feat(asyncify): byok in async loop (#2421)
This commit is contained in:
@@ -53,13 +53,13 @@ class AnthropicClient(LLMClientBase):
|
||||
|
||||
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
    """
    Performs the underlying asynchronous (non-streaming) request to the Anthropic beta messages API.

    Args:
        request_data: Keyword arguments forwarded as-is to ``client.beta.messages.create``.
        llm_config: Model configuration handed to ``_get_anthropic_client_async`` to build the client.

    Returns:
        The API response converted to a plain dict via ``model_dump()``.
    """
    # Build the client exactly once through the awaitable helper. The previous extra call to the
    # synchronous ``_get_anthropic_client`` was dead code: its result was overwritten immediately.
    client = await self._get_anthropic_client_async(llm_config, async_client=True)
    response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
    return response.model_dump()
|
||||
|
||||
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[BetaRawMessageStreamEvent]:
    """
    Performs the underlying asynchronous streaming request to the Anthropic beta messages API.

    Args:
        request_data: Keyword arguments forwarded to ``client.beta.messages.create``.
            NOTE: this dict is modified in place (``"stream"`` is set to ``True``).
        llm_config: Model configuration handed to ``_get_anthropic_client_async`` to build the client.

    Returns:
        The awaited result of ``client.beta.messages.create`` with streaming enabled.
    """
    # Build the client exactly once through the awaitable helper. The previous extra call to the
    # synchronous ``_get_anthropic_client`` was dead code: its result was overwritten immediately.
    client = await self._get_anthropic_client_async(llm_config, async_client=True)
    # Kept as an in-place assignment (rather than a separate ``stream=True`` keyword) so that a
    # pre-existing "stream" entry in request_data is overridden instead of raising a duplicate-kwarg error.
    request_data["stream"] = True
    return await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
|
||||
|
||||
@@ -99,7 +99,7 @@ class AnthropicClient(LLMClientBase):
|
||||
for agent_id in agent_messages_mapping
|
||||
}
|
||||
|
||||
client = self._get_anthropic_client(list(agent_llm_config_mapping.values())[0], async_client=True)
|
||||
client = await self._get_anthropic_client_async(list(agent_llm_config_mapping.values())[0], async_client=True)
|
||||
|
||||
anthropic_requests = [
|
||||
Request(custom_id=agent_id, params=MessageCreateParamsNonStreaming(**params)) for agent_id, params in requests.items()
|
||||
@@ -134,6 +134,26 @@ class AnthropicClient(LLMClientBase):
|
||||
else anthropic.Anthropic(max_retries=model_settings.anthropic_max_retries)
|
||||
)
|
||||
|
||||
@trace_method
async def _get_anthropic_client_async(
    self, llm_config: LLMConfig, async_client: bool = False
) -> Union[anthropic.AsyncAnthropic, anthropic.Anthropic]:
    """
    Builds an Anthropic SDK client, resolving a BYOK override key without a synchronous lookup.

    Args:
        llm_config: Model configuration; its ``provider_category`` decides whether an override
            key is looked up, and its ``provider_name`` identifies which provider to look up.
        async_client: When True an ``anthropic.AsyncAnthropic`` is returned, otherwise an
            ``anthropic.Anthropic``.

    Returns:
        A client configured with ``model_settings.anthropic_max_retries``. ``api_key`` is passed
        explicitly only when a truthy override key was found; otherwise the SDK's own default
        key resolution applies.
    """
    byok_key = None
    if llm_config.provider_category == ProviderCategory.byok:
        byok_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)

    # Pick the client class first, then assemble the constructor arguments once.
    client_cls = anthropic.AsyncAnthropic if async_client else anthropic.Anthropic

    client_kwargs = {}
    if byok_key:
        # A missing or empty key is deliberately not forwarded, matching the no-api_key constructor call.
        client_kwargs["api_key"] = byok_key
    client_kwargs["max_retries"] = model_settings.anthropic_max_retries

    return client_cls(**client_kwargs)
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
|
||||
@@ -125,6 +125,23 @@ class OpenAIClient(LLMClientBase):
|
||||
|
||||
return kwargs
|
||||
|
||||
async def _prepare_client_kwargs_async(self, llm_config: LLMConfig) -> dict:
    """
    Builds the keyword arguments for constructing an OpenAI SDK client, awaiting the BYOK key lookup.

    Key resolution order (first truthy value wins):
      1. the BYOK override key, when ``llm_config.provider_category`` is BYOK;
      2. the Together key (settings, then ``TOGETHER_API_KEY``), when the endpoint type is Together;
      3. the OpenAI key (settings, then ``OPENAI_API_KEY``);
      4. the placeholder ``"DUMMY_API_KEY"``.

    Args:
        llm_config: Model configuration supplying the provider category/name, endpoint type
            and ``model_endpoint`` (used as ``base_url``).

    Returns:
        A dict with the keys ``"api_key"`` and ``"base_url"``.
    """
    api_key = None
    if llm_config.provider_category == ProviderCategory.byok:
        # Function-scope import kept from the original; presumably avoids a circular import - TODO confirm.
        from letta.services.provider_manager import ProviderManager

        api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)

    # Fix: the Together key is only a fallback. Previously this branch ran unconditionally and
    # replaced a BYOK key that had just been resolved with the server-side Together key.
    if not api_key and llm_config.model_endpoint_type == ProviderType.together:
        api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")

    if not api_key:
        api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")

    # supposedly the openai python client requires a dummy API key
    return {"api_key": api_key or "DUMMY_API_KEY", "base_url": llm_config.model_endpoint}
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
@@ -230,7 +247,8 @@ class OpenAIClient(LLMClientBase):
|
||||
"""
|
||||
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
|
||||
"""
|
||||
client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
|
||||
kwargs = await self._prepare_client_kwargs_async(llm_config)
|
||||
client = AsyncOpenAI(**kwargs)
|
||||
response: ChatCompletion = await client.chat.completions.create(**request_data)
|
||||
return response.model_dump()
|
||||
|
||||
@@ -265,7 +283,8 @@ class OpenAIClient(LLMClientBase):
|
||||
"""
|
||||
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
|
||||
"""
|
||||
client = AsyncOpenAI(**self._prepare_client_kwargs(llm_config))
|
||||
kwargs = await self._prepare_client_kwargs_async(llm_config)
|
||||
client = AsyncOpenAI(**kwargs)
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
**request_data, stream=True, stream_options={"include_usage": True}
|
||||
)
|
||||
|
||||
@@ -129,6 +129,12 @@ class ProviderManager:
|
||||
providers = self.list_providers(name=provider_name, actor=actor)
|
||||
return providers[0].api_key if providers else None
|
||||
|
||||
@enforce_types
@trace_method
async def get_override_key_async(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]:
    """
    Asynchronously fetches the API key of the first provider matching ``provider_name``.

    Args:
        provider_name: Name passed as the ``name`` filter to ``list_providers_async``.
        actor: User on whose behalf the providers are listed.

    Returns:
        The ``api_key`` of the first listed provider, or ``None`` when nothing is listed.
    """
    matches = await self.list_providers_async(name=provider_name, actor=actor)
    if not matches:
        return None
    return matches[0].api_key
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def check_provider_api_key(self, provider_check: ProviderCheck) -> None:
|
||||
|
||||
@@ -73,8 +73,8 @@ async def test_send_llm_batch_request_async_success(
|
||||
anthropic_client, mock_agent_messages, mock_agent_tools, mock_agent_llm_config, dummy_beta_message_batch
|
||||
):
|
||||
"""Test a successful batch request using mocked Anthropic client responses."""
|
||||
# Patch the _get_anthropic_client method so that it returns a mock client.
|
||||
with patch.object(anthropic_client, "_get_anthropic_client") as mock_get_client:
|
||||
# Patch the _get_anthropic_client_async method so that it returns a mock client.
|
||||
with patch.object(anthropic_client, "_get_anthropic_client_async") as mock_get_client:
|
||||
mock_client = AsyncMock()
|
||||
# Set the create method to return the dummy response asynchronously.
|
||||
mock_client.beta.messages.batches.create.return_value = dummy_beta_message_batch
|
||||
|
||||
Reference in New Issue
Block a user