feat: enable bedrock for anthropic models (#8847)
* feat: enable bedrock for anthropic models * parallel tool calls in ade * attempt add to ci * update tests * add env vars * hardcode region * get it working * debugging * add bedrock extra * default env var [skip ci] * run ci * reasoner model update * secrets * clean up log * clean up
This commit is contained in:
@@ -16,6 +16,18 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BedrockClient(AnthropicClient):
|
||||
@staticmethod
|
||||
def get_inference_profile_id_from_handle(handle: str) -> str:
    """
    Return the Bedrock inference profile ID embedded in an LLMConfig handle.

    Everything after the first "/" in the handle is treated as the profile ID.
    A handle that contains no "/" is returned unchanged.

    Example handle: bedrock/us.anthropic.claude-opus-4-5-20250918-v1:0
    Result:         us.anthropic.claude-opus-4-5-20250918-v1:0
    """
    # partition() splits on the first "/" only, so any later slashes stay
    # inside the returned profile ID.
    _prefix, slash, profile_id = handle.partition("/")
    return profile_id if slash else handle
|
||||
|
||||
async def get_byok_overrides_async(self, llm_config: LLMConfig) -> tuple[str, str, str]:
|
||||
override_access_key_id, override_secret_access_key, override_default_region = None, None, None
|
||||
if llm_config.provider_category == ProviderCategory.byok:
|
||||
@@ -74,6 +86,13 @@ class BedrockClient(AnthropicClient):
|
||||
tool_return_truncation_chars: Optional[int] = None,
|
||||
) -> dict:
|
||||
data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)
|
||||
|
||||
# Swap the model name back to the Bedrock inference profile ID for the API call
|
||||
# The LLMConfig.model contains the Anthropic-style name (e.g., "claude-opus-4-5-20250918")
|
||||
# but Bedrock API needs the inference profile ID (e.g., "us.anthropic.claude-opus-4-5-20250918-v1:0")
|
||||
if llm_config.handle:
|
||||
data["model"] = self.get_inference_profile_id_from_handle(llm_config.handle)
|
||||
|
||||
# remove disallowed fields
|
||||
if "tool_choice" in data:
|
||||
del data["tool_choice"]["disable_parallel_tool_use"]
|
||||
|
||||
@@ -182,7 +182,7 @@ class LLMConfig(BaseModel):
|
||||
if is_openai_reasoning_model(model):
|
||||
values["put_inner_thoughts_in_kwargs"] = False
|
||||
|
||||
if values.get("model_endpoint_type") == "anthropic" and (
|
||||
if values.get("model_endpoint_type") in ("anthropic", "bedrock") and (
|
||||
model.startswith("claude-3-7-sonnet")
|
||||
or model.startswith("claude-sonnet-4")
|
||||
or model.startswith("claude-opus-4")
|
||||
@@ -413,7 +413,7 @@ class LLMConfig(BaseModel):
|
||||
|
||||
@classmethod
|
||||
def is_anthropic_reasoning_model(cls, config: "LLMConfig") -> bool:
|
||||
return config.model_endpoint_type == "anthropic" and (
|
||||
return config.model_endpoint_type in ("anthropic", "bedrock") and (
|
||||
config.model.startswith("claude-opus-4")
|
||||
or config.model.startswith("claude-sonnet-4")
|
||||
or config.model.startswith("claude-3-7-sonnet")
|
||||
|
||||
@@ -18,22 +18,46 @@ logger = get_logger(__name__)
|
||||
class BedrockProvider(Provider):
|
||||
provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.")
|
||||
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
||||
access_key: str = Field(..., description="AWS secret access key for Bedrock.")
|
||||
access_key: str | None = Field(None, description="AWS access key ID for Bedrock")
|
||||
api_key: str | None = Field(None, description="AWS secret access key for Bedrock")
|
||||
region: str = Field(..., description="AWS region for Bedrock")
|
||||
|
||||
@staticmethod
|
||||
def extract_anthropic_model_name(inference_profile_id: str) -> str:
    """
    Convert a Bedrock inference profile ID into an Anthropic-style model name.

    Example input:  us.anthropic.claude-opus-4-5-20250918-v1:0
    Example output: claude-opus-4-5-20250918

    With the prefix and version suffix stripped, Bedrock models share the
    naming format of regular Anthropic models, so existing model-name checks
    (startswith("claude-"), etc.) apply to them as well.
    """
    import re

    # Drop everything up to and including ".anthropic." (e.g. "us.anthropic.").
    # IDs without that marker are used as-is.
    pieces = inference_profile_id.split(".anthropic.")
    bare_id = pieces[1] if len(pieces) > 1 else inference_profile_id

    # Strip a trailing version tag: "-v" + digits, optionally ":" + digits
    # (e.g. "-v1:0" or "-v1").
    version_suffix = re.compile(r"-v\d+(?::\d+)?$")
    return version_suffix.sub("", bare_id)
|
||||
|
||||
async def bedrock_get_model_list_async(self) -> list[dict]:
|
||||
"""
|
||||
List Bedrock inference profiles using boto3.
|
||||
"""
|
||||
from aioboto3.session import Session
|
||||
|
||||
try:
|
||||
# Decrypt credentials before using
|
||||
access_key = await self.access_key_enc.get_plaintext_async() if self.access_key_enc else None
|
||||
secret_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
|
||||
|
||||
session = Session()
|
||||
async with session.client(
|
||||
"bedrock",
|
||||
aws_access_key_id=access_key,
|
||||
aws_secret_access_key=secret_key,
|
||||
aws_access_key_id=self.access_key,
|
||||
aws_secret_access_key=self.api_key,
|
||||
region_name=self.region,
|
||||
) as bedrock:
|
||||
response = await bedrock.list_inference_profiles()
|
||||
@@ -43,34 +67,43 @@ class BedrockProvider(Provider):
|
||||
raise e
|
||||
|
||||
async def check_api_key(self):
|
||||
"""Check if the Bedrock credentials are valid"""
|
||||
"""Check if the Bedrock credentials are valid by listing models"""
|
||||
from letta.errors import LLMAuthenticationError
|
||||
|
||||
try:
|
||||
# For BYOK providers, use the custom credentials
|
||||
if self.provider_category == ProviderCategory.byok:
|
||||
# If we can list models, the credentials are valid
|
||||
await self.bedrock_get_model_list_async()
|
||||
else:
|
||||
# For base providers, use default credentials
|
||||
bedrock_get_model_list(region_name=self.region)
|
||||
# If we can list models, the credentials are valid
|
||||
await self.bedrock_get_model_list_async()
|
||||
except Exception as e:
|
||||
raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}")
|
||||
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
models = await self.bedrock_get_model_list_async()
|
||||
|
||||
configs = []
|
||||
# Deduplicate models by normalized name - prefer regional (us., eu.) over global
|
||||
seen_models: dict[str, tuple[str, dict]] = {} # model_name -> (inference_profile_id, model_summary)
|
||||
for model_summary in models:
|
||||
model_arn = model_summary["inferenceProfileArn"]
|
||||
inference_profile_id = model_summary["inferenceProfileId"]
|
||||
model_name = self.extract_anthropic_model_name(inference_profile_id)
|
||||
|
||||
if model_name not in seen_models:
|
||||
seen_models[model_name] = (inference_profile_id, model_summary)
|
||||
else:
|
||||
# Prefer regional profiles over global ones
|
||||
existing_id = seen_models[model_name][0]
|
||||
if existing_id.startswith("global.") and not inference_profile_id.startswith("global."):
|
||||
seen_models[model_name] = (inference_profile_id, model_summary)
|
||||
|
||||
configs = []
|
||||
for model_name, (inference_profile_id, model_summary) in seen_models.items():
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model_arn,
|
||||
model=model_name,
|
||||
model_endpoint_type=self.provider_type.value,
|
||||
model_endpoint=None,
|
||||
context_window=self.get_model_context_window(model_arn),
|
||||
handle=self.get_handle(model_arn),
|
||||
max_tokens=self.get_default_max_output_tokens(model_arn),
|
||||
context_window=self.get_model_context_window(inference_profile_id),
|
||||
# Store the full inference profile ID in the handle for API calls
|
||||
handle=self.get_handle(inference_profile_id),
|
||||
max_tokens=self.get_default_max_output_tokens(inference_profile_id),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
@@ -82,15 +115,19 @@ class BedrockProvider(Provider):
|
||||
"""
|
||||
Get context window size for a specific model.
|
||||
|
||||
Bedrock doesn't provide this via API, so we maintain a mapping
|
||||
200k for anthropic: https://aws.amazon.com/bedrock/anthropic/
|
||||
Bedrock doesn't provide this via API, so we maintain a mapping.
|
||||
"""
|
||||
if model_name.startswith("anthropic"):
|
||||
model_lower = model_name.lower()
|
||||
if "anthropic" in model_lower or "claude" in model_lower:
|
||||
return 200_000
|
||||
else:
|
||||
return 100_000 # default to 100k if unknown
|
||||
return 100_000 # default if unknown
|
||||
|
||||
def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str:
|
||||
logger.debug("Getting handle for model_name: %s", model_name)
|
||||
model = model_name.split(".")[-1]
|
||||
return f"{self.name}/{model}"
|
||||
"""
|
||||
Create handle from inference profile ID.
|
||||
|
||||
Input format: us.anthropic.claude-sonnet-4-20250514-v1:0
|
||||
Output: bedrock/us.anthropic.claude-sonnet-4-20250514-v1:0
|
||||
"""
|
||||
return f"{self.name}/{model_name}"
|
||||
|
||||
@@ -296,6 +296,8 @@ class SyncServer(object):
|
||||
self._enabled_providers.append(
|
||||
BedrockProvider(
|
||||
name="bedrock",
|
||||
access_key=model_settings.aws_access_key_id,
|
||||
api_key=model_settings.aws_secret_access_key,
|
||||
region=model_settings.aws_default_region,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -144,7 +144,7 @@ class ModelSettings(BaseSettings):
|
||||
# Bedrock
|
||||
aws_access_key_id: Optional[str] = None
|
||||
aws_secret_access_key: Optional[str] = None
|
||||
aws_default_region: Optional[str] = None
|
||||
aws_default_region: str = "us-east-1"
|
||||
bedrock_anthropic_version: Optional[str] = "bedrock-2023-05-31"
|
||||
|
||||
# anthropic
|
||||
|
||||
@@ -229,7 +229,8 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool:
|
||||
return model.startswith("o1") or model.startswith("o3") or model.startswith("o4") or model.startswith("gpt-5")
|
||||
|
||||
# Anthropic reasoning models (from anthropic_client.py:608-616)
|
||||
elif provider_type == "anthropic":
|
||||
# Also applies to Bedrock with Anthropic models
|
||||
elif provider_type in ("anthropic", "bedrock"):
|
||||
return (
|
||||
model.startswith("claude-3-7-sonnet")
|
||||
or model.startswith("claude-sonnet-4")
|
||||
|
||||
@@ -484,8 +484,22 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool:
|
||||
)
|
||||
# Z.ai models output reasoning by default
|
||||
is_zai_reasoning = model_settings.get("provider_type") == "zai"
|
||||
# Bedrock Anthropic reasoning models
|
||||
is_bedrock_reasoning = model_settings.get("provider_type") == "bedrock" and (
|
||||
"claude-3-7-sonnet" in model_handle
|
||||
or "claude-sonnet-4" in model_handle
|
||||
or "claude-opus-4" in model_handle
|
||||
or "claude-haiku-4-5" in model_handle
|
||||
)
|
||||
|
||||
return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning or is_zai_reasoning
|
||||
return (
|
||||
is_openai_reasoning
|
||||
or is_anthropic_reasoning
|
||||
or is_google_vertex_reasoning
|
||||
or is_google_ai_reasoning
|
||||
or is_zai_reasoning
|
||||
or is_bedrock_reasoning
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------
|
||||
@@ -653,8 +667,8 @@ async def test_parallel_tool_calls(
|
||||
model_handle, model_settings = model_config
|
||||
provider_type = model_settings.get("provider_type", "")
|
||||
|
||||
if provider_type not in ["anthropic", "openai", "google_ai", "google_vertex"]:
|
||||
pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.")
|
||||
if provider_type not in ["anthropic", "openai", "google_ai", "google_vertex", "bedrock"]:
|
||||
pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, Gemini, and Bedrock models.")
|
||||
|
||||
if "gpt-5" in model_handle or "o3" in model_handle:
|
||||
pytest.skip("GPT-5 takes too long to test, o3 is bad at this task.")
|
||||
|
||||
9
tests/model_settings/bedrock-claude-4-5-opus.json
Normal file
9
tests/model_settings/bedrock-claude-4-5-opus.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"handle": "bedrock/us.anthropic.claude-opus-4-5-20251101-v1:0",
|
||||
"model_settings": {
|
||||
"provider_type": "bedrock",
|
||||
"temperature": 1.0,
|
||||
"max_output_tokens": 16000,
|
||||
"parallel_tool_calls": false
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user