Files
letta-server/letta/server/rest_api/streaming_response.py

342 lines
14 KiB
Python

# Alternative implementation of StreamingResponse that allows for effectively
# stremaing HTTP trailers, as we cannot set codes after the initial response.
# Taken from: https://github.com/fastapi/fastapi/discussions/10138#discussioncomment-10377361
import asyncio
import json
from collections.abc import AsyncIterator
import anyio
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from starlette.types import Send
from letta.errors import LettaUnexpectedStreamCancellationError, PendingApprovalError
from letta.log import get_logger
from letta.schemas.enums import JobStatus
from letta.schemas.letta_ping import LettaPing
from letta.schemas.user import User
from letta.server.rest_api.utils import capture_sentry_exception
from letta.services.job_manager import JobManager
from letta.settings import settings
from letta.utils import safe_create_task
logger = get_logger(__name__)
class JobCancelledException(Exception):
"""Exception raised when a job is explicitly cancelled (not due to client timeout)"""
def __init__(self, job_id: str, message: str = None):
self.job_id = job_id
super().__init__(message or f"Job {job_id} was explicitly cancelled")
async def add_keepalive_to_stream(
stream_generator: AsyncIterator[str | bytes],
keepalive_interval: float = 30.0,
) -> AsyncIterator[str | bytes]:
"""
Adds periodic keepalive messages to a stream to prevent connection timeouts.
Sends a keepalive ping every `keepalive_interval` seconds, regardless of
whether data is flowing. This ensures connections stay alive during long
operations like tool execution.
Args:
stream_generator: The original stream generator to wrap
keepalive_interval: Seconds between keepalive messages (default: 30)
Yields:
Original stream chunks interspersed with keepalive messages
"""
# Use a queue to decouple the stream reading from keepalive timing
queue = asyncio.Queue()
stream_exhausted = False
async def stream_reader():
"""Read from the original stream and put items in the queue."""
nonlocal stream_exhausted
try:
async for item in stream_generator:
await queue.put(("data", item))
finally:
stream_exhausted = True
await queue.put(("end", None))
# Start the stream reader task
reader_task = safe_create_task(stream_reader(), label="stream_reader")
try:
while True:
try:
# Wait for data with a timeout equal to keepalive interval
msg_type, data = await asyncio.wait_for(queue.get(), timeout=keepalive_interval)
if msg_type == "end":
# Stream finished
break
elif msg_type == "data":
yield data
except asyncio.TimeoutError:
# No data received within keepalive interval
if not stream_exhausted:
# Send keepalive ping in the same format as [DONE]
yield f"data: {LettaPing().model_dump_json()}\n\n"
else:
# Stream is done but queue might be processing
# Check if there's anything left
try:
msg_type, data = queue.get_nowait()
if msg_type == "end":
break
elif msg_type == "data":
yield data
except asyncio.QueueEmpty:
# Really done now
break
finally:
# Clean up the reader task
reader_task.cancel()
try:
await reader_task
except asyncio.CancelledError:
pass
# TODO (cliandy) wrap this and handle types
async def cancellation_aware_stream_wrapper(
stream_generator: AsyncIterator[str | bytes],
job_manager: JobManager,
job_id: str,
actor: User,
cancellation_check_interval: float = 0.5,
) -> AsyncIterator[str | bytes]:
"""
Wraps a stream generator to provide real-time job cancellation checking.
This wrapper periodically checks for job cancellation while streaming and
can interrupt the stream at any point, not just at step boundaries.
Args:
stream_generator: The original stream generator to wrap
job_manager: Job manager instance for checking job status
job_id: ID of the job to monitor for cancellation
actor: User/actor making the request
cancellation_check_interval: How often to check for cancellation (seconds)
Yields:
Stream chunks from the original generator until cancelled
Raises:
asyncio.CancelledError: If the job is cancelled during streaming
"""
last_cancellation_check = asyncio.get_event_loop().time()
try:
async for chunk in stream_generator:
# Check for cancellation periodically (not on every chunk for performance)
current_time = asyncio.get_event_loop().time()
if current_time - last_cancellation_check >= cancellation_check_interval:
try:
job = await job_manager.get_job_by_id_async(job_id=job_id, actor=actor)
if job.status == JobStatus.cancelled:
logger.info(f"Stream cancelled for job {job_id}, interrupting stream")
# Send cancellation event to client
cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
yield f"data: {json.dumps(cancellation_event)}\n\n"
# Raise custom exception for explicit job cancellation
raise JobCancelledException(job_id, f"Job {job_id} was cancelled")
except Exception as e:
# Log warning but don't fail the stream if cancellation check fails
logger.warning(f"Failed to check job cancellation for job {job_id}: {e}")
last_cancellation_check = current_time
yield chunk
except JobCancelledException:
# Re-raise JobCancelledException to distinguish from client timeout
logger.info(f"Stream for job {job_id} was explicitly cancelled and cleaned up")
raise
except asyncio.CancelledError:
# Re-raise CancelledError (likely client timeout) to ensure proper cleanup
logger.info(f"Stream for job {job_id} was cancelled (likely client timeout) and cleaned up")
raise
except Exception as e:
logger.error(f"Error in cancellation-aware stream wrapper for job {job_id}: {e}")
raise
class StreamingResponseWithStatusCode(StreamingResponse):
"""
Variation of StreamingResponse that can dynamically decide the HTTP status code,
based on the return value of the content iterator (parameter `content`).
Expects the content to yield either just str content as per the original `StreamingResponse`
or else tuples of (`content`: `str`, `status_code`: `int`).
"""
body_iterator: AsyncIterator[str | bytes]
response_started: bool = False
_client_connected: bool = True
async def stream_response(self, send: Send) -> None:
if settings.use_asyncio_shield:
try:
await asyncio.shield(self._protected_stream_response(send))
except asyncio.CancelledError:
logger.info("Stream response was cancelled, but shielded task should continue")
except anyio.ClosedResourceError:
logger.info("Client disconnected, but shielded task should continue")
self._client_connected = False
except PendingApprovalError as e:
# This is an expected error, don't log as error
logger.info(f"Pending approval conflict in stream response: {e}")
# Re-raise as HTTPException for proper client handling
raise HTTPException(
status_code=409, detail={"code": "PENDING_APPROVAL", "message": str(e), "pending_request_id": e.pending_request_id}
)
except Exception as e:
logger.error(f"Error in protected stream response: {e}")
raise
else:
await self._protected_stream_response(send)
async def _protected_stream_response(self, send: Send) -> None:
more_body = True
try:
first_chunk = await self.body_iterator.__anext__()
logger.debug("stream_response first chunk:", first_chunk)
if isinstance(first_chunk, tuple):
first_chunk_content, self.status_code = first_chunk
else:
first_chunk_content = first_chunk
if isinstance(first_chunk_content, str):
first_chunk_content = first_chunk_content.encode(self.charset)
try:
await send(
{
"type": "http.response.start",
"status": self.status_code,
"headers": self.raw_headers,
}
)
self.response_started = True
await send(
{
"type": "http.response.body",
"body": first_chunk_content,
"more_body": more_body,
}
)
except anyio.ClosedResourceError:
logger.info("Client disconnected during initial response, continuing processing without sending more chunks")
self._client_connected = False
async for chunk in self.body_iterator:
if isinstance(chunk, tuple):
content, status_code = chunk
if status_code // 100 != 2:
# An error occurred mid-stream
if not isinstance(content, bytes):
content = content.encode(self.charset)
more_body = False
raise Exception(f"An exception occurred mid-stream with status code {status_code} with content {content}")
else:
content = chunk
if isinstance(content, str):
content = content.encode(self.charset)
more_body = True
# Only attempt to send if client is still connected
if self._client_connected:
try:
await send(
{
"type": "http.response.body",
"body": content,
"more_body": more_body,
}
)
except anyio.ClosedResourceError:
logger.info("Client disconnected, continuing processing without sending more data")
self._client_connected = False
# Continue processing but don't try to send more data
# Handle explicit job cancellations (should not throw error)
except JobCancelledException as exc:
logger.info(f"Stream was explicitly cancelled for job {exc.job_id}")
# Handle explicit cancellation gracefully without error
more_body = False
cancellation_resp = {"message": "Job was cancelled"}
cancellation_event = f"event: cancelled\ndata: {json.dumps(cancellation_resp)}\n\n".encode(self.charset)
if not self.response_started:
await send(
{
"type": "http.response.start",
"status": 200, # Use 200 for graceful cancellation
"headers": self.raw_headers,
}
)
raise
if self._client_connected:
try:
await send(
{
"type": "http.response.body",
"body": cancellation_event,
"more_body": more_body,
}
)
except anyio.ClosedResourceError:
self._client_connected = False
return
# Handle client timeouts (should throw error to inform user)
except asyncio.CancelledError as exc:
logger.warning("Stream was terminated due to unexpected cancellation from server")
# Handle unexpected cancellation with error
more_body = False
capture_sentry_exception(exc)
raise LettaUnexpectedStreamCancellationError("Stream was terminated due to unexpected cancellation from server")
except Exception as exc:
logger.exception(f"Unhandled Streaming Error: {str(exc)}")
more_body = False
# error_resp = {"error": {"message": str(exc)}}
error_resp = {"error": str(exc), "code": "INTERNAL_SERVER_ERROR"}
error_event = f"event: error\ndata: {json.dumps(error_resp)}\n\n".encode(self.charset)
logger.debug("response_started:", self.response_started)
if not self.response_started:
await send(
{
"type": "http.response.start",
"status": 500,
"headers": self.raw_headers,
}
)
raise
if self._client_connected:
try:
await send(
{
"type": "http.response.body",
"body": error_event,
"more_body": more_body,
}
)
except anyio.ClosedResourceError:
self._client_connected = False
capture_sentry_exception(exc)
return
if more_body and self._client_connected:
try:
await send({"type": "http.response.body", "body": b"", "more_body": False})
except anyio.ClosedResourceError:
self._client_connected = False