feat: Implement streaming chat completions endpoint [LET-5485] (#5446)

* wip

* Add chat completions router and fix streaming service

* Finish chat completions

* Finish chat completions

* Remove extra print statement

* Run just api

* Don't explicitly throw http exceptions but surface Letta errors

* Remap errors

* Trigger CI

* Add missing Optional import
This commit is contained in:
Matthew Zhou
2025-10-15 11:03:48 -07:00
committed by Caren Thomas
parent 714978f5ee
commit 2dae4d33c3
9 changed files with 1721 additions and 54 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -142,6 +142,14 @@ class LettaMCPTimeoutError(LettaMCPError):
super().__init__(message=message, code=ErrorCode.TIMEOUT, details=details)
class LettaServiceUnavailableError(LettaError):
    """Error raised when a required service is unavailable."""

    def __init__(self, message: str, service_name: Optional[str] = None):
        # Record which backing service failed (e.g. "redis") so handlers can
        # surface it; omit the key entirely when no name was supplied.
        details: dict = {}
        if service_name:
            details["service_name"] = service_name
        super().__init__(message=message, code=ErrorCode.INTERNAL_SERVER_ERROR, details=details)
class LettaUnexpectedStreamCancellationError(LettaError):
"""Error raised when a streaming request is terminated unexpectedly."""

View File

@@ -31,6 +31,7 @@ from letta.errors import (
LettaInvalidMCPSchemaError,
LettaMCPConnectionError,
LettaMCPTimeoutError,
LettaServiceUnavailableError,
LettaToolCreateError,
LettaToolNameConflictError,
LettaUnsupportedFileUploadError,
@@ -39,6 +40,7 @@ from letta.errors import (
LLMError,
LLMRateLimitError,
LLMTimeoutError,
PendingApprovalError,
)
from letta.helpers.pinecone_utils import get_pinecone_indices, should_use_pinecone, upsert_pinecone_indices
from letta.jobs.scheduler import start_scheduler_with_leader_election
@@ -274,6 +276,7 @@ def create_application() -> "FastAPI":
app.add_exception_handler(ForeignKeyConstraintViolationError, _error_handler_409)
app.add_exception_handler(UniqueConstraintViolationError, _error_handler_409)
app.add_exception_handler(IntegrityError, _error_handler_409)
app.add_exception_handler(PendingApprovalError, _error_handler_409)
# 415 Unsupported Media Type errors
app.add_exception_handler(LettaUnsupportedFileUploadError, _error_handler_415)
@@ -287,6 +290,7 @@ def create_application() -> "FastAPI":
# 503 Service Unavailable errors
app.add_exception_handler(OperationalError, _error_handler_503)
app.add_exception_handler(LettaServiceUnavailableError, _error_handler_503)
@app.exception_handler(IncompatibleAgentType)
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):

View File

@@ -1,6 +1,7 @@
from letta.server.rest_api.routers.v1.agents import router as agents_router
from letta.server.rest_api.routers.v1.archives import router as archives_router
from letta.server.rest_api.routers.v1.blocks import router as blocks_router
from letta.server.rest_api.routers.v1.chat_completions import router as chat_completions_router, router as openai_chat_completions_router
from letta.server.rest_api.routers.v1.embeddings import router as embeddings_router
from letta.server.rest_api.routers.v1.folders import router as folders_router
from letta.server.rest_api.routers.v1.groups import router as groups_router
@@ -26,6 +27,7 @@ ROUTERS = [
sources_router,
folders_router,
agents_router,
chat_completions_router,
groups_router,
identities_router,
internal_templates_router,
@@ -42,4 +44,5 @@ ROUTERS = [
messages_router,
voice_router,
embeddings_router,
openai_chat_completions_router,
]

View File

@@ -0,0 +1,146 @@
from typing import Optional, Union
from fastapi import APIRouter, Body, Depends
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletion
from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
from pydantic import BaseModel, Field
from letta.errors import LettaInvalidArgumentError
from letta.log import get_logger
from letta.schemas.enums import MessageRole
from letta.schemas.letta_request import LettaStreamingRequest
from letta.schemas.message import MessageCreate
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.server import SyncServer
from letta.services.streaming_service import StreamingService
logger = get_logger(__name__)
router = APIRouter(tags=["chat"])
class ChatCompletionRequest(BaseModel):
    """OpenAI-compatible chat completion request - exactly matching OpenAI's schema."""

    # NOTE: field names, defaults, and validation constraints deliberately
    # mirror OpenAI's /v1/chat/completions request schema so that unmodified
    # OpenAI client libraries can talk to this endpoint. In Letta, `model`
    # carries an agent ID (format "agent-...") rather than an LLM model name.
    model: str = Field(..., description="ID of the model to use")
    messages: list[ChatCompletionMessageParam] = Field(..., description="Messages comprising the conversation so far")
    # optional parameters
    temperature: Optional[float] = Field(None, ge=0, le=2, description="Sampling temperature")
    top_p: Optional[float] = Field(None, ge=0, le=1, description="Nucleus sampling parameter")
    n: Optional[int] = Field(1, ge=1, description="Number of chat completion choices to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream back partial progress")
    stop: Optional[Union[str, list[str]]] = Field(None, description="Sequences where the API will stop generating")
    max_tokens: Optional[int] = Field(None, description="Maximum number of tokens to generate")
    presence_penalty: Optional[float] = Field(None, ge=-2, le=2, description="Presence penalty")
    frequency_penalty: Optional[float] = Field(None, ge=-2, le=2, description="Frequency penalty")
    user: Optional[str] = Field(None, description="A unique identifier representing your end-user")
async def _handle_chat_completion(
request: ChatCompletionRequest,
server: SyncServer,
headers: HeaderParams,
) -> Union[ChatCompletion, StreamingResponse]:
"""
Internal handler for chat completion logic.
Args:
request: OpenAI-compatible chat completion request
server: Letta server instance
headers: Request headers with user info
Returns:
Streaming or non-streaming chat completion response
"""
if request.user:
actor = await server.user_manager.get_actor_or_default_async(actor_id=request.user)
else:
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
resolved_agent_id = request.model
if not resolved_agent_id.startswith("agent-"):
raise LettaInvalidArgumentError(
f"For this endpoint, the 'model' field should contain an agent ID (format: 'agent-...'). Received: '{resolved_agent_id}'",
argument_name="model",
)
await server.agent_manager.validate_agent_exists_async(resolved_agent_id, actor)
# convert OpenAI messages to Letta MessageCreate format
# NOTE: we only process the last user message
if len(request.messages) > 1:
logger.warning(
f"Chat completions endpoint received {len(request.messages)} messages. "
"Letta maintains conversation state internally, so only the last user message will be processed. "
"Previous messages are already stored in the agent's memory."
)
last_user_message = None
for msg in reversed(request.messages):
role = msg.get("role", "user")
if role == "user":
last_user_message = msg
break
if not last_user_message:
raise LettaInvalidArgumentError(
"No user message found in the request. Please include at least one message with role='user'.",
argument_name="messages",
)
letta_messages = [
MessageCreate(
role=MessageRole.user,
content=last_user_message.get("content", ""),
)
]
letta_request = LettaStreamingRequest(
messages=letta_messages,
stream_tokens=True,
)
if request.stream:
streaming_service = StreamingService(server)
return await streaming_service.create_agent_stream_openai_chat_completions(
agent_id=resolved_agent_id,
actor=actor,
request=letta_request,
)
else:
raise LettaInvalidArgumentError(
"Non-streaming chat completions not yet implemented. Please set stream=true.",
argument_name="stream",
)
@router.post(
    "/chat/completions",
    response_model=ChatCompletion,
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "application/json": {"schema": {"$ref": "#/components/schemas/ChatCompletion"}},
                "text/event-stream": {"description": "Server-Sent Events stream (when stream=true)"},
            },
        }
    },
    operation_id="create_chat_completion",
)
async def create_chat_completion(
    request: ChatCompletionRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
) -> Union[ChatCompletion, StreamingResponse]:
    """
    Create a chat completion using a Letta agent (OpenAI-compatible).
    This endpoint provides full OpenAI API compatibility. The agent is selected based on:
    - The 'model' parameter in the request (should contain an agent ID in format 'agent-...')
    When streaming is enabled (stream=true), the response will be Server-Sent Events
    with ChatCompletionChunk objects.

    Args:
        request: OpenAI-style chat completion request body.
        server: Injected Letta server instance.
        headers: Injected request headers carrying the actor identity.

    Returns:
        A StreamingResponse of SSE ChatCompletionChunk events when
        stream=true; non-streaming (stream=false) currently raises
        LettaInvalidArgumentError in the underlying handler.
    """
    # thin wrapper: all validation and dispatch lives in _handle_chat_completion
    return await _handle_chat_completion(request, server, headers)

