"""Redis stream manager for reading and writing SSE chunks with batching and TTL."""

import asyncio
import json
import time
from collections import defaultdict
from typing import AsyncIterator, Dict, List, Optional

from letta.data_sources.redis_client import AsyncRedisClient
from letta.log import get_logger

logger = get_logger(__name__)


class RedisSSEStreamWriter:
    """
    Efficiently writes SSE chunks to Redis streams with batching and TTL management.

    Features:
    - Batches writes using Redis pipelines for performance
    - Automatically sets/refreshes TTL on streams
    - Tracks sequential IDs for cursor-based recovery
    - Handles flush on size or time thresholds
    """

    def __init__(
        self,
        redis_client: AsyncRedisClient,
        flush_interval: float = 0.5,
        flush_size: int = 50,
        stream_ttl_seconds: int = 21600,  # 6 hours default
        max_stream_length: int = 10000,  # Max entries per stream
    ):
        """
        Initialize the Redis SSE stream writer.

        Args:
            redis_client: Redis client instance
            flush_interval: Seconds between automatic flushes
            flush_size: Number of chunks to buffer before flushing
            stream_ttl_seconds: TTL for streams in seconds (default: 6 hours)
            max_stream_length: Maximum entries per stream before trimming
        """
        self.redis = redis_client
        self.flush_interval = flush_interval
        self.flush_size = flush_size
        self.stream_ttl = stream_ttl_seconds
        self.max_stream_length = max_stream_length

        # Buffer for batching: run_id -> list of chunks
        self.buffer: Dict[str, List[Dict]] = defaultdict(list)
        # Track sequence IDs per run
        self.seq_counters: Dict[str, int] = defaultdict(int)
        # Track last flush time per run.
        # NOTE: the default factory is time.time (not float) so that a run's very
        # first chunk does not observe last_flush == 0.0 and immediately trigger a
        # single-chunk flush, which would defeat batching for every new run.
        self.last_flush: Dict[str, float] = defaultdict(time.time)

        # Background flush task
        self._flush_task: Optional[asyncio.Task] = None
        self._running = False

    async def start(self):
        """Start the background flush task (idempotent)."""
        if not self._running:
            self._running = True
            self._flush_task = asyncio.create_task(self._periodic_flush())

    async def stop(self):
        """Stop the background flush task and flush remaining data."""
        self._running = False
        if self._flush_task:
            self._flush_task.cancel()
            try:
                await self._flush_task
            except asyncio.CancelledError:
                pass

        # Flush all remaining buffers. Failures are isolated per run so that one
        # broken stream cannot prevent the remaining runs from being flushed
        # during shutdown.
        for run_id in list(self.buffer.keys()):
            if self.buffer[run_id]:
                try:
                    await self._flush_run(run_id)
                except Exception as e:
                    logger.error(f"Failed to flush run {run_id} during shutdown: {e}")

    async def write_chunk(
        self,
        run_id: str,
        data: str,
        is_complete: bool = False,
    ) -> int:
        """
        Write an SSE chunk to the buffer for a specific run.

        Args:
            run_id: The run ID to write to
            data: SSE-formatted chunk data
            is_complete: Whether this is the final chunk

        Returns:
            The sequence ID assigned to this chunk
        """
        # Assign a monotonically increasing sequence ID (per run) so readers can
        # resume from a cursor after a disconnect.
        seq_id = self.seq_counters[run_id]
        self.seq_counters[run_id] += 1

        chunk = {
            "seq_id": seq_id,
            "data": data,
            "timestamp": int(time.time() * 1000),
        }

        # Mark completion if this is the last chunk
        if is_complete:
            chunk["complete"] = "true"

        self.buffer[run_id].append(chunk)

        # Flush on size threshold, on completion, or when the time budget for
        # this run's buffered chunks has elapsed.
        should_flush = (
            len(self.buffer[run_id]) >= self.flush_size or is_complete or (time.time() - self.last_flush[run_id]) > self.flush_interval
        )

        if should_flush:
            await self._flush_run(run_id)

        return seq_id

    async def _flush_run(self, run_id: str):
        """Flush buffered chunks for a specific run to Redis in one pipeline."""
        if not self.buffer[run_id]:
            return

        # Swap the buffer out first so concurrent write_chunk calls append to a
        # fresh list while this batch is in flight.
        chunks = self.buffer[run_id]
        self.buffer[run_id] = []
        stream_key = f"sse:run:{run_id}"

        try:
            client = await self.redis.get_client()

            # Use a non-transactional pipeline: one network round trip for the
            # whole batch, no need for atomicity across entries.
            async with client.pipeline(transaction=False) as pipe:
                for chunk in chunks:
                    pipe.xadd(stream_key, chunk, maxlen=self.max_stream_length, approximate=True)

                # Set/refresh TTL on the stream
                pipe.expire(stream_key, self.stream_ttl)

                await pipe.execute()

            self.last_flush[run_id] = time.time()

            logger.debug(
                f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, " f"seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}"
            )

            # If this was a completion chunk, clean up tracking
            if chunks[-1].get("complete") == "true":
                self._cleanup_run(run_id)

        except Exception as e:
            logger.error(f"Failed to flush chunks for run {run_id}: {e}")
            # Requeue the chunks (ahead of anything buffered meanwhile) so they
            # are retried on the next flush.
            # NOTE(review): if the pipeline partially executed before failing,
            # the retry may write duplicate entries — readers dedupe via seq_id.
            self.buffer[run_id] = chunks + self.buffer[run_id]
            raise

    async def _periodic_flush(self):
        """Background task to periodically flush buffers on a time threshold."""
        while self._running:
            try:
                await asyncio.sleep(self.flush_interval)

                # Check each run for time-based flush
                current_time = time.time()
                runs_to_flush = [
                    run_id
                    for run_id, last_flush in self.last_flush.items()
                    if (current_time - last_flush) > self.flush_interval and self.buffer[run_id]
                ]

                for run_id in runs_to_flush:
                    await self._flush_run(run_id)

            except asyncio.CancelledError:
                break
            except Exception as e:
                logger.error(f"Error in periodic flush: {e}")

    def _cleanup_run(self, run_id: str):
        """Clean up tracking data for a completed run."""
        self.buffer.pop(run_id, None)
        self.seq_counters.pop(run_id, None)
        self.last_flush.pop(run_id, None)

    async def mark_complete(self, run_id: str):
        """Mark a stream as complete and flush by appending the [DONE] marker."""
        await self.write_chunk(run_id, "data: [DONE]\n\n", is_complete=True)


