feat: add keepalive wrapper to stream route (#3645)

Co-authored-by: Caren Thomas <carenthomas@gmail.com>
This commit is contained in:
Charles Packer
2025-07-29 22:53:38 -07:00
committed by GitHub
parent b94699e910
commit ce4e69cac6
3 changed files with 108 additions and 15 deletions

View File

@@ -1061,28 +1061,42 @@ async def send_message_streaming(
else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER
),
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint:
raw_stream = agent_loop.step_stream(
input_messages=request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
)
# Conditionally wrap with keepalive based on settings
if settings.enable_keepalive:
stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval)
else:
stream = raw_stream
result = StreamingResponseWithStatusCode(
agent_loop.step_stream(
input_messages=request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
),
stream,
media_type="text/event-stream",
)
else:
raw_stream = agent_loop.step_stream_no_tokens(
request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
)
# Conditionally wrap with keepalive based on settings
if settings.enable_keepalive:
stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval)
else:
stream = raw_stream
result = StreamingResponseWithStatusCode(
agent_loop.step_stream_no_tokens(
request.messages,
max_steps=request.max_steps,
use_assistant_message=request.use_assistant_message,
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=request.include_return_message_types,
),
stream,
media_type="text/event-stream",
)
else:

View File

@@ -11,6 +11,7 @@ from starlette.types import Send
from letta.log import get_logger
from letta.schemas.enums import JobStatus
from letta.schemas.letta_ping import LettaPing
from letta.schemas.user import User
from letta.server.rest_api.utils import capture_sentry_exception
from letta.services.job_manager import JobManager
@@ -18,6 +19,80 @@ from letta.services.job_manager import JobManager
logger = get_logger(__name__)
async def add_keepalive_to_stream(
stream_generator: AsyncIterator[str | bytes],
keepalive_interval: float = 30.0,
) -> AsyncIterator[str | bytes]:
"""
Adds periodic keepalive messages to a stream to prevent connection timeouts.
Sends a keepalive ping every `keepalive_interval` seconds, regardless of
whether data is flowing. This ensures connections stay alive during long
operations like tool execution.
Args:
stream_generator: The original stream generator to wrap
keepalive_interval: Seconds between keepalive messages (default: 30)
Yields:
Original stream chunks interspersed with keepalive messages
"""
# Use a queue to decouple the stream reading from keepalive timing
queue = asyncio.Queue()
stream_exhausted = False
async def stream_reader():
"""Read from the original stream and put items in the queue."""
nonlocal stream_exhausted
try:
async for item in stream_generator:
await queue.put(("data", item))
finally:
stream_exhausted = True
await queue.put(("end", None))
# Start the stream reader task
reader_task = asyncio.create_task(stream_reader())
try:
while True:
try:
# Wait for data with a timeout equal to keepalive interval
msg_type, data = await asyncio.wait_for(queue.get(), timeout=keepalive_interval)
if msg_type == "end":
# Stream finished
break
elif msg_type == "data":
yield data
except asyncio.TimeoutError:
# No data received within keepalive interval
if not stream_exhausted:
# Send keepalive ping in the same format as [DONE]
yield f"data: {LettaPing().model_dump_json()}\n\n"
else:
# Stream is done but queue might be processing
# Check if there's anything left
try:
msg_type, data = queue.get_nowait()
if msg_type == "end":
break
elif msg_type == "data":
yield data
except asyncio.QueueEmpty:
# Really done now
break
finally:
# Clean up the reader task
reader_task.cancel()
try:
await reader_task
except asyncio.CancelledError:
pass
# TODO (cliandy) wrap this and handle types
async def cancellation_aware_stream_wrapper(
stream_generator: AsyncIterator[str | bytes],

View File

@@ -194,6 +194,10 @@ class Settings(BaseSettings):
debug: Optional[bool] = False
cors_origins: Optional[list] = cors_origins
# SSE Streaming keepalive settings
enable_keepalive: bool = Field(True, description="Enable keepalive messages in SSE streams to prevent timeouts")
keepalive_interval: float = Field(15.0, description="Seconds between keepalive messages (default: 15)")
# default handles
default_llm_handle: Optional[str] = None
default_embedding_handle: Optional[str] = None