Files
letta-server/letta/llm_api/bedrock_client.py
Charles Packer a4041879a4 feat: add new agent loop (squash rebase of OSS PR) (#4815)
* feat: squash rebase of OSS PR

* fix: revert changes that weren't on manual rebase

* fix: caught another one

* fix: disable force

* chore: drop print

* fix: just stage-api && just publish-api

* fix: make agent_type consistently an arg in the client

* fix: patch multi-modal support

* chore: put in todo stub

* fix: disable hardcoding for tests

* fix: patch validate agent sync (#4882)

patch validate agent sync

* fix: strip bad merge diff

* fix: revert unrelated diff

* fix: react_v2 naming -> letta_v1 naming

* fix: strip bad merge

---------

Co-authored-by: Kevin Lin <klin5061@gmail.com>
2025-10-07 17:50:45 -07:00

79 lines
3.4 KiB
Python

from typing import List, Optional, Union
import anthropic
from aioboto3.session import Session
from letta.llm_api.anthropic_client import AnthropicClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.enums import AgentType, ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.services.provider_manager import ProviderManager
from letta.settings import model_settings
# Module-level logger obtained from letta's logging helper, keyed by this module's name.
logger = get_logger(__name__)
class BedrockClient(AnthropicClient):
    """Anthropic client variant that targets Claude models served through AWS Bedrock.

    Request construction is inherited from ``AnthropicClient``. This subclass
    changes two things:

    * the SDK client is ``anthropic.AnthropicBedrock`` /
      ``anthropic.AsyncAnthropicBedrock``, authenticated with temporary STS
      session credentials derived from either BYOK provider credentials or the
      server-wide ``model_settings``;
    * request fields that Bedrock disallows are stripped from the payload.
    """

    async def get_byok_overrides_async(
        self, llm_config: LLMConfig
    ) -> tuple[Optional[str], Optional[str], Optional[str]]:
        """Look up bring-your-own-key AWS credentials for ``llm_config``.

        Only configs whose ``provider_category`` is ``ProviderCategory.byok``
        trigger a lookup, via ``ProviderManager.get_bedrock_credentials_async``
        for ``llm_config.provider_name`` on behalf of ``self.actor``. For any
        other category all three values are ``None``, which makes callers fall
        back to ``model_settings``.

        Args:
            llm_config: model configuration whose provider category/name select
                the credentials.

        Returns:
            ``(access_key_id, secret_access_key, default_region)``. Any element
            may be ``None``.
        """
        if llm_config.provider_category != ProviderCategory.byok:
            return None, None, None

        # Unpack explicitly so a malformed (non-3-item) result fails loudly here.
        access_key_id, secret_access_key, default_region = await ProviderManager().get_bedrock_credentials_async(
            llm_config.provider_name,
            actor=self.actor,
        )
        return access_key_id, secret_access_key, default_region

    @trace_method
    async def _get_anthropic_client_async(
        self, llm_config: LLMConfig, async_client: bool = False
    ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic, anthropic.AsyncAnthropicBedrock, anthropic.AnthropicBedrock]:
        """Create a Bedrock-backed Anthropic SDK client for ``llm_config``.

        Long-lived AWS credentials (BYOK overrides first, then
        ``model_settings``) are exchanged for temporary session credentials via
        STS ``get_session_token``. The SDK client is then built from those
        temporary credentials.

        Args:
            llm_config: model configuration. It determines whether BYOK
                credentials are used.
            async_client: when True return ``anthropic.AsyncAnthropicBedrock``,
                otherwise ``anthropic.AnthropicBedrock``.

        Returns:
            The constructed Bedrock SDK client.
        """
        override_access_key_id, override_secret_access_key, override_default_region = await self.get_byok_overrides_async(llm_config)
        # The same region is used both for the STS call and for the Bedrock client.
        aws_region = override_default_region or model_settings.aws_default_region

        session = Session()
        async with session.client(
            "sts",
            aws_access_key_id=override_access_key_id or model_settings.aws_access_key_id,
            aws_secret_access_key=override_secret_access_key or model_settings.aws_secret_access_key,
            region_name=aws_region,
        ) as sts_client:
            session_token = await sts_client.get_session_token()
        credentials = session_token["Credentials"]

        client_cls = anthropic.AsyncAnthropicBedrock if async_client else anthropic.AnthropicBedrock
        return client_cls(
            aws_access_key=credentials["AccessKeyId"],
            aws_secret_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            aws_region=aws_region,
            max_retries=model_settings.anthropic_max_retries,
        )

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build the Anthropic request payload, minus fields Bedrock disallows.

        Delegates to ``AnthropicClient.build_request_data`` and then removes
        ``tool_choice["disable_parallel_tool_use"]`` if present.

        Args:
            agent_type: agent type forwarded to the parent builder.
            messages: conversation messages forwarded to the parent builder.
            llm_config: model configuration forwarded to the parent builder.
            tools: optional tool schemas forwarded to the parent builder.
            force_tool_call: optional tool name forwarded to the parent builder.

        Returns:
            The request payload dict (mutated in place and returned).
        """
        data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)

        # Remove disallowed fields. Use pop() with a default instead of `del`:
        # nothing in this file guarantees the parent always includes this key
        # (or that tool_choice is a dict rather than None). An unconditional
        # `del` would raise KeyError/TypeError in those cases.
        tool_choice = data.get("tool_choice")
        if isinstance(tool_choice, dict):
            tool_choice.pop("disable_parallel_tool_use", None)
        return data