From 9dbf428c1ff468e126844caf7dd51537442c8e41 Mon Sep 17 00:00:00 2001 From: Ari Webb Date: Mon, 19 Jan 2026 11:59:32 -0800 Subject: [PATCH] feat: enable bedrock for anthropic models (#8847) * feat: enable bedrock for anthropic models * parallel tool calls in ade * attempt add to ci * update tests * add env vars * hardcode region * get it working * debugging * add bedrock extra * default env var [skip ci] * run ci * reasoner model update * secrets * clean up log * clean up --- letta/llm_api/bedrock_client.py | 19 ++++ letta/schemas/llm_config.py | 4 +- letta/schemas/providers/bedrock.py | 93 +++++++++++++------ letta/server/server.py | 2 + letta/settings.py | 2 +- tests/integration_test_send_message.py | 3 +- tests/integration_test_send_message_v2.py | 20 +++- .../bedrock-claude-4-5-opus.json | 9 ++ 8 files changed, 117 insertions(+), 35 deletions(-) create mode 100644 tests/model_settings/bedrock-claude-4-5-opus.json diff --git a/letta/llm_api/bedrock_client.py b/letta/llm_api/bedrock_client.py index d471424f..a49dcfb2 100644 --- a/letta/llm_api/bedrock_client.py +++ b/letta/llm_api/bedrock_client.py @@ -16,6 +16,18 @@ logger = get_logger(__name__) class BedrockClient(AnthropicClient): + @staticmethod + def get_inference_profile_id_from_handle(handle: str) -> str: + """ + Extract the Bedrock inference profile ID from the LLMConfig handle. + + The handle format is: bedrock/us.anthropic.claude-opus-4-5-20250918-v1:0 + Returns: us.anthropic.claude-opus-4-5-20250918-v1:0 + """ + if "/" in handle: + return handle.split("/", 1)[1] + return handle + async def get_byok_overrides_async(self, llm_config: LLMConfig) -> tuple[str, str, str]: override_access_key_id, override_secret_access_key, override_default_region = None, None, None if llm_config.provider_category == ProviderCategory.byok: @@ -74,6 +86,13 @@ class BedrockClient(AnthropicClient): tool_return_truncation_chars: Optional[int] = None, ) -> dict: data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call) + + # Swap the model name back to the Bedrock inference profile ID for the API call + # The LLMConfig.model contains the Anthropic-style name (e.g., "claude-opus-4-5-20250918") + # but Bedrock API needs the inference profile ID (e.g., "us.anthropic.claude-opus-4-5-20250918-v1:0") + if llm_config.handle: + data["model"] = self.get_inference_profile_id_from_handle(llm_config.handle) + # remove disallowed fields if "tool_choice" in data: del data["tool_choice"]["disable_parallel_tool_use"] diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 58b1f22c..5ce041f8 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -182,7 +182,7 @@ class LLMConfig(BaseModel): if is_openai_reasoning_model(model): values["put_inner_thoughts_in_kwargs"] = False - if values.get("model_endpoint_type") == "anthropic" and ( + if values.get("model_endpoint_type") in ("anthropic", "bedrock") and ( model.startswith("claude-3-7-sonnet") or model.startswith("claude-sonnet-4") or model.startswith("claude-opus-4") @@ -413,7 +413,7 @@ class LLMConfig(BaseModel): @classmethod def is_anthropic_reasoning_model(cls, config: "LLMConfig") -> bool: - return config.model_endpoint_type == "anthropic" and ( + return config.model_endpoint_type in ("anthropic", "bedrock") and ( config.model.startswith("claude-opus-4") or config.model.startswith("claude-sonnet-4") or config.model.startswith("claude-3-7-sonnet") diff --git a/letta/schemas/providers/bedrock.py b/letta/schemas/providers/bedrock.py index 4cfaff81..7f833ad8 100644 --- a/letta/schemas/providers/bedrock.py +++ b/letta/schemas/providers/bedrock.py @@ -18,22 +18,46 @@ logger = get_logger(__name__) class BedrockProvider(Provider): provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - access_key: str = Field(..., description="AWS secret access key for Bedrock.") + access_key: str | None = Field(None, description="AWS access key ID for Bedrock") + api_key: str | None = Field(None, description="AWS secret access key for Bedrock") region: str = Field(..., description="AWS region for Bedrock") + @staticmethod + def extract_anthropic_model_name(inference_profile_id: str) -> str: + """ + Extract the Anthropic-style model name from a Bedrock inference profile ID. + + Input format: us.anthropic.claude-opus-4-5-20250918-v1:0 + Output: claude-opus-4-5-20250918 + + This allows Bedrock models to use the same model name format as regular Anthropic models, + so all the existing model name checks (startswith("claude-"), etc.) work correctly. + """ + # Remove region prefix (e.g., "us.anthropic." -> "claude-...") + if ".anthropic." in inference_profile_id: + model_part = inference_profile_id.split(".anthropic.")[1] + else: + model_part = inference_profile_id + + # Remove version suffix (e.g., "-v1:0" at the end) + # Pattern: -v followed by digits, optionally followed by :digits + import re + + model_name = re.sub(r"-v\d+(?::\d+)?$", "", model_part) + return model_name + async def bedrock_get_model_list_async(self) -> list[dict]: + """ + List Bedrock inference profiles using boto3. + """ from aioboto3.session import Session try: - # Decrypt credentials before using - access_key = await self.access_key_enc.get_plaintext_async() if self.access_key_enc else None - secret_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None - session = Session() async with session.client( "bedrock", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, + aws_access_key_id=self.access_key, + aws_secret_access_key=self.api_key, region_name=self.region, ) as bedrock: response = await bedrock.list_inference_profiles() @@ -43,34 +67,43 @@ class BedrockProvider(Provider): raise e async def check_api_key(self): - """Check if the Bedrock credentials are valid""" + """Check if the Bedrock credentials are valid by listing models""" from letta.errors import LLMAuthenticationError try: - # For BYOK providers, use the custom credentials - if self.provider_category == ProviderCategory.byok: - # If we can list models, the credentials are valid - await self.bedrock_get_model_list_async() - else: - # For base providers, use default credentials - bedrock_get_model_list(region_name=self.region) + # If we can list models, the credentials are valid + await self.bedrock_get_model_list_async() except Exception as e: raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}") async def list_llm_models_async(self) -> list[LLMConfig]: models = await self.bedrock_get_model_list_async() - configs = [] + # Deduplicate models by normalized name - prefer regional (us., eu.) over global + seen_models: dict[str, tuple[str, dict]] = {} # model_name -> (inference_profile_id, model_summary) for model_summary in models: - model_arn = model_summary["inferenceProfileArn"] + inference_profile_id = model_summary["inferenceProfileId"] + model_name = self.extract_anthropic_model_name(inference_profile_id) + + if model_name not in seen_models: + seen_models[model_name] = (inference_profile_id, model_summary) + else: + # Prefer regional profiles over global ones + existing_id = seen_models[model_name][0] + if existing_id.startswith("global.") and not inference_profile_id.startswith("global."): + seen_models[model_name] = (inference_profile_id, model_summary) + + configs = [] + for model_name, (inference_profile_id, model_summary) in seen_models.items(): configs.append( LLMConfig( - model=model_arn, + model=model_name, model_endpoint_type=self.provider_type.value, model_endpoint=None, - context_window=self.get_model_context_window(model_arn), - handle=self.get_handle(model_arn), - max_tokens=self.get_default_max_output_tokens(model_arn), + context_window=self.get_model_context_window(inference_profile_id), + # Store the full inference profile ID in the handle for API calls + handle=self.get_handle(inference_profile_id), + max_tokens=self.get_default_max_output_tokens(inference_profile_id), provider_name=self.name, provider_category=self.provider_category, ) @@ -82,15 +115,19 @@ class BedrockProvider(Provider): """ Get context window size for a specific model. - Bedrock doesn't provide this via API, so we maintain a mapping - 200k for anthropic: https://aws.amazon.com/bedrock/anthropic/ + Bedrock doesn't provide this via API, so we maintain a mapping. """ - if model_name.startswith("anthropic"): + model_lower = model_name.lower() + if "anthropic" in model_lower or "claude" in model_lower: return 200_000 else: - return 100_000 # default to 100k if unknown + return 100_000 # default if unknown def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str: - logger.debug("Getting handle for model_name: %s", model_name) - model = model_name.split(".")[-1] - return f"{self.name}/{model}" + """ + Create handle from inference profile ID. + + Input format: us.anthropic.claude-sonnet-4-20250514-v1:0 + Output: bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0 + """ + return f"{self.name}/{model_name}" diff --git a/letta/server/server.py b/letta/server/server.py index 196c35f8..a021bdfc 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -296,6 +296,8 @@ class SyncServer(object): self._enabled_providers.append( BedrockProvider( name="bedrock", + access_key=model_settings.aws_access_key_id, + api_key=model_settings.aws_secret_access_key, region=model_settings.aws_default_region, ) ) diff --git a/letta/settings.py b/letta/settings.py index 3caead22..365b9d95 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -144,7 +144,7 @@ class ModelSettings(BaseSettings): # Bedrock aws_access_key_id: Optional[str] = None aws_secret_access_key: Optional[str] = None - aws_default_region: Optional[str] = None + aws_default_region: str = "us-east-1" bedrock_anthropic_version: Optional[str] = "bedrock-2023-05-31" # anthropic diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 53df36b0..c329180a 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -229,7 +229,8 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool: return model.startswith("o1") or model.startswith("o3") or model.startswith("o4") or model.startswith("gpt-5") # Anthropic reasoning models (from anthropic_client.py:608-616) - elif provider_type == "anthropic": + # Also applies to Bedrock with Anthropic models + elif provider_type in ("anthropic", "bedrock"): return ( model.startswith("claude-3-7-sonnet") or model.startswith("claude-sonnet-4") diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py index e383ecc8..124b1a5b 100644 --- a/tests/integration_test_send_message_v2.py +++ b/tests/integration_test_send_message_v2.py @@ -484,8 +484,22 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool: ) # Z.ai models output reasoning by default is_zai_reasoning = model_settings.get("provider_type") == "zai" + # Bedrock Anthropic reasoning models + is_bedrock_reasoning = model_settings.get("provider_type") == "bedrock" and ( + "claude-3-7-sonnet" in model_handle + or "claude-sonnet-4" in model_handle + or "claude-opus-4" in model_handle + or "claude-haiku-4-5" in model_handle + ) - return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning or is_zai_reasoning + return ( + is_openai_reasoning + or is_anthropic_reasoning + or is_google_vertex_reasoning + or is_google_ai_reasoning + or is_zai_reasoning + or is_bedrock_reasoning + ) # ------------------------------ @@ -653,8 +667,8 @@ async def test_parallel_tool_calls( model_handle, model_settings = model_config provider_type = model_settings.get("provider_type", "") - if provider_type not in ["anthropic", "openai", "google_ai", "google_vertex"]: - pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.") + if provider_type not in ["anthropic", "openai", "google_ai", "google_vertex", "bedrock"]: + pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, Gemini, and Bedrock models.") if "gpt-5" in model_handle or "o3" in model_handle: pytest.skip("GPT-5 takes too long to test, o3 is bad at this task.") diff --git a/tests/model_settings/bedrock-claude-4-5-opus.json b/tests/model_settings/bedrock-claude-4-5-opus.json new file mode 100644 index 00000000..e9d3b10c --- /dev/null +++ b/tests/model_settings/bedrock-claude-4-5-opus.json @@ -0,0 +1,9 @@ +{ + "handle": "bedrock/us.anthropic.claude-opus-4-5-20251101-v1:0", + "model_settings": { + "provider_type": "bedrock", + "temperature": 1.0, + "max_output_tokens": 16000, + "parallel_tool_calls": false + } +}