diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py
index f131e776..b81967d4 100644
--- a/letta/llm_api/anthropic_client.py
+++ b/letta/llm_api/anthropic_client.py
@@ -53,13 +53,13 @@ class AnthropicClient(LLMClientBase):
 
     @trace_method
     async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
-        client = self._get_anthropic_client(llm_config, async_client=True)
+        client = await self._get_anthropic_client_async(llm_config, async_client=True)
         response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
         return response.model_dump()
 
     @trace_method
     async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
-        client = self._get_anthropic_client(llm_config, async_client=True)
+        client = await self._get_anthropic_client_async(llm_config, async_client=True)
         request_data["stream"] = True
         return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
 
@@ -99,7 +99,7 @@ class AnthropicClient(LLMClientBase):
             for agent_id in agent_messages_mapping
         }
 
-        client = self._get_anthropic_client(list(agent_llm_config_mapping.values())[0], async_client=True)
+        client = await self._get_anthropic_client_async(list(agent_llm_config_mapping.values())[0], async_client=True)
         anthropic_requests = [
             Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
         ]
@@ -134,6 +134,26 @@ class AnthropicClient(LLMClientBase):
             else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
         )
 
+    @trace_method
+    async def _get_anthropic_client_async(
+        self, llm_config: LLMConfig, async_client: bool = False
+    ) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
+        override_key = None
+        if llm_config.provider_category == ProviderCategory.byok:
+            override_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
+
+        if async_client:
+            return (
+                anthropic.AsyncAnthropic(api_key=override_key, max_retries=model_settings.anthropic_max_retries)
+                if override_key
+                else anthropic.AsyncAnthropic(max_retries=model_settings.anthropic_max_retries)
+            )
+        return (
+            anthropic.Anthropic(api_key=override_key, max_retries=model_settings.anthropic_max_retries)
+            if override_key
+            else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
+        )
+
     @trace_method
     def build_request_data(
         self,
diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py
index 3872d851..d0633238 100644
--- a/letta/llm_api/openai_client.py
+++ b/letta/llm_api/openai_client.py
@@ -125,6 +125,23 @@ class OpenAIClient(LLMClientBase):
 
         return kwargs
 
+    async def _prepare_client_kwargs_async(self, llm_config: LLMConfig) -> dict:
+        api_key = None
+        if llm_config.provider_category == ProviderCategory.byok:
+            from letta.services.provider_manager import ProviderManager
+
+            api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
+        if llm_config.model_endpoint_type == ProviderType.together:
+            api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
+
+        if not api_key:
+            api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
+        # supposedly the openai python client requires a dummy API key
+        api_key = api_key or "DUMMY_API_KEY"
+        kwargs = {"api_key": api_key, "base_url": llm_config.model_endpoint}
+
+        return kwargs
+
     @trace_method
     def build_request_data(
         self,
@@ -230,7 +247,8 @@ class OpenAIClient(LLMClientBase):
         """
         Performs underlying asynchronous request to OpenAI API and returns raw response dict.
         """
-        client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
+        kwargs = await self._prepare_client_kwargs_async(llm_config)
+        client = AsyncOpenAI(**kwargs)
         response: ChatCompletion = await client.chat.completions.create(**request_data)
         return response.model_dump()
 
@@ -265,7 +283,8 @@ class OpenAIClient(LLMClientBase):
         """
         Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
         """
-        client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
+        kwargs = await self._prepare_client_kwargs_async(llm_config)
+        client = AsyncOpenAI(**kwargs)
         response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
             **request_data, stream=True, stream_options={"include_usage": True}
         )
diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py
index 6b2bab01..9d957e5f 100644
--- a/letta/services/provider_manager.py
+++ b/letta/services/provider_manager.py
@@ -129,6 +129,12 @@ class ProviderManager:
         providers = self.list_providers(name=provider_name, actor=actor)
         return providers[0].api_key if providers else None
 
+    @enforce_types
+    @trace_method
+    async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
+        providers = await self.list_providers_async(name=provider_name, actor=actor)
+        return providers[0].api_key if providers else None
+
     @enforce_types
     @trace_method
     def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
diff --git a/tests/test_llm_clients.py b/tests/test_llm_clients.py
index 7eabb864..099a98a9 100644
--- a/tests/test_llm_clients.py
+++ b/tests/test_llm_clients.py
@@ -73,8 +73,8 @@ async def test_send_llm_batch_request_async_success(
     anthropic_client, mock_agent_messages, mock_agent_tools, mock_agent_llm_config, dummy_beta_message_batch
 ):
     """Test a successful batch request using mocked Anthropic client responses."""
-    # Patch the _get_anthropic_client method so that it returns a mock client.
-    with patch.object(anthropic_client, "_get_anthropic_client") as mock_get_client:
+    # Patch the _get_anthropic_client_async method so that it returns a mock client.
+    with patch.object(anthropic_client, "_get_anthropic_client_async") as mock_get_client:
         mock_client = AsyncMock()
         # Set the create method to return the dummy response asynchronously.
         mock_client.beta.messages.batches.create.return_value = dummy_beta_message_batch