Files
letta-server/letta/llm_api/bedrock_client.py
Ari Webb 9dbf428c1f feat: enable bedrock for anthropic models (#8847)
* feat: enable bedrock for anthropic models

* parallel tool calls in ade

* attempt add to ci

* update tests

* add env vars

* hardcode region

* get it working

* debugging

* add bedrock extra

* default env var [skip ci]

* run ci

* reasoner model update

* secrets

* clean up log

* clean up
2026-01-19 15:54:44 -08:00

100 lines
4.4 KiB
Python

from typing import List, Optional, Union
import anthropic
from aioboto3.session import Session
from letta.llm_api.anthropic_client import AnthropicClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.enums import AgentType, ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.services.provider_manager import ProviderManager
from letta.settings import model_settings
# Module-level logger, named after this module via __name__.
logger = get_logger(__name__)
class BedrockClient(AnthropicClient):
    """Anthropic client variant that sends requests through AWS Bedrock.

    Reuses the request-building logic of ``AnthropicClient`` and overrides
    client construction (temporary AWS credentials obtained from STS) and the
    final request payload (Bedrock inference profile ID as the model name,
    ``disable_parallel_tool_use`` stripped from ``tool_choice``).
    """

    @staticmethod
    def get_inference_profile_id_from_handle(handle: str) -> str:
        """
        Extract the Bedrock inference profile ID from the LLMConfig handle.

        The handle format is: bedrock/us.anthropic.claude-opus-4-5-20250918-v1:0
        Returns: us.anthropic.claude-opus-4-5-20250918-v1:0

        A handle that contains no "/" is returned unchanged.
        """
        if "/" in handle:
            # Split only on the first "/" so the remainder is kept intact.
            return handle.split("/", 1)[1]
        return handle

    async def get_byok_overrides_async(self, llm_config: LLMConfig) -> tuple[Optional[str], Optional[str], Optional[str]]:
        """
        Look up per-provider AWS credentials for bring-your-own-key configs.

        Returns:
            ``(access_key_id, secret_access_key, default_region)`` as returned by
            ``ProviderManager.get_bedrock_credentials_async`` when
            ``llm_config.provider_category`` is ``ProviderCategory.byok``;
            otherwise ``(None, None, None)`` so callers fall back to the
            values in ``model_settings``.
        """
        if llm_config.provider_category != ProviderCategory.byok:
            return None, None, None

        (
            override_access_key_id,
            override_secret_access_key,
            override_default_region,
        ) = await ProviderManager().get_bedrock_credentials_async(
            llm_config.provider_name,
            actor=self.actor,
        )
        return override_access_key_id, override_secret_access_key, override_default_region

    @trace_method
    async def _get_anthropic_client_async(
        self, llm_config: LLMConfig, async_client: bool = False
    ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic, anthropic.AsyncAnthropicBedrock, anthropic.AnthropicBedrock]:
        """
        Build an Anthropic Bedrock client backed by temporary STS credentials.

        BYOK overrides take precedence over ``model_settings`` for the access
        key, secret key and region. The resolved credentials are used to call
        STS ``get_session_token``; the returned temporary credentials are then
        handed to the Anthropic Bedrock client.

        Args:
            llm_config: Model configuration; used to resolve BYOK credentials.
            async_client: When True return ``anthropic.AsyncAnthropicBedrock``,
                otherwise ``anthropic.AnthropicBedrock``.
        """
        override_access_key_id, override_secret_access_key, override_default_region = await self.get_byok_overrides_async(llm_config)
        # Resolve the region once; it is used for both the STS call and the Bedrock client.
        region = override_default_region or model_settings.aws_default_region

        session = Session()
        async with session.client(
            "sts",
            aws_access_key_id=override_access_key_id or model_settings.aws_access_key_id,
            aws_secret_access_key=override_secret_access_key or model_settings.aws_secret_access_key,
            region_name=region,
        ) as sts_client:
            session_token = await sts_client.get_session_token()
        credentials = session_token["Credentials"]

        # Both client flavours take the same keyword arguments; only the class differs.
        client_cls = anthropic.AsyncAnthropicBedrock if async_client else anthropic.AnthropicBedrock
        return client_cls(
            aws_access_key=credentials["AccessKeyId"],
            aws_secret_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            aws_region=region,
            max_retries=model_settings.anthropic_max_retries,
        )

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """
        Build the Anthropic request payload and adapt it for Bedrock.

        Delegates to ``AnthropicClient.build_request_data`` and then:
          * replaces ``data["model"]`` with the inference profile ID derived
            from ``llm_config.handle`` (when a handle is set), and
          * removes ``disable_parallel_tool_use`` from ``data["tool_choice"]``
            when present, since it is treated as a disallowed field here.

        Returns:
            The request payload dict produced by the parent, adjusted as above.
        """
        # Forward the truncation limit only when the caller set one, so the
        # default call into the parent is exactly what it was before.
        # NOTE(review): assumes AnthropicClient.build_request_data accepts this
        # keyword (this override's signature appears to mirror it) - confirm.
        extra_kwargs = {}
        if tool_return_truncation_chars is not None:
            extra_kwargs["tool_return_truncation_chars"] = tool_return_truncation_chars

        data = super().build_request_data(
            agent_type,
            messages,
            llm_config,
            tools,
            force_tool_call,
            requires_subsequent_tool_call,
            **extra_kwargs,
        )

        # Swap the model name back to the Bedrock inference profile ID for the API call
        # The LLMConfig.model contains the Anthropic-style name (e.g., "claude-opus-4-5-20250918")
        # but Bedrock API needs the inference profile ID (e.g., "us.anthropic.claude-opus-4-5-20250918-v1:0")
        if llm_config.handle:
            data["model"] = self.get_inference_profile_id_from_handle(llm_config.handle)

        # remove disallowed fields
        # pop() with a default avoids a KeyError when the parent did not set the
        # flag; the isinstance guard covers a None / non-dict tool_choice.
        tool_choice = data.get("tool_choice")
        if isinstance(tool_choice, dict):
            tool_choice.pop("disable_parallel_tool_use", None)

        return data