Files
letta-server/letta/services/streaming_service.py
Kian Jones f5c4ab50f4 chore: add ty + pre-commit hook and repeal even more ruff rules (#9504)
* auto fixes

* auto fix pt2 and transitive deps and undefined var checking locals()

* manual fixes (ignored or letta-code fixed)

* fix circular import

* remove all ignores, add FastAPI rules and Ruff rules

* add ty and precommit

* ruff stuff

* ty check fixes

* ty check fixes pt 2

* error on invalid
2026-02-24 10:55:11 -08:00

727 lines
30 KiB
Python

import json
import time
from typing import AsyncIterator, Optional, Union
from uuid import uuid4
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice, ChoiceDelta
from letta.agents.agent_loop import AgentLoop
from letta.agents.base_agent_v2 import BaseAgentV2
from letta.constants import REDIS_RUN_ID_PREFIX
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import (
LettaInvalidArgumentError,
LettaServiceUnavailableError,
LLMAuthenticationError,
LLMError,
LLMRateLimitError,
LLMTimeoutError,
PendingApprovalError,
)
from letta.helpers.datetime_helpers import get_utc_timestamp_ns
from letta.log import get_logger
from letta.otel.context import get_ctx_attributes
from letta.otel.metric_registry import MetricRegistry
from letta.schemas.agent import AgentState
from letta.schemas.enums import AgentType, MessageStreamStatus, RunStatus
from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_message import AssistantMessage, LettaErrorMessage, MessageType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_request import ClientToolSchema, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import MessageCreate
from letta.schemas.run import Run as PydanticRun, RunUpdate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator
from letta.server.rest_api.streaming_response import (
RunCancelledException,
StreamingResponseWithStatusCode,
add_keepalive_to_stream,
cancellation_aware_stream_wrapper,
get_cancellation_event_for_run,
)
from letta.server.rest_api.utils import capture_sentry_exception
from letta.services.run_manager import RunManager
from letta.settings import settings
from letta.utils import safe_create_task
logger = get_logger(__name__)
class StreamingService:
    """
    Service for managing agent streaming responses.

    Handles run creation, stream generation, error handling, and format conversion.
    """

    def __init__(self, server):
        """
        Initialize the streaming service.

        Args:
            server: The SyncServer instance for accessing managers and services
        """
        self.server = server
        # Run tracking is optional: when disabled, runs_manager stays None and no
        # Run records are created or updated by this service.
        self.runs_manager = RunManager() if settings.track_agent_run else None

    async def create_agent_stream(
        self,
        agent_id: str,
        actor: User,
        request: LettaStreamingRequest,
        run_type: str = "streaming",
        conversation_id: Optional[str] = None,
    ) -> tuple[Optional[PydanticRun], Union[StreamingResponse, LettaResponse]]:
        """
        Create a streaming response for an agent.

        Args:
            agent_id: The agent ID to stream from
            actor: The user making the request
            request: The LettaStreamingRequest containing all request parameters
            run_type: Type of run for tracking
            conversation_id: Optional conversation ID for conversation-scoped messaging

        Returns:
            Tuple of (run object or None, streaming response)

        Raises:
            LettaServiceUnavailableError: If background streaming is requested but
                Redis is unavailable.
        """
        request_start_timestamp_ns = get_utc_timestamp_ns()
        MetricRegistry().user_message_counter.add(1, get_ctx_attributes())

        # get redis client
        redis_client = await get_redis_client()

        # load agent and check eligibility
        agent = await self.server.agent_manager.get_agent_by_id_async(
            agent_id,
            actor,
            include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools", "tags"],
        )

        # Handle model override if specified in the request
        if request.override_model:
            override_llm_config = await self.server.get_llm_config_from_handle_async(
                actor=actor,
                handle=request.override_model,
            )
            # Create a copy of agent state with the overridden llm_config
            agent = agent.model_copy(update={"llm_config": override_llm_config})

        model_compatible_token_streaming = self._is_token_streaming_compatible(agent)

        # Attempt to acquire conversation lock if conversation_id is provided.
        # This prevents concurrent message processing for the same conversation.
        # Skip locking if Redis is not available (graceful degradation).
        if conversation_id and not isinstance(redis_client, NoopAsyncRedisClient):
            await redis_client.acquire_conversation_lock(
                conversation_id=conversation_id,
                token=str(uuid4()),
            )

        # create run if tracking is enabled
        run = None
        run_update_metadata = None
        # Initialize run_status up front so the finally block can never hit an
        # unbound local, regardless of where an exception is raised in the try.
        run_status = None
        try:
            if settings.track_agent_run:
                run = await self._create_run(agent_id, request, run_type, actor, conversation_id=conversation_id)
            await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)

            # use agent loop for streaming
            agent_loop = AgentLoop.load(agent_state=agent, actor=actor)

            # create the base stream with error handling
            raw_stream = self._create_error_aware_stream(
                agent_loop=agent_loop,
                messages=request.messages,
                max_steps=request.max_steps,
                stream_tokens=request.stream_tokens and model_compatible_token_streaming,
                run_id=run.id if run else None,
                use_assistant_message=request.use_assistant_message,
                request_start_timestamp_ns=request_start_timestamp_ns,
                include_return_message_types=request.include_return_message_types,
                actor=actor,
                conversation_id=conversation_id,
                client_tools=request.client_tools,
                include_compaction_messages=request.include_compaction_messages,
            )

            # handle background streaming if requested
            if request.background and settings.track_agent_run:
                if isinstance(redis_client, NoopAsyncRedisClient):
                    raise LettaServiceUnavailableError(
                        f"Background streaming requires Redis to be running. "
                        f"Please ensure Redis is properly configured. "
                        f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}",
                        service_name="redis",
                    )

                # Wrap the agent loop stream with cancellation awareness for background task
                background_stream = raw_stream
                if settings.enable_cancellation_aware_streaming and run:
                    background_stream = cancellation_aware_stream_wrapper(
                        stream_generator=raw_stream,
                        run_manager=self.runs_manager,
                        run_id=run.id,
                        actor=actor,
                        cancellation_event=get_cancellation_event_for_run(run.id),
                    )

                safe_create_task(
                    create_background_stream_processor(
                        stream_generator=background_stream,
                        redis_client=redis_client,
                        run_id=run.id,
                        run_manager=self.server.run_manager,
                        actor=actor,
                        conversation_id=conversation_id,
                    ),
                    label=f"background_stream_processor_{run.id}",
                )

                # The client consumes the Redis-backed replica of the stream while
                # the background task drives the actual agent loop.
                raw_stream = redis_sse_stream_generator(
                    redis_client=redis_client,
                    run_id=run.id,
                )

            # wrap client stream with cancellation awareness if enabled and tracking runs
            stream = raw_stream
            if settings.enable_cancellation_aware_streaming and settings.track_agent_run and run and not request.background:
                stream = cancellation_aware_stream_wrapper(
                    stream_generator=raw_stream,
                    run_manager=self.runs_manager,
                    run_id=run.id,
                    actor=actor,
                    cancellation_event=get_cancellation_event_for_run(run.id),
                )

            # conditionally wrap with keepalive based on request parameter
            if request.include_pings and settings.enable_keepalive:
                # run is None when run tracking is disabled; previously this
                # dereferenced run.id unconditionally and raised AttributeError.
                stream = add_keepalive_to_stream(
                    stream,
                    keepalive_interval=settings.keepalive_interval,
                    run_id=run.id if run else None,
                )

            result = StreamingResponseWithStatusCode(
                stream,
                media_type="text/event-stream",
            )

            # update run status to running before returning
            if settings.track_agent_run and run:
                # refetch run since it may have been updated by another service
                run = await self.server.run_manager.get_run_by_id(run_id=run.id, actor=actor)
                if run.status == RunStatus.created:
                    run_status = RunStatus.running
                else:
                    # don't override run status if it has already been updated
                    run_status = None
            return run, result
        except PendingApprovalError as e:
            # Record the failure on the run, then surface the conflict to the caller.
            if settings.track_agent_run:
                run_update_metadata = {"error": str(e)}
                run_status = RunStatus.failed
            raise
        except Exception as e:
            if settings.track_agent_run:
                run_update_metadata = {"error": str(e)}
                run_status = RunStatus.failed
            raise
        finally:
            # Persist the final status exactly once for both success and failure paths;
            # run_status is None when no update should be written.
            if settings.track_agent_run and run and run_status:
                await self.server.run_manager.update_run_by_id_async(
                    run_id=run.id,
                    conversation_id=conversation_id,
                    update=RunUpdate(status=run_status, metadata=run_update_metadata),
                    actor=actor,
                )

    async def create_agent_stream_openai_chat_completions(
        self,
        agent_id: str,
        actor: User,
        request: LettaStreamingRequest,
    ) -> StreamingResponse:
        """
        Create OpenAI-compatible chat completions streaming response.

        Transforms Letta's internal streaming format to match OpenAI's
        ChatCompletionChunk schema, filtering out internal tool execution
        and only streaming assistant text responses.

        Args:
            agent_id: The agent ID to stream from
            actor: The user making the request
            request: The LettaStreamingRequest containing all request parameters

        Returns:
            StreamingResponse with OpenAI-formatted SSE chunks

        Raises:
            LettaInvalidArgumentError: If the underlying stream is not a
                StreamingResponseWithStatusCode (agent not streaming-compatible).
        """
        # load agent to get model info for the completion chunks
        agent = await self.server.agent_manager.get_agent_by_id_async(agent_id, actor)

        # create standard Letta stream (returns SSE-formatted stream)
        run, letta_stream_response = await self.create_agent_stream(
            agent_id=agent_id,
            actor=actor,
            request=request,
            run_type="openai_chat_completions",
        )

        # extract the stream iterator from the response
        if isinstance(letta_stream_response, StreamingResponseWithStatusCode):
            letta_stream = letta_stream_response.body_iterator
        else:
            raise LettaInvalidArgumentError(
                "Agent is not compatible with streaming mode",
                argument_name="model",
            )

        # create transformer with agent's model info
        model_name = agent.llm_config.model if agent.llm_config else "unknown"
        completion_id = f"chatcmpl-{run.id if run else str(uuid4())}"
        transformer = OpenAIChatCompletionsStreamTransformer(
            model=model_name,
            completion_id=completion_id,
        )

        # transform Letta SSE stream to OpenAI format (parser handles SSE strings)
        openai_stream = transformer.transform_stream(letta_stream)
        return StreamingResponse(
            openai_stream,
            media_type="text/event-stream",
        )

    def _create_error_aware_stream(
        self,
        agent_loop: BaseAgentV2,
        messages: list[MessageCreate],
        max_steps: int,
        stream_tokens: bool,
        run_id: Optional[str],
        use_assistant_message: bool,
        request_start_timestamp_ns: int,
        include_return_message_types: Optional[list[MessageType]],
        actor: User,
        conversation_id: Optional[str] = None,
        client_tools: Optional[list[ClientToolSchema]] = None,
        include_compaction_messages: bool = False,
    ) -> AsyncIterator:
        """
        Create a stream with unified error handling.

        Wraps the agent loop's stream so that LLM errors, cancellations, and
        unexpected failures are converted into SSE error events followed by a
        terminal [DONE] marker, and the run's final status is always persisted.

        Returns:
            Async iterator that yields chunks with proper error handling
        """

        async def error_aware_stream():
            """Stream that handles early LLM errors gracefully in streaming format."""
            run_status = None
            stop_reason = None
            error_data = None
            saw_done = False
            saw_error = False
            try:
                stream = agent_loop.stream(
                    input_messages=messages,
                    max_steps=max_steps,
                    stream_tokens=stream_tokens,
                    run_id=run_id,
                    use_assistant_message=use_assistant_message,
                    request_start_timestamp_ns=request_start_timestamp_ns,
                    include_return_message_types=include_return_message_types,
                    conversation_id=conversation_id,
                    client_tools=client_tools,
                    include_compaction_messages=include_compaction_messages,
                )
                async for chunk in stream:
                    # Track terminal events (check at line start to avoid false positives in message content)
                    if isinstance(chunk, str):
                        if "\ndata: [DONE]" in chunk or chunk.startswith("data: [DONE]"):
                            saw_done = True
                        if "\nevent: error" in chunk or chunk.startswith("event: error"):
                            saw_error = True
                    yield chunk

                # Stream completed - check if we got a terminal event
                if not saw_done and not saw_error:
                    # Stream ended without terminal - treat as error to avoid hanging clients
                    logger.error(
                        f"Stream for run {run_id} ended without terminal event. "
                        f"Agent stop_reason: {agent_loop.stop_reason}. Emitting error + [DONE]."
                    )
                    stop_reason = LettaStopReason(stop_reason=StopReasonType.error)
                    error_message = LettaErrorMessage(
                        run_id=run_id,
                        error_type="stream_incomplete",
                        message="Stream ended unexpectedly without a terminal event.",
                        detail=None,
                    )
                    error_data = {"error": error_message.model_dump()}
                    yield f"data: {stop_reason.model_dump_json()}\n\n"
                    yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
                    yield "data: [DONE]\n\n"
                    saw_error = True
                    saw_done = True
                    run_status = RunStatus.failed
                else:
                    # set run status after successful completion
                    if agent_loop.stop_reason and agent_loop.stop_reason.stop_reason.value == "cancelled":
                        run_status = RunStatus.cancelled
                    else:
                        run_status = RunStatus.completed
                    stop_reason = agent_loop.stop_reason if agent_loop.stop_reason else LettaStopReason(stop_reason=StopReasonType.end_turn)
            except LLMTimeoutError as e:
                run_status = RunStatus.failed
                stop_reason = LettaStopReason(stop_reason=StopReasonType.llm_api_error)
                error_message = LettaErrorMessage(
                    run_id=run_id,
                    error_type="llm_timeout",
                    message="The LLM request timed out. Please try again.",
                    detail=str(e),
                )
                error_data = {"error": error_message.model_dump()}
                logger.error(f"Run {run_id} stopped with LLM timeout error: {e}, error_data: {error_message.model_dump()}")
                yield f"data: {stop_reason.model_dump_json()}\n\n"
                yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
                # Send [DONE] marker to properly close the stream
                yield "data: [DONE]\n\n"
            except LLMRateLimitError as e:
                run_status = RunStatus.failed
                stop_reason = LettaStopReason(stop_reason=StopReasonType.llm_api_error)
                error_message = LettaErrorMessage(
                    run_id=run_id,
                    error_type="llm_rate_limit",
                    message="Rate limit exceeded for LLM model provider. Please wait before making another request.",
                    detail=str(e),
                )
                error_data = {"error": error_message.model_dump()}
                logger.warning(f"Run {run_id} stopped with LLM rate limit error: {e}, error_data: {error_message.model_dump()}")
                yield f"data: {stop_reason.model_dump_json()}\n\n"
                yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
                # Send [DONE] marker to properly close the stream
                yield "data: [DONE]\n\n"
            except LLMAuthenticationError as e:
                run_status = RunStatus.failed
                stop_reason = LettaStopReason(stop_reason=StopReasonType.llm_api_error)
                error_message = LettaErrorMessage(
                    run_id=run_id,
                    error_type="llm_authentication",
                    message="Authentication failed with the LLM model provider.",
                    detail=str(e),
                )
                error_data = {"error": error_message.model_dump()}
                logger.warning(f"Run {run_id} stopped with LLM authentication error: {e}, error_data: {error_message.model_dump()}")
                yield f"data: {stop_reason.model_dump_json()}\n\n"
                yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
                # Send [DONE] marker to properly close the stream
                yield "data: [DONE]\n\n"
            except LLMError as e:
                run_status = RunStatus.failed
                stop_reason = LettaStopReason(stop_reason=StopReasonType.llm_api_error)
                error_message = LettaErrorMessage(
                    run_id=run_id,
                    error_type="llm_error",
                    message="An error occurred with the LLM request.",
                    detail=str(e),
                )
                error_data = {"error": error_message.model_dump()}
                logger.error(f"Run {run_id} stopped with LLM error: {e}, error_data: {error_message.model_dump()}")
                yield f"data: {stop_reason.model_dump_json()}\n\n"
                yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
                # Send [DONE] marker to properly close the stream
                yield "data: [DONE]\n\n"
            except RunCancelledException:
                # Run was explicitly cancelled - this is not an error
                # The cancellation has already been handled by cancellation_aware_stream_wrapper
                logger.info(f"Run {run_id} was cancelled, exiting stream gracefully")
                # Send [DONE] to properly close the stream
                yield "data: [DONE]\n\n"
                # Don't update run status in finally - cancellation is already recorded
                run_status = None  # Signal to finally block to skip update
            except Exception as e:
                run_status = RunStatus.failed
                stop_reason = LettaStopReason(stop_reason=StopReasonType.error)
                # Use repr() if str() is empty (happens with Exception() with no args)
                error_detail = str(e) or repr(e)
                error_message = LettaErrorMessage(
                    run_id=run_id,
                    error_type="internal_error",
                    message="An unknown error occurred with the LLM streaming request.",
                    detail=error_detail,
                )
                error_data = {"error": error_message.model_dump()}
                logger.error(f"Run {run_id} stopped with unknown error: {error_detail}, error_data: {error_message.model_dump()}")
                yield f"data: {stop_reason.model_dump_json()}\n\n"
                yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
                # Send [DONE] marker to properly close the stream
                yield "data: [DONE]\n\n"
                # Capture for Sentry but don't re-raise to allow stream to complete gracefully
                capture_sentry_exception(e)
            finally:
                # always update run status, whether success or failure
                if run_id and self.runs_manager and run_status:
                    # Normalize to the StopReasonType enum in both branches
                    # (previously the fallback used .value, mixing str and enum).
                    stop_reason_value = stop_reason.stop_reason if stop_reason else StopReasonType.error
                    await self.runs_manager.update_run_by_id_async(
                        run_id=run_id,
                        conversation_id=conversation_id,
                        update=RunUpdate(status=run_status, stop_reason=stop_reason_value, metadata=error_data),
                        actor=actor,
                    )

        return error_aware_stream()

    def _is_token_streaming_compatible(self, agent: AgentState) -> bool:
        """Check if agent's model supports token-level streaming."""
        # These provider endpoint types support token streaming for all agent types.
        base_compatible = agent.llm_config.model_endpoint_type in [
            "anthropic",
            "openai",
            "bedrock",
            "deepseek",
            "zai",
            "chatgpt_oauth",
            "minimax",
            "openrouter",
        ]
        # Google endpoints only support token streaming for letta_v1 agents.
        google_letta_v1 = agent.agent_type == AgentType.letta_v1_agent and agent.llm_config.model_endpoint_type in [
            "google_ai",
            "google_vertex",
        ]
        return base_compatible or google_letta_v1

    async def _create_run(
        self, agent_id: str, request: LettaStreamingRequest, run_type: str, actor: User, conversation_id: Optional[str] = None
    ) -> PydanticRun:
        """Create a run for tracking execution.

        Only called when settings.track_agent_run is enabled (runs_manager is set).
        """
        run = await self.runs_manager.create_run(
            pydantic_run=PydanticRun(
                agent_id=agent_id,
                conversation_id=conversation_id,
                background=request.background or False,
                metadata={
                    "run_type": run_type,
                },
                request_config=LettaRequestConfig.from_letta_request(request),
            ),
            actor=actor,
        )
        return run

    async def _update_run_status(
        self,
        run_id: str,
        status: RunStatus,
        actor: User,
        error: Optional[str] = None,
        stop_reason: Optional[str] = None,
        conversation_id: Optional[str] = None,
    ):
        """Update the status of a run.

        No-op when run tracking is disabled (runs_manager is None).

        Args:
            run_id: ID of the run to update
            status: New run status
            actor: The user performing the update
            error: Optional error string stored in the run metadata
            stop_reason: Optional stop reason to record
            conversation_id: Optional conversation scope for the update
        """
        if not self.runs_manager:
            return
        update = RunUpdate(status=status)
        if error:
            update.metadata = {"error": error}
        if stop_reason:
            update.stop_reason = stop_reason
        await self.runs_manager.update_run_by_id_async(
            run_id=run_id,
            update=update,
            actor=actor,
            conversation_id=conversation_id,
        )
