feat: Use Async OpenAI client to prevent blocking server thread (#811)
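This change swaps the synchronous `OpenAI` client for `AsyncOpenAI` in the chat-completions helpers, turns those helpers into coroutines, and drives them from the existing synchronous call sites through a new `run_async_task` utility in `letta.utils`. As a rough illustration of the motivation (a self-contained sketch, not code from this commit): a blocking HTTP call made inside a coroutine stalls the whole event loop, while an awaited call lets other requests make progress.

```python
# Illustration only (not part of this commit): each handler works for one
# second. The awaited version overlaps; the blocking version serializes, which
# is what a synchronous OpenAI request does to the server's event loop.
import asyncio
import time


async def awaited_handler() -> None:
    await asyncio.sleep(1)  # stands in for an awaited AsyncOpenAI request


async def blocking_handler() -> None:
    time.sleep(1)  # stands in for a synchronous OpenAI request; blocks the loop


async def main() -> None:
    start = time.monotonic()
    await asyncio.gather(awaited_handler(), awaited_handler())
    print(f"awaited:  {time.monotonic() - start:.1f}s")  # ~1.0s

    start = time.monotonic()
    await asyncio.gather(blocking_handler(), blocking_handler())
    print(f"blocking: {time.monotonic() - start:.1f}s")  # ~2.0s


asyncio.run(main())
```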
@@ -29,6 +29,7 @@ from letta.schemas.openai.chat_completion_request import ChatCompletionRequest,
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.settings import ModelSettings
 from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface
+from letta.utils import run_async_task
 
 LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"]
 
@@ -156,21 +157,25 @@ def create(
             assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
                 stream_interface, AgentRefreshStreamingInterface
             ), type(stream_interface)
-            response = openai_chat_completions_process_stream(
-                url=llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
-                api_key=model_settings.openai_api_key,
-                chat_completion_request=data,
-                stream_interface=stream_interface,
+            response = run_async_task(
+                openai_chat_completions_process_stream(
+                    url=llm_config.model_endpoint,
+                    api_key=model_settings.openai_api_key,
+                    chat_completion_request=data,
+                    stream_interface=stream_interface,
+                )
             )
         else:  # Client did not request token streaming (expect a blocking backend response)
             data.stream = False
             if isinstance(stream_interface, AgentChunkStreamingInterface):
                 stream_interface.stream_start()
             try:
-                response = openai_chat_completions_request(
-                    url=llm_config.model_endpoint,  # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions
-                    api_key=model_settings.openai_api_key,
-                    chat_completion_request=data,
+                response = run_async_task(
+                    openai_chat_completions_request(
+                        url=llm_config.model_endpoint,
+                        api_key=model_settings.openai_api_key,
+                        chat_completion_request=data,
+                    )
                 )
             finally:
                 if isinstance(stream_interface, AgentChunkStreamingInterface):
@@ -344,9 +349,12 @@ def create(
             stream_interface.stream_start()
         try:
             # groq uses the openai chat completions API, so this component should be reusable
-            response = openai_chat_completions_request(
-                api_key=model_settings.groq_api_key,
-                chat_completion_request=data,
+            response = run_async_task(
+                openai_chat_completions_request(
+                    url=llm_config.model_endpoint,
+                    api_key=model_settings.groq_api_key,
+                    chat_completion_request=data,
+                )
             )
         finally:
             if isinstance(stream_interface, AgentChunkStreamingInterface):
@@ -1,8 +1,8 @@
 import warnings
-from typing import Generator, List, Optional, Union
+from typing import AsyncGenerator, List, Optional, Union
 
 import requests
-from openai import OpenAI
+from openai import AsyncOpenAI
 
 from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request
 from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST
@@ -158,7 +158,7 @@ def build_openai_chat_completions_request(
     return data
 
 
-def openai_chat_completions_process_stream(
+async def openai_chat_completions_process_stream(
     url: str,
     api_key: str,
     chat_completion_request: ChatCompletionRequest,
@@ -229,9 +229,10 @@ def openai_chat_completions_process_stream(
         stream_interface.stream_start()
 
     n_chunks = 0  # approx == n_tokens
+    chunk_idx = 0
     try:
-        for chunk_idx, chat_completion_chunk in enumerate(
-            openai_chat_completions_request_stream(url=url, api_key=api_key, chat_completion_request=chat_completion_request)
+        async for chat_completion_chunk in openai_chat_completions_request_stream(
+            url=url, api_key=api_key, chat_completion_request=chat_completion_request
         ):
             assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
 
@@ -348,6 +349,7 @@ def openai_chat_completions_process_stream(
 
             # increment chunk counter
             n_chunks += 1
+            chunk_idx += 1
 
     except Exception as e:
         if stream_interface:
@@ -380,24 +382,24 @@ def openai_chat_completions_process_stream(
     return chat_completion_response
 
 
-def openai_chat_completions_request_stream(
+async def openai_chat_completions_request_stream(
     url: str,
     api_key: str,
     chat_completion_request: ChatCompletionRequest,
-) -> Generator[ChatCompletionChunkResponse, None, None]:
+) -> AsyncGenerator[ChatCompletionChunkResponse, None]:
     data = prepare_openai_payload(chat_completion_request)
     data["stream"] = True
-    client = OpenAI(
+    client = AsyncOpenAI(
         api_key=api_key,
         base_url=url,
     )
-    stream = client.chat.completions.create(**data)
-    for chunk in stream:
+    stream = await client.chat.completions.create(**data)
+    async for chunk in stream:
         # TODO: Use the native OpenAI objects here?
         yield ChatCompletionChunkResponse(**chunk.model_dump(exclude_none=True))
 
 
-def openai_chat_completions_request(
+async def openai_chat_completions_request(
     url: str,
     api_key: str,
     chat_completion_request: ChatCompletionRequest,
@@ -410,8 +412,8 @@ def openai_chat_completions_request(
     https://platform.openai.com/docs/guides/text-generation?lang=curl
     """
     data = prepare_openai_payload(chat_completion_request)
-    client = OpenAI(api_key=api_key, base_url=url)
-    chat_completion = client.chat.completions.create(**data)
+    client = AsyncOpenAI(api_key=api_key, base_url=url)
+    chat_completion = await client.chat.completions.create(**data)
     return ChatCompletionResponse(**chat_completion.model_dump())
 
 
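The streaming helper above now returns an `AsyncGenerator` and is consumed with `async for`. A minimal, self-contained sketch of that pattern follows (the names `chunk_stream` and `consume_stream` are made up for illustration; they mirror `openai_chat_completions_request_stream` and `openai_chat_completions_process_stream` but are not taken from this commit):

```python
# Sketch of an async chunk stream and its consumer (hypothetical names).
import asyncio
from typing import AsyncGenerator, Callable


async def chunk_stream(n: int) -> AsyncGenerator[str, None]:
    # Plays the role of openai_chat_completions_request_stream: each chunk is
    # awaited, so the event loop can run other work between chunks.
    for i in range(n):
        await asyncio.sleep(0)
        yield f"chunk-{i}"


async def consume_stream(n: int, on_chunk: Callable[[str], None]) -> int:
    # Plays the role of openai_chat_completions_process_stream: iterate with
    # `async for` and hand each chunk to a callback, analogous to the
    # stream_interface calls in the real function.
    n_chunks = 0
    async for chunk in chunk_stream(n):
        on_chunk(chunk)
        n_chunks += 1
    return n_chunks


print(asyncio.run(consume_stream(3, print)))  # prints each chunk, then 3
```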
@@ -1,3 +1,4 @@
+import asyncio
 import copy
 import difflib
 import hashlib
@@ -15,7 +16,7 @@ import uuid
 from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
 from functools import wraps
-from typing import List, Union, _GenericAlias, get_args, get_origin, get_type_hints
+from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints
 from urllib.parse import urljoin, urlparse
 
 import demjson3 as demjson
@@ -1127,3 +1128,25 @@ def get_friendly_error_msg(function_name: str, exception_name: str, exception_me
     if len(error_msg) > MAX_ERROR_MESSAGE_CHAR_LIMIT:
         error_msg = error_msg[:MAX_ERROR_MESSAGE_CHAR_LIMIT]
     return error_msg
+
+
+def run_async_task(coro: Coroutine[Any, Any, Any]) -> Any:
+    """
+    Safely runs an asynchronous coroutine in a synchronous context.
+
+    If an event loop is already running, it uses `asyncio.ensure_future`.
+    Otherwise, it creates a new event loop and runs the coroutine.
+
+    Args:
+        coro: The coroutine to execute.
+
+    Returns:
+        The result of the coroutine.
+    """
+    try:
+        # If there's already a running event loop, schedule the coroutine
+        loop = asyncio.get_running_loop()
+        return asyncio.run_until_complete(coro) if loop.is_closed() else asyncio.ensure_future(coro)
+    except RuntimeError:
+        # If no event loop is running, create a new one
+        return asyncio.run(coro)
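A usage sketch for the new helper follows (the import matches the `from letta.utils import run_async_task` line added in the first hunk; the `add` coroutine is made up for illustration). The two branches behave differently: when no event loop is running, `asyncio.run` executes the coroutine and returns its result, which is what the synchronous call sites in `create()` rely on; inside an already-running loop, `asyncio.ensure_future` returns a `Task` rather than the awaited result.

```python
# Usage sketch (hypothetical coroutine). Called from plain synchronous code,
# run_async_task takes the RuntimeError branch and returns the result directly.
import asyncio

from letta.utils import run_async_task


async def add(a: int, b: int) -> int:
    await asyncio.sleep(0)  # stands in for an awaited OpenAI request
    return a + b


print(run_async_task(add(1, 2)))  # -> 3
```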