async def create_background_stream_processor(
    stream_generator,
    redis_client: AsyncRedisClient,
    run_id: str,
    writer: Optional[RedisSSEStreamWriter] = None,
) -> None:
    """
    Process a stream in the background and store chunks to Redis.

    This function consumes the stream generator and writes all chunks
    to Redis for later retrieval.

    NOTE(review): callers launch this with a bare asyncio.create_task(...);
    they should retain a reference to the task (or add a done callback) so it
    cannot be garbage-collected mid-stream.

    Args:
        stream_generator: The async generator yielding SSE chunks
        redis_client: Redis client instance
        run_id: The run ID to store chunks under
        writer: Optional pre-configured writer (creates new if not provided)
    """
    if writer is None:
        writer = RedisSSEStreamWriter(redis_client)
        await writer.start()
        should_stop_writer = True
    else:
        # A shared writer's lifecycle is owned by the caller.
        should_stop_writer = False

    try:
        async for chunk in stream_generator:
            # Check if this is the final chunk
            is_done = "data: [DONE]" in chunk if isinstance(chunk, str) else False

            await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)

            if is_done:
                break

    except Exception as e:
        logger.error(f"Error processing stream for run {run_id}: {e}")
        # Best-effort: surface the error to readers of the stream. If Redis is
        # the component that failed, this write can also fail; swallow that so
        # the original error logged above remains the recorded cause and the
        # background task does not die with an unhandled secondary exception.
        try:
            error_chunk = {"message_type": "error", "error": str(e)}
            await writer.write_chunk(run_id=run_id, data=f"data: {json.dumps(error_chunk)}\n\n", is_complete=True)
        except Exception as write_err:
            logger.error(f"Failed to write error chunk for run {run_id}: {write_err}")
    finally:
        if should_stop_writer:
            await writer.stop()