View File

@@ -1047,6 +1047,23 @@ class AgentManager:
archive_ids = [row[0] for row in result.fetchall()]
return archive_ids
@enforce_types
@trace_method
async def validate_agent_exists_async(self, agent_id: str, actor: PydanticUser) -> None:
    """Check that an agent exists and is accessible to the given user.

    A lightweight existence/permission probe that avoids loading the full
    agent object.

    Args:
        agent_id: Identifier of the agent to check.
        actor: User on whose behalf the check is performed.

    Raises:
        LettaAgentNotFoundError: If the agent is missing or the user lacks
            access to it.
    """
    # delegate to the shared module-level validator inside a fresh session
    async with db_registry.async_session() as session:
        await validate_agent_exists_async(session, agent_id, actor)
@enforce_types
@trace_method
async def delete_agent_async(self, agent_id: str, actor: PydanticUser) -> None:

View File

@@ -1,36 +1,42 @@
"""
Streaming service for handling agent message streaming with various formats.
Provides a unified interface for streaming agent responses with support for
different output formats (Letta native, OpenAI-compatible, etc.)
"""
import asyncio
import json
import time
from typing import AsyncIterator, Optional, Union
from uuid import uuid4
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
from letta.agents.agent_loop import AgentLoop
from letta.agents.base_agent_v2 import BaseAgentV2
from letta.constants import REDIS_RUN_ID_PREFIX
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import LLMAuthenticationError, LLMError, LLMRateLimitError, LLMTimeoutError, PendingApprovalError
from letta.errors import (
LettaInvalidArgumentError,
LettaServiceUnavailableError,
LLMAuthenticationError,
LLMError,
LLMRateLimitError,
LLMTimeoutError,
PendingApprovalError,
)
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.schemas.agent import AgentState
from letta.schemas.enums import AgentType, RunStatus
from letta.schemas.enums import AgentType, MessageStreamStatus, RunStatus
from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message import AssistantMessage, MessageType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import MessageCreate
from letta.schemas.run import Run as PydanticRun, RunUpdate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
from letta.services.lettuce import LettuceClient
from letta.services.run_manager import RunManager
from letta.settings import settings
from letta.utils import safe_create_task
@@ -116,13 +122,11 @@ class StreamingService:
# handle background streaming if requested
if request.background and settings.track_agent_run:
if isinstance(redis_client, NoopAsyncRedisClient):
raise HTTPException(
status_code=503,
detail=(
"Background streaming requires Redis to be running. "
"Please ensure Redis is properly configured. "
f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
),
raise LettaServiceUnavailableError(
f"Background streaming requires Redis to be running. "
f"Please ensure Redis is properly configured. "
f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}",
service_name="redis",
)
safe_create_task(
@@ -176,9 +180,7 @@ class StreamingService:
if settings.track_agent_run:
run_update_metadata = {"error": str(e)}
run_status = RunStatus.failed
raise HTTPException(
status_code=409, detail={"code": "PENDING_APPROVAL", "message": str(e), "pending_request_id": e.pending_request_id}
)
raise
except Exception as e:
if settings.track_agent_run:
run_update_metadata = {"error": str(e)}
@@ -192,9 +194,67 @@ class StreamingService:
actor=actor,
)
async def create_agent_stream_openai_chat_completions(
    self,
    agent_id: str,
    actor: User,
    request: LettaStreamingRequest,
) -> StreamingResponse:
    """
    Create OpenAI-compatible chat completions streaming response.

    Runs the standard Letta agent stream and pipes it through
    OpenAIChatCompletionsStreamTransformer, so callers using OpenAI clients
    receive ChatCompletionChunk SSE events; internal tool-execution messages
    are filtered out by the transformer.

    Args:
        agent_id: The agent ID to stream from
        actor: The user making the request
        request: The LettaStreamingRequest containing all request parameters

    Returns:
        StreamingResponse with OpenAI-formatted SSE chunks

    Raises:
        LettaInvalidArgumentError: If the underlying stream is not a
            status-code-aware streaming response.
    """
    # the agent's configured model name is echoed back in every chunk
    agent_state = await self.server.agent_manager.get_agent_by_id_async(agent_id, actor)

    # kick off the regular Letta stream (SSE-formatted strings)
    run, stream_response = await self.create_agent_stream(
        agent_id=agent_id,
        actor=actor,
        request=request,
        run_type="openai_chat_completions",
    )

    # guard: we can only transform a status-code-aware streaming response
    if not isinstance(stream_response, StreamingResponseWithStatusCode):
        raise LettaInvalidArgumentError(
            "Agent is not compatible with streaming mode",
            argument_name="model",
        )
    letta_stream = stream_response.body_iterator

    # reuse the run ID for the completion ID when available for traceability
    transformer = OpenAIChatCompletionsStreamTransformer(
        model=agent_state.llm_config.model if agent_state.llm_config else "unknown",
        completion_id=f"chatcmpl-{run.id if run else str(uuid4())}",
    )
    return StreamingResponse(
        transformer.transform_stream(letta_stream),
        media_type="text/event-stream",
    )
