From 9c5033e0bd1ab747c92a0ea7e7bbfae9786b19c7 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 28 Jan 2025 12:02:33 -1000 Subject: [PATCH] feat: Use Async OpenAI client to prevent blocking server thread (#811) --- letta/llm_api/llm_api_tools.py | 32 ++++++++++++++++++++------------ letta/llm_api/openai.py | 28 +++++++++++++++------------- letta/utils.py | 25 ++++++++++++++++++++++++- 3 files changed, 59 insertions(+), 26 deletions(-) diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index c6e8d63a..fe198453 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -29,6 +29,7 @@ from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.settings import ModelSettings from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface +from letta.utils import run_async_task LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"] @@ -156,21 +157,25 @@ def create( assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance( stream_interface, AgentRefreshStreamingInterface ), type(stream_interface) - response = openai_chat_completions_process_stream( - url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions - api_key=model_settings.openai_api_key, - chat_completion_request=data, - stream_interface=stream_interface, + response = run_async_task( + openai_chat_completions_process_stream( + url=llm_config.model_endpoint, + api_key=model_settings.openai_api_key, + chat_completion_request=data, + stream_interface=stream_interface, + ) ) else: # Client did not request token streaming (expect a blocking backend response) data.stream = False if isinstance(stream_interface, AgentChunkStreamingInterface): stream_interface.stream_start() try: - response = 
openai_chat_completions_request( - url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions - api_key=model_settings.openai_api_key, - chat_completion_request=data, + response = run_async_task( + openai_chat_completions_request( + url=llm_config.model_endpoint, + api_key=model_settings.openai_api_key, + chat_completion_request=data, + ) ) finally: if isinstance(stream_interface, AgentChunkStreamingInterface): @@ -344,9 +349,12 @@ def create( stream_interface.stream_start() try: # groq uses the openai chat completions API, so this component should be reusable - response = openai_chat_completions_request( - api_key=model_settings.groq_api_key, - chat_completion_request=data, + response = run_async_task( + openai_chat_completions_request( + url=llm_config.model_endpoint, + api_key=model_settings.groq_api_key, + chat_completion_request=data, + ) ) finally: if isinstance(stream_interface, AgentChunkStreamingInterface): diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index ca0c25f2..d931e8fb 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -1,8 +1,8 @@ import warnings -from typing import Generator, List, Optional, Union +from typing import AsyncGenerator, List, Optional, Union import requests -from openai import OpenAI +from openai import AsyncOpenAI from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST @@ -158,7 +158,7 @@ def build_openai_chat_completions_request( return data -def openai_chat_completions_process_stream( +async def openai_chat_completions_process_stream( url: str, api_key: str, chat_completion_request: ChatCompletionRequest, @@ -229,9 +229,10 @@ def openai_chat_completions_process_stream( stream_interface.stream_start() n_chunks = 0 # approx == n_tokens + chunk_idx = 0 try: - 
for chunk_idx, chat_completion_chunk in enumerate( - openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request) + async for chat_completion_chunk in openai_chat_completions_request_stream( + url=url, api_key=api_key, chat_completion_request=chat_completion_request ): assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk) @@ -348,6 +349,7 @@ def openai_chat_completions_process_stream( # increment chunk counter n_chunks += 1 + chunk_idx += 1 except Exception as e: if stream_interface: @@ -380,24 +382,24 @@ def openai_chat_completions_process_stream( return chat_completion_response -def openai_chat_completions_request_stream( +async def openai_chat_completions_request_stream( url: str, api_key: str, chat_completion_request: ChatCompletionRequest, -) -> Generator[ChatCompletionChunkResponse, None, None]: +) -> AsyncGenerator[ChatCompletionChunkResponse, None]: data = prepare_openai_payload(chat_completion_request) data["stream"] = True - client = OpenAI( + client = AsyncOpenAI( api_key=api_key, base_url=url, ) - stream = client.chat.completions.create(**data) - for chunk in stream: + stream = await client.chat.completions.create(**data) + async for chunk in stream: # TODO: Use the native OpenAI objects here? 
yield ChatCompletionChunkResponse(**chunk.model_dump(exclude_none=True)) -def openai_chat_completions_request( +async def openai_chat_completions_request( url: str, api_key: str, chat_completion_request: ChatCompletionRequest, @@ -410,8 +412,8 @@ def openai_chat_completions_request( https://platform.openai.com/docs/guides/text-generation?lang=curl """ data = prepare_openai_payload(chat_completion_request) - client = OpenAI(api_key=api_key, base_url=url) - chat_completion = client.chat.completions.create(**data) + client = AsyncOpenAI(api_key=api_key, base_url=url) + chat_completion = await client.chat.completions.create(**data) return ChatCompletionResponse(**chat_completion.model_dump()) diff --git a/letta/utils.py b/letta/utils.py index 18a5093a..171391e3 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -1,3 +1,4 @@ +import asyncio import copy import difflib import hashlib @@ -15,7 +16,7 @@ import uuid from contextlib import contextmanager from datetime import datetime, timedelta, timezone from functools import wraps -from typing import List, Union, _GenericAlias, get_args, get_origin, get_type_hints +from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints from urllib.parse import urljoin, urlparse import demjson3 as demjson @@ -1127,3 +1128,25 @@ def get_friendly_error_msg(function_name: str, exception_name: str, exception_me if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT: error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT] return error_msg + + +def run_async_task(coro: Coroutine[Any, Any, Any]) -> Any: + """ + Safely runs an asynchronous coroutine in a synchronous context. + + If an event loop is already running, it uses `asyncio.ensure_future`. + Otherwise, it creates a new event loop and runs the coroutine. + + Args: + coro: The coroutine to execute. + + Returns: + The result of the coroutine. 
+ """
+ try:
+ # If there's already a running event loop, schedule the coroutine
+ loop = asyncio.get_running_loop()
+ return loop.run_until_complete(coro) if loop.is_closed() else asyncio.ensure_future(coro)
+ except RuntimeError:
+ # If no event loop is running, create a new one
+ return asyncio.run(coro)