feat: add bedrock to byok (#2891)
This commit is contained in:
@@ -811,12 +811,20 @@ def anthropic_chat_completions_request(
|
||||
def anthropic_bedrock_chat_completions_request(
|
||||
data: ChatCompletionRequest,
|
||||
inner_thoughts_xml_tag: Optional[str] = "thinking",
|
||||
provider_name: Optional[str] = None,
|
||||
provider_category: Optional[ProviderCategory] = None,
|
||||
user_id: Optional[str] = None,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Make a chat completion request to Anthropic via AWS Bedrock."""
|
||||
data = _prepare_anthropic_request(data, inner_thoughts_xml_tag, bedrock=True)
|
||||
|
||||
# Get the client
|
||||
client = get_bedrock_client()
|
||||
if provider_category == ProviderCategory.byok:
|
||||
actor = UserManager().get_user_or_default(user_id=user_id)
|
||||
access_key, secret_key, region = ProviderManager().get_bedrock_credentials_async(provider_name, actor=actor)
|
||||
client = get_bedrock_client(access_key, secret_key, region)
|
||||
else:
|
||||
client = get_bedrock_client()
|
||||
|
||||
# Make the request
|
||||
try:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from anthropic import AnthropicBedrock
|
||||
|
||||
@@ -14,7 +14,11 @@ def has_valid_aws_credentials() -> bool:
|
||||
return valid_aws_credentials
|
||||
|
||||
|
||||
def get_bedrock_client():
|
||||
def get_bedrock_client(
|
||||
access_key: Optional[str] = None,
|
||||
secret_key: Optional[str] = None,
|
||||
region: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Get a Bedrock client
|
||||
"""
|
||||
@@ -22,9 +26,9 @@ def get_bedrock_client():
|
||||
|
||||
sts_client = boto3.client(
|
||||
"sts",
|
||||
aws_access_key_id=model_settings.aws_access_key,
|
||||
aws_secret_access_key=model_settings.aws_secret_access_key,
|
||||
region_name=model_settings.aws_region,
|
||||
aws_access_key_id=access_key or model_settings.aws_access_key,
|
||||
aws_secret_access_key=secret_key or model_settings.aws_secret_access_key,
|
||||
region_name=region or model_settings.aws_region,
|
||||
)
|
||||
credentials = sts_client.get_session_token()["Credentials"]
|
||||
|
||||
@@ -32,7 +36,7 @@ def get_bedrock_client():
|
||||
aws_access_key=credentials["AccessKeyId"],
|
||||
aws_secret_key=credentials["SecretAccessKey"],
|
||||
aws_session_token=credentials["SessionToken"],
|
||||
aws_region=model_settings.aws_region,
|
||||
aws_region=region or model_settings.aws_region,
|
||||
)
|
||||
return bedrock
|
||||
|
||||
|
||||
@@ -569,6 +569,9 @@ def create(
|
||||
# NOTE: max_tokens is required for Anthropic API
|
||||
max_tokens=llm_config.max_tokens,
|
||||
),
|
||||
provider_name=llm_config.provider_name,
|
||||
provider_category=llm_config.provider_category,
|
||||
user_id=user_id,
|
||||
)
|
||||
|
||||
elif llm_config.model_endpoint_type == "deepseek":
|
||||
|
||||
Reference in New Issue
Block a user