From 396f37156c29c6a17db2c01624409bd3548c0b44 Mon Sep 17 00:00:00 2001 From: Eric Ly <111820150+lyeric2022@users.noreply.github.com> Date: Thu, 17 Jul 2025 11:39:46 -0700 Subject: [PATCH] feat: create 'test connection' bedrock api + fix endpoints for test connection (ant, openai, gemini) (#3227) Co-authored-by: Eric Ly --- letta/llm_api/aws_bedrock.py | 18 ++++++++++++++++-- letta/schemas/providers.py | 20 ++++++++++++++++++++ letta/services/provider_manager.py | 2 +- 3 files changed, 37 insertions(+), 3 deletions(-) diff --git a/letta/llm_api/aws_bedrock.py b/letta/llm_api/aws_bedrock.py index c395868d..539ce0fd 100644 --- a/letta/llm_api/aws_bedrock.py +++ b/letta/llm_api/aws_bedrock.py @@ -41,22 +41,36 @@ def get_bedrock_client( return bedrock -def bedrock_get_model_list(region_name: str) -> List[dict]: +def bedrock_get_model_list( + region_name: str, + access_key_id: Optional[str] = None, + secret_access_key: Optional[str] = None, +) -> List[dict]: """ Get list of available models from Bedrock. Args: region_name: AWS region name + access_key_id: Optional AWS access key ID + secret_access_key: Optional AWS secret access key + + TODO: Implement model_provider and output_modality filtering model_provider: Optional provider name to filter models. If None, returns all models. output_modality: Output modality to filter models. Defaults to "text". Returns: List of model summaries + """ import boto3 try: - bedrock = boto3.client("bedrock", region_name=region_name) + bedrock = boto3.client( + "bedrock", + region_name=region_name, + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + ) response = bedrock.list_inference_profiles() return response["inferenceProfileSummaries"] except Exception as e: diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 51d988bc..97d68281 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -1536,6 +1536,26 @@ class BedrockProvider(Provider): provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") region: str = Field(..., description="AWS region for Bedrock") + def check_api_key(self): + """Check if the Bedrock credentials are valid""" + from letta.errors import LLMAuthenticationError + from letta.llm_api.aws_bedrock import bedrock_get_model_list + + try: + # For BYOK providers, use the custom credentials + if self.provider_category == ProviderCategory.byok: + # If we can list models, the credentials are valid + bedrock_get_model_list( + region_name=self.region, + access_key_id=self.access_key, + secret_access_key=self.api_key, # api_key stores the secret access key + ) + else: + # For base providers, use default credentials + bedrock_get_model_list(region_name=self.region) + except Exception as e: + raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}") + def list_llm_models(self): from letta.llm_api.aws_bedrock import bedrock_get_model_list diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index e919b93f..610ffb2e 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -213,7 +213,7 @@ class ProviderManager: provider_type=provider_check.provider_type, api_key=provider_check.api_key, provider_category=ProviderCategory.byok, - secret_key=provider_check.api_secret, + access_id_key=provider_check.access_id_key, # This contains the access key ID for Bedrock region=provider_check.region, ).cast_to_subtype()