# Files
# letta-server/letta/server/rest_api/streaming_response.py
# cthomas 6599aa3b44 feat: populate seq_id for ping messages (#8844)
# * feat: populate seq_id for ping messages
#
# * fix import
# 2026-01-19 15:54:43 -08:00
#
# 357 lines
# 15 KiB
# Python

# Alternative implementation of StreamingResponse that allows for effectively
# streaming HTTP trailers, as we cannot set status codes after the initial response.
# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361
import asyncio
import json
import re
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from uuid import uuid4
import anyio
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from starlette.types import Send
from letta.errors import LettaUnexpectedStreamCancellationError, PendingApprovalError
from letta.log import get_logger
from letta.schemas.enums import RunStatus
from letta.schemas.letta_message import LettaPing
from letta.schemas.user import User
from letta.server.rest_api.utils import capture_sentry_exception
from letta.services.run_manager import RunManager
from letta.settings import settings
from letta.utils import safe_create_task
# Module-level logger shared by the streaming helpers in this file.
logger = get_logger(__name__)
class RunCancelledException(Exception):
"""Exception raised when a run is explicitly cancelled (not due to client timeout)"""
def __init__(self, run_id: str, message: str = None):
self.run_id = run_id
super().__init__(message or f"Run {run_id} was explicitly cancelled")
async def add_keepalive_to_stream(
stream_generator: AsyncIterator[str | bytes],
run_id: str,
keepalive_interval: float = 30.0,
) -> AsyncIterator[str | bytes]:
"""
Adds periodic keepalive messages to a stream to prevent connection timeouts.
Sends a keepalive ping every `keepalive_interval` seconds, regardless of
whether data is flowing. This ensures connections stay alive during long
operations like tool execution.
Args:
stream_generator: The original stream generator to wrap
keepalive_interval: Seconds between keepalive messages (default: 30)
Yields:
Original stream chunks interspersed with keepalive messages
"""
# Use a bounded queue to decouple reading from keepalive while preserving backpressure
# A small maxsize prevents unbounded memory growth if the client is slow
queue = asyncio.Queue(maxsize=1)
stream_exhausted = False
last_seq_id = None
async def stream_reader():
"""Read from the original stream and put items in the queue."""
nonlocal stream_exhausted
try:
async for item in stream_generator:
await queue.put(("data", item))
finally:
stream_exhausted = True
await queue.put(("end", None))
# Start the stream reader task
reader_task = safe_create_task(stream_reader(), label="stream_reader")
try:
while True:
try:
# Wait for data with a timeout equal to keepalive interval
msg_type, data = await asyncio.wait_for(queue.get(), timeout=keepalive_interval)
if msg_type == "end":
# Stream finished
break
elif msg_type == "data":
# Track seq_id from chunks for ping messages
if isinstance(data, str):
seq_id_match = re.search(r'"seq_id":(\d+)', data) # Look for "seq_id":<number> pattern in the SSE chunk
if seq_id_match:
last_seq_id = int(seq_id_match.group(1))
yield data
except asyncio.TimeoutError:
# No data received within keepalive interval
if not stream_exhausted:
# Send keepalive ping with the last seq_id to allow clients to track progress
yield f"data: {LettaPing(id=f'ping-{uuid4()}', date=datetime.now(timezone.utc), run_id=run_id, seq_id=last_seq_id).model_dump_json()}\n\n"
else:
# Stream is done but queue might be processing
# Check if there's anything left
try:
msg_type, data = queue.get_nowait()
if msg_type == "end":
break
elif msg_type == "data":
yield data
except asyncio.QueueEmpty:
# Really done now
break
finally:
# Clean up the reader task
reader_task.cancel()
try:
await reader_task
except asyncio.CancelledError:
pass
# TODO (cliandy) wrap this and handle types
async def cancellation_aware_stream_wrapper(
    stream_generator: AsyncIterator[str | bytes],
    run_manager: RunManager,
    run_id: str,
    actor: User,
    cancellation_check_interval: float = 0.5,
) -> AsyncIterator[str | bytes]:
    """
    Wraps a stream generator to check for run cancellation while streaming.

    The run status is polled at most once every ``cancellation_check_interval``
    seconds, and only when a new chunk arrives from ``stream_generator``. This lets
    cancellation take effect mid-step rather than only at step boundaries. A stream
    that is stalled waiting for its next chunk is not interrupted until that chunk
    arrives. When the run is found cancelled, a ``stop_reason`` event is yielded
    instead of the chunk that triggered the check, and ``RunCancelledException`` is
    raised. A failed status check is logged and does not interrupt the stream.

    Args:
        stream_generator: The original stream generator to wrap
        run_manager: Run manager instance for checking run status
        run_id: ID of the run to monitor for cancellation
        actor: User/actor making the request
        cancellation_check_interval: Minimum seconds between cancellation checks

    Yields:
        Stream chunks from the original generator until cancelled

    Raises:
        RunCancelledException: If the run is found cancelled during streaming.
        asyncio.CancelledError: If the consuming task itself is cancelled
            (likely a client timeout); re-raised after logging.
    """
    loop = asyncio.get_running_loop()
    last_cancellation_check = loop.time()
    try:
        async for chunk in stream_generator:
            # Check for cancellation periodically (not on every chunk for performance)
            current_time = loop.time()
            if current_time - last_cancellation_check >= cancellation_check_interval:
                try:
                    run = await run_manager.get_run_by_id(run_id=run_id, actor=actor)
                    if run.status == RunStatus.cancelled:
                        logger.info(f"Stream cancelled for run {run_id}, interrupting stream")
                        # Send cancellation event to client
                        cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
                        yield f"data: {json.dumps(cancellation_event)}\n\n"
                        # Raise custom exception for explicit run cancellation
                        raise RunCancelledException(run_id, f"Run {run_id} was cancelled")
                except RunCancelledException:
                    # Re-raise cancellation immediately, don't swallow it below
                    raise
                except Exception as e:
                    # Best-effort: log but don't fail the stream if the check itself fails
                    logger.warning(f"Failed to check run cancellation for run {run_id}: {e}")
                last_cancellation_check = current_time
            yield chunk
    except RunCancelledException:
        # Re-raise RunCancelledException to distinguish from client timeout
        logger.info(f"Stream for run {run_id} was explicitly cancelled and cleaned up")
        raise
    except asyncio.CancelledError:
        # Re-raise CancelledError (likely client timeout) to ensure proper cleanup
        logger.info(f"Stream for run {run_id} was cancelled (likely client timeout) and cleaned up")
        raise
    except Exception as e:
        logger.error(f"Error in cancellation-aware stream wrapper for run {run_id}: {e}")
        raise
