diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index eaccf6e9..4fc949c0 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -27,7 +27,7 @@ from letta.otel.context import get_ctx_attributes from letta.otel.metric_registry import MetricRegistry from letta.otel.tracing import log_event, trace_method, tracer from letta.schemas.agent import AgentState, UpdateAgent -from letta.schemas.enums import MessageRole +from letta.schemas.enums import MessageRole, ProviderType from letta.schemas.letta_message import MessageType from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_response import LettaResponse @@ -512,12 +512,12 @@ class LettaAgent(BaseAgent): # TODO: THIS IS INCREDIBLY UGLY # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED - if agent_state.llm_config.model_endpoint_type == "anthropic": + if agent_state.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]: interface = AnthropicStreamingInterface( use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, ) - elif agent_state.llm_config.model_endpoint_type == "openai": + elif agent_state.llm_config.model_endpoint_type == ProviderType.openai: interface = OpenAIStreamingInterface( use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 6250b9ac..e8bb91fa 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -49,20 +49,20 @@ class AnthropicClient(LLMClientBase): @trace_method def request(self, request_data: dict, llm_config: LLMConfig) -> dict: client = self._get_anthropic_client(llm_config, async_client=False) - response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) + response = client.beta.messages.create(**request_data) return response.model_dump() @trace_method async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: client = await self._get_anthropic_client_async(llm_config, async_client=True) - response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) + response = await client.beta.messages.create(**request_data) return response.model_dump() @trace_method async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]: client = await self._get_anthropic_client_async(llm_config, async_client=True) request_data["stream"] = True - return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"]) + return await client.beta.messages.create(**request_data) @trace_method async def send_llm_batch_request_async( diff --git a/letta/llm_api/aws_bedrock.py b/letta/llm_api/aws_bedrock.py index 9f767397..c395868d 100644 --- a/letta/llm_api/aws_bedrock.py +++ b/letta/llm_api/aws_bedrock.py @@ -64,6 +64,28 @@ def bedrock_get_model_list(region_name: str) -> List[dict]: raise e +async def bedrock_get_model_list_async( + access_key_id: Optional[str] = None, + secret_access_key: Optional[str] = None, + default_region: Optional[str] = None, +) -> List[dict]: + from aioboto3.session import Session + + try: + session = Session() + async with session.client( + "bedrock", + aws_access_key_id=access_key_id, + aws_secret_access_key=secret_access_key, + region_name=default_region, + ) as bedrock: + response = await bedrock.list_inference_profiles() + return response["inferenceProfileSummaries"] + except Exception as e: + print(f"Error getting model list: {str(e)}") + raise e + + def bedrock_get_model_details(region_name: str, model_id: str) -> Dict[str, Any]: """ Get details for a specific model from Bedrock. diff --git a/letta/llm_api/bedrock_client.py b/letta/llm_api/bedrock_client.py new file mode 100644 index 00000000..b7b60a5e --- /dev/null +++ b/letta/llm_api/bedrock_client.py @@ -0,0 +1,74 @@ +from typing import List, Optional, Union + +import anthropic +from aioboto3.session import Session + +from letta.llm_api.anthropic_client import AnthropicClient +from letta.log import get_logger +from letta.otel.tracing import trace_method +from letta.schemas.enums import ProviderCategory +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message as PydanticMessage +from letta.services.provider_manager import ProviderManager +from letta.settings import model_settings + +logger = get_logger(__name__) + + +class BedrockClient(AnthropicClient): + + @trace_method + async def _get_anthropic_client_async( + self, llm_config: LLMConfig, async_client: bool = False + ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic, anthropic.AsyncAnthropicBedrock, anthropic.AnthropicBedrock]: + override_access_key_id, override_secret_access_key, override_default_region = None, None, None + if llm_config.provider_category == ProviderCategory.byok: + ( + override_access_key_id, + override_secret_access_key, + override_default_region, + ) = await ProviderManager().get_bedrock_credentials_async( + llm_config.provider_name, + actor=self.actor, + ) + + session = Session() + async with session.client( + "sts", + aws_access_key_id=override_access_key_id or model_settings.aws_access_key_id, + aws_secret_access_key=override_secret_access_key or model_settings.aws_secret_access_key, + region_name=override_default_region or model_settings.aws_default_region, + ) as sts_client: + session_token = await sts_client.get_session_token() + credentials = session_token["Credentials"] + + if async_client: + return anthropic.AsyncAnthropicBedrock( + aws_access_key=credentials["AccessKeyId"], + aws_secret_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + aws_region=override_default_region or model_settings.aws_default_region, + max_retries=model_settings.anthropic_max_retries, + ) + else: + return anthropic.AnthropicBedrock( + aws_access_key=credentials["AccessKeyId"], + aws_secret_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + aws_region=override_default_region or model_settings.aws_default_region, + max_retries=model_settings.anthropic_max_retries, + ) + + @trace_method + def build_request_data( + self, + messages: List[PydanticMessage], + llm_config: LLMConfig, + tools: Optional[List[dict]] = None, + force_tool_call: Optional[str] = None, + ) -> dict: + data = super().build_request_data(messages, llm_config, tools, force_tool_call) + # remove disallowed fields + if "tool_choice" in data: + del data["tool_choice"]["disable_parallel_tool_use"] + return data diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index 7372b68a..b21e4da2 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -51,6 +51,13 @@ class LLMClient: put_inner_thoughts_first=put_inner_thoughts_first, actor=actor, ) + case ProviderType.bedrock: + from letta.llm_api.bedrock_client import BedrockClient + + return BedrockClient( + put_inner_thoughts_first=put_inner_thoughts_first, + actor=actor, + ) case ProviderType.openai | ProviderType.together: from letta.llm_api.openai_client import OpenAIClient diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 13179e16..34eb1250 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -98,7 +98,7 @@ class Provider(ProviderBase): case ProviderType.anthropic: return AnthropicProvider(**self.model_dump(exclude_none=True)) case ProviderType.bedrock: - return AnthropicBedrockProvider(**self.model_dump(exclude_none=True)) + return BedrockProvider(**self.model_dump(exclude_none=True)) case ProviderType.ollama: return OllamaProvider(**self.model_dump(exclude_none=True)) case ProviderType.google_ai: @@ -1513,7 +1513,7 @@ class CohereProvider(OpenAIProvider): pass -class AnthropicBedrockProvider(Provider): +class BedrockProvider(Provider): provider_type: Literal[ProviderType.bedrock] = Field(ProviderType.bedrock, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") region: str = Field(..., description="AWS region for Bedrock") @@ -1539,6 +1539,32 @@ class AnthropicBedrockProvider(Provider): ) return configs + async def list_llm_models_async(self) -> List[LLMConfig]: + from letta.llm_api.aws_bedrock import bedrock_get_model_list_async + + models = await bedrock_get_model_list_async( + self.access_key, + self.api_key, + self.region, + ) + + configs = [] + for model_summary in models: + model_arn = model_summary["inferenceProfileArn"] + configs.append( + LLMConfig( + model=model_arn, + model_endpoint_type=self.provider_type.value, + model_endpoint=None, + context_window=self.get_model_context_window(model_arn), + handle=self.get_handle(model_arn), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs + def list_embedding_models(self): return [] @@ -1548,7 +1574,7 @@ class AnthropicBedrockProvider(Provider): return bedrock_get_model_context_window(model_name) - def get_handle(self, model_name: str) -> str: + def get_handle(self, model_name: str, is_embedding: bool = False, base_name: Optional[str] = None) -> str: print(model_name) model = model_name.split(".")[-1] - return f"bedrock/{model}" + return f"{self.name}/{model}" diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 4ddbdb3a..f1fef585 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -13,7 +13,7 @@ from sqlalchemy.exc import IntegrityError, OperationalError from starlette.responses import Response, StreamingResponse from letta.agents.letta_agent import LettaAgent -from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.constants import DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_MODEL_ENDPOINT from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 from letta.helpers.datetime_helpers import get_utc_timestamp_ns from letta.log import get_logger @@ -686,7 +686,7 @@ async def send_message( # TODO: This is redundant, remove soon agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"]) agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"] + model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] if agent_eligible and model_compatible: if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent: @@ -768,9 +768,9 @@ async def send_message_streaming( # TODO: This is redundant, remove soon agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"]) agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"] - model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai"] - not_letta_endpoint = not ("inference.letta.com" in agent.llm_config.model_endpoint) + model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] + model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] + not_letta_endpoint = LETTA_MODEL_ENDPOINT != agent.llm_config.model_endpoint if agent_eligible and model_compatible: if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent: @@ -857,7 +857,14 @@ async def process_message_background( try: agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"]) agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"] + model_compatible = agent.llm_config.model_endpoint_type in [ + "anthropic", + "openai", + "together", + "google_ai", + "google_vertex", + "bedrock", + ] if agent_eligible and model_compatible: if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent: agent_loop = SleeptimeMultiAgentV2( @@ -1021,7 +1028,7 @@ async def summarize_agent_conversation( actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor, include_relationships=["multi_agent_group"]) agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] - model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex"] + model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"] if agent_eligible and model_compatible: agent = LettaAgent( diff --git a/letta/server/server.py b/letta/server/server.py index 378fb687..4f5e766f 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -54,9 +54,9 @@ from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySumm from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.passage import Passage, PassageUpdate from letta.schemas.providers import ( - AnthropicBedrockProvider, AnthropicProvider, AzureProvider, + BedrockProvider, DeepSeekProvider, GoogleAIProvider, GoogleVertexProvider, @@ -367,7 +367,7 @@ class SyncServer(Server): ) if model_settings.aws_access_key_id and model_settings.aws_secret_access_key and model_settings.aws_default_region: self._enabled_providers.append( - AnthropicBedrockProvider( + BedrockProvider( name="bedrock", region=model_settings.aws_default_region, ) diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 60bbb547..e919b93f 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -1,4 +1,4 @@ -from typing import List, Optional, Union +from typing import List, Optional, Tuple, Union from letta.orm.provider import Provider as ProviderModel from letta.otel.tracing import trace_method @@ -196,10 +196,12 @@ class ProviderManager: @enforce_types @trace_method - async def get_bedrock_credentials_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]: + async def get_bedrock_credentials_async( + self, provider_name: Union[str, None], actor: PydanticUser + ) -> Tuple[Optional[str], Optional[str], Optional[str]]: providers = await self.list_providers_async(name=provider_name, actor=actor) - access_key = providers[0].api_key if providers else None - secret_key = providers[0].api_secret if providers else None + access_key = providers[0].access_key if providers else None + secret_key = providers[0].api_key if providers else None region = providers[0].region if providers else None return access_key, secret_key, region diff --git a/poetry.lock b/poetry.lock index 42fbc3f2..b904429b 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,5 +1,66 @@ # This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand. +[[package]] +name = "aioboto3" +version = "14.3.0" +description = "Async boto3 wrapper" +optional = true +python-versions = "<4.0,>=3.8" +groups = ["main"] +markers = "extra == \"bedrock\"" +files = [ + {file = "aioboto3-14.3.0-py3-none-any.whl", hash = "sha256:aec5de94e9edc1ffbdd58eead38a37f00ddac59a519db749a910c20b7b81bca7"}, + {file = "aioboto3-14.3.0.tar.gz", hash = "sha256:1d18f88bb56835c607b62bb6cb907754d717bedde3ddfff6935727cb48a80135"}, +] + +[package.dependencies] +aiobotocore = {version = "2.22.0", extras = ["boto3"]} +aiofiles = ">=23.2.1" + +[package.extras] +chalice = ["chalice (>=1.24.0)"] +s3cse = ["cryptography (>=44.0.1)"] + +[[package]] +name = "aiobotocore" +version = "2.22.0" +description = "Async client for aws services using botocore and aiohttp" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"bedrock\"" +files = [ + {file = "aiobotocore-2.22.0-py3-none-any.whl", hash = "sha256:b4e6306f79df9d81daff1f9d63189a2dbee4b77ce3ab937304834e35eaaeeccf"}, + {file = "aiobotocore-2.22.0.tar.gz", hash = "sha256:11091477266b75c2b5d28421c1f2bc9a87d175d0b8619cb830805e7a113a170b"}, +] + +[package.dependencies] +aiohttp = ">=3.9.2,<4.0.0" +aioitertools = ">=0.5.1,<1.0.0" +boto3 = {version = ">=1.37.2,<1.37.4", optional = true, markers = "extra == \"boto3\""} +botocore = ">=1.37.2,<1.37.4" +jmespath = ">=0.7.1,<2.0.0" +multidict = ">=6.0.0,<7.0.0" +python-dateutil = ">=2.1,<3.0.0" +wrapt = ">=1.10.10,<2.0.0" + +[package.extras] +awscli = ["awscli (>=1.38.2,<1.38.4)"] +boto3 = ["boto3 (>=1.37.2,<1.37.4)"] + +[[package]] +name = "aiofiles" +version = "24.1.0" +description = "File support for asyncio." +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"bedrock\"" +files = [ + {file = "aiofiles-24.1.0-py3-none-any.whl", hash = "sha256:b4ec55f4195e3eb5d7abd1bf7e061763e864dd4954231fb8539a0ef8bb8260e5"}, + {file = "aiofiles-24.1.0.tar.gz", hash = "sha256:22a075c9e5a3810f0c2e48f3008c94d68c65d763b9b03857924c99e57355166c"}, +] + [[package]] name = "aiohappyeyeballs" version = "2.6.1" @@ -116,6 +177,23 @@ yarl = ">=1.17.0,<2.0" [package.extras] speedups = ["Brotli ; platform_python_implementation == \"CPython\"", "aiodns (>=3.2.0) ; sys_platform == \"linux\" or sys_platform == \"darwin\"", "brotlicffi ; platform_python_implementation != \"CPython\""] +[[package]] +name = "aioitertools" +version = "0.12.0" +description = "itertools and builtins for AsyncIO and mixed iterables" +optional = true +python-versions = ">=3.8" +groups = ["main"] +markers = "extra == \"bedrock\"" +files = [ + {file = "aioitertools-0.12.0-py3-none-any.whl", hash = "sha256:fc1f5fac3d737354de8831cbba3eb04f79dd649d8f3afb4c5b114925e662a796"}, + {file = "aioitertools-0.12.0.tar.gz", hash = "sha256:c2a9055b4fbb7705f561b9d86053e8af5d10cc845d22c32008c43490b2d8dd6b"}, +] + +[package.extras] +dev = ["attribution (==1.8.0)", "black (==24.8.0)", "build (>=1.2)", "coverage (==7.6.1)", "flake8 (==7.1.1)", "flit (==3.9.0)", "mypy (==1.11.2)", "ufmt (==2.7.1)", "usort (==1.0.8.post1)"] +docs = ["sphinx (==8.0.2)", "sphinx-mdinclude (==0.6.2)"] + [[package]] name = "aiomultiprocess" version = "0.9.1" @@ -620,19 +698,19 @@ files = [ [[package]] name = "boto3" -version = "1.37.31" +version = "1.37.3" description = "The AWS SDK for Python" optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"bedrock\"" files = [ - {file = "boto3-1.37.31-py3-none-any.whl", hash = "sha256:cf8997be0742a5cab9d33a138ef56e423a8ebd8881f6f73e95076b26656b36dc"}, - {file = "boto3-1.37.31.tar.gz", hash = "sha256:dfee02b2f8f632a239a2f4ba6a2d568e2edd7f7464e9afd8a487fdb3fa9a0ad3"}, + {file = "boto3-1.37.3-py3-none-any.whl", hash = "sha256:2063b40af99fd02f6228ff52397b552ff3353831edaf8d25cc04801827ab9794"}, + {file = "boto3-1.37.3.tar.gz", hash = "sha256:21f3ce0ef111297e63a6eb998a25197b8c10982970c320d4c6e8db08be2157be"}, ] [package.dependencies] -botocore = ">=1.37.31,<1.38.0" +botocore = ">=1.37.3,<1.38.0" jmespath = ">=0.7.1,<2.0.0" s3transfer = ">=0.11.0,<0.12.0" @@ -641,15 +719,15 @@ crt = ["botocore[crt] (>=1.21.0,<2.0a0)"] [[package]] name = "botocore" -version = "1.37.31" +version = "1.37.3" description = "Low-level, data-driven core of boto 3." optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"bedrock\"" files = [ - {file = "botocore-1.37.31-py3-none-any.whl", hash = "sha256:598a33a7a0e5a014bd1416c999a0b9c634fbbba3d1363e2368e6a92da4544df4"}, - {file = "botocore-1.37.31.tar.gz", hash = "sha256:eb3dfa44a87187bd82c3b493d568d8436270d4d000f237b49b669a01fcd8a21c"}, + {file = "botocore-1.37.3-py3-none-any.whl", hash = "sha256:d01bd3bf4c80e61fa88d636ad9f5c9f60a551d71549b481386c6b4efe0bb2b2e"}, + {file = "botocore-1.37.3.tar.gz", hash = "sha256:fe8403eb55a88faf9b0f9da6615e5bee7be056d75e17af66c3c8f0a3b0648da4"}, ] [package.dependencies] @@ -6326,22 +6404,22 @@ pyasn1 = ">=0.1.3" [[package]] name = "s3transfer" -version = "0.11.4" +version = "0.11.3" description = "An Amazon S3 Transfer Manager" optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"bedrock\"" files = [ - {file = "s3transfer-0.11.4-py3-none-any.whl", hash = "sha256:ac265fa68318763a03bf2dc4f39d5cbd6a9e178d81cc9483ad27da33637e320d"}, - {file = "s3transfer-0.11.4.tar.gz", hash = "sha256:559f161658e1cf0a911f45940552c696735f5c74e64362e515f333ebed87d679"}, + {file = "s3transfer-0.11.3-py3-none-any.whl", hash = "sha256:ca855bdeb885174b5ffa95b9913622459d4ad8e331fc98eb01e6d5eb6a30655d"}, + {file = "s3transfer-0.11.3.tar.gz", hash = "sha256:edae4977e3a122445660c7c114bba949f9d191bae3b34a096f18a1c8c354527a"}, ] [package.dependencies] -botocore = ">=1.37.4,<2.0a.0" +botocore = ">=1.36.0,<2.0a.0" [package.extras] -crt = ["botocore[crt] (>=1.37.4,<2.0a.0)"] +crt = ["botocore[crt] (>=1.36.0,<2.0a.0)"] [[package]] name = "scramp" @@ -7783,7 +7861,7 @@ cffi = ["cffi (>=1.11)"] [extras] all = ["autoflake", "black", "docker", "fastapi", "granian", "isort", "langchain", "langchain-community", "locust", "pexpect", "pg8000", "pgvector", "pre-commit", "psycopg2", "psycopg2-binary", "pyright", "pytest-asyncio", "pytest-order", "redis", "uvicorn", "uvloop", "wikipedia"] -bedrock = ["boto3"] +bedrock = ["aioboto3", "boto3"] cloud-tool-sandbox = ["e2b-code-interpreter"] desktop = ["docker", "fastapi", "langchain", "langchain-community", "locust", "pg8000", "pgvector", "psycopg2", "psycopg2-binary", "pyright", "uvicorn", "wikipedia"] dev = ["autoflake", "black", "isort", "locust", "pexpect", "pre-commit", "pyright", "pytest-asyncio", "pytest-order"] @@ -7798,4 +7876,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.1" python-versions = "<3.14,>=3.10" -content-hash = "87b1d77da4ccba13d41d7b6ed9fe24302982e181f84ad93f0cb409f216e33255" +content-hash = "9372e0eacfc54bd204ddafcbfc11ac7bf0688a270e1906a10f57fda0191b4d73" diff --git a/pyproject.toml b/pyproject.toml index a575de79..dcbbea0a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -98,6 +98,7 @@ granian = {version = "^2.3.2", extras = ["uvloop", "reload"], optional = true} redis = {version = "^6.2.0", optional = true} structlog = "^25.4.0" certifi = "^2025.6.15" +aioboto3 = {version = "^14.3.0", optional = true} [tool.poetry.extras] @@ -109,7 +110,7 @@ server = ["websockets", "fastapi", "uvicorn"] cloud-tool-sandbox = ["e2b-code-interpreter"] external-tools = ["docker", "langchain", "wikipedia", "langchain-community", "firecrawl-py"] tests = ["wikipedia"] -bedrock = ["boto3"] +bedrock = ["boto3", "aioboto3"] google = ["google-genai"] desktop = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pyright", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust"] all = ["pgvector", "pg8000", "psycopg2-binary", "psycopg2", "pytest", "pytest-asyncio", "pexpect", "black", "pre-commit", "pyright", "pytest-order", "autoflake", "isort", "websockets", "fastapi", "uvicorn", "docker", "langchain", "wikipedia", "langchain-community", "locust", "uvloop", "granian", "redis"] diff --git a/tests/configs/llm_model_configs/bedrock-claude-3-5-sonnet.json b/tests/configs/llm_model_configs/bedrock-claude-4-sonnet.json similarity index 57% rename from tests/configs/llm_model_configs/bedrock-claude-3-5-sonnet.json rename to tests/configs/llm_model_configs/bedrock-claude-4-sonnet.json index af62ae69..2680ee89 100644 --- a/tests/configs/llm_model_configs/bedrock-claude-3-5-sonnet.json +++ b/tests/configs/llm_model_configs/bedrock-claude-4-sonnet.json @@ -1,6 +1,6 @@ { "context_window": 200000, - "model": "arn:aws:bedrock:us-west-2:850995572407:inference-profile/us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "model": "arn:aws:bedrock:us-east-1:474668403324:inference-profile/us.anthropic.claude-sonnet-4-20250514-v1:0", "model_endpoint_type": "bedrock", "model_endpoint": null, "model_wrapper": null, diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 3a848527..16c9e745 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -110,6 +110,7 @@ all_configs = [ "claude-3-5-sonnet.json", "claude-3-7-sonnet.json", "claude-3-7-sonnet-extended.json", + "bedrock-claude-4-sonnet.json", "gemini-1.5-pro.json", "gemini-2.5-flash-vertex.json", "gemini-2.5-pro-vertex.json",