feat: create 'test connection' bedrock api + fix endpoints for test connection (ant, openai, gemini) (#3227)
Co-authored-by: Eric Ly <lyyeric@letta.com>
This commit is contained in:
@@ -41,22 +41,36 @@ def get_bedrock_client(
|
||||
return bedrock
|
||||
|
||||
|
||||
def bedrock_get_model_list(region_name: str) -> List[dict]:
|
||||
def bedrock_get_model_list(
|
||||
region_name: str,
|
||||
access_key_id: Optional[str] = None,
|
||||
secret_access_key: Optional[str] = None,
|
||||
) -> List[dict]:
|
||||
"""
|
||||
Get list of available models from Bedrock.
|
||||
|
||||
Args:
|
||||
region_name: AWS region name
|
||||
access_key_id: Optional AWS access key ID
|
||||
secret_access_key: Optional AWS secret access key
|
||||
|
||||
TODO: Implement model_provider and output_modality filtering
|
||||
model_provider: Optional provider name to filter models. If None, returns all models.
|
||||
output_modality: Output modality to filter models. Defaults to "text".
|
||||
|
||||
Returns:
|
||||
List of model summaries
|
||||
|
||||
"""
|
||||
import boto3
|
||||
|
||||
try:
|
||||
bedrock = boto3.client("bedrock", region_name=region_name)
|
||||
bedrock = boto3.client(
|
||||
"bedrock",
|
||||
region_name=region_name,
|
||||
aws_access_key_id=access_key_id,
|
||||
aws_secret_access_key=secret_access_key,
|
||||
)
|
||||
response = bedrock.list_inference_profiles()
|
||||
return response["inferenceProfileSummaries"]
|
||||
except Exception as e:
|
||||
|
||||
@@ -1536,6 +1536,26 @@ class BedrockProvider(Provider):
|
||||
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
||||
region: str = Field(..., description="AWS region for Bedrock")
|
||||
|
||||
def check_api_key(self):
|
||||
"""Check if the Bedrock credentials are valid"""
|
||||
from letta.errors import LLMAuthenticationError
|
||||
from letta.llm_api.aws_bedrock import bedrock_get_model_list
|
||||
|
||||
try:
|
||||
# For BYOK providers, use the custom credentials
|
||||
if self.provider_category == ProviderCategory.byok:
|
||||
# If we can list models, the credentials are valid
|
||||
bedrock_get_model_list(
|
||||
region_name=self.region,
|
||||
access_key_id=self.access_key,
|
||||
secret_access_key=self.api_key, # api_key stores the secret access key
|
||||
)
|
||||
else:
|
||||
# For base providers, use default credentials
|
||||
bedrock_get_model_list(region_name=self.region)
|
||||
except Exception as e:
|
||||
raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}")
|
||||
|
||||
def list_llm_models(self):
|
||||
from letta.llm_api.aws_bedrock import bedrock_get_model_list
|
||||
|
||||
|
||||
@@ -213,7 +213,7 @@ class ProviderManager:
|
||||
provider_type=provider_check.provider_type,
|
||||
api_key=provider_check.api_key,
|
||||
provider_category=ProviderCategory.byok,
|
||||
secret_key=provider_check.api_secret,
|
||||
access_id_key=provider_check.access_id_key, # This contains the access key ID for Bedrock
|
||||
region=provider_check.region,
|
||||
).cast_to_subtype()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user