diff --git a/letta/llm_api/aws_bedrock.py b/letta/llm_api/aws_bedrock.py deleted file mode 100644 index 67497b4e..00000000 --- a/letta/llm_api/aws_bedrock.py +++ /dev/null @@ -1,90 +0,0 @@ -""" -Note that this formally only supports Anthropic Bedrock. -TODO (cliandy): determine what other providers are supported and what is needed to add support. -""" - -from typing import Any, Optional - -from anthropic import AnthropicBedrock - -from letta.log import get_logger -from letta.settings import model_settings - -logger = get_logger(__name__) - - -def get_bedrock_client( - access_key_id: Optional[str] = None, - secret_key: Optional[str] = None, - default_region: Optional[str] = None, -): - """ - Get a Bedrock client - """ - import boto3 - - sts_client = boto3.client( - "sts", - aws_access_key_id=access_key_id or model_settings.aws_access_key_id, - aws_secret_access_key=secret_key or model_settings.aws_secret_access_key, - region_name=default_region or model_settings.aws_default_region, - ) - credentials = sts_client.get_session_token()["Credentials"] - - bedrock = AnthropicBedrock( - aws_access_key=credentials["AccessKeyId"], - aws_secret_key=credentials["SecretAccessKey"], - aws_session_token=credentials["SessionToken"], - aws_region=default_region or model_settings.aws_default_region, - ) - return bedrock - - -async def bedrock_get_model_list_async( - access_key_id: Optional[str] = None, - secret_access_key: Optional[str] = None, - default_region: Optional[str] = None, -) -> list[dict]: - from aioboto3.session import Session - - try: - session = Session() - async with session.client( - "bedrock", - aws_access_key_id=access_key_id, - aws_secret_access_key=secret_access_key, - region_name=default_region, - ) as bedrock: - response = await bedrock.list_inference_profiles() - return response["inferenceProfileSummaries"] - except Exception as e: - logger.error(f"Error getting model list for bedrock: %s", e) - raise e - - -def bedrock_get_model_details(region_name: str, 
model_id: str) -> dict[str, Any]: - """ - Get details for a specific model from Bedrock. - """ - import boto3 - from botocore.exceptions import ClientError - - try: - bedrock = boto3.client("bedrock", region_name=region_name) - response = bedrock.get_foundation_model(modelIdentifier=model_id) - return response["modelDetails"] - except ClientError as e: - logger.exception(f"Error getting model details: {str(e)}") - raise e - - def bedrock_get_model_context_window(model_id: str) -> int: - """ - Get context window size for a specific model. - """ - # Bedrock doesn't provide this via API, so we maintain a mapping - # 200k for anthropic: https://aws.amazon.com/bedrock/anthropic/ - if model_id.startswith("anthropic"): - return 200_000 - else: - return 100_000 # default to 100k if unknown diff --git a/letta/schemas/providers/bedrock.py b/letta/schemas/providers/bedrock.py index d7d8437f..7c9c2fa9 100644 --- a/letta/schemas/providers/bedrock.py +++ b/letta/schemas/providers/bedrock.py @@ -20,20 +20,32 @@ class BedrockProvider(Provider): provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") region: str = Field(..., description="AWS region for Bedrock") + async def bedrock_get_model_list_async(self) -> list[dict]: + from aioboto3.session import Session + + try: + session = Session() + async with session.client( + "bedrock", + aws_access_key_id=self.access_key, + aws_secret_access_key=self.api_key, + region_name=self.region, + ) as bedrock: + response = await bedrock.list_inference_profiles() + return response["inferenceProfileSummaries"] + except Exception as e: + logger.error("Error getting model list for bedrock: %s", e) + raise e + async def check_api_key(self): """Check if the Bedrock credentials are valid""" from letta.errors import LLMAuthenticationError - from letta.llm_api.aws_bedrock import bedrock_get_model_list_async try: # For BYOK providers, use the custom credentials if 
self.provider_category == ProviderCategory.byok: # If we can list models, the credentials are valid - await bedrock_get_model_list_async( - access_key_id=self.access_key, - secret_access_key=self.api_key, # api_key stores the secret access key - region_name=self.region, - ) + await self.bedrock_get_model_list_async() else: # For base providers, use default credentials bedrock_get_model_list(region_name=self.region) @@ -41,13 +53,7 @@ class BedrockProvider(Provider): raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}") async def list_llm_models_async(self) -> list[LLMConfig]: - from letta.llm_api.aws_bedrock import bedrock_get_model_list_async - - models = await bedrock_get_model_list_async( - self.access_key, - self.api_key, - self.region, - ) + models = await self.bedrock_get_model_list_async() configs = [] for model_summary in models: @@ -67,10 +73,16 @@ class BedrockProvider(Provider): return configs def get_model_context_window(self, model_name: str) -> int | None: - # Context windows for Claude models - from letta.llm_api.aws_bedrock import bedrock_get_model_context_window + """ + Get context window size for a specific model. - return bedrock_get_model_context_window(model_name) + Bedrock doesn't provide this via API, so we maintain a mapping + 200k for anthropic: https://aws.amazon.com/bedrock/anthropic/ + """ + if model_name.startswith("anthropic"): + return 200_000 + else: + return 100_000 # default to 100k if unknown def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str: logger.debug("Getting handle for model_name: %s", model_name)