Files
letta-server/letta/server/rest_api/routers/v1/runs.py
2025-10-07 17:50:44 -07:00

366 lines
14 KiB
Python

from datetime import timedelta
from typing import Annotated, List, Literal, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from pydantic import Field
from temporalio.client import Client
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.helpers.datetime_helpers import get_utc_time
from letta.orm.errors import NoResultFound
from letta.schemas.enums import JobStatus, JobType
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_request import RetrieveStreamRequest
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.run import Run
from letta.schemas.step import Step
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
from letta.server.rest_api.streaming_response import (
StreamingResponseWithStatusCode,
add_keepalive_to_stream,
cancellation_aware_stream_wrapper,
)
from letta.server.server import SyncServer
from letta.settings import settings
# Router holding every `/runs` endpoint in this module; endpoints are tagged "runs".
router = APIRouter(tags=["runs"], prefix="/runs")
@router.get("/", response_model=List[Run], operation_id="list_runs")
def list_runs(
    server: "SyncServer" = Depends(get_letta_server),
    agent_id: Optional[str] = Query(None, description="The unique identifier of the agent associated with the run."),
    agent_ids: Optional[List[str]] = Query(None, description="(DEPRECATED) The unique identifiers of the agents associated with the run."),
    background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
    stop_reason: Optional[StopReasonType] = Query(None, description="Filter runs by stop reason."),
    after: Optional[str] = Query(None, description="Cursor for pagination"),
    before: Optional[str] = Query(None, description="Cursor for pagination"),
    limit: Optional[int] = Query(50, description="Maximum number of runs to return"),
    active: bool = Query(False, description="Filter for active runs."),
    ascending: bool = Query(
        False,
        description="Whether to sort runs oldest to newest (True) or newest to oldest (False, default)",
    ),
    headers: HeaderParams = Depends(get_headers),
):
    """
    List all runs.
    """
    # Query semantics:
    #   - `active=True` restricts results to runs whose status is `created` or `running`.
    #   - `agent_id`, when given, replaces the deprecated `agent_ids` list.
    #   - `ascending` controls sort order (newest first by default).
    actor = server.user_manager.get_user_or_default(user_id=headers.actor_id)
    statuses = [JobStatus.created, JobStatus.running] if active else None
    if agent_id:
        # NOTE: `agent_ids` is deprecated, so this single-agent filter will become the primary path.
        agent_ids = [agent_id]
    jobs = server.job_manager.list_jobs(
        actor=actor,
        statuses=statuses,
        job_type=JobType.RUN,
        limit=limit,
        before=before,
        after=after,
        # Honor the caller's requested sort order. This was hard-coded to False,
        # which silently ignored the `ascending` query parameter.
        ascending=ascending,
        stop_reason=stop_reason,
        agent_ids=agent_ids,
        background=background,
    )
    return [Run.from_job(job) for job in jobs]
@router.get("/active", response_model=List[Run], operation_id="list_active_runs", deprecated=True)
def list_active_runs(
    server: "SyncServer" = Depends(get_letta_server),
    agent_id: Optional[str] = Query(None, description="The unique identifier of the agent associated with the run."),
    background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
    headers: HeaderParams = Depends(get_headers),
):
    """
    List all active runs.
    """
    # Deprecated endpoint: lists RUN jobs whose status is `created` or `running`,
    # optionally narrowed to a single agent and/or to background-mode runs.
    actor = server.user_manager.get_user_or_default(user_id=headers.actor_id)
    agent_filter = [agent_id] if agent_id else None
    matching_jobs = server.job_manager.list_jobs(
        actor=actor,
        statuses=[JobStatus.created, JobStatus.running],
        job_type=JobType.RUN,
        agent_ids=agent_filter,
        background=background,
    )
    return [Run.from_job(job) for job in matching_jobs]
def _temporal_status_to_job_status(status_name: str) -> JobStatus:
    """Map a Temporal workflow execution status name to a ``JobStatus``.

    Statuses without an explicit mapping (e.g. TERMINATED, TIMED_OUT,
    CONTINUED_AS_NEW, or an empty name) fall back to ``JobStatus.created``.
    """
    if status_name == "RUNNING":
        return JobStatus.running
    if status_name == "COMPLETED":
        return JobStatus.completed
    if status_name == "FAILED":
        return JobStatus.failed
    if status_name == "CANCELED":
        # Temporal spells this CANCELED. JobStatus has no `canceled` member (the
        # member is `cancelled`), so the previous code raised AttributeError here.
        return JobStatus.cancelled
    # NOTE(review): TERMINATED / TIMED_OUT currently report as `created`; consider
    # mapping them to terminal statuses instead.
    return JobStatus.created


