From ce4e69cac6d3a8831cf51503f2e12d5e119fc987 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Tue, 29 Jul 2025 22:53:38 -0700 Subject: [PATCH] feat: add keepalive wrapper to stream route (#3645) Co-authored-by: Caren Thomas --- letta/server/rest_api/routers/v1/agents.py | 44 +++++++----- letta/server/rest_api/streaming_response.py | 75 +++++++++++++++++++++ letta/settings.py | 4 ++ 3 files changed, 108 insertions(+), 15 deletions(-) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 2811dc95..387732cf 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1061,28 +1061,42 @@ async def send_message_streaming( else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER ), ) - from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode + from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint: + raw_stream = agent_loop.step_stream( + input_messages=request.messages, + max_steps=request.max_steps, + use_assistant_message=request.use_assistant_message, + request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, + ) + # Conditionally wrap with keepalive based on settings + if settings.enable_keepalive: + stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval) + else: + stream = raw_stream + result = StreamingResponseWithStatusCode( - agent_loop.step_stream( - input_messages=request.messages, - max_steps=request.max_steps, - use_assistant_message=request.use_assistant_message, - request_start_timestamp_ns=request_start_timestamp_ns, - include_return_message_types=request.include_return_message_types, - ), + stream, media_type="text/event-stream", ) else: + raw_stream = agent_loop.step_stream_no_tokens( + 
request.messages, + max_steps=request.max_steps, + use_assistant_message=request.use_assistant_message, + request_start_timestamp_ns=request_start_timestamp_ns, + include_return_message_types=request.include_return_message_types, + ) + # Conditionally wrap with keepalive based on settings + if settings.enable_keepalive: + stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval) + else: + stream = raw_stream + + result = StreamingResponseWithStatusCode( - agent_loop.step_stream_no_tokens( - request.messages, - max_steps=request.max_steps, - use_assistant_message=request.use_assistant_message, - request_start_timestamp_ns=request_start_timestamp_ns, - include_return_message_types=request.include_return_message_types, - ), + stream, media_type="text/event-stream", ) else: diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py index 344d067d..ab3ca925 100644 --- a/letta/server/rest_api/streaming_response.py +++ b/letta/server/rest_api/streaming_response.py @@ -11,6 +11,7 @@ from starlette.types import Send from letta.log import get_logger from letta.schemas.enums import JobStatus +from letta.schemas.letta_ping import LettaPing from letta.schemas.user import User from letta.server.rest_api.utils import capture_sentry_exception from letta.services.job_manager import JobManager @@ -18,6 +19,80 @@ from letta.services.job_manager import JobManager logger = get_logger(__name__) +async def add_keepalive_to_stream( + stream_generator: AsyncIterator[str | bytes], + keepalive_interval: float = 30.0, +) -> AsyncIterator[str | bytes]: + """ + Adds periodic keepalive messages to a stream to prevent connection timeouts. + + Sends a keepalive ping whenever no data has been received for + `keepalive_interval` seconds, so the connection stays alive during long + idle operations like tool execution. 
+ + Args: + stream_generator: The original stream generator to wrap + keepalive_interval: Seconds between keepalive messages (default: 30) + + Yields: + Original stream chunks interspersed with keepalive messages + """ + # Use a queue to decouple the stream reading from keepalive timing + queue = asyncio.Queue() + stream_exhausted = False + + async def stream_reader(): + """Read from the original stream and put items in the queue.""" + nonlocal stream_exhausted + try: + async for item in stream_generator: + await queue.put(("data", item)) + finally: + stream_exhausted = True + await queue.put(("end", None)) + + # Start the stream reader task + reader_task = asyncio.create_task(stream_reader()) + + try: + while True: + try: + # Wait for data with a timeout equal to keepalive interval + msg_type, data = await asyncio.wait_for(queue.get(), timeout=keepalive_interval) + + if msg_type == "end": + # Stream finished + break + elif msg_type == "data": + yield data + + except asyncio.TimeoutError: + # No data received within keepalive interval + if not stream_exhausted: + # Send keepalive ping in the same format as [DONE] + yield f"data: {LettaPing().model_dump_json()}\n\n" + else: + # Stream is done but queue might be processing + # Check if there's anything left + try: + msg_type, data = queue.get_nowait() + if msg_type == "end": + break + elif msg_type == "data": + yield data + except asyncio.QueueEmpty: + # Really done now + break + + finally: + # Clean up the reader task + reader_task.cancel() + try: + await reader_task + except asyncio.CancelledError: + pass + + # TODO (cliandy) wrap this and handle types async def cancellation_aware_stream_wrapper( stream_generator: AsyncIterator[str | bytes], diff --git a/letta/settings.py b/letta/settings.py index 7997c33c..b8de25b9 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -194,6 +194,10 @@ class Settings(BaseSettings): debug: Optional[bool] = False cors_origins: Optional[list] = cors_origins + # SSE Streaming 
keepalive settings + enable_keepalive: bool = Field(True, description="Enable keepalive messages in SSE streams to prevent timeouts") + keepalive_interval: float = Field(15.0, description="Seconds between keepalive messages (default: 15)") + # default handles default_llm_handle: Optional[str] = None default_embedding_handle: Optional[str] = None