letta-server/letta/server/rest_api/routers/v1/chat_completions.py
Matthew Zhou 2dae4d33c3 feat: Implement streaming chat completions endpoint [LET-5485] (#5446)
* wip

* Add chat completions router and fix streaming service

* Finish chat completions

* Finish chat completions

* Remove extra print statement

* Run just api

* Don't explicitly throw http exceptions but surface Letta errors

* Remap errors

* Trigger CI

* Add missing Optional import
2025-10-24 15:12:11 -07:00

from typing import Optional, Union

from fastapi import APIRouter, Body, Depends
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from pydantic import BaseModel, Field

from letta.errors import LettaInvalidArgumentError
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_request import LettaStreamingRequest
from letta.schemas.message import MessageCreate
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.server import SyncServer
from letta.services.streaming_service import StreamingService

logger = get_logger(__name__)

router = APIRouter(tags=["chat"])


class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request - exactly matching OpenAI's schema."""

    model: str = Field(..., description="ID of the model to use")
    messages: list[ChatCompletionMessageParam] = Field(..., description="Messages comprising the conversation so far")

    # optional parameters
    temperature: Optional[float] = Field(None, ge=0, le=2, description="Sampling temperature")
    top_p: Optional[float] = Field(None, ge=0, le=1, description="Nucleus sampling parameter")
    n: Optional[int] = Field(1, ge=1, description="Number of chat completion choices to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream back partial progress")
    stop: Optional[Union[str, list[str]]] = Field(None, description="Sequences where the API will stop generating")
    max_tokens: Optional[int] = Field(None, description="Maximum number of tokens to generate")
    presence_penalty: Optional[float] = Field(None, ge=-2, le=2, description="Presence penalty")
    frequency_penalty: Optional[float] = Field(None, ge=-2, le=2, description="Frequency penalty")
    user: Optional[str] = Field(None, description="A unique identifier representing your end-user")
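
# Example request body (hypothetical values): the payload mirrors OpenAI's
# chat.completions schema, except that 'model' carries a Letta agent ID, e.g.
#     {"model": "agent-...", "messages": [{"role": "user", "content": "Hi"}], "stream": true}
# Only the most recent user message is forwarded; earlier turns are assumed to
# already live in the agent's own memory (see _handle_chat_completion below).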


async def _handle_chat_completion(
    request: ChatCompletionRequest,
    server: SyncServer,
    headers: HeaderParams,
) -> Union[ChatCompletion, StreamingResponse]:
    """
    Internal handler for chat completion logic.

    Args:
        request: OpenAI-compatible chat completion request
        server: Letta server instance
        headers: Request headers with user info

    Returns:
        Streaming or non-streaming chat completion response
    """
    if request.user:
        actor = await server.user_manager.get_actor_or_default_async(actor_id=request.user)
    else:
        actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    resolved_agent_id = request.model
    if not resolved_agent_id.startswith("agent-"):
        raise LettaInvalidArgumentError(
            f"For this endpoint, the 'model' field should contain an agent ID (format: 'agent-...'). Received: '{resolved_agent_id}'",
            argument_name="model",
        )

    await server.agent_manager.validate_agent_exists_async(resolved_agent_id, actor)

    # convert OpenAI messages to Letta MessageCreate format
    # NOTE: we only process the last user message
    if len(request.messages) > 1:
        logger.warning(
            f"Chat completions endpoint received {len(request.messages)} messages. "
            "Letta maintains conversation state internally, so only the last user message will be processed. "
            "Previous messages are already stored in the agent's memory."
        )

    last_user_message = None
    for msg in reversed(request.messages):
        role = msg.get("role", "user")
        if role == "user":
            last_user_message = msg
            break

    if not last_user_message:
        raise LettaInvalidArgumentError(
            "No user message found in the request. Please include at least one message with role='user'.",
            argument_name="messages",
        )

    letta_messages = [
        MessageCreate(
            role=MessageRole.user,
            content=last_user_message.get("content", ""),
        )
    ]

    letta_request = LettaStreamingRequest(
        messages=letta_messages,
        stream_tokens=True,
    )

    if request.stream:
        streaming_service = StreamingService(server)
        return await streaming_service.create_agent_stream_openai_chat_completions(
            agent_id=resolved_agent_id,
            actor=actor,
            request=letta_request,
        )
    else:
        raise LettaInvalidArgumentError(
            "Non-streaming chat completions not yet implemented. Please set stream=true.",
            argument_name="stream",
        )


@router.post(
    "/chat/completions",
    response_model=ChatCompletion,
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "application/json": {"schema": {"$ref": "#/components/schemas/ChatCompletion"}},
                "text/event-stream": {"description": "Server-Sent Events stream (when stream=true)"},
            },
        }
    },
    operation_id="create_chat_completion",
)
async def create_chat_completion(
    request: ChatCompletionRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
) -> Union[ChatCompletion, StreamingResponse]:
    """
    Create a chat completion using a Letta agent (OpenAI-compatible).

    This endpoint provides full OpenAI API compatibility. The agent is selected based on:
    - The 'model' parameter in the request (should contain an agent ID in format 'agent-...')

    When streaming is enabled (stream=true), the response will be Server-Sent Events
    with ChatCompletionChunk objects.
    """
    return await _handle_chat_completion(request, server, headers)
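
Below is a minimal client sketch (not part of the file above) showing how this endpoint could be called with the official openai Python SDK. The base URL, API key, and agent ID are placeholders; it assumes the Letta server mounts this router under a /v1 prefix and that streaming is requested, since non-streaming completions are rejected.

from openai import OpenAI

# Placeholder values: point base_url at your Letta server's /v1 prefix and supply a
# real API key and agent ID for your deployment.
client = OpenAI(base_url="http://localhost:8283/v1", api_key="YOUR_LETTA_API_KEY")

# 'model' carries the Letta agent ID; stream=True is required because the endpoint
# raises an error for non-streaming requests.
stream = client.chat.completions.create(
    model="agent-xxxxxxxx",
    messages=[{"role": "user", "content": "Hello, agent!"}],
    stream=True,
)

# The server replies with Server-Sent Events carrying ChatCompletionChunk objects.
for chunk in stream:
    delta = chunk.choices[0].delta.content
    if delta:
        print(delta, end="", flush=True)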