from typing import List, Optional, Union

import anthropic
from aioboto3.session import Session

from letta.llm_api.anthropic_client import AnthropicClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.enums import AgentType, ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.services.provider_manager import ProviderManager
from letta.settings import model_settings

logger = get_logger(__name__)


class BedrockClient(AnthropicClient):
    """Anthropic client that builds its SDK client against AWS Bedrock.

    Credentials come either from a bring-your-own-key (BYOK) provider entry
    or from ``model_settings``; they are exchanged for a temporary STS
    session token before the Bedrock SDK client is constructed.
    """

    async def get_byok_overrides_async(
        self, llm_config: LLMConfig
    ) -> tuple[Optional[str], Optional[str], Optional[str]]:
        """Return BYOK credential overrides for ``llm_config``.

        Args:
            llm_config: Model configuration whose ``provider_category`` and
                ``provider_name`` select the credential source.

        Returns:
            ``(access_key_id, secret_access_key, default_region)``. All three
            are ``None`` when the provider category is not BYOK, which tells
            the caller to fall back to ``model_settings``.
        """
        if llm_config.provider_category != ProviderCategory.byok:
            return None, None, None

        access_key_id, secret_access_key, default_region = await ProviderManager().get_bedrock_credentials_async(
            llm_config.provider_name,
            actor=self.actor,
        )
        return access_key_id, secret_access_key, default_region

    @trace_method
    async def _get_anthropic_client_async(
        self, llm_config: LLMConfig, async_client: bool = False
    ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic, anthropic.AsyncAnthropicBedrock, anthropic.AnthropicBedrock]:
        """Build a Bedrock-backed Anthropic SDK client using temporary STS credentials.

        Args:
            llm_config: Model configuration used to look up BYOK overrides.
            async_client: When true, return ``anthropic.AsyncAnthropicBedrock``;
                otherwise return the synchronous ``anthropic.AnthropicBedrock``.

        Returns:
            A Bedrock client configured with the session credentials, the
            resolved region, and ``model_settings.anthropic_max_retries``.
        """
        access_key_id, secret_access_key, default_region = await self.get_byok_overrides_async(llm_config)

        # BYOK values win; anything missing falls back to the global settings.
        region = default_region or model_settings.aws_default_region

        session = Session()
        async with session.client(
            "sts",
            aws_access_key_id=access_key_id or model_settings.aws_access_key_id,
            aws_secret_access_key=secret_access_key or model_settings.aws_secret_access_key,
            region_name=region,
        ) as sts_client:
            session_token = await sts_client.get_session_token()
            credentials = session_token["Credentials"]

        # Both client flavours take identical keyword arguments; only the class differs.
        client_cls = anthropic.AsyncAnthropicBedrock if async_client else anthropic.AnthropicBedrock
        return client_cls(
            aws_access_key=credentials["AccessKeyId"],
            aws_secret_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            aws_region=region,
            max_retries=model_settings.anthropic_max_retries,
        )

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """Build the request payload via the parent client, minus disallowed fields.

        Delegates to ``AnthropicClient.build_request_data`` and then removes
        ``disable_parallel_tool_use`` from ``tool_choice`` when it is present.

        Args:
            agent_type: Forwarded to the parent implementation.
            messages: Forwarded to the parent implementation.
            llm_config: Forwarded to the parent implementation.
            tools: Forwarded to the parent implementation.
            force_tool_call: Forwarded to the parent implementation.
            requires_subsequent_tool_call: Forwarded to the parent implementation.
            tool_return_truncation_chars: Accepted for signature compatibility;
                currently not forwarded to the parent.

        Returns:
            The request dict produced by the parent, with the disallowed
            ``tool_choice`` field stripped.
        """
        # NOTE(review): tool_return_truncation_chars is not passed to super();
        # confirm whether the parent accepts it and whether it should be forwarded.
        data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)

        # remove disallowed fields; tolerate a tool_choice that is None or that
        # never carried the flag instead of raising KeyError/TypeError.
        tool_choice = data.get("tool_choice")
        if isinstance(tool_choice, dict):
            tool_choice.pop("disable_parallel_tool_use", None)
        return data