class OpenAIChatCompletionsStreamTransformer:
    """
    Transforms Letta streaming messages into OpenAI ChatCompletionChunk format.

    Filters out internal tool execution and only streams assistant text responses.
    """

    def __init__(self, model: str, completion_id: str):
        """
        Initialize the transformer.

        Args:
            model: Model name to include in chunks
            completion_id: Unique ID for this completion (format: chatcmpl-{uuid})
        """
        self.model = model
        self.completion_id = completion_id
        # Tracks whether the next emitted chunk is the first one (the first
        # chunk carries the assistant role in its delta, per OpenAI convention).
        self.first_chunk = True
        # Creation timestamp reused in every chunk so the whole completion
        # shares a single `created` value.
        self.created = int(time.time())

    # TODO: This is lowkey really ugly and poor code design, but this works fine for now
    def _parse_sse_chunk(self, sse_string: str):
        """
        Parse SSE-formatted string back into a message object.

        Args:
            sse_string: SSE formatted string like "data: {...}\n\n"

        Returns:
            Parsed message object, MessageStreamStatus.done for the [DONE]
            marker, or None if the chunk can't be parsed or is filtered out
        """
        try:
            # strip SSE formatting
            if sse_string.startswith("data: "):
                json_str = sse_string[6:].strip()

                # handle [DONE] marker
                if json_str == "[DONE]":
                    return MessageStreamStatus.done

                # parse JSON
                data = json.loads(json_str)

                # reconstruct message object based on message_type
                message_type = data.get("message_type")
                if message_type == "assistant_message":
                    return AssistantMessage(**data)
                elif message_type == "usage_statistics":
                    return LettaUsageStatistics(**data)
                elif message_type == "stop_reason":
                    # skip stop_reason, we use [DONE] instead
                    return None
                else:
                    # other message types we skip
                    return None
            return None
        except Exception as e:
            logger.warning(f"Failed to parse SSE chunk: {e}")
            return None

    async def transform_stream(self, letta_stream: AsyncIterator) -> AsyncIterator[str]:
        """
        Transform Letta stream to OpenAI ChatCompletionChunk SSE format.

        Args:
            letta_stream: Async iterator of Letta messages (may be SSE strings or objects)

        Yields:
            SSE-formatted strings: "data: {json}\n\n"
        """
        try:
            async for raw_chunk in letta_stream:
                # parse SSE string if needed
                if isinstance(raw_chunk, str):
                    chunk = self._parse_sse_chunk(raw_chunk)
                    if chunk is None:
                        continue  # skip unparseable or filtered chunks
                else:
                    chunk = raw_chunk

                # only process assistant messages
                if isinstance(chunk, AssistantMessage):
                    # NOTE: a leftover debug print() was removed here; use the
                    # module logger for any future diagnostics.
                    async for sse_chunk in self._process_assistant_message(chunk):
                        yield sse_chunk

                # handle completion status
                elif chunk == MessageStreamStatus.done:
                    # emit final chunk with finish_reason
                    final_chunk = ChatCompletionChunk(
                        id=self.completion_id,
                        object="chat.completion.chunk",
                        created=self.created,
                        model=self.model,
                        choices=[
                            Choice(
                                index=0,
                                delta=ChoiceDelta(),
                                finish_reason="stop",
                            )
                        ],
                    )
                    yield f"data: {final_chunk.model_dump_json()}\n\n"
                    yield "data: [DONE]\n\n"
        except Exception as e:
            logger.error(f"Error in OpenAI stream transformation: {e}", exc_info=True)
            error_chunk = {"error": {"message": str(e), "type": "server_error"}}
            yield f"data: {json.dumps(error_chunk)}\n\n"

    async def _process_assistant_message(self, message: AssistantMessage) -> AsyncIterator[str]:
        """
        Convert AssistantMessage to OpenAI ChatCompletionChunk(s).

        Args:
            message: Letta AssistantMessage with content

        Yields:
            SSE-formatted chunk strings
        """
        # extract text from content (can be string or list of TextContent)
        text_content = self._extract_text_content(message.content)
        if not text_content:
            return

        # emit role on first chunk only
        if self.first_chunk:
            self.first_chunk = False
            # first chunk includes role
            chunk = ChatCompletionChunk(
                id=self.completion_id,
                object="chat.completion.chunk",
                created=self.created,
                model=self.model,
                choices=[
                    Choice(
                        index=0,
                        delta=ChoiceDelta(role="assistant", content=text_content),
                        finish_reason=None,
                    )
                ],
            )
        else:
            # subsequent chunks just have content
            chunk = ChatCompletionChunk(
                id=self.completion_id,
                object="chat.completion.chunk",
                created=self.created,
                model=self.model,
                choices=[
                    Choice(
                        index=0,
                        delta=ChoiceDelta(content=text_content),
                        finish_reason=None,
                    )
                ],
            )
        yield f"data: {chunk.model_dump_json()}\n\n"

    def _extract_text_content(self, content: Union[str, list[TextContent]]) -> str:
        """
        Extract text string from content field.

        Args:
            content: Either a string or list of TextContent objects

        Returns:
            Extracted text string ("" for unsupported content types)
        """
        if isinstance(content, str):
            return content
        elif isinstance(content, list):
            # concatenate all TextContent items
            text_parts = []
            for item in content:
                if isinstance(item, TextContent):
                    text_parts.append(item.text)
            return "".join(text_parts)
        return ""