feat: enable bedrock for anthropic models (#8847)

* feat: enable bedrock for anthropic models

* parallel tool calls in ade

* attempt add to ci

* update tests

* add env vars

* hardcode region

* get it working

* debugging

* add bedrock extra

* default env var [skip ci]

* run ci

* reasoner model update

* secrets

* clean up log

* clean up
This commit is contained in:
Ari Webb
2026-01-19 11:59:32 -08:00
committed by Sarah Wooders
parent 4be366470b
commit 9dbf428c1f
8 changed files with 117 additions and 35 deletions

View File

@@ -16,6 +16,18 @@ logger = get_logger(__name__)
class BedrockClient(AnthropicClient): class BedrockClient(AnthropicClient):
@staticmethod
def get_inference_profile_id_from_handle(handle: str) -> str:
"""
Extract the Bedrock inference profile ID from the LLMConfig handle.
The handle format is: bedrock/us.anthropic.claude-opus-4-5-20250918-v1:0
Returns: us.anthropic.claude-opus-4-5-20250918-v1:0
"""
if "/" in handle:
return handle.split("/", 1)[1]
return handle
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> tuple[str, str, str]: async def get_byok_overrides_async(self, llm_config: LLMConfig) -> tuple[str, str, str]:
override_access_key_id, override_secret_access_key, override_default_region = None, None, None override_access_key_id, override_secret_access_key, override_default_region = None, None, None
if llm_config.provider_category == ProviderCategory.byok: if llm_config.provider_category == ProviderCategory.byok:
@@ -74,6 +86,13 @@ class BedrockClient(AnthropicClient):
tool_return_truncation_chars: Optional[int] = None, tool_return_truncation_chars: Optional[int] = None,
) -> dict: ) -> dict:
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call) data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
# Swap the model name back to the Bedrock inference profile ID for the API call
# The LLMConfig.model contains the Anthropic-style name (e.g., "claude-opus-4-5-20250918")
# but Bedrock API needs the inference profile ID (e.g., "us.anthropic.claude-opus-4-5-20250918-v1:0")
if llm_config.handle:
data["model"] = self.get_inference_profile_id_from_handle(llm_config.handle)
# remove disallowed fields # remove disallowed fields
if "tool_choice" in data: if "tool_choice" in data:
del data["tool_choice"]["disable_parallel_tool_use"] del data["tool_choice"]["disable_parallel_tool_use"]

View File

