feat: add new client.runs.stream endpoint (#4165)
This commit is contained in:
@@ -1,16 +1,21 @@
|
||||
from typing import Annotated, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Query
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
from pydantic import Field
|
||||
|
||||
from letta.data_sources.redis_client import get_redis_client
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.enums import JobStatus, JobType, MessageRole
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_request import RetrieveStreamRequest
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.step import Step
|
||||
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode, add_keepalive_to_stream
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
router = APIRouter(prefix="/runs", tags=["runs"])
|
||||
|
||||
@@ -213,3 +218,41 @@ async def delete_run(
|
||||
return Run.from_job(job)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Run not found")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/{run_id}/stream",
|
||||
response_model=None,
|
||||
operation_id="retrieve_stream",
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"text/event-stream": {"description": "Server-Sent Events stream"},
|
||||
},
|
||||
}
|
||||
},
|
||||
)
|
||||
async def retrieve_stream(
|
||||
run_id: str,
|
||||
request: RetrieveStreamRequest = Body(...),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
redis_client = await get_redis_client()
|
||||
|
||||
stream = redis_sse_stream_generator(
|
||||
redis_client=redis_client,
|
||||
run_id=run_id,
|
||||
start_cursor=request.starting_after,
|
||||
poll_interval=request.poll_interval,
|
||||
batch_size=request.batch_size,
|
||||
)
|
||||
|
||||
if request.include_pings and settings.enable_keepalive:
|
||||
stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval)
|
||||
|
||||
return StreamingResponseWithStatusCode(
|
||||
stream,
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user