81 lines
3.5 KiB
Python
81 lines
3.5 KiB
Python
from typing import List, Optional, Union
|
|
|
|
import anthropic
|
|
from aioboto3.session import Session
|
|
|
|
from letta.llm_api.anthropic_client import AnthropicClient
|
|
from letta.log import get_logger
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.enums import AgentType, ProviderCategory
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.message import Message as PydanticMessage
|
|
from letta.services.provider_manager import ProviderManager
|
|
from letta.settings import model_settings
|
|
|
|
# Module-level logger, named after this module via letta's get_logger helper.
logger = get_logger(__name__)
|
|
|
|
|
|
class BedrockClient(AnthropicClient):
    """Anthropic client variant that talks to Claude through AWS Bedrock.

    Reuses the request-building logic of ``AnthropicClient`` but constructs
    ``anthropic.AnthropicBedrock`` / ``anthropic.AsyncAnthropicBedrock`` SDK
    clients authenticated with temporary STS session credentials.
    """

    async def get_byok_overrides_async(
        self, llm_config: LLMConfig
    ) -> tuple[Optional[str], Optional[str], Optional[str]]:
        """Return bring-your-own-key AWS credential overrides for ``llm_config``.

        Args:
            llm_config: configuration of the model/provider being called.

        Returns:
            ``(access_key_id, secret_access_key, default_region)``. All three are
            ``None`` unless ``llm_config.provider_category`` is BYOK, in which case
            they are fetched from ``ProviderManager().get_bedrock_credentials_async``
            for ``llm_config.provider_name`` (individual entries may presumably
            still be ``None`` if the provider stores no value — TODO confirm).
        """
        if llm_config.provider_category != ProviderCategory.byok:
            return None, None, None

        # Unpack explicitly so a malformed credentials result fails loudly here.
        (
            access_key_id,
            secret_access_key,
            default_region,
        ) = await ProviderManager().get_bedrock_credentials_async(
            llm_config.provider_name,
            actor=self.actor,
        )
        return access_key_id, secret_access_key, default_region

    @trace_method
    async def _get_anthropic_client_async(
        self, llm_config: LLMConfig, async_client: bool = False
    ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic, anthropic.AsyncAnthropicBedrock, anthropic.AnthropicBedrock]:
        """Build a Bedrock-backed Anthropic SDK client using STS session credentials.

        BYOK overrides (from ``get_byok_overrides_async``) take precedence over the
        AWS values in ``model_settings``. Those credentials are exchanged for
        temporary ones via STS ``GetSessionToken``, and the temporary credentials
        are handed to the Anthropic Bedrock client.

        Args:
            llm_config: configuration of the model/provider being called.
            async_client: when True return ``anthropic.AsyncAnthropicBedrock``,
                otherwise the synchronous ``anthropic.AnthropicBedrock``.

        Returns:
            The constructed Anthropic Bedrock client.
        """
        access_key_id, secret_access_key, default_region = await self.get_byok_overrides_async(llm_config)
        # Resolve once so STS and the Bedrock client always use the same region.
        region = default_region or model_settings.aws_default_region

        session = Session()
        async with session.client(
            "sts",
            aws_access_key_id=access_key_id or model_settings.aws_access_key_id,
            aws_secret_access_key=secret_access_key or model_settings.aws_secret_access_key,
            region_name=region,
        ) as sts_client:
            session_token = await sts_client.get_session_token()
        credentials = session_token["Credentials"]

        # Sync and async Bedrock clients share the same constructor arguments.
        client_cls = anthropic.AsyncAnthropicBedrock if async_client else anthropic.AnthropicBedrock
        return client_cls(
            aws_access_key=credentials["AccessKeyId"],
            aws_secret_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            aws_region=region,
            max_retries=model_settings.anthropic_max_retries,
        )

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """Build the Anthropic request payload, then strip fields Bedrock rejects.

        Delegates to ``AnthropicClient.build_request_data``, forwarding every
        argument including ``tool_return_truncation_chars``. Then removes
        ``disable_parallel_tool_use`` from ``tool_choice`` as a disallowed field.

        Returns:
            The request payload dict produced by the parent class, with
            ``tool_choice.disable_parallel_tool_use`` removed if present.
        """
        data = super().build_request_data(
            agent_type,
            messages,
            llm_config,
            tools,
            force_tool_call,
            requires_subsequent_tool_call,
            tool_return_truncation_chars=tool_return_truncation_chars,
        )
        # Remove disallowed fields. Not every tool_choice form carries this key,
        # and tool_choice may be None, so remove it defensively rather than
        # raising KeyError/TypeError.
        tool_choice = data.get("tool_choice")
        if isinstance(tool_choice, dict):
            tool_choice.pop("disable_parallel_tool_use", None)
        return data
|