def _create_error_aware_stream(
self,
agent_loop: AgentLoop,
agent_loop: BaseAgentV2,
messages: list[MessageCreate],
max_steps: int,
stream_tokens: bool,
@@ -223,6 +283,7 @@ class StreamingService:
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=include_return_message_types,
)
async for chunk in stream:
yield chunk
@@ -336,3 +397,185 @@ class StreamingService:
update=update,
actor=actor,
)
class OpenAIChatCompletionsStreamTransformer:
    """
    Transforms Letta streaming messages into OpenAI ChatCompletionChunk format.
    Filters out internal tool execution and only streams assistant text responses.
    """

    def __init__(self, model: str, completion_id: str):
        """
        Initialize the transformer.

        Args:
            model: Model name to include in chunks
            completion_id: Unique ID for this completion (format: chatcmpl-{uuid})
        """
        self.model = model
        self.completion_id = completion_id
        # True until the first content chunk is emitted; per OpenAI's streaming
        # convention only the first delta carries the assistant role.
        self.first_chunk = True
        # single timestamp shared by all chunks of this completion
        self.created = int(time.time())

    # NOTE: the upstream stream yields SSE-formatted strings, so we parse them
    # back into objects here. Re-plumbing the stream to yield objects directly
    # would avoid this round-trip.
    def _parse_sse_chunk(self, sse_string: str):
        """
        Parse SSE-formatted string back into a message object.

        Args:
            sse_string: SSE formatted string like "data: {...}\n\n"

        Returns:
            Parsed message object, MessageStreamStatus.done for the [DONE]
            marker, or None for unparseable/filtered chunks.
        """
        try:
            if not sse_string.startswith("data: "):
                return None
            json_str = sse_string[6:].strip()
            # handle [DONE] marker
            if json_str == "[DONE]":
                return MessageStreamStatus.done
            data = json.loads(json_str)
            message_type = data.get("message_type")
            if message_type == "assistant_message":
                return AssistantMessage(**data)
            if message_type == "usage_statistics":
                return LettaUsageStatistics(**data)
            # stop_reason and all other internal message types are filtered;
            # the [DONE] marker is used as the terminal signal instead
            return None
        except Exception as e:
            logger.warning(f"Failed to parse SSE chunk: {e}")
            return None

    async def transform_stream(self, letta_stream: AsyncIterator) -> AsyncIterator[str]:
        """
        Transform Letta stream to OpenAI ChatCompletionChunk SSE format.

        Args:
            letta_stream: Async iterator of Letta messages (may be SSE strings or objects)

        Yields:
            SSE-formatted strings: "data: {json}\n\n"
        """
        try:
            async for raw_chunk in letta_stream:
                # parse SSE string if needed; pass through already-parsed objects
                if isinstance(raw_chunk, str):
                    chunk = self._parse_sse_chunk(raw_chunk)
                    if chunk is None:
                        continue  # skip unparseable or filtered chunks
                else:
                    chunk = raw_chunk

                # only assistant text is forwarded to the client
                if isinstance(chunk, AssistantMessage):
                    async for sse_chunk in self._process_assistant_message(chunk):
                        yield sse_chunk
                elif chunk == MessageStreamStatus.done:
                    # emit an empty delta with finish_reason="stop", then [DONE]
                    final_chunk = ChatCompletionChunk(
                        id=self.completion_id,
                        object="chat.completion.chunk",
                        created=self.created,
                        model=self.model,
                        choices=[
                            Choice(
                                index=0,
                                delta=ChoiceDelta(),
                                finish_reason="stop",
                            )
                        ],
                    )
                    yield f"data: {final_chunk.model_dump_json()}\n\n"
                    yield "data: [DONE]\n\n"
        except Exception as e:
            # surface the failure to the client as an OpenAI-style error event
            logger.error(f"Error in OpenAI stream transformation: {e}", exc_info=True)
            error_chunk = {"error": {"message": str(e), "type": "server_error"}}
            yield f"data: {json.dumps(error_chunk)}\n\n"

    async def _process_assistant_message(self, message: AssistantMessage) -> AsyncIterator[str]:
        """
        Convert AssistantMessage to OpenAI ChatCompletionChunk(s).

        Args:
            message: Letta AssistantMessage with content

        Yields:
            SSE-formatted chunk strings
        """
        # extract text from content (can be string or list of TextContent)
        text_content = self._extract_text_content(message.content)
        if not text_content:
            return

        # only the first delta of the stream carries the assistant role
        if self.first_chunk:
            self.first_chunk = False
            delta = ChoiceDelta(role="assistant", content=text_content)
        else:
            delta = ChoiceDelta(content=text_content)

        chunk = ChatCompletionChunk(
            id=self.completion_id,
            object="chat.completion.chunk",
            created=self.created,
            model=self.model,
            choices=[
                Choice(
                    index=0,
                    delta=delta,
                    finish_reason=None,
                )
            ],
        )
        yield f"data: {chunk.model_dump_json()}\n\n"

    def _extract_text_content(self, content: Union[str, list[TextContent]]) -> str:
        """
        Extract text string from content field.

        Args:
            content: Either a string or list of TextContent objects

        Returns:
            Extracted text string (non-TextContent items are ignored)
        """
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            # concatenate all TextContent items
            text_parts = []
            for item in content:
                if isinstance(item, TextContent):
                    text_parts.append(item.text)
            return "".join(text_parts)
        return ""

