feat: add bedrock to byok (#2891)

This commit is contained in:
cthomas
2025-06-18 16:03:28 -07:00
committed by GitHub
parent bf76d91d6b
commit e89164f71b
10 changed files with 566 additions and 77 deletions

View File

@@ -811,12 +811,20 @@ def anthropic_chat_completions_request(
def anthropic_bedrock_chat_completions_request(
data: ChatCompletionRequest,
inner_thoughts_xml_tag: Optional[str] = "thinking",
provider_name: Optional[str] = None,
provider_category: Optional[ProviderCategory] = None,
user_id: Optional[str] = None,
) -> ChatCompletionResponse:
"""Make a chat completion request to Anthropic via AWS Bedrock."""
data = _prepare_anthropic_request(data, inner_thoughts_xml_tag, bedrock=True)
# Get the client
client = get_bedrock_client()
if provider_category == ProviderCategory.byok:
actor = UserManager().get_user_or_default(user_id=user_id)
access_key, secret_key, region = ProviderManager().get_bedrock_credentials_async(provider_name, actor=actor)
client = get_bedrock_client(access_key, secret_key, region)
else:
client = get_bedrock_client()
# Make the request
try:

View File

@@ -1,5 +1,5 @@
import os
from typing import Any, Dict, List
from typing import Any, Dict, List, Optional
from anthropic import AnthropicBedrock
@@ -14,7 +14,11 @@ def has_valid_aws_credentials() -> bool:
return valid_aws_credentials
def get_bedrock_client():
def get_bedrock_client(
access_key: Optional[str] = None,
secret_key: Optional[str] = None,
region: Optional[str] = None,
):
"""
Get a Bedrock client
"""
@@ -22,9 +26,9 @@ def get_bedrock_client():
sts_client = boto3.client(
"sts",
aws_access_key_id=model_settings.aws_access_key,
aws_secret_access_key=model_settings.aws_secret_access_key,
region_name=model_settings.aws_region,
aws_access_key_id=access_key or model_settings.aws_access_key,
aws_secret_access_key=secret_key or model_settings.aws_secret_access_key,
region_name=region or model_settings.aws_region,
)
credentials = sts_client.get_session_token()["Credentials"]
@@ -32,7 +36,7 @@ def get_bedrock_client():
aws_access_key=credentials["AccessKeyId"],
aws_secret_key=credentials["SecretAccessKey"],
aws_session_token=credentials["SessionToken"],
aws_region=model_settings.aws_region,
aws_region=region or model_settings.aws_region,
)
return bedrock

View File

@@ -569,6 +569,9 @@ def create(
# NOTE: max_tokens is required for Anthropic API
max_tokens=llm_config.max_tokens,
),
provider_name=llm_config.provider_name,
provider_category=llm_config.provider_category,
user_id=user_id,
)
elif llm_config.model_endpoint_type == "deepseek":