@router.get("/{run_id}", response_model=Run, operation_id="retrieve_run")
async def retrieve_run(
    run_id: str,
    headers: HeaderParams = Depends(get_headers),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Get the status of a run.
    """
    # Raises HTTPException(404) when the run does not exist.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    try:
        job = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor)
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Run not found")

    # For runs flagged as Temporal-backed, the live workflow status replaces the stored one.
    # `metadata` may be None, so guard before calling `.get`.
    if (job.metadata or {}).get("temporal") and settings.temporal_endpoint:
        client = await Client.connect(
            settings.temporal_endpoint,
            namespace=settings.temporal_namespace,
            api_key=settings.temporal_api_key,
            tls=True,  # This should be false for local runs
        )
        # The workflow is looked up by the run id.
        desc = await client.get_workflow_handle(run_id).describe()
        # The description's status may be absent; treat that like an unmapped status.
        status_name = desc.status.name if desc.status is not None else ""
        job.status = _temporal_status_to_job_status(status_name)
    return Run.from_job(job)
# JSON-schema override so the response is documented as an array of LettaMessageUnion refs.
_RUN_MESSAGES_ARRAY_SCHEMA = {
    "type": "array",
    "items": {"$ref": "#/components/schemas/LettaMessageUnion"},
}

# Response model for the run-messages listing endpoint.
RunMessagesResponse = Annotated[List[LettaMessageUnion], Field(json_schema_extra=_RUN_MESSAGES_ARRAY_SCHEMA)]
@router.get(
    "/{run_id}/messages",
    response_model=RunMessagesResponse,
    operation_id="list_run_messages",
)
async def list_run_messages(
    run_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
    before: Optional[str] = Query(
        None, description="Message ID cursor for pagination. Returns messages that come before this message ID in the specified sort order"
    ),
    after: Optional[str] = Query(
        None, description="Message ID cursor for pagination. Returns messages that come after this message ID in the specified sort order"
    ),
    limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
    order: Literal["asc", "desc"] = Query(
        "asc", description="Sort order for messages by creation time. 'asc' for oldest first, 'desc' for newest first"
    ),
):
    """Get response messages associated with a run."""
    # A missing run surfaces as NoResultFound and is translated into a 404 carrying its message.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    oldest_first = order == "asc"
    try:
        return server.job_manager.get_run_messages(
            run_id=run_id,
            actor=actor,
            limit=limit,
            before=before,
            after=after,
            ascending=oldest_first,
        )
    except NoResultFound as err:
        raise HTTPException(status_code=404, detail=str(err))
@router.get("/{run_id}/usage", response_model=UsageStatistics, operation_id="retrieve_run_usage")
def retrieve_run_usage(
    run_id: str,
    headers: HeaderParams = Depends(get_headers),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Get usage statistics for a run.
    """
    # Usage is looked up by treating the run id as a job id; an unknown id becomes a 404.
    actor = server.user_manager.get_user_or_default(user_id=headers.actor_id)
    try:
        return server.job_manager.get_job_usage(job_id=run_id, actor=actor)
    except NoResultFound:
        raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
@router.get(
    "/{run_id}/steps",
    response_model=List[Step],
    operation_id="list_run_steps",
)
async def list_run_steps(
    run_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
    before: Optional[str] = Query(None, description="Cursor for pagination"),
    after: Optional[str] = Query(None, description="Cursor for pagination"),
    limit: Optional[int] = Query(100, description="Maximum number of steps to return"),
    order: str = Query(
        "desc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order."
    ),
):
    """
    Get steps associated with a run with filtering options.
    Args:
        run_id: ID of the run
        before: A cursor for use in pagination. `before` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, starting with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
        after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
        limit: Maximum number of steps to return
        order: Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
    Returns:
        A list of steps associated with the run.
    Raises:
        HTTPException: 400 if `order` is not 'asc' or 'desc'; 404 if the run is not found.
    """
    # `order` is a plain string, so it is validated here and rejected with a 400.
    if order not in ["asc", "desc"]:
        raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    try:
        return server.job_manager.get_job_steps(
            job_id=run_id,
            actor=actor,
            limit=limit,
            before=before,
            after=after,
            ascending=(order == "asc"),
        )
    except NoResultFound as e:
        raise HTTPException(status_code=404, detail=str(e))
@router.delete("/{run_id}", response_model=Run, operation_id="delete_run")
async def delete_run(
    run_id: str,
    headers: HeaderParams = Depends(get_headers),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Delete a run by its run_id.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    try:
        removed_job = await server.job_manager.delete_job_by_id_async(job_id=run_id, actor=actor)
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Run not found")
    # The deleted record is returned to the caller as a Run.
    return Run.from_job(removed_job)
@router.post(
    "/{run_id}/stream",
    response_model=None,
    operation_id="retrieve_stream",
    responses={
        200: {
            "description": "Successful response",
            "content": {
                # Align streaming schema with agents.create_stream so SDKs accept approval messages
                "text/event-stream": {
                    "description": "Server-Sent Events stream",
                    "schema": {
                        "oneOf": [
                            {"$ref": "#/components/schemas/SystemMessage"},
                            {"$ref": "#/components/schemas/UserMessage"},
                            {"$ref": "#/components/schemas/ReasoningMessage"},
                            {"$ref": "#/components/schemas/HiddenReasoningMessage"},
                            {"$ref": "#/components/schemas/ToolCallMessage"},
                            {"$ref": "#/components/schemas/ToolReturnMessage"},
                            {"$ref": "#/components/schemas/AssistantMessage"},
                            {"$ref": "#/components/schemas/ApprovalRequestMessage"},
                            {"$ref": "#/components/schemas/ApprovalResponseMessage"},
                            {"$ref": "#/components/schemas/LettaPing"},
                            {"$ref": "#/components/schemas/LettaStopReason"},
                            {"$ref": "#/components/schemas/LettaUsageStatistics"},
                        ]
                    },
                },
            },
        }
    },
)
async def retrieve_stream(
    run_id: str,
    request: RetrieveStreamRequest = Body(None),
    headers: HeaderParams = Depends(get_headers),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Stream the output of a background run as Server-Sent Events.

    The run must have been created in background mode and be at most 3 hours old,
    and Redis must be available. The request body is optional.
    """
    # Error responses: 404 unknown run, 400 not a background run, 410 older than
    # 3 hours, 503 Redis not configured.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    try:
        # Async lookup, consistent with retrieve_run; the sync variant blocked the event loop.
        job = await server.job_manager.get_job_by_id_async(job_id=run_id, actor=actor)
    except NoResultFound:
        raise HTTPException(status_code=404, detail="Run not found")

    # The body is declared Body(None), so it may be omitted. Use default options instead of
    # failing with AttributeError on `None.starting_after`.
    # NOTE(review): assumes every RetrieveStreamRequest field has a default — confirm.
    if request is None:
        request = RetrieveStreamRequest()

    run = Run.from_job(job)
    if not run.background:
        raise HTTPException(status_code=400, detail="Run was not created in background mode, so it cannot be retrieved.")
    max_age = timedelta(hours=3)
    if run.created_at < get_utc_time() - max_age:
        raise HTTPException(status_code=410, detail="Run was created more than 3 hours ago, and is now expired.")

    redis_client = await get_redis_client()
    if isinstance(redis_client, NoopAsyncRedisClient):
        raise HTTPException(
            status_code=503,
            detail=(
                "Background streaming requires Redis to be running. "
                "Please ensure Redis is properly configured. "
                f"LETTA_REDIS_HOST: {settings.redis_host}, LETTA_REDIS_PORT: {settings.redis_port}"
            ),
        )

    stream = redis_sse_stream_generator(
        redis_client=redis_client,
        run_id=run_id,
        starting_after=request.starting_after,
        poll_interval=request.poll_interval,
        batch_size=request.batch_size,
    )
    # Optional wrappers, each gated by settings (pings are also gated by the request).
    if settings.enable_cancellation_aware_streaming:
        stream = cancellation_aware_stream_wrapper(
            stream_generator=stream,
            job_manager=server.job_manager,
            job_id=run_id,
            actor=actor,
        )
    if request.include_pings and settings.enable_keepalive:
        stream = add_keepalive_to_stream(stream, keepalive_interval=settings.keepalive_interval)
    return StreamingResponseWithStatusCode(
        stream,
        media_type="text/event-stream",
    )