From 11a36aaa6ab7bfcd4b35000740febebdcbd68011 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 14 Oct 2025 11:01:26 -0700 Subject: [PATCH] feat: Create streaming service class [LET-5504] (#5418) Create streaming service --- letta/server/rest_api/routers/v1/agents.py | 199 +----------- letta/services/streaming_service.py | 338 +++++++++++++++++++++ 2 files changed, 350 insertions(+), 187 deletions(-) create mode 100644 letta/services/streaming_service.py diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 08af9533..a2372fb7 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -15,7 +15,7 @@ from starlette.responses import Response, StreamingResponse from letta.agents.agent_loop import AgentLoop from letta.agents.letta_agent_v2 import LettaAgentV2 from letta.constants import AGENT_ID_PATTERN, DEFAULT_MAX_STEPS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REDIS_RUN_ID_PREFIX -from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client +from letta.data_sources.redis_client import get_redis_client from letta.errors import ( AgentExportIdMappingError, AgentExportProcessingError, @@ -54,7 +54,6 @@ from letta.schemas.tool import Tool from letta.schemas.user import User from letta.serialize_schemas.pydantic_agent_schema import AgentSchema from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server -from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator from letta.server.server import SyncServer from letta.services.lettuce import LettuceClient from letta.services.run_manager import RunManager @@ -1291,195 +1290,21 @@ async def send_message_streaming( This endpoint accepts a message from a user and processes it through the agent. It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True. 
""" - request_start_timestamp_ns = get_utc_timestamp_ns() - MetricRegistry().user_message_counter.add(1, get_ctx_attributes()) - - # TODO (cliandy): clean this up - redis_client = await get_redis_client() + from letta.services.streaming_service import StreamingService actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - # TODO: This is redundant, remove soon - agent = await server.agent_manager.get_agent_by_id_async( - agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"] + + # use the streaming service for unified stream handling + streaming_service = StreamingService(server) + + run, result = await streaming_service.create_agent_stream( + agent_id=agent_id, + actor=actor, + request=request, + run_type="send_message_streaming", ) - agent_eligible = agent.multi_agent_group is None or agent.multi_agent_group.manager_type in ["sleeptime", "voice_sleeptime"] - model_compatible = agent.llm_config.model_endpoint_type in [ - "anthropic", - "openai", - "together", - "google_ai", - "google_vertex", - "bedrock", - "ollama", - "azure", - "xai", - "groq", - "deepseek", - ] - model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock", "deepseek"] - if agent.agent_type == AgentType.letta_v1_agent and agent.llm_config.model_endpoint_type in ["google_ai", "google_vertex"]: - model_compatible_token_streaming = True - # Create a new run for execution tracking - if settings.track_agent_run: - runs_manager = RunManager() - run = await runs_manager.create_run( - pydantic_run=PydanticRun( - agent_id=agent_id, - background=request.background or False, - metadata={ - "run_type": "send_message_streaming", - }, - request_config=LettaRequestConfig.from_letta_request(request), - ), - actor=actor, - ) - run_update_metadata = None - await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None) - else: - run = None 
- - try: - if agent_eligible and model_compatible: - agent_loop = AgentLoop.load(agent_state=agent, actor=actor) - - async def error_aware_stream(): - """Stream that handles early LLM errors gracefully in streaming format.""" - from letta.errors import LLMAuthenticationError, LLMError, LLMRateLimitError, LLMTimeoutError - - try: - stream = agent_loop.stream( - input_messages=request.messages, - max_steps=request.max_steps, - stream_tokens=request.stream_tokens and model_compatible_token_streaming, - run_id=run.id if run else None, - use_assistant_message=request.use_assistant_message, - request_start_timestamp_ns=request_start_timestamp_ns, - include_return_message_types=request.include_return_message_types, - ) - async for chunk in stream: - yield chunk - - if run: - runs_manager = RunManager() - from letta.schemas.enums import RunStatus - - if agent_loop.stop_reason.stop_reason.value == "cancelled": - run_status = RunStatus.cancelled - else: - run_status = RunStatus.completed - - await runs_manager.update_run_by_id_async( - run_id=run.id, - update=RunUpdate(status=run_status, stop_reason=agent_loop.stop_reason.stop_reason.value), - actor=actor, - ) - - except LLMTimeoutError as e: - error_data = { - "error": {"type": "llm_timeout", "message": "The LLM request timed out. Please try again.", "detail": str(e)} - } - yield (f"data: {json.dumps(error_data)}\n\n", 504) - except LLMRateLimitError as e: - error_data = { - "error": { - "type": "llm_rate_limit", - "message": "Rate limit exceeded for LLM model provider. 
Please wait before making another request.", - "detail": str(e), - } - } - yield (f"data: {json.dumps(error_data)}\n\n", 429) - except LLMAuthenticationError as e: - error_data = { - "error": { - "type": "llm_authentication", - "message": "Authentication failed with the LLM model provider.", - "detail": str(e), - } - } - yield (f"data: {json.dumps(error_data)}\n\n", 401) - except LLMError as e: - error_data = {"error": {"type": "llm_error", "message": "An error occurred with the LLM request.", "detail": str(e)}} - yield (f"data: {json.dumps(error_data)}\n\n", 502) - except Exception as e: - error_data = {"error": {"type": "internal_error", "message": "An internal server error occurred.", "detail": str(e)}} - yield (f"data: {json.dumps(error_data)}\n\n", 500) - - raw_stream = error_aware_stream() - - from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream - - if request.background and settings.track_agent_run: - if isinstance(redis_client, NoopAsyncRedisClient): - raise HTTPException( - status_code=503, - detail=( - "Background streaming requires Redis to be running. " - "Please ensure Redis is properly configured. 
" - f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}" - ), - ) - - safe_create_task( - create_background_stream_processor( - stream_generator=raw_stream, - redis_client=redis_client, - run_id=run.id, - run_manager=server.run_manager, - actor=actor, - ), - label=f"background_stream_processor_{run.id}", - ) - - raw_stream = redis_sse_stream_generator( - redis_client=redis_client, - run_id=run.id, - ) - - # Conditionally wrap with keepalive based on request parameter - if request.include_pings and settings.enable_keepalive: - stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval) - else: - stream = raw_stream - - result = StreamingResponseWithStatusCode( - stream, - media_type="text/event-stream", - ) - else: - result = await server.send_message_to_agent( - agent_id=agent_id, - actor=actor, - input_messages=request.messages, - stream_steps=True, - stream_tokens=request.stream_tokens, - # Support for AssistantMessage - use_assistant_message=request.use_assistant_message, - assistant_message_tool_name=request.assistant_message_tool_name, - assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, - request_start_timestamp_ns=request_start_timestamp_ns, - include_return_message_types=request.include_return_message_types, - ) - if settings.track_agent_run: - run_status = RunStatus.running - return result - except PendingApprovalError as e: - if settings.track_agent_run: - run_update_metadata = {"error": str(e)} - run_status = RunStatus.failed - raise HTTPException( - status_code=409, detail={"code": "PENDING_APPROVAL", "message": str(e), "pending_request_id": e.pending_request_id} - ) - except Exception as e: - if settings.track_agent_run: - run_update_metadata = {"error": str(e)} - run_status = RunStatus.failed - raise - finally: - if settings.track_agent_run: - await server.run_manager.update_run_by_id_async( - run_id=run.id, update=RunUpdate(status=run_status, 
metadata=run_update_metadata), actor=actor - ) + return result class CancelAgentRunRequest(BaseModel): diff --git a/letta/services/streaming_service.py b/letta/services/streaming_service.py new file mode 100644 index 00000000..66c03ae4 --- /dev/null +++ b/letta/services/streaming_service.py @@ -0,0 +1,338 @@ +""" +Streaming service for handling agent message streaming with various formats. +Provides a unified interface for streaming agent responses with support for +different output formats (Letta native, OpenAI-compatible, etc.) +""" + +import asyncio +import json +from typing import AsyncIterator, Optional, Union + +from fastapi import HTTPException +from fastapi.responses import StreamingResponse + +from letta.agents.agent_loop import AgentLoop +from letta.constants import REDIS_RUN_ID_PREFIX +from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client +from letta.errors import LLMAuthenticationError, LLMError, LLMRateLimitError, LLMTimeoutError, PendingApprovalError +from letta.helpers.datetime_helpers import get_utc_timestamp_ns +from letta.log import get_logger +from letta.otel.context import get_ctx_attributes +from letta.otel.metric_registry import MetricRegistry +from letta.schemas.agent import AgentState +from letta.schemas.enums import AgentType, RunStatus +from letta.schemas.job import LettaRequestConfig +from letta.schemas.letta_message import MessageType +from letta.schemas.letta_request import LettaStreamingRequest +from letta.schemas.letta_response import LettaResponse +from letta.schemas.message import MessageCreate +from letta.schemas.run import Run as PydanticRun, RunUpdate +from letta.schemas.user import User +from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator +from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream +from letta.services.lettuce import LettuceClient +from letta.services.run_manager import 
RunManager  # completes the `from letta.services.run_manager import` statement begun on the previous line
from letta.settings import settings
from letta.utils import safe_create_task

logger = get_logger(__name__)


class StreamingService:
    """
    Service for managing agent streaming responses.

    Handles run creation, stream generation, error handling, and format conversion.
    One instance wraps a server object and, when ``settings.track_agent_run`` is
    enabled, a ``RunManager`` used to record the outcome of each stream.
    """

    # Membership is tested with ``in`` on tuples (equality-based), exactly like the
    # list literals these constants replace.
    _ELIGIBLE_GROUP_MANAGER_TYPES = ("sleeptime", "voice_sleeptime")
    _STREAMING_ENDPOINT_TYPES = (
        "anthropic",
        "openai",
        "together",
        "google_ai",
        "google_vertex",
        "bedrock",
        "ollama",
        "azure",
        "xai",
        "groq",
        "deepseek",
    )
    _TOKEN_STREAMING_ENDPOINT_TYPES = ("anthropic", "openai", "bedrock", "deepseek")
    # Endpoint types that additionally support token streaming for letta_v1 agents.
    _LETTA_V1_TOKEN_STREAMING_ENDPOINT_TYPES = ("google_ai", "google_vertex")

    def __init__(self, server):
        """
        Initialize the streaming service.

        Args:
            server: The SyncServer instance for accessing managers and services
        """
        self.server = server
        # Only instantiated when run tracking is on; every use below is guarded.
        self.runs_manager = RunManager() if settings.track_agent_run else None

    async def create_agent_stream(
        self,
        agent_id: str,
        actor: User,
        request: LettaStreamingRequest,
        run_type: str = "streaming",
    ) -> tuple[Optional[PydanticRun], Union[StreamingResponse, LettaResponse]]:
        """
        Create a streaming response for an agent.

        Args:
            agent_id: The agent ID to stream from
            actor: The user making the request
            request: The LettaStreamingRequest containing all request parameters
            run_type: Type of run for tracking (stored in the run's metadata)

        Returns:
            Tuple of (run object or None, streaming response). The run is None
            when ``settings.track_agent_run`` is disabled.

        Raises:
            HTTPException: 503 when background streaming is requested but the
                Redis client is the no-op implementation; 409 when the agent
                raises ``PendingApprovalError``.
        """
        request_start_timestamp_ns = get_utc_timestamp_ns()
        MetricRegistry().user_message_counter.add(1, get_ctx_attributes())

        # get redis client
        redis_client = await get_redis_client()

        # load agent and check eligibility
        agent = await self.server.agent_manager.get_agent_by_id_async(
            agent_id, actor, include_relationships=["memory", "multi_agent_group", "sources", "tool_exec_environment_variables", "tools"]
        )

        agent_eligible = self._is_agent_eligible(agent)
        model_compatible = self._is_model_compatible(agent)
        model_compatible_token_streaming = self._is_token_streaming_compatible(agent)

        # create run if tracking is enabled
        run = None
        run_update_metadata = None
        # Pessimistic default. The ``finally`` clause below always reads
        # ``run_status``; without a default, an exception that is not an
        # ``Exception`` subclass (e.g. asyncio.CancelledError) would reach it
        # with the name unbound and be replaced by an UnboundLocalError.
        run_status = RunStatus.failed
        if settings.track_agent_run:
            run = await self._create_run(agent_id, request, run_type, actor)
            await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)

        try:
            if agent_eligible and model_compatible:
                # use agent loop for streaming
                agent_loop = AgentLoop.load(agent_state=agent, actor=actor)

                # create the base stream with error handling
                raw_stream = self._create_error_aware_stream(
                    agent_loop=agent_loop,
                    messages=request.messages,
                    max_steps=request.max_steps,
                    stream_tokens=request.stream_tokens and model_compatible_token_streaming,
                    run_id=run.id if run else None,
                    use_assistant_message=request.use_assistant_message,
                    request_start_timestamp_ns=request_start_timestamp_ns,
                    include_return_message_types=request.include_return_message_types,
                    actor=actor,
                )

                # handle background streaming if requested
                # (track_agent_run being true guarantees ``run`` was created above)
                if request.background and settings.track_agent_run:
                    if isinstance(redis_client, NoopAsyncRedisClient):
                        raise HTTPException(
                            status_code=503,
                            detail=(
                                "Background streaming requires Redis to be running. "
                                "Please ensure Redis is properly configured. "
                                f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
                            ),
                        )

                    # The agent stream is consumed by a background task; the HTTP
                    # response reads from the Redis-backed generator instead.
                    safe_create_task(
                        create_background_stream_processor(
                            stream_generator=raw_stream,
                            redis_client=redis_client,
                            run_id=run.id,
                            run_manager=self.server.run_manager,
                            actor=actor,
                        ),
                        label=f"background_stream_processor_{run.id}",
                    )

                    raw_stream = redis_sse_stream_generator(
                        redis_client=redis_client,
                        run_id=run.id,
                    )

                # conditionally wrap with keepalive based on request parameter
                if request.include_pings and settings.enable_keepalive:
                    stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval)
                else:
                    stream = raw_stream

                result = StreamingResponseWithStatusCode(
                    stream,
                    media_type="text/event-stream",
                )
            else:
                # fallback to non-agent-loop streaming
                result = await self.server.send_message_to_agent(
                    agent_id=agent_id,
                    actor=actor,
                    input_messages=request.messages,
                    stream_steps=True,
                    stream_tokens=request.stream_tokens,
                    use_assistant_message=request.use_assistant_message,
                    assistant_message_tool_name=request.assistant_message_tool_name,
                    assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
                    request_start_timestamp_ns=request_start_timestamp_ns,
                    include_return_message_types=request.include_return_message_types,
                )

            # The response object has only been built at this point, so the run
            # is recorded as running; the error-aware stream records the
            # completed/cancelled status once it has been fully consumed.
            if settings.track_agent_run and run:
                run_status = RunStatus.running

            return run, result

        except PendingApprovalError as e:
            if settings.track_agent_run:
                run_update_metadata = {"error": str(e)}
                run_status = RunStatus.failed
            raise HTTPException(
                status_code=409, detail={"code": "PENDING_APPROVAL", "message": str(e), "pending_request_id": e.pending_request_id}
            ) from e
        except Exception as e:
            if settings.track_agent_run:
                run_update_metadata = {"error": str(e)}
                run_status = RunStatus.failed
            raise
        finally:
            if settings.track_agent_run and run:
                await self.server.run_manager.update_run_by_id_async(
                    run_id=run.id,
                    update=RunUpdate(status=run_status, metadata=run_update_metadata),
                    actor=actor,
                )

    def _create_error_aware_stream(
        self,
        agent_loop: AgentLoop,
        messages: list[MessageCreate],
        max_steps: int,
        stream_tokens: bool,
        run_id: Optional[str],
        use_assistant_message: bool,
        request_start_timestamp_ns: int,
        include_return_message_types: Optional[list[MessageType]],
        actor: User,
    ) -> AsyncIterator:
        """
        Create a stream with unified error handling.

        Chunks from ``agent_loop.stream`` are forwarded unchanged. If an
        exception escapes the loop, a single ``(sse_payload, http_status)``
        tuple describing the error is yielded instead of propagating it.

        Returns:
            Async iterator that yields chunks with proper error handling
        """

        async def error_aware_stream():
            """Stream that handles early LLM errors gracefully in streaming format."""
            try:
                stream = agent_loop.stream(
                    input_messages=messages,
                    max_steps=max_steps,
                    stream_tokens=stream_tokens,
                    run_id=run_id,
                    use_assistant_message=use_assistant_message,
                    request_start_timestamp_ns=request_start_timestamp_ns,
                    include_return_message_types=include_return_message_types,
                )
                async for chunk in stream:
                    yield chunk

                # update run status after completion
                if run_id and self.runs_manager:
                    stop_reason = agent_loop.stop_reason.stop_reason.value
                    run_status = RunStatus.cancelled if stop_reason == "cancelled" else RunStatus.completed
                    await self._update_run_status(
                        run_id=run_id,
                        status=run_status,
                        actor=actor,
                        stop_reason=stop_reason,
                    )

            # NOTE(review): the handlers below report the error to the client but
            # do not update the run record here; confirm that another component
            # marks the run as failed, otherwise it stays in "running".
            except LLMTimeoutError as e:
                yield self._sse_error_chunk("llm_timeout", "The LLM request timed out. Please try again.", e, 504)
            except LLMRateLimitError as e:
                yield self._sse_error_chunk(
                    "llm_rate_limit",
                    "Rate limit exceeded for LLM model provider. Please wait before making another request.",
                    e,
                    429,
                )
            except LLMAuthenticationError as e:
                yield self._sse_error_chunk("llm_authentication", "Authentication failed with the LLM model provider.", e, 401)
            except LLMError as e:
                yield self._sse_error_chunk("llm_error", "An error occurred with the LLM request.", e, 502)
            except Exception as e:
                # Deliberate catch-all at the stream boundary: the failure is
                # surfaced to the client as an error event rather than raised.
                yield self._sse_error_chunk("internal_error", "An internal server error occurred.", e, 500)

        return error_aware_stream()

    @staticmethod
    def _sse_error_chunk(error_type: str, message: str, exc: Exception, status_code: int) -> tuple[str, int]:
        """Build the ``(sse_payload, http_status)`` tuple yielded when a stream fails."""
        error_data = {"error": {"type": error_type, "message": message, "detail": str(exc)}}
        return f"data: {json.dumps(error_data)}\n\n", status_code

    def _is_agent_eligible(self, agent: AgentState) -> bool:
        """Check if agent is eligible for streaming."""
        group = agent.multi_agent_group
        return group is None or group.manager_type in self._ELIGIBLE_GROUP_MANAGER_TYPES

    def _is_model_compatible(self, agent: AgentState) -> bool:
        """Check if agent's model is compatible with streaming."""
        return agent.llm_config.model_endpoint_type in self._STREAMING_ENDPOINT_TYPES

    def _is_token_streaming_compatible(self, agent: AgentState) -> bool:
        """Check if agent's model supports token-level streaming."""
        endpoint_type = agent.llm_config.model_endpoint_type
        if endpoint_type in self._TOKEN_STREAMING_ENDPOINT_TYPES:
            return True
        # letta_v1 agents additionally support token streaming on the Google endpoints
        return agent.agent_type == AgentType.letta_v1_agent and endpoint_type in self._LETTA_V1_TOKEN_STREAMING_ENDPOINT_TYPES

    async def _create_run(self, agent_id: str, request: LettaStreamingRequest, run_type: str, actor: User) -> PydanticRun:
        """
        Create a run for tracking execution.

        Only called when ``settings.track_agent_run`` is enabled, which is also
        the condition under which ``self.runs_manager`` is set.
        """
        return await self.runs_manager.create_run(
            pydantic_run=PydanticRun(
                agent_id=agent_id,
                background=request.background or False,
                metadata={
                    "run_type": run_type,
                },
                request_config=LettaRequestConfig.from_letta_request(request),
            ),
            actor=actor,
        )

    async def _update_run_status(
        self,
        run_id: str,
        status: RunStatus,
        actor: User,
        error: Optional[str] = None,
        stop_reason: Optional[str] = None,
    ) -> None:
        """
        Update the status of a run.

        Does nothing when run tracking is disabled (no run manager).

        Args:
            run_id: ID of the run to update
            status: New status to record
            actor: The user on whose behalf the update is made
            error: Optional error text, stored as ``{"error": error}`` metadata
            stop_reason: Optional stop reason to record alongside the status
        """
        if not self.runs_manager:
            return

        # Build the update in a single constructor call so every field is
        # supplied the same way as the other RunUpdate(...) calls in this class.
        # NOTE(review): attribute assignment after construction may skip the
        # model's validation depending on its config - confirm against RunUpdate.
        fields = {"status": status}
        if error:
            fields["metadata"] = {"error": error}
        if stop_reason:
            fields["stop_reason"] = stop_reason

        await self.runs_manager.update_run_by_id_async(
            run_id=run_id,
            update=RunUpdate(**fields),
            actor=actor,
        )