feat: add new client.runs.stream endpoint (#4165)

This commit is contained in:
cthomas
2025-08-25 13:42:22 -07:00
committed by GitHub
parent dc83ff9f52
commit 8909fd257b
3 changed files with 62 additions and 4 deletions

View File

@@ -70,3 +70,21 @@ class CreateBatch(BaseModel):
"'status' is the final batch status (e.g., 'completed', 'failed'), and "
"'completed_at' is an ISO 8601 timestamp indicating when the batch job completed.",
)
class RetrieveStreamRequest(BaseModel):
    # Options a client may send when asking to (re)attach to a run's stream.
    # Every field carries a default, so an empty JSON object is a valid body.
    #
    # Kept as comments rather than a class docstring on purpose: pydantic
    # copies a model docstring into the generated JSON schema as its
    # description, which would alter the published schema.

    # Resume cursor: chunks up to and including this sequence id are skipped.
    starting_after: int = Field(
        default=0,
        description="Sequence id to use as a cursor for pagination. Response will start streaming after this chunk sequence id",
    )

    # Opt-in flag for keepalive ping messages in the stream.
    include_pings: Optional[bool] = Field(
        description="Whether to include periodic keepalive ping messages in the stream to prevent connection timeouts.",
        default=False,
    )

    # Idle delay (seconds) between successive reads when nothing new arrived.
    poll_interval: Optional[float] = Field(
        description="Seconds to wait between polls when no new data.",
        default=0.1,
    )

    # Upper bound on entries fetched per read.
    batch_size: Optional[int] = Field(
        description="Number of entries to read per batch.",
        default=100,
    )

View File

@@ -263,9 +263,6 @@ async def redis_sse_stream_generator(
logger.debug(f"Starting redis_sse_stream_generator for run_id={run_id}, stream_key={stream_key}")
# Add a small initial delay to allow background task to start writing
await asyncio.sleep(0.05)
while True:
entries = await redis_client.xrange(stream_key, start=last_redis_id, count=batch_size)

View File

@@ -1,16 +1,21 @@
from typing import Annotated, List, Optional
from fastapi import APIRouter, Depends, Header, HTTPException, Query
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from pydantic import Field
from letta.data_sources.redis_client import get_redis_client
from letta.orm.errors import NoResultFound
from letta.schemas.enums import JobStatus, JobType, MessageRole
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_request import RetrieveStreamRequest
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.run import Run
from letta.schemas.step import Step
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.settings import settings
router = APIRouter(prefix="/runs", tags=["runs"])
@@ -213,3 +218,41 @@ async def delete_run(
return Run.from_job(job)
except NoResultFound:
raise HTTPException(status_code=404, detail="Run not found")
@router.post(
    "/{run_id}/stream",
    response_model=None,
    operation_id="retrieve_stream",
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
)
async def retrieve_stream(
    run_id: str,
    request: RetrieveStreamRequest = Body(...),
    actor_id: Optional[str] = Header(None, alias="user_id"),
    server: "SyncServer" = Depends(get_letta_server),
):
    # Serve the stream of an existing run as a text/event-stream response.
    #
    # The events come from redis_sse_stream_generator, which is handed the
    # run id together with the cursor, poll interval and batch size taken
    # from the request body.
    #
    # Kept as comments rather than a docstring on purpose: FastAPI publishes
    # an endpoint docstring as the operation description in the OpenAPI spec.
    #
    # NOTE(review): actor_id and server are accepted but not read in this
    # body — confirm whether an ownership check on run_id is expected here.
    sse_events = redis_sse_stream_generator(
        redis_client=await get_redis_client(),
        run_id=run_id,
        start_cursor=request.starting_after,
        poll_interval=request.poll_interval,
        batch_size=request.batch_size,
    )

    # Pings are added only when the caller asks for them AND the server-side
    # setting allows keepalives; either one being falsy leaves the stream bare.
    wants_keepalive = request.include_pings and settings.enable_keepalive
    if wants_keepalive:
        sse_events = add_keepalive_to_stream(
            sse_events,
            keepalive_interval=settings.keepalive_interval,
        )

    return StreamingResponseWithStatusCode(sse_events, media_type="text/event-stream")