feat: Fast chat completions endpoint (#1091)
This commit is contained in:
@@ -750,7 +750,6 @@ class Agent(BaseAgent):
|
||||
put_inner_thoughts_first: bool = True,
|
||||
) -> AgentStepResponse:
|
||||
"""Runs a single step in the agent loop (generates at most one LLM call)"""
|
||||
|
||||
try:
|
||||
|
||||
# Extract job_id from metadata if present
|
||||
|
||||
@@ -161,10 +161,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
|
||||
Called externally with a ChatCompletionChunkResponse. Transforms
|
||||
it if necessary, then enqueues partial messages for streaming back.
|
||||
"""
|
||||
# print(chunk)
|
||||
processed_chunk = self._process_chunk_to_openai_style(chunk)
|
||||
# print(processed_chunk)
|
||||
# print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
|
||||
if processed_chunk is not None:
|
||||
self._push_to_buffer(processed_chunk)
|
||||
|
||||
|
||||
@@ -1,21 +1,30 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, List, Optional, Union
|
||||
|
||||
import httpx
|
||||
import openai
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
from starlette.concurrency import run_in_threadpool
|
||||
|
||||
from letta.agent import Agent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import Message
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
|
||||
|
||||
# TODO this belongs in a controller!
|
||||
from letta.server.rest_api.utils import get_letta_server, sse_async_generator
|
||||
from letta.server.rest_api.utils import (
|
||||
convert_letta_messages_to_openai,
|
||||
create_assistant_message_from_openai_response,
|
||||
create_user_message,
|
||||
get_letta_server,
|
||||
get_messages_from_completion_request,
|
||||
sse_async_generator,
|
||||
)
|
||||
from letta.settings import model_settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.server.server import SyncServer
|
||||
@@ -25,6 +34,88 @@ router = APIRouter(prefix="/v1", tags=["chat_completions"])
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post(
    "/fast/chat/completions",
    response_model=None,
    operation_id="create_fast_chat_completions",
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
)
async def create_fast_chat_completions(
    completion_request: CompletionCreateParams = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
    user_id: Optional[str] = Header(None, alias="user_id"),
):
    """Proxy a chat completion for an agent directly to OpenAI over SSE.

    The target agent id travels in the request's ``user`` field. The agent's
    in-context history is flattened into OpenAI-style messages, the raw
    completion chunks are streamed straight back to the caller, and the
    resulting user/assistant message pair is persisted to the agent afterwards.

    Raises:
        HTTPException(400): if the 'user' field (agent id) is missing, or the
            request's messages fail validation.
    """
    # TODO: This is necessary, we need to factor out CompletionCreateParams due to weird behavior
    # Fix: test for None *before* str() -- str(None) == "None", which made the
    # original missing-agent-id guard dead code.
    raw_agent_id = completion_request.get("user", None)
    if raw_agent_id is None:
        error_msg = "Must pass agent_id in the 'user' field"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    agent_id = str(raw_agent_id)
    model = completion_request.get("model")

    actor = server.user_manager.get_user_or_default(user_id=user_id)
    client = openai.AsyncClient(
        api_key=model_settings.openai_api_key,
        max_retries=0,
        http_client=httpx.AsyncClient(
            timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
            follow_redirects=True,
            limits=httpx.Limits(
                max_connections=50,
                max_keepalive_connections=50,
                keepalive_expiry=120,
            ),
        ),
    )

    # Magic message manipulating
    input_message = get_messages_from_completion_request(completion_request)[-1]
    completion_request.pop("messages")

    # Get in context messages
    in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor)
    openai_dict_in_context_messages = convert_letta_messages_to_openai(in_context_messages)
    openai_dict_in_context_messages.append(input_message)

    async def event_stream():
        # TODO: Factor this out into separate interface
        response_accumulator = []

        try:
            stream = await client.chat.completions.create(**completion_request, messages=openai_dict_in_context_messages)

            async with stream:
                async for chunk in stream:
                    if chunk.choices and chunk.choices[0].delta.content:
                        # TODO: This does not support tool calling right now
                        response_accumulator.append(chunk.choices[0].delta.content)
                        yield f"data: {chunk.model_dump_json()}\n\n"

            # Construct messages
            user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
            assistant_message = create_assistant_message_from_openai_response(
                response_text="".join(response_accumulator), agent_id=agent_id, model=str(model), actor=actor
            )

            # Persist both in one synchronous DB call, done in a threadpool
            await run_in_threadpool(
                server.agent_manager.append_to_in_context_messages,
                [user_message, assistant_message],
                agent_id=agent_id,
                actor=actor,
            )

            yield "data: [DONE]\n\n"
        finally:
            # Fix: the original leaked the per-request AsyncClient (and its
            # underlying httpx connection pool) on every call.
            await client.close()

    return StreamingResponse(event_stream(), media_type="text/event-stream")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/chat/completions",
|
||||
response_model=None,
|
||||
@@ -44,26 +135,8 @@ async def create_chat_completions(
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
# Validate and process fields
|
||||
try:
|
||||
messages = list(cast(Iterable[ChatCompletionMessageParam], completion_request["messages"]))
|
||||
except KeyError:
|
||||
# Handle the case where "messages" is not present in the request
|
||||
raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
|
||||
except TypeError:
|
||||
# Handle the case where "messages" is not iterable
|
||||
raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
|
||||
except Exception as e:
|
||||
# Catch any other unexpected errors and include the exception message
|
||||
raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")
|
||||
|
||||
if messages[-1]["role"] != "user":
|
||||
logger.error(f"The last message does not have a `user` role: {messages}")
|
||||
raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")
|
||||
|
||||
messages = get_messages_from_completion_request(completion_request)
|
||||
input_message = messages[-1]
|
||||
if not isinstance(input_message["content"], str):
|
||||
logger.error(f"The input message does not have valid content: {input_message}")
|
||||
raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")
|
||||
|
||||
# Process remaining fields
|
||||
if not completion_request["stream"]:
|
||||
|
||||
@@ -1,23 +1,33 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
import warnings
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Optional, Union
|
||||
from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast
|
||||
|
||||
from fastapi import Header
|
||||
import pytz
|
||||
from fastapi import Header, HTTPException
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.errors import ContextWindowExceededError, RateLimitExceededError
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
# from letta.orm.user import User
|
||||
# from letta.orm.utilities import get_db_session
|
||||
|
||||
SSE_PREFIX = "data: "
|
||||
SSE_SUFFIX = "\n\n"
|
||||
@@ -128,3 +138,172 @@ def log_error_to_sentry(e):
|
||||
import sentry_sdk
|
||||
|
||||
sentry_sdk.capture_exception(e)
|
||||
|
||||
|
||||
def create_user_message(input_message: dict, agent_id: str, actor: User) -> Message:
    """
    Wrap a raw user input message in the internal structured format.

    The user's text is placed inside a JSON envelope (type/message/time) and
    stored as the text content of a new ``Message`` with role ``user``.
    """
    # Timestamp rendered in the agent-facing format, pinned to US/Pacific.
    # NOTE(review): hardcoded timezone — presumably intentional; confirm.
    timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")

    # Structured JSON envelope carrying the original content.
    envelope = {"type": "user_message", "message": input_message["content"], "time": timestamp}

    # Build and return the internal Message object.
    return Message(
        id=f"message-{uuid.uuid4()}",
        role=MessageRole.user,
        content=[TextContent(text=json.dumps(envelope, indent=2))],
        organization_id=actor.organization_id,
        agent_id=agent_id,
        model=None,
        tool_calls=None,
        tool_call_id=None,
        created_at=datetime.now(timezone.utc),
    )
|
||||
|
||||
|
||||
def create_assistant_message_from_openai_response(
    response_text: str,
    agent_id: str,
    model: str,
    actor: User,
) -> Message:
    """
    Converts an OpenAI response into a Message that follows the internal
    paradigm where LLM responses are structured as tool calls instead of content.

    The assistant text becomes the argument of a synthetic DEFAULT_MESSAGE_TOOL
    call. Arguments are serialized with json.dumps so that quotes, newlines, or
    backslashes in ``response_text`` cannot produce malformed JSON (the original
    hand-concatenated string did not escape the text).
    """
    tool_call_id = str(uuid.uuid4())

    # Construct the tool call with the assistant's message.
    # Fix: json.dumps performs proper escaping; the previous string
    # concatenation broke whenever response_text contained a double quote.
    tool_call = OpenAIToolCall(
        id=tool_call_id,
        function=OpenAIFunction(
            name=DEFAULT_MESSAGE_TOOL,
            arguments=json.dumps(
                {DEFAULT_MESSAGE_TOOL_KWARG: response_text, "request_heartbeat": True},
                indent=2,
            ),
        ),
        type="function",
    )

    # Construct the Message object (content stays empty; the text lives in the tool call)
    assistant_message = Message(
        id=f"message-{uuid.uuid4()}",
        role=MessageRole.assistant,
        content=[],
        organization_id=actor.organization_id,
        agent_id=agent_id,
        model=model,
        tool_calls=[tool_call],
        tool_call_id=None,
        created_at=datetime.now(timezone.utc),
    )

    return assistant_message
|
||||
|
||||
|
||||
def convert_letta_messages_to_openai(messages: List[Message]) -> List[dict]:
    """
    Flattens Letta's messages (with system, user, assistant, tool roles, etc.)
    into standard OpenAI chat messages (system, user, assistant).

    Transformation rules:
        1. Assistant + send_message tool_call => content = tool_call's "message"
        2. Tool (role=tool) referencing send_message => skip
        3. User messages might store actual text inside JSON => parse that into content
        4. System => pass through as normal
    """

    openai_messages = []

    for msg in messages:
        # 1. Assistant + 'send_message' tool_calls => flatten
        if msg.role == MessageRole.assistant and msg.tool_calls:
            # Find any 'send_message' tool_calls
            send_message_calls = [tc for tc in msg.tool_calls if tc.function.name == "send_message"]
            if send_message_calls:
                # If we have multiple calls, just pick the first or merge them
                # Typically there's only one.
                tc = send_message_calls[0]
                try:
                    arguments = json.loads(tc.function.arguments)
                except json.JSONDecodeError:
                    # Fix: malformed tool-call arguments used to raise here and
                    # abort the whole conversion. Pass the message through
                    # unflattened instead (mirrors the user-branch handling).
                    arguments = None

                if arguments is not None:
                    # Extract the "message" string
                    extracted_text = arguments.get("message", "")

                    # Create a new content with the extracted text
                    msg = Message(
                        id=msg.id,
                        role=msg.role,
                        content=[TextContent(text=extracted_text)],
                        organization_id=msg.organization_id,
                        agent_id=msg.agent_id,
                        model=msg.model,
                        name=msg.name,
                        tool_calls=None,  # no longer needed
                        tool_call_id=None,
                        created_at=msg.created_at,
                    )

        # 2. If role=tool and it's referencing send_message => skip
        if msg.role == MessageRole.tool and msg.name == "send_message":
            # Usually 'tool' messages with `send_message` are just status/OK messages
            # that OpenAI doesn't need to see. So skip them.
            continue

        # 3. User messages might store text in JSON => parse it
        if msg.role == MessageRole.user:
            # Example: content=[TextContent(text='{"type": "user_message","message":"Hello"}')]
            # Attempt to parse JSON and extract "message"
            if msg.content and msg.content[0].text.strip().startswith("{"):
                try:
                    parsed = json.loads(msg.content[0].text)
                    # If there's a "message" field, use that as the content
                    if "message" in parsed:
                        actual_user_text = parsed["message"]
                        msg = Message(
                            id=msg.id,
                            role=msg.role,
                            content=[TextContent(text=actual_user_text)],
                            organization_id=msg.organization_id,
                            agent_id=msg.agent_id,
                            model=msg.model,
                            name=msg.name,
                            tool_calls=msg.tool_calls,
                            tool_call_id=msg.tool_call_id,
                            created_at=msg.created_at,
                        )
                except json.JSONDecodeError:
                    pass  # It's not JSON, leave as-is

        # 4. System is left as-is (or any other role that doesn't need special handling)
        #
        # Finally, convert to dict using your existing method
        openai_messages.append(msg.to_openai_dict())

    return openai_messages
|
||||
|
||||
|
||||
def get_messages_from_completion_request(completion_request: CompletionCreateParams) -> List[Dict]:
    """
    Validate and extract the `messages` list from a completion request.

    Returns the messages as a list of dicts. Raises HTTPException(400) when the
    field is missing or non-iterable, or when the final message is not a
    plain-text user message.
    """
    try:
        messages = list(cast(Iterable[ChatCompletionMessageParam], completion_request["messages"]))
    except KeyError:
        # Handle the case where "messages" is not present in the request
        raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
    except TypeError:
        # Handle the case where "messages" is not iterable
        raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
    except Exception as e:
        # Catch any other unexpected errors and include the exception message
        raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")

    # The conversation must end with a user turn whose content is plain text.
    input_message = messages[-1]

    if input_message["role"] != "user":
        logger.error(f"The last message does not have a `user` role: {messages}")
        raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")

    if not isinstance(input_message["content"], str):
        logger.error(f"The input message does not have valid content: {input_message}")
        raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")

    return messages
|
||||
|
||||
@@ -112,13 +112,12 @@ def _assert_valid_chunk(chunk, idx, chunks):
|
||||
|
||||
|
||||
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
|
||||
def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message):
|
||||
@pytest.mark.parametrize("endpoint", ["chat/completions", "fast/chat/completions"])
|
||||
def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message, endpoint):
|
||||
"""Tests chat completion streaming via SSE."""
|
||||
request = _get_chat_request(agent.id, message)
|
||||
|
||||
response = _sse_post(
|
||||
f"{client.base_url}/openai/{client.api_prefix}/chat/completions", request.model_dump(exclude_none=True), client.headers
|
||||
)
|
||||
response = _sse_post(f"{client.base_url}/openai/{client.api_prefix}/{endpoint}", request.model_dump(exclude_none=True), client.headers)
|
||||
|
||||
try:
|
||||
chunks = list(response)
|
||||
|
||||
Reference in New Issue
Block a user