class StreamingResponseWithStatusCode(StreamingResponse):
    """
    Variation of StreamingResponse that can dynamically decide the HTTP status code,
    based on the return value of the content iterator (parameter `content`).

    Expects the content to yield either plain str/bytes content as per the original
    `StreamingResponse`, or tuples of (`content`: `str | bytes`, `status_code`: `int`).
    Only a tuple yielded as the first chunk can set the HTTP status line. A non-2xx
    tuple yielded later is reported in-band as an ``event: error`` SSE message.
    After a client disconnect the body iterator keeps being consumed, but nothing
    further is sent.
    """

    # Chunks are either raw content or (content, status_code) tuples.
    body_iterator: AsyncIterator[str | bytes | tuple[str | bytes, int]]
    # True once the initial `http.response.start` message has been sent successfully.
    response_started: bool = False
    # Set to False after a send fails with anyio.ClosedResourceError.
    _client_connected: bool = True

    async def stream_response(self, send: Send) -> None:
        """Stream the response body, optionally shielding the work from cancellation.

        With ``settings.use_asyncio_shield`` enabled, the streaming coroutine runs under
        ``asyncio.shield``. Cancelling this coroutine therefore does not cancel the
        streaming work; cancellation and client disconnects are logged, not propagated.

        Raises:
            HTTPException: 409 ``PENDING_APPROVAL`` when shielding is enabled and the
                stream raises ``PendingApprovalError``.
        """
        if settings.use_asyncio_shield:
            try:
                await asyncio.shield(self._protected_stream_response(send))
            except asyncio.CancelledError:
                logger.info("Stream response was cancelled, but shielded task should continue")
            except anyio.ClosedResourceError:
                logger.info("Client disconnected, but shielded task should continue")
                self._client_connected = False
            except PendingApprovalError as e:
                # This is an expected error, don't log as error
                logger.info(f"Pending approval conflict in stream response: {e}")
                # Re-raise as HTTPException for proper client handling
                raise HTTPException(
                    status_code=409, detail={"code": "PENDING_APPROVAL", "message": str(e), "pending_request_id": e.pending_request_id}
                ) from e
            except Exception as e:
                logger.error(f"Error in protected stream response: {e}")
                raise
        else:
            await self._protected_stream_response(send)

    async def _protected_stream_response(self, send: Send) -> None:
        """Send the response start and all body chunks, turning failures into SSE events.

        - The first chunk may be a ``(content, status_code)`` tuple that sets the status.
        - An empty body iterator produces an empty response with the default status.
        - ``RunCancelledException`` ends the stream with an ``event: cancelled`` message.
          If the response had not started yet, it instead sends a 200 start and re-raises.
        - Cancellation of this task is re-raised as ``LettaUnexpectedStreamCancellationError``.
        - Any other exception ends the stream with an ``event: error`` message. If the
          response had not started yet, it instead sends a 500 start and re-raises.
        """
        more_body = True
        try:
            try:
                first_chunk = await self.body_iterator.__anext__()
            except StopAsyncIteration:
                # Empty stream: reply with an empty body instead of letting the exhausted
                # iterator fall into the generic handler below as a 500 error.
                first_chunk = b""
            logger.debug("stream_response first chunk: %s", first_chunk)
            if isinstance(first_chunk, tuple):
                first_chunk_content, self.status_code = first_chunk
            else:
                first_chunk_content = first_chunk
            if isinstance(first_chunk_content, str):
                first_chunk_content = first_chunk_content.encode(self.charset)

            try:
                await send(
                    {
                        "type": "http.response.start",
                        "status": self.status_code,
                        "headers": self.raw_headers,
                    }
                )
                self.response_started = True
                await send(
                    {
                        "type": "http.response.body",
                        "body": first_chunk_content,
                        "more_body": more_body,
                    }
                )
            except anyio.ClosedResourceError:
                logger.info("Client disconnected during initial response, continuing processing without sending more chunks")
                self._client_connected = False

            async for chunk in self.body_iterator:
                if isinstance(chunk, tuple):
                    content, status_code = chunk
                    if status_code // 100 != 2:
                        # An error occurred mid-stream: the status line is already sent,
                        # so hand off to the generic handler, which emits an error event.
                        if not isinstance(content, bytes):
                            content = content.encode(self.charset)
                        more_body = False
                        raise Exception(f"An exception occurred mid-stream with status code {status_code} with content {content}")
                else:
                    content = chunk

                if isinstance(content, str):
                    content = content.encode(self.charset)
                more_body = True

                # Only attempt to send if client is still connected
                if self._client_connected:
                    try:
                        await send(
                            {
                                "type": "http.response.body",
                                "body": content,
                                "more_body": more_body,
                            }
                        )
                    except anyio.ClosedResourceError:
                        logger.info("Client disconnected, continuing processing without sending more data")
                        self._client_connected = False
                        # Continue processing but don't try to send more data

        # Handle explicit run cancellations (should not throw error)
        except RunCancelledException as exc:
            logger.info(f"Stream was explicitly cancelled for run {exc.run_id}")
            # Handle explicit cancellation gracefully without error
            more_body = False
            cancellation_resp = {"message": "Run was cancelled"}
            cancellation_event = f"event: cancelled\ndata: {json.dumps(cancellation_resp)}\n\n".encode(self.charset)
            if not self.response_started:
                await send(
                    {
                        "type": "http.response.start",
                        "status": 200,  # Use 200 for graceful cancellation
                        "headers": self.raw_headers,
                    }
                )
                # NOTE(review): re-raising here means the cancellation event below is
                # never sent when the response had not started — confirm this is intended.
                raise
            if self._client_connected:
                try:
                    await send(
                        {
                            "type": "http.response.body",
                            "body": cancellation_event,
                            "more_body": more_body,
                        }
                    )
                except anyio.ClosedResourceError:
                    self._client_connected = False
            return

        # Handle client timeouts (should throw error to inform user)
        except asyncio.CancelledError as exc:
            logger.warning("Stream was terminated due to unexpected cancellation from server")
            # Handle unexpected cancellation with error
            more_body = False
            capture_sentry_exception(exc)
            raise LettaUnexpectedStreamCancellationError("Stream was terminated due to unexpected cancellation from server") from exc

        except Exception as exc:
            logger.exception(f"Unhandled Streaming Error: {str(exc)}")
            more_body = False
            error_resp = {"error": str(exc), "code": "INTERNAL_SERVER_ERROR"}
            error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
            logger.debug("response_started: %s", self.response_started)
            if not self.response_started:
                await send(
                    {
                        "type": "http.response.start",
                        "status": 500,
                        "headers": self.raw_headers,
                    }
                )
                raise
            if self._client_connected:
                try:
                    await send(
                        {
                            "type": "http.response.body",
                            "body": error_event,
                            "more_body": more_body,
                        }
                    )
                except anyio.ClosedResourceError:
                    self._client_connected = False
            capture_sentry_exception(exc)
            return

        # Normal completion: close the body with a final empty chunk.
        if more_body and self._client_connected:
            try:
                await send({"type": "http.response.body", "body": b"", "more_body": False})
            except anyio.ClosedResourceError:
                self._client_connected = False