feat: add keepalive wrapper to stream route (#3645)
Co-authored-by: Caren Thomas <carenthomas@gmail.com>
This commit is contained in:
@@ -1061,28 +1061,42 @@ async def send_message_streaming(
|
||||
else SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER
|
||||
),
|
||||
)
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
|
||||
|
||||
if request.stream_tokens and model_compatible_token_streaming and not_letta_endpoint:
|
||||
raw_stream = agent_loop.step_stream(
|
||||
input_messages=request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
# Conditionally wrap with keepalive based on settings
|
||||
if settings.enable_keepalive:
|
||||
stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval)
|
||||
else:
|
||||
stream = raw_stream
|
||||
|
||||
result = StreamingResponseWithStatusCode(
|
||||
agent_loop.step_stream(
|
||||
input_messages=request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
raw_stream = agent_loop.step_stream_no_tokens(
|
||||
request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
# Conditionally wrap with keepalive based on settings
|
||||
if settings.enable_keepalive:
|
||||
stream = add_keepalive_to_stream(raw_stream, keepalive_interval=settings.keepalive_interval)
|
||||
else:
|
||||
stream = raw_stream
|
||||
|
||||
result = StreamingResponseWithStatusCode(
|
||||
agent_loop.step_stream_no_tokens(
|
||||
request.messages,
|
||||
max_steps=request.max_steps,
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
|
||||
@@ -11,6 +11,7 @@ from starlette.types import Send
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.letta_ping import LettaPing
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import capture_sentry_exception
|
||||
from letta.services.job_manager import JobManager
|
||||
@@ -18,6 +19,80 @@ from letta.services.job_manager import JobManager
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def add_keepalive_to_stream(
|
||||
stream_generator: AsyncIterator[str | bytes],
|
||||
keepalive_interval: float = 30.0,
|
||||
) -> AsyncIterator[str | bytes]:
|
||||
"""
|
||||
Adds periodic keepalive messages to a stream to prevent connection timeouts.
|
||||
|
||||
Sends a keepalive ping every `keepalive_interval` seconds, regardless of
|
||||
whether data is flowing. This ensures connections stay alive during long
|
||||
operations like tool execution.
|
||||
|
||||
Args:
|
||||
stream_generator: The original stream generator to wrap
|
||||
keepalive_interval: Seconds between keepalive messages (default: 30)
|
||||
|
||||
Yields:
|
||||
Original stream chunks interspersed with keepalive messages
|
||||
"""
|
||||
# Use a queue to decouple the stream reading from keepalive timing
|
||||
queue = asyncio.Queue()
|
||||
stream_exhausted = False
|
||||
|
||||
async def stream_reader():
|
||||
"""Read from the original stream and put items in the queue."""
|
||||
nonlocal stream_exhausted
|
||||
try:
|
||||
async for item in stream_generator:
|
||||
await queue.put(("data", item))
|
||||
finally:
|
||||
stream_exhausted = True
|
||||
await queue.put(("end", None))
|
||||
|
||||
# Start the stream reader task
|
||||
reader_task = asyncio.create_task(stream_reader())
|
||||
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
# Wait for data with a timeout equal to keepalive interval
|
||||
msg_type, data = await asyncio.wait_for(queue.get(), timeout=keepalive_interval)
|
||||
|
||||
if msg_type == "end":
|
||||
# Stream finished
|
||||
break
|
||||
elif msg_type == "data":
|
||||
yield data
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
# No data received within keepalive interval
|
||||
if not stream_exhausted:
|
||||
# Send keepalive ping in the same format as [DONE]
|
||||
yield f"data: {LettaPing().model_dump_json()}\n\n"
|
||||
else:
|
||||
# Stream is done but queue might be processing
|
||||
# Check if there's anything left
|
||||
try:
|
||||
msg_type, data = queue.get_nowait()
|
||||
if msg_type == "end":
|
||||
break
|
||||
elif msg_type == "data":
|
||||
yield data
|
||||
except asyncio.QueueEmpty:
|
||||
# Really done now
|
||||
break
|
||||
|
||||
finally:
|
||||
# Clean up the reader task
|
||||
reader_task.cancel()
|
||||
try:
|
||||
await reader_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
|
||||
# TODO (cliandy) wrap this and handle types
|
||||
async def cancellation_aware_stream_wrapper(
|
||||
stream_generator: AsyncIterator[str | bytes],
|
||||
|
||||
@@ -194,6 +194,10 @@ class Settings(BaseSettings):
|
||||
debug: Optional[bool] = False
|
||||
cors_origins: Optional[list] = cors_origins
|
||||
|
||||
# SSE Streaming keepalive settings
|
||||
enable_keepalive: bool = Field(True, description="Enable keepalive messages in SSE streams to prevent timeouts")
|
||||
keepalive_interval: float = Field(15.0, description="Seconds between keepalive messages (default: 15)")
|
||||
|
||||
# default handles
|
||||
default_llm_handle: Optional[str] = None
|
||||
default_embedding_handle: Optional[str] = None
|
||||
|
||||
Reference in New Issue
Block a user