* base requirements * autofix * Configure ruff for Python linting and formatting - Set up minimal ruff configuration with basic checks (E, W, F, I) - Add temporary ignores for common issues during migration - Configure pre-commit hooks to use ruff with pass_filenames - This enables gradual migration from black to ruff * Delete sdj * autofixed only * migrate lint action * more autofixed * more fixes * change precommit * try changing the hook * try this stuff
78 lines
3.4 KiB
Python
78 lines
3.4 KiB
Python
from typing import List, Optional, Union
|
|
|
|
import anthropic
|
|
from aioboto3.session import Session
|
|
|
|
from letta.llm_api.anthropic_client import AnthropicClient
|
|
from letta.log import get_logger
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.enums import ProviderCategory
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.message import Message as PydanticMessage
|
|
from letta.services.provider_manager import ProviderManager
|
|
from letta.settings import model_settings
|
|
|
|
# Module-level logger obtained from letta.log.get_logger, keyed by this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
class BedrockClient(AnthropicClient):
    """Anthropic client variant that reaches Anthropic models through AWS Bedrock.

    Request building is inherited from AnthropicClient. This subclass builds the
    SDK client from AWS credentials (per-provider BYOK overrides when available,
    otherwise the global ``model_settings`` values) and removes request fields
    that are disallowed for Bedrock.
    """

    async def get_byok_overrides_async(self, llm_config: LLMConfig) -> tuple[Optional[str], Optional[str], Optional[str]]:
        """Return BYOK AWS credential overrides for ``llm_config``.

        Args:
            llm_config: Model configuration. Overrides are only looked up when
                ``llm_config.provider_category`` is ``ProviderCategory.byok``.

        Returns:
            ``(access_key_id, secret_access_key, default_region)``. All three are
            ``None`` for non-BYOK providers; callers treat ``None`` as "fall back
            to the global ``model_settings`` value".
        """
        override_access_key_id, override_secret_access_key, override_default_region = None, None, None
        if llm_config.provider_category == ProviderCategory.byok:
            (
                override_access_key_id,
                override_secret_access_key,
                override_default_region,
            ) = await ProviderManager().get_bedrock_credentials_async(
                llm_config.provider_name,
                actor=self.actor,
            )
        # Previously returned the undefined name `override_default_regions`,
        # which raised NameError on every call.
        return override_access_key_id, override_secret_access_key, override_default_region

    @trace_method
    async def _get_anthropic_client_async(
        self, llm_config: LLMConfig, async_client: bool = False
    ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic, anthropic.AsyncAnthropicBedrock, anthropic.AnthropicBedrock]:
        """Build a Bedrock-backed Anthropic SDK client.

        Temporary credentials are requested from AWS STS (``get_session_token``)
        using the BYOK overrides when present, otherwise the AWS credentials from
        ``model_settings``. Those temporary credentials configure the Bedrock client.

        Args:
            llm_config: Model configuration used to resolve BYOK overrides.
            async_client: If True, return ``anthropic.AsyncAnthropicBedrock``;
                otherwise return ``anthropic.AnthropicBedrock``.

        Returns:
            The configured Bedrock client instance.
        """
        override_access_key_id, override_secret_access_key, override_default_region = await self.get_byok_overrides_async(llm_config)
        # Resolve the region once; the same value is used for STS and for the Bedrock client.
        region = override_default_region or model_settings.aws_default_region

        session = Session()
        async with session.client(
            "sts",
            aws_access_key_id=override_access_key_id or model_settings.aws_access_key_id,
            aws_secret_access_key=override_secret_access_key or model_settings.aws_secret_access_key,
            region_name=region,
        ) as sts_client:
            session_token = await sts_client.get_session_token()
            credentials = session_token["Credentials"]

        # Both client classes take identical constructor arguments; only the class differs.
        client_cls = anthropic.AsyncAnthropicBedrock if async_client else anthropic.AnthropicBedrock
        return client_cls(
            aws_access_key=credentials["AccessKeyId"],
            aws_secret_key=credentials["SecretAccessKey"],
            aws_session_token=credentials["SessionToken"],
            aws_region=region,
            max_retries=model_settings.anthropic_max_retries,
        )

    @trace_method
    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build the request payload via AnthropicClient and adapt it for Bedrock.

        Args:
            messages: Conversation messages to send.
            llm_config: Model configuration.
            tools: Optional tool schemas forwarded to the parent builder.
            force_tool_call: Optional tool name the parent builder may force.

        Returns:
            The parent's request dict, with ``disable_parallel_tool_use`` removed
            from ``tool_choice`` when present.
        """
        data = super().build_request_data(messages, llm_config, tools, force_tool_call)

        # Remove fields disallowed on Bedrock. pop() with a default tolerates a
        # tool_choice that never carried the key, and the isinstance guard tolerates
        # a missing or None tool_choice; the old `del` raised KeyError/TypeError there.
        tool_choice = data.get("tool_choice")
        if isinstance(tool_choice, dict):
            tool_choice.pop("disable_parallel_tool_use", None)

        return data
|