@@ -182,7 +182,7 @@ class LLMConfig(BaseModel):
if is_openai_reasoning_model(model): if is_openai_reasoning_model(model):
values["put_inner_thoughts_in_kwargs"] = False values["put_inner_thoughts_in_kwargs"] = False
if values.get("model_endpoint_type") == "anthropic" and ( if values.get("model_endpoint_type") in ("anthropic", "bedrock") and (
model.startswith("claude-3-7-sonnet") model.startswith("claude-3-7-sonnet")
or model.startswith("claude-sonnet-4") or model.startswith("claude-sonnet-4")
or model.startswith("claude-opus-4") or model.startswith("claude-opus-4")
@@ -413,7 +413,7 @@ class LLMConfig(BaseModel):
@classmethod @classmethod
def is_anthropic_reasoning_model(cls, config: "LLMConfig") -> bool: def is_anthropic_reasoning_model(cls, config: "LLMConfig") -> bool:
return config.model_endpoint_type == "anthropic" and ( return config.model_endpoint_type in ("anthropic", "bedrock") and (
config.model.startswith("claude-opus-4") config.model.startswith("claude-opus-4")
or config.model.startswith("claude-sonnet-4") or config.model.startswith("claude-sonnet-4")
or config.model.startswith("claude-3-7-sonnet") or config.model.startswith("claude-3-7-sonnet")

View File

@@ -18,22 +18,46 @@ logger = get_logger(__name__)
class BedrockProvider(Provider): class BedrockProvider(Provider):
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
access_key: str = Field(..., description="AWS secret access key for Bedrock.") access_key: str | None = Field(None, description="AWS access key ID for Bedrock")
api_key: str | None = Field(None, description="AWS secret access key for Bedrock")
region: str = Field(..., description="AWS region for Bedrock") region: str = Field(..., description="AWS region for Bedrock")
@staticmethod
def extract_anthropic_model_name(inference_profile_id: str) -> str:
"""
Extract the Anthropic-style model name from a Bedrock inference profile ID.
Input format: us.anthropic.claude-opus-4-5-20250918-v1:0
Output: claude-opus-4-5-20250918
This allows Bedrock models to use the same model name format as regular Anthropic models,
so all the existing model name checks (startswith("claude-"), etc.) work correctly.
"""
# Remove region prefix (e.g., "us.anthropic." -> "claude-...")
if ".anthropic." in inference_profile_id:
model_part = inference_profile_id.split(".anthropic.")[1]
else:
model_part = inference_profile_id
# Remove version suffix (e.g., "-v1:0" at the end)
# Pattern: -v followed by digits, optionally followed by :digits
import re
model_name = re.sub(r"-v\d+(?::\d+)?$", "", model_part)
return model_name
async def bedrock_get_model_list_async(self) -> list[dict]: async def bedrock_get_model_list_async(self) -> list[dict]:
"""
List Bedrock inference profiles using boto3.
"""
from aioboto3.session import Session from aioboto3.session import Session
try: try:
# Decrypt credentials before using
access_key = await self.access_key_enc.get_plaintext_async() if self.access_key_enc else None
secret_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
session = Session() session = Session()
async with session.client( async with session.client(
"bedrock", "bedrock",
aws_access_key_id=access_key, aws_access_key_id=self.access_key,
aws_secret_access_key=secret_key, aws_secret_access_key=self.api_key,
region_name=self.region, region_name=self.region,
) as bedrock: ) as bedrock:
response = await bedrock.list_inference_profiles() response = await bedrock.list_inference_profiles()
@@ -43,34 +67,43 @@ class BedrockProvider(Provider):
raise e raise e
async def check_api_key(self): async def check_api_key(self):
"""Check if the Bedrock credentials are valid""" """Check if the Bedrock credentials are valid by listing models"""
from letta.errors import LLMAuthenticationError from letta.errors import LLMAuthenticationError
try: try:
# For BYOK providers, use the custom credentials # If we can list models, the credentials are valid
if self.provider_category == ProviderCategory.byok: await self.bedrock_get_model_list_async()
# If we can list models, the credentials are valid
await self.bedrock_get_model_list_async()
else:
# For base providers, use default credentials
bedrock_get_model_list(region_name=self.region)
except Exception as e: except Exception as e:
raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}") raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}")
async def list_llm_models_async(self) -> list[LLMConfig]: async def list_llm_models_async(self) -> list[LLMConfig]:
models = await self.bedrock_get_model_list_async() models = await self.bedrock_get_model_list_async()
configs = [] # Deduplicate models by normalized name - prefer regional (us., eu.) over global
seen_models: dict[str, tuple[str, dict]] = {} # model_name -> (inference_profile_id, model_summary)
for model_summary in models: for model_summary in models:
model_arn = model_summary["inferenceProfileArn"] inference_profile_id = model_summary["inferenceProfileId"]
model_name = self.extract_anthropic_model_name(inference_profile_id)
if model_name not in seen_models:
seen_models[model_name] = (inference_profile_id, model_summary)
else:
# Prefer regional profiles over global ones
existing_id = seen_models[model_name][0]
if existing_id.startswith("global.") and not inference_profile_id.startswith("global."):
seen_models[model_name] = (inference_profile_id, model_summary)
configs = []
for model_name, (inference_profile_id, model_summary) in seen_models.items():
configs.append( configs.append(
LLMConfig( LLMConfig(
model=model_arn, model=model_name,
model_endpoint_type=self.provider_type.value, model_endpoint_type=self.provider_type.value,
model_endpoint=None, model_endpoint=None,
context_window=self.get_model_context_window(model_arn), context_window=self.get_model_context_window(inference_profile_id),
handle=self.get_handle(model_arn), # Store the full inference profile ID in the handle for API calls
max_tokens=self.get_default_max_output_tokens(model_arn), handle=self.get_handle(inference_profile_id),
max_tokens=self.get_default_max_output_tokens(inference_profile_id),
provider_name=self.name, provider_name=self.name,
provider_category=self.provider_category, provider_category=self.provider_category,
) )
@@ -82,15 +115,19 @@ class BedrockProvider(Provider):
""" """
Get context window size for a specific model. Get context window size for a specific model.
Bedrock doesn't provide this via API, so we maintain a mapping Bedrock doesn't provide this via API, so we maintain a mapping.
200k for anthropic: https://aws.amazon.com/bedrock/anthropic/
""" """
if model_name.startswith("anthropic"): model_lower = model_name.lower()
if "anthropic" in model_lower or "claude" in model_lower:
return 200_000 return 200_000
else: else:
return 100_000 # default to 100k if unknown return 100_000 # default if unknown
def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str: def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str:
logger.debug("Getting handle for model_name: %s", model_name) """
model = model_name.split(".")[-1] Create handle from inference profile ID.
return f"{self.name}/{model}"
Input format: us.anthropic.claude-sonnet-4-20250514-v1:0
Output: bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0
"""
return f"{self.name}/{model_name}"

View File

@@ -296,6 +296,8 @@ class SyncServer(object):
self._enabled_providers.append( self._enabled_providers.append(
BedrockProvider( BedrockProvider(
name="bedrock", name="bedrock",
access_key=model_settings.aws_access_key_id,
api_key=model_settings.aws_secret_access_key,
region=model_settings.aws_default_region, region=model_settings.aws_default_region,
) )
) )

View File

@@ -144,7 +144,7 @@ class ModelSettings(BaseSettings):
# Bedrock # Bedrock
aws_access_key_id: Optional[str] = None aws_access_key_id: Optional[str] = None
aws_secret_access_key: Optional[str] = None aws_secret_access_key: Optional[str] = None
aws_default_region: Optional[str] = None aws_default_region: str = "us-east-1"
bedrock_anthropic_version: Optional[str] = "bedrock-2023-05-31" bedrock_anthropic_version: Optional[str] = "bedrock-2023-05-31"
# anthropic # anthropic

View File

@@ -229,7 +229,8 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool:
return model.startswith("o1") or model.startswith("o3") or model.startswith("o4") or model.startswith("gpt-5") return model.startswith("o1") or model.startswith("o3") or model.startswith("o4") or model.startswith("gpt-5")
# Anthropic reasoning models (from anthropic_client.py:608-616) # Anthropic reasoning models (from anthropic_client.py:608-616)
elif provider_type == "anthropic": # Also applies to Bedrock with Anthropic models
elif provider_type in ("anthropic", "bedrock"):
return ( return (
model.startswith("claude-3-7-sonnet") model.startswith("claude-3-7-sonnet")
or model.startswith("claude-sonnet-4") or model.startswith("claude-sonnet-4")

View File

@@ -484,8 +484,22 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool:
) )
# Z.ai models output reasoning by default # Z.ai models output reasoning by default
is_zai_reasoning = model_settings.get("provider_type") == "zai" is_zai_reasoning = model_settings.get("provider_type") == "zai"
# Bedrock Anthropic reasoning models
is_bedrock_reasoning = model_settings.get("provider_type") == "bedrock" and (
"claude-3-7-sonnet" in model_handle
or "claude-sonnet-4" in model_handle
or "claude-opus-4" in model_handle
or "claude-haiku-4-5" in model_handle
)
return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning or is_zai_reasoning return (
is_openai_reasoning
or is_anthropic_reasoning
or is_google_vertex_reasoning
or is_google_ai_reasoning
or is_zai_reasoning
or is_bedrock_reasoning
)
# ------------------------------ # ------------------------------
@@ -653,8 +667,8 @@ async def test_parallel_tool_calls(
model_handle, model_settings = model_config model_handle, model_settings = model_config
provider_type = model_settings.get("provider_type", "") provider_type = model_settings.get("provider_type", "")
if provider_type not in ["anthropic", "openai", "google_ai", "google_vertex"]: if provider_type not in ["anthropic", "openai", "google_ai", "google_vertex", "bedrock"]:
pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.") pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, Gemini, and Bedrock models.")
if "gpt-5" in model_handle or "o3" in model_handle: if "gpt-5" in model_handle or "o3" in model_handle:
pytest.skip("GPT-5 takes too long to test, o3 is bad at this task.") pytest.skip("GPT-5 takes too long to test, o3 is bad at this task.")

View File

@@ -0,0 +1,9 @@
{
"handle": "bedrock/us.anthropic.claude-opus-4-5-20251101-v1:0",
"model_settings": {
"provider_type": "bedrock",
"temperature": 1.0,
"max_output_tokens": 16000,
"parallel_tool_calls": false
}
}