feat(asyncify): byok in async loop (#2421)

This commit is contained in:
cthomas
2025-05-25 19:47:20 -07:00
committed by GitHub
parent 0633ae116b
commit 20470844a7
4 changed files with 52 additions and 7 deletions

View File

@@ -53,13 +53,13 @@ class AnthropicClient(LLMClientBase):
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
    """Perform a non-streaming Anthropic messages.create call asynchronously.

    Args:
        request_data: Keyword arguments forwarded verbatim to
            ``client.beta.messages.create`` (model, messages, tools, ...).
        llm_config: Provider configuration used to resolve the client
            (including BYOK key lookup for ``ProviderCategory.byok``).

    Returns:
        The raw response as a plain dict (``model_dump()`` of the SDK model).
    """
    # Resolve the client via the async path so a BYOK override-key lookup
    # does not block the event loop. (Removed the stale duplicate
    # synchronous `_get_anthropic_client` assignment whose result was
    # immediately discarded.)
    client = await self._get_anthropic_client_async(llm_config, async_client=True)
    response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
    return response.model_dump()
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
    """Perform a streaming Anthropic messages.create call asynchronously.

    Args:
        request_data: Keyword arguments forwarded to
            ``client.beta.messages.create``; ``stream`` is forced to True here.
        llm_config: Provider configuration used to resolve the client
            (including BYOK key lookup for ``ProviderCategory.byok``).

    Returns:
        The SDK's async stream of raw message events for the caller to iterate.
    """
    # Async client resolution keeps BYOK key lookups off the event loop.
    # (Removed the stale duplicate synchronous `_get_anthropic_client`
    # assignment whose result was immediately discarded.)
    client = await self._get_anthropic_client_async(llm_config, async_client=True)
    request_data["stream"] = True
    return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
@@ -99,7 +99,7 @@ class AnthropicClient(LLMClientBase):
for agent_id in agent_messages_mapping
}
client = self._get_anthropic_client(list(agent_llm_config_mapping.values())[0], async_client=True)
client = await self._get_anthropic_client_async(list(agent_llm_config_mapping.values())[0], async_client=True)
anthropic_requests = [
Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
@@ -134,6 +134,26 @@ class AnthropicClient(LLMClientBase):
else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
)
@trace_method
async def _get_anthropic_client_async(
    self, llm_config: LLMConfig, async_client: bool = False
) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
    """Build an Anthropic SDK client, resolving any BYOK override key asynchronously.

    Args:
        llm_config: Provider configuration; when its category is
            ``ProviderCategory.byok`` the user's override key is fetched.
        async_client: When True return ``anthropic.AsyncAnthropic``,
            otherwise the synchronous ``anthropic.Anthropic``.

    Returns:
        A configured client; ``api_key`` is passed only when an override
        key was found, so the SDK's own environment-based key resolution
        applies otherwise.
    """
    override_key = None
    if llm_config.provider_category == ProviderCategory.byok:
        override_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)

    # Select the client class once, then assemble kwargs: omitting
    # `api_key` entirely (rather than passing None) lets the SDK fall
    # back to its default key discovery.
    client_cls = anthropic.AsyncAnthropic if async_client else anthropic.Anthropic
    client_kwargs = {"max_retries": model_settings.anthropic_max_retries}
    if override_key:
        client_kwargs["api_key"] = override_key
    return client_cls(**client_kwargs)
@trace_method
def build_request_data(
self,

View File

@@ -125,6 +125,23 @@ class OpenAIClient(LLMClientBase):
return kwargs
async def _prepare_client_kwargs_async(self, llm_config: LLMConfig) -> dict:
    """Assemble AsyncOpenAI constructor kwargs, resolving the API key asynchronously.

    Key precedence (unchanged from the original):
      1. BYOK override key looked up via ProviderManager.
      2. Together key (settings or TOGETHER_API_KEY env) when the endpoint
         type is `together` — NOTE(review): this overwrites a BYOK key for
         together endpoints; looks intentional but confirm.
      3. OpenAI key (settings or OPENAI_API_KEY env) if still unset.
      4. A placeholder, since the openai client insists on some api_key.

    Returns:
        Dict with ``api_key`` and ``base_url`` for ``AsyncOpenAI(**kwargs)``.
    """
    api_key = None
    if llm_config.provider_category == ProviderCategory.byok:
        from letta.services.provider_manager import ProviderManager

        api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
    if llm_config.model_endpoint_type == ProviderType.together:
        api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
    if not api_key:
        api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
    # supposedly the openai python client requires a dummy API key
    return {"api_key": api_key or "DUMMY_API_KEY", "base_url": llm_config.model_endpoint}
@trace_method
def build_request_data(
self,
@@ -230,7 +247,8 @@ class OpenAIClient(LLMClientBase):
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
kwargs = await self._prepare_client_kwargs_async(llm_config)
client = AsyncOpenAI(**kwargs)
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
@@ -265,7 +283,8 @@ class OpenAIClient(LLMClientBase):
"""
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
"""
client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
kwargs = await self._prepare_client_kwargs_async(llm_config)
client = AsyncOpenAI(**kwargs)
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
**request_data, stream=True, stream_options={"include_usage": True}
)

View File

@@ -129,6 +129,12 @@ class ProviderManager:
providers = self.list_providers(name=provider_name, actor=actor)
return providers[0].api_key if providers else None
@enforce_types
@trace_method
async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
    """Async variant of the override-key lookup.

    Args:
        provider_name: Name filter passed to the provider listing (may be None).
        actor: User whose providers are searched.

    Returns:
        The ``api_key`` of the first matching provider, or None when no
        provider matches.
    """
    matches = await self.list_providers_async(name=provider_name, actor=actor)
    if not matches:
        return None
    return matches[0].api_key
@enforce_types
@trace_method
def check_provider_api_key(self, provider_check: ProviderCheck) -> None:

View File

@@ -73,8 +73,8 @@ async def test_send_llm_batch_request_async_success(
anthropic_client, mock_agent_messages, mock_agent_tools, mock_agent_llm_config, dummy_beta_message_batch
):
"""Test a successful batch request using mocked Anthropic client responses."""
# Patch the _get_anthropic_client method so that it returns a mock client.
with patch.object(anthropic_client, "_get_anthropic_client") as mock_get_client:
# Patch the _get_anthropic_client_async method so that it returns a mock client.
with patch.object(anthropic_client, "_get_anthropic_client_async") as mock_get_client:
mock_client = AsyncMock()
# Set the create method to return the dummy response asynchronously.
mock_client.beta.messages.batches.create.return_value = dummy_beta_message_batch