diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py
index 496345e6..3c8b6be7 100644
--- a/letta/schemas/letta_request.py
+++ b/letta/schemas/letta_request.py
@@ -70,3 +70,21 @@ class CreateBatch(BaseModel):
         "'status' is the final batch status (e.g., 'completed', 'failed'), and "
         "'completed_at' is an ISO 8601 timestamp indicating when the batch job completed.",
     )
+
+
+class RetrieveStreamRequest(BaseModel):
+    starting_after: int = Field(
+        0, description="Sequence id to use as a cursor for pagination. Response will start streaming after this chunk sequence id"
+    )
+    include_pings: Optional[bool] = Field(
+        default=False,
+        description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
+    )
+    poll_interval: Optional[float] = Field(
+        default=0.1,
+        description="Seconds to wait between polls when no new data.",
+    )
+    batch_size: Optional[int] = Field(
+        default=100,
+        description="Number of entries to read per batch.",
+    )
diff --git a/letta/server/rest_api/redis_stream_manager.py b/letta/server/rest_api/redis_stream_manager.py
index 46eaf5f3..091304f1 100644
--- a/letta/server/rest_api/redis_stream_manager.py
+++ b/letta/server/rest_api/redis_stream_manager.py
@@ -263,9 +263,6 @@ async def redis_sse_stream_generator(
 
     logger.debug(f"Starting redis_sse_stream_generator for run_id={run_id}, stream_key={stream_key}")
 
-    # Add a small initial delay to allow background task to start writing
-    await asyncio.sleep(0.05)
-
     while True:
         entries = await redis_client.xrange(stream_key, start=last_redis_id, count=batch_size)
 
diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py
index ae1c0f59..72ef6d8b 100644
--- a/letta/server/rest_api/routers/v1/runs.py
+++ b/letta/server/rest_api/routers/v1/runs.py
@@ -1,16 +1,21 @@
 from typing import Annotated, List, Optional
 
-from fastapi import APIRouter, Depends, Header, HTTPException, Query
+from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
 from pydantic import Field
 
+from letta.data_sources.redis_client import get_redis_client
 from letta.orm.errors import NoResultFound
 from letta.schemas.enums import JobStatus, JobType, MessageRole
 from letta.schemas.letta_message import LettaMessageUnion
+from letta.schemas.letta_request import RetrieveStreamRequest
 from letta.schemas.openai.chat_completion_response import UsageStatistics
 from letta.schemas.run import Run
 from letta.schemas.step import Step
+from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
+from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
 from letta.server.rest_api.utils import get_letta_server
 from letta.server.server import SyncServer
+from letta.settings import settings
 
 router = APIRouter(prefix="/runs", tags=["runs"])
 
@@ -213,3 +218,41 @@ async def delete_run(
         return Run.from_job(job)
     except NoResultFound:
         raise HTTPException(status_code=404, detail="Run not found")
+
+
+@router.post(
+    "/{run_id}/stream",
+    response_model=None,
+    operation_id="retrieve_stream",
+    responses={
+        200: {
+            "description": "Successful response",
+            "content": {
+                "text/event-stream": {"description": "Server-Sent Events stream"},
+            },
+        }
+    },
+)
+async def retrieve_stream(
+    run_id: str,
+    request: RetrieveStreamRequest = Body(...),
+    actor_id: Optional[str] = Header(None, alias="user_id"),
+    server: "SyncServer" = Depends(get_letta_server),
+):
+    redis_client = await get_redis_client()
+
+    stream = redis_sse_stream_generator(
+        redis_client=redis_client,
+        run_id=run_id,
+        start_cursor=request.starting_after,
+        poll_interval=request.poll_interval,
+        batch_size=request.batch_size,
+    )
+
+    if request.include_pings and settings.enable_keepalive:
+        stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval)
+
+    return StreamingResponseWithStatusCode(
+        stream,
+        media_type="text/event-stream",
+    )