feat: Use Async OpenAI client to prevent blocking server thread (#811)

This commit is contained in:
Matthew Zhou
2025-01-28 12:02:33 -10:00
committed by GitHub
parent 7bc59d6612
commit 9c5033e0bd
3 changed files with 59 additions and 26 deletions

View File

@@ -29,6 +29,7 @@ from letta.schemas.openai.chat_completion_request import ChatCompletionRequest,
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.settings import ModelSettings
from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
from letta.utils import run_async_task
LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"]
@@ -156,21 +157,25 @@ def create(
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
stream_interface, AgentRefreshStreamingInterface
), type(stream_interface)
response = openai_chat_completions_process_stream(
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=model_settings.openai_api_key,
chat_completion_request=data,
stream_interface=stream_interface,
response = run_async_task(
openai_chat_completions_process_stream(
url=llm_config.model_endpoint,
api_key=model_settings.openai_api_key,
chat_completion_request=data,
stream_interface=stream_interface,
)
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_start()
try:
response = openai_chat_completions_request(
url=llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
api_key=model_settings.openai_api_key,
chat_completion_request=data,
response = run_async_task(
openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=model_settings.openai_api_key,
chat_completion_request=data,
)
)
finally:
if isinstance(stream_interface, AgentChunkStreamingInterface):
@@ -344,9 +349,12 @@ def create(
stream_interface.stream_start()
try:
# groq uses the openai chat completions API, so this component should be reusable
response = openai_chat_completions_request(
api_key=model_settings.groq_api_key,
chat_completion_request=data,
response = run_async_task(
openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=model_settings.groq_api_key,
chat_completion_request=data,
)
)
finally:
if isinstance(stream_interface, AgentChunkStreamingInterface):

View File

@@ -1,8 +1,8 @@
import warnings
from typing import Generator, List, Optional, Union
from typing import AsyncGenerator, List, Optional, Union
import requests
from openai import OpenAI
from openai import AsyncOpenAI
from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
@@ -158,7 +158,7 @@ def build_openai_chat_completions_request(
return data
def openai_chat_completions_process_stream(
async def openai_chat_completions_process_stream(
url: str,
api_key: str,
chat_completion_request: ChatCompletionRequest,
@@ -229,9 +229,10 @@ def openai_chat_completions_process_stream(
stream_interface.stream_start()
n_chunks = 0 # approx == n_tokens
chunk_idx = 0
try:
for chunk_idx, chat_completion_chunk in enumerate(
openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request)
async for chat_completion_chunk in openai_chat_completions_request_stream(
url=url, api_key=api_key, chat_completion_request=chat_completion_request
):
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
@@ -348,6 +349,7 @@ def openai_chat_completions_process_stream(
# increment chunk counter
n_chunks += 1
chunk_idx += 1
except Exception as e:
if stream_interface:
@@ -380,24 +382,24 @@ def openai_chat_completions_process_stream(
return chat_completion_response
def openai_chat_completions_request_stream(
async def openai_chat_completions_request_stream(
url: str,
api_key: str,
chat_completion_request: ChatCompletionRequest,
) -> Generator[ChatCompletionChunkResponse, None, None]:
) -> AsyncGenerator[ChatCompletionChunkResponse, None]:
data = prepare_openai_payload(chat_completion_request)
data["stream"] = True
client = OpenAI(
client = AsyncOpenAI(
api_key=api_key,
base_url=url,
)
stream = client.chat.completions.create(**data)
for chunk in stream:
stream = await client.chat.completions.create(**data)
async for chunk in stream:
# TODO: Use the native OpenAI objects here?
yield ChatCompletionChunkResponse(**chunk.model_dump(exclude_none=True))
def openai_chat_completions_request(
async def openai_chat_completions_request(
url: str,
api_key: str,
chat_completion_request: ChatCompletionRequest,
@@ -410,8 +412,8 @@ def openai_chat_completions_request(
https://platform.openai.com/docs/guides/text-generation?lang=curl
"""
data = prepare_openai_payload(chat_completion_request)
client = OpenAI(api_key=api_key, base_url=url)
chat_completion = client.chat.completions.create(**data)
client = AsyncOpenAI(api_key=api_key, base_url=url)
chat_completion = await client.chat.completions.create(**data)
return ChatCompletionResponse(**chat_completion.model_dump())

View File

@@ -1,3 +1,4 @@
import asyncio
import copy
import difflib
import hashlib
@@ -15,7 +16,7 @@ import uuid
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from functools import wraps
from typing import List, Union, _GenericAlias, get_args, get_origin, get_type_hints
from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints
from urllib.parse import urljoin, urlparse
import demjson3 as demjson
@@ -1127,3 +1128,25 @@ def get_friendly_error_msg(function_name: str, exception_name: str, exception_me
if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
return error_msg
def run_async_task(coro: Coroutine[Any, Any, Any]) -> Any:
    """
    Safely run an asynchronous coroutine from a synchronous context.

    If no event loop is running in the current thread, the coroutine is run
    to completion with `asyncio.run`. If a loop *is* already running, we
    cannot block on it from the same thread (that would deadlock), and
    merely scheduling a Task would not yield the coroutine's result — so the
    coroutine is executed on a fresh event loop in a worker thread and its
    result is returned.

    Args:
        coro: The coroutine to execute.

    Returns:
        The result produced by the coroutine.

    Raises:
        Whatever exception the coroutine itself raises.
    """
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running in this thread — safe to create one and block.
        return asyncio.run(coro)

    # A loop is already running in this thread. NOTE: `asyncio` has no
    # module-level `run_until_complete`, and `ensure_future` would return a
    # Task rather than the result; run the coroutine on its own loop in a
    # separate thread so we can block here and hand back the real result.
    from concurrent.futures import ThreadPoolExecutor

    with ThreadPoolExecutor(max_workers=1) as executor:
        return executor.submit(asyncio.run, coro).result()