View File

@@ -1,6 +1,7 @@
import os
import threading
import uuid
from typing import List
import pytest
from dotenv import load_dotenv
@@ -9,8 +10,9 @@ from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.enums import AgentType, MessageStreamStatus
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage as OpenAIUserMessage
from letta.schemas.usage import LettaUsageStatistics
from tests.utils import wait_for_server
@@ -71,7 +73,7 @@ def weather_tool(client):
"""
Fetches the current weather for a given location.
Parameters:
Args:
location (str): The location to get the weather for.
Returns:
@@ -100,6 +102,7 @@ def weather_tool(client):
def agent(client, roll_dice_tool, weather_tool):
"""Creates an agent and ensures cleanup after tests."""
agent_state = client.agents.create(
agent_type=AgentType.letta_v1_agent,
name=f"test_compl_{str(uuid.uuid4())[5:]}",
tool_ids=[roll_dice_tool.id, weather_tool.id],
include_base_tools=True,
@@ -111,7 +114,6 @@ def agent(client, roll_dice_tool, weather_tool):
embedding_config=EmbeddingConfig.default_config(provider="openai"),
)
yield agent_state
client.agents.delete(agent_state.id)
# --- Helper Functions --- #
@@ -149,42 +151,46 @@ def _assert_valid_chunk(chunk, idx, chunks):
@pytest.mark.asyncio
@pytest.mark.parametrize("message", ["Tell me something interesting about bananas.", "What's the weather in SF?"])
@pytest.mark.parametrize("endpoint", ["openai/v1"])
async def test_chat_completions_streaming_openai_client(disable_e2b_api_key, client, agent, message, endpoint):
"""Tests chat completion streaming using the Async OpenAI client."""
request = _get_chat_request(message)
@pytest.mark.parametrize("message", ["Tell me a short joke"])
async def test_chat_completions_streaming_openai_client(disable_e2b_api_key, client, agent, roll_dice_tool, message):
"""Tests Letta's OpenAI-compatible chat completions streaming endpoint."""
async_client = AsyncOpenAI(base_url="http://localhost:8283/v1", max_retries=0)
async_client = AsyncOpenAI(base_url=f"http://localhost:8283/{endpoint}/{agent.id}", max_retries=0)
stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True))
stream = await async_client.chat.completions.create(
model=agent.id, # agent ID goes in model field
messages=[{"role": "user", "content": message}],
stream=True,
)
received_chunks = 0
stop_chunk_count = 0
last_chunk = None
content_parts = []
try:
async with stream:
async for chunk in stream:
assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}"
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
async for chunk in stream:
assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}"
assert chunk.choices, "Each ChatCompletionChunk should have at least one choice."
# Track last chunk for final verification
last_chunk = chunk
last_chunk = chunk
# If this chunk has a finish reason of "stop", track it
if chunk.choices[0].finish_reason == "stop":
stop_chunk_count += 1
# Fail early if more than one stop chunk is sent
assert stop_chunk_count == 1, f"Multiple stop chunks detected: {chunk.model_dump_json(indent=4)}"
continue
if chunk.choices[0].finish_reason == "stop":
stop_chunk_count += 1
assert stop_chunk_count == 1, f"Multiple stop chunks detected: {chunk.model_dump_json(indent=4)}"
continue
# Validate regular content chunks
assert chunk.choices[0].delta.content, f"Chunk at index {received_chunks} has no content: {chunk.model_dump_json(indent=4)}"
if chunk.choices[0].delta.content:
content_parts.append(chunk.choices[0].delta.content)
received_chunks += 1
except Exception as e:
pytest.fail(f"Streaming failed with exception: {e}")
assert received_chunks > 0, "No valid streaming chunks were received."
print("\n=== Stream Summary ===")
print(f"Received chunks: {received_chunks}")
print(f"Full response: {''.join(content_parts)}")
print(f"Stop chunk count: {stop_chunk_count}")
# Ensure the last chunk is the expected stop chunk
assert received_chunks > 0, "No valid streaming chunks were received."
assert stop_chunk_count == 1, "Expected exactly one stop chunk."
assert last_chunk is not None, "No last chunk received."
assert last_chunk.choices[0].finish_reason == "stop", "Last chunk should have finish_reason='stop'"