async def redis_sse_stream_generator(
    redis_client: AsyncRedisClient,
    run_id: str,
    start_cursor: Optional[int] = None,
    poll_interval: float = 0.1,
    batch_size: int = 100,
) -> AsyncIterator[str]:
    """
    Generate SSE events from Redis stream chunks.

    This generator reads chunks stored in Redis streams and yields them as SSE events.
    It supports cursor-based recovery by allowing you to start from a specific seq_id.

    Args:
        redis_client: Redis client instance
        run_id: The run ID to read chunks for
        start_cursor: Sequential ID (integer) to start reading from (default: 0 for beginning)
        poll_interval: Seconds to wait between polls when no new data (default: 0.1)
        batch_size: Number of entries to read per batch (default: 100)

    Yields:
        SSE-formatted chunks from the Redis stream
    """
    stream_key = f"sse:run:{run_id}"
    # "-" is the XRANGE minimum ID; subsequent reads start (inclusively) from
    # the last entry seen, so that entry is skipped by the id check below.
    last_redis_id = "-"
    cursor_seq_id = start_cursor or 0

    logger.debug(f"Starting redis_sse_stream_generator for run_id={run_id}, stream_key={stream_key}")

    # Add a small initial delay to allow background task to start writing
    await asyncio.sleep(0.05)

    while True:
        entries = await redis_client.xrange(stream_key, start=last_redis_id, count=batch_size)

        if entries:
            yielded_any = False
            for entry_id, fields in entries:
                # XRANGE start is inclusive — skip the entry we already emitted.
                if entry_id == last_redis_id:
                    continue

                # Skip chunks before the requested cursor (reconnect recovery).
                chunk_seq_id = int(fields.get("seq_id", 0))
                if chunk_seq_id >= cursor_seq_id:
                    data = fields.get("data", "")
                    if not data:
                        continue

                    # Terminal marker written by RedisSSEStreamWriter.mark_complete.
                    if data == "data: [DONE]\n\n":
                        yield data
                        return

                    yield data
                    yielded_any = True

                last_redis_id = entry_id

            # A full batch of pre-cursor entries: poll again immediately rather
            # than sleeping, since more data is likely already available.
            if not yielded_any and len(entries) > 1:
                continue

        # No new entries (empty read, or only the already-seen entry): back off.
        if not entries or (len(entries) == 1 and entries[0][0] == last_redis_id):
            await asyncio.sleep(poll_interval)
AgentExportIdMappingError, AgentExportProcessingError, AgentFileImportError, AgentNotFoundForExportError from letta.groups.sleeptime_multi_agent_v2 import SleeptimeMultiAgentV2 from letta.helpers.datetime_helpers import get_utc_timestamp_ns @@ -40,6 +40,7 @@ from letta.schemas.source import Source from letta.schemas.tool import Tool from letta.schemas.user import User from letta.serialize_schemas.pydantic_agent_schema import AgentSchema +from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer from letta.services.summarizer.enums import SummarizationMode @@ -1259,8 +1260,43 @@ async def send_message_streaming( else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER ), ) + from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream + if request.background and request.stream_tokens and settings.track_agent_run: + if isinstance(redis_client, NoopAsyncRedisClient): + raise HTTPException( + status_code=503, + detail="Background streaming is not available: Redis is not configured. 
Please ensure Redis is properly configured and running.", + ) + + asyncio.create_task( + create_background_stream_processor( + stream_generator=agent_loop.step_stream( + input_messages=request.messages, + max_steps=request.max_steps, + use_assistant_message=request.use_assistant_message, + request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, + ), + redis_client=redis_client, + run_id=run.id, + ) + ) + + stream = redis_sse_stream_generator( + redis_client=redis_client, + run_id=run.id, + ) + + if request.include_pings and settings.enable_keepalive: + stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval) + + return StreamingResponseWithStatusCode( + stream, + media_type="text/event-stream", + ) + if request.stream_tokens and model_compatible_token_streaming: raw_stream = agent_loop.step_stream( input_messages=request.messages,