feat: Fast chat completions endpoint (#1091)

This commit is contained in:
Matthew Zhou
2025-02-20 18:55:25 -08:00
committed by GitHub
parent 4af770fd09
commit df1f9839a6
5 changed files with 283 additions and 36 deletions

View File

@@ -750,7 +750,6 @@ class Agent(BaseAgent):
put_inner_thoughts_first: bool = True,
) -> AgentStepResponse:
"""Runs a single step in the agent loop (generates at most one LLM call)"""
try:
# Extract job_id from metadata if present

View File

@@ -161,10 +161,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
Called externally with a ChatCompletionChunkResponse. Transforms
it if necessary, then enqueues partial messages for streaming back.
"""
# print(chunk)
processed_chunk = self._process_chunk_to_openai_style(chunk)
# print(processed_chunk)
# print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
if processed_chunk is not None:
self._push_to_buffer(processed_chunk)

View File

@@ -1,21 +1,30 @@
import asyncio
from typing import TYPE_CHECKING, Iterable, List, Optional, Union, cast
from typing import TYPE_CHECKING, List, Optional, Union
import httpx
import openai
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.completion_create_params import CompletionCreateParams
from starlette.concurrency import run_in_threadpool
from letta.agent import Agent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.schemas.message import MessageCreate
from letta.schemas.openai.chat_completion_response import Message
from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
# TODO this belongs in a controller!
from letta.server.rest_api.utils import get_letta_server, sse_async_generator
from letta.server.rest_api.utils import (
convert_letta_messages_to_openai,
create_assistant_message_from_openai_response,
create_user_message,
get_letta_server,
get_messages_from_completion_request,
sse_async_generator,
)
from letta.settings import model_settings
if TYPE_CHECKING:
from letta.server.server import SyncServer
@@ -25,6 +34,88 @@ router = APIRouter(prefix="/v1", tags=["chat_completions"])
logger = get_logger(__name__)
@router.post(
    "/fast/chat/completions",
    response_model=None,
    operation_id="create_fast_chat_completions",
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
)
async def create_fast_chat_completions(
    completion_request: CompletionCreateParams = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
    user_id: Optional[str] = Header(None, alias="user_id"),
):
    """Stream a chat completion for an agent directly from OpenAI, bypassing the full agent loop.

    The agent id travels in the request's ``user`` field. The agent's in-context
    messages are converted to OpenAI format and prepended to the caller's final
    user message; the completion is streamed back as SSE, and afterwards both the
    user and assistant messages are persisted to the agent's context.

    Raises:
        HTTPException(400): if the ``user`` field (agent id) is missing, or the
            ``messages`` field fails validation.
    """
    # TODO: This is necessary, we need to factor out CompletionCreateParams due to weird behavior
    # BUG FIX: check for a missing 'user' field *before* stringifying it —
    # the original `str(completion_request.get("user", None))` turned None into
    # the string "None", making the `is None` check below unreachable.
    raw_agent_id = completion_request.get("user")
    if raw_agent_id is None:
        error_msg = "Must pass agent_id in the 'user' field"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    agent_id = str(raw_agent_id)

    model = completion_request.get("model")
    actor = server.user_manager.get_user_or_default(user_id=user_id)

    # Per-request transport; closed explicitly when the stream finishes (see below).
    http_client = httpx.AsyncClient(
        timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
        follow_redirects=True,
        limits=httpx.Limits(
            max_connections=50,
            max_keepalive_connections=50,
            keepalive_expiry=120,
        ),
    )
    client = openai.AsyncClient(
        api_key=model_settings.openai_api_key,
        max_retries=0,
        http_client=http_client,
    )

    # Keep only the final (validated user) message from the request; the rest of
    # the conversation context comes from the agent's stored in-context messages.
    input_message = get_messages_from_completion_request(completion_request)[-1]
    completion_request.pop("messages")

    # Build the prompt: agent context (flattened to OpenAI format) + new user message.
    in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor)
    openai_dict_in_context_messages = convert_letta_messages_to_openai(in_context_messages)
    openai_dict_in_context_messages.append(input_message)

    async def event_stream():
        # TODO: Factor this out into separate interface
        try:
            response_accumulator = []
            stream = await client.chat.completions.create(**completion_request, messages=openai_dict_in_context_messages)
            async with stream:
                async for chunk in stream:
                    if chunk.choices and chunk.choices[0].delta.content:
                        # TODO: This does not support tool calling right now
                        response_accumulator.append(chunk.choices[0].delta.content)
                    yield f"data: {chunk.model_dump_json()}\n\n"

            # Construct the Letta-side records of this exchange
            user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
            assistant_message = create_assistant_message_from_openai_response(
                response_text="".join(response_accumulator), agent_id=agent_id, model=str(model), actor=actor
            )
            # Persist both in one synchronous DB call, done in a threadpool
            await run_in_threadpool(
                server.agent_manager.append_to_in_context_messages,
                [user_message, assistant_message],
                agent_id=agent_id,
                actor=actor,
            )
            yield "data: [DONE]\n\n"
        finally:
            # BUG FIX: the per-request httpx client was never closed, leaking
            # connections/file descriptors on every request.
            await http_client.aclose()

    return StreamingResponse(event_stream(), media_type="text/event-stream")
@router.post(
"/chat/completions",
response_model=None,
@@ -44,26 +135,8 @@ async def create_chat_completions(
user_id: Optional[str] = Header(None, alias="user_id"),
):
# Validate and process fields
try:
messages = list(cast(Iterable[ChatCompletionMessageParam], completion_request["messages"]))
except KeyError:
# Handle the case where "messages" is not present in the request
raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
except TypeError:
# Handle the case where "messages" is not iterable
raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
except Exception as e:
# Catch any other unexpected errors and include the exception message
raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")
if messages[-1]["role"] != "user":
logger.error(f"The last message does not have a `user` role: {messages}")
raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")
messages = get_messages_from_completion_request(completion_request)
input_message = messages[-1]
if not isinstance(input_message["content"], str):
logger.error(f"The input message does not have valid content: {input_message}")
raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")
# Process remaining fields
if not completion_request["stream"]:

View File

@@ -1,23 +1,33 @@
import asyncio
import json
import os
import uuid
import warnings
from datetime import datetime, timezone
from enum import Enum
from typing import TYPE_CHECKING, AsyncGenerator, Optional, Union
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
from fastapi import Header
import pytz
from fastapi import Header, HTTPException
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
from openai.types.chat.completion_create_params import CompletionCreateParams
from pydantic import BaseModel
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.errors import ContextWindowExceededError, RateLimitExceededError
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import TextContent
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.interface import StreamingServerInterface
if TYPE_CHECKING:
from letta.server.server import SyncServer
# from letta.orm.user import User
# from letta.orm.utilities import get_db_session
SSE_PREFIX = "data: "
SSE_SUFFIX = "\n\n"
@@ -128,3 +138,172 @@ def log_error_to_sentry(e):
import sentry_sdk
sentry_sdk.capture_exception(e)
def create_user_message(input_message: dict, agent_id: str, actor: User) -> Message:
    """Wrap a raw user input message in the internal structured ``Message`` format.

    The user's text is embedded in a JSON payload of shape
    ``{"type": "user_message", "message": ..., "time": ...}`` before being stored.
    """
    # Human-readable Pacific-time timestamp embedded in the structured payload
    timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")

    payload = {
        "type": "user_message",
        "message": input_message["content"],
        "time": timestamp,
    }

    return Message(
        id=f"message-{uuid.uuid4()}",
        role=MessageRole.user,
        content=[TextContent(text=json.dumps(payload, indent=2))],  # store structured JSON
        organization_id=actor.organization_id,
        agent_id=agent_id,
        model=None,
        tool_calls=None,
        tool_call_id=None,
        created_at=datetime.now(timezone.utc),
    )
def create_assistant_message_from_openai_response(
    response_text: str,
    agent_id: str,
    model: str,
    actor: User,
) -> Message:
    """
    Converts an OpenAI response into a Message that follows the internal
    paradigm where LLM responses are structured as tool calls instead of content.

    Args:
        response_text: The accumulated assistant text from the OpenAI stream.
        agent_id: The agent this message belongs to.
        model: Name of the model that produced the response.
        actor: The requesting user (supplies organization_id).

    Returns:
        An assistant-role Message whose content lives in a send_message tool call.
    """
    tool_call_id = str(uuid.uuid4())

    # BUG FIX: serialize the tool-call arguments with json.dumps instead of raw
    # string interpolation — any quote, backslash, or newline in response_text
    # previously produced invalid JSON that downstream json.loads would reject.
    arguments = json.dumps(
        {DEFAULT_MESSAGE_TOOL_KWARG: response_text, "request_heartbeat": True},
        indent=2,
    )

    # Construct the tool call with the assistant's message
    tool_call = OpenAIToolCall(
        id=tool_call_id,
        function=OpenAIFunction(
            name=DEFAULT_MESSAGE_TOOL,
            arguments=arguments,
        ),
        type="function",
    )

    # Construct the Message object; content stays empty because the text is
    # carried entirely inside the tool call.
    assistant_message = Message(
        id=f"message-{uuid.uuid4()}",
        role=MessageRole.assistant,
        content=[],
        organization_id=actor.organization_id,
        agent_id=agent_id,
        model=model,
        tool_calls=[tool_call],
        tool_call_id=None,
        created_at=datetime.now(timezone.utc),
    )
    return assistant_message
def convert_letta_messages_to_openai(messages: List[Message]) -> List[dict]:
    """
    Flattens Letta's messages (with system, user, assistant, tool roles, etc.)
    into standard OpenAI chat messages (system, user, assistant).

    Transformation rules:
    1. Assistant + send_message tool_call => content = tool_call's "message"
    2. Tool (role=tool) referencing send_message => skip
    3. User messages might store actual text inside JSON => parse that into content
    4. System => pass through as normal
    """
    openai_messages = []
    for msg in messages:
        # 1. Assistant + 'send_message' tool_calls => flatten
        if msg.role == MessageRole.assistant and msg.tool_calls:
            # CONSISTENCY FIX: compare against the imported DEFAULT_MESSAGE_TOOL /
            # DEFAULT_MESSAGE_TOOL_KWARG constants rather than hard-coded
            # "send_message" / "message" literals used elsewhere in this file.
            send_message_calls = [tc for tc in msg.tool_calls if tc.function.name == DEFAULT_MESSAGE_TOOL]
            if send_message_calls:
                # If we have multiple calls, just pick the first or merge them
                # Typically there's only one.
                tc = send_message_calls[0]
                try:
                    arguments = json.loads(tc.function.arguments)
                except json.JSONDecodeError:
                    # ROBUSTNESS FIX: one malformed tool call previously crashed the
                    # whole conversion; pass the message through unflattened instead.
                    arguments = None
                if arguments is not None:
                    # Extract the "message" string
                    extracted_text = arguments.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
                    # Create a new content with the extracted text
                    msg = Message(
                        id=msg.id,
                        role=msg.role,
                        content=[TextContent(text=extracted_text)],
                        organization_id=msg.organization_id,
                        agent_id=msg.agent_id,
                        model=msg.model,
                        name=msg.name,
                        tool_calls=None,  # no longer needed
                        tool_call_id=None,
                        created_at=msg.created_at,
                    )

        # 2. If role=tool and it's referencing send_message => skip
        if msg.role == MessageRole.tool and msg.name == DEFAULT_MESSAGE_TOOL:
            # Usually 'tool' messages with `send_message` are just status/OK messages
            # that OpenAI doesn't need to see. So skip them.
            continue

        # 3. User messages might store text in JSON => parse it
        if msg.role == MessageRole.user:
            # Example: content=[TextContent(text='{"type": "user_message","message":"Hello"}')]
            # Attempt to parse JSON and extract "message"
            if msg.content and msg.content[0].text.strip().startswith("{"):
                try:
                    parsed = json.loads(msg.content[0].text)
                    # If there's a "message" field, use that as the content
                    if "message" in parsed:
                        actual_user_text = parsed["message"]
                        msg = Message(
                            id=msg.id,
                            role=msg.role,
                            content=[TextContent(text=actual_user_text)],
                            organization_id=msg.organization_id,
                            agent_id=msg.agent_id,
                            model=msg.model,
                            name=msg.name,
                            tool_calls=msg.tool_calls,
                            tool_call_id=msg.tool_call_id,
                            created_at=msg.created_at,
                        )
                except json.JSONDecodeError:
                    pass  # It's not JSON, leave as-is

        # 4. System is left as-is (or any other role that doesn't need special handling)
        #
        # Finally, convert to dict using the existing serialization method
        openai_messages.append(msg.to_openai_dict())

    return openai_messages
def get_messages_from_completion_request(completion_request: "CompletionCreateParams") -> List[Dict]:
    """
    Extract and validate the ``messages`` list from an OpenAI-style completion request.

    Returns the messages unchanged after verifying that the field exists, is a
    non-empty iterable, and that the final message has role ``user`` with string
    content.

    Raises:
        HTTPException(400): on any validation failure.
    """
    try:
        # NOTE: the runtime-no-op cast(...) was removed; list() alone performs
        # the iterability check.
        messages = list(completion_request["messages"])
    except KeyError:
        # Handle the case where "messages" is not present in the request
        raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
    except TypeError:
        # Handle the case where "messages" is not iterable
        raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
    except Exception as e:
        # Catch any other unexpected errors and include the exception message
        raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")

    # ROBUSTNESS FIX: an empty list previously raised IndexError (HTTP 500) at
    # messages[-1] below; report it as a client error instead.
    if not messages:
        raise HTTPException(status_code=400, detail="The 'messages' field must not be empty.")

    if messages[-1]["role"] != "user":
        logger.error(f"The last message does not have a `user` role: {messages}")
        raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")

    input_message = messages[-1]
    if not isinstance(input_message["content"], str):
        logger.error(f"The input message does not have valid content: {input_message}")
        raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")

    return messages

View File

@@ -112,13 +112,12 @@ def _assert_valid_chunk(chunk, idx, chunks):
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message):
@pytest.mark.parametrize("endpoint", ["chat/completions", "fast/chat/completions"])
def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message, endpoint):
"""Tests chat completion streaming via SSE."""
request = _get_chat_request(agent.id, message)
response = _sse_post(
f"{client.base_url}/openai/{client.api_prefix}/chat/completions", request.model_dump(exclude_none=True), client.headers
)
response = _sse_post(f"{client.base_url}/openai/{client.api_prefix}/{endpoint}", request.model_dump(exclude_none=True), client.headers)
try:
chunks = list(response)