feat: add route to get steps for job (#1174)

This commit is contained in:
cthomas
2025-03-02 17:43:18 -08:00
committed by GitHub
parent 61a15cebe8
commit cfdc2ad417
3 changed files with 159 additions and 0 deletions

View File

@@ -9,6 +9,7 @@ from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.run import Run
from letta.schemas.step import Step
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
@@ -137,6 +138,54 @@ def retrieve_run_usage(
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
@router.get(
"/{run_id}/steps",
response_model=List[Step],
operation_id="list_run_steps",
)
async def list_run_steps(
run_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
before: Optional[str] = Query(None, description="Cursor for pagination"),
after: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(100, description="Maximum number of messages to return"),
order: str = Query(
"desc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order."
),
):
"""
Get messages associated with a run with filtering options.
Args:
run_id: ID of the run
before: A cursor for use in pagination. `before` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, starting with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
limit: Maximum number of steps to return
order: Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.
Returns:
A list of steps associated with the run.
"""
if order not in ["asc", "desc"]:
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
actor = server.user_manager.get_user_or_default(user_id=actor_id)
try:
steps = server.job_manager.get_job_steps(
job_id=run_id,
actor=actor,
limit=limit,
before=before,
after=after,
ascending=(order == "asc"),
)
return steps
except NoResultFound as e:
raise HTTPException(status_code=404, detail=str(e))
@router.delete("/{run_id}", response_model=Run, operation_id="delete_run")
def delete_run(
run_id: str,

View File

@@ -13,12 +13,14 @@ from letta.orm.job_messages import JobMessage
from letta.orm.message import Message as MessageModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step
from letta.orm.step import Step as StepModel
from letta.schemas.enums import JobStatus, MessageRole
from letta.schemas.job import Job as PydanticJob
from letta.schemas.job import JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.run import Run as PydanticRun
from letta.schemas.step import Step as PydanticStep
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
@@ -161,6 +163,51 @@ class JobManager:
return [message.to_pydantic() for message in messages]
@enforce_types
def get_job_steps(
self,
job_id: str,
actor: PydanticUser,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 100,
ascending: bool = True,
) -> List[PydanticStep]:
"""
Get all steps associated with a job.
Args:
job_id: The ID of the job to get steps for
actor: The user making the request
before: Cursor for pagination
after: Cursor for pagination
limit: Maximum number of steps to return
ascending: Optional flag to sort in ascending order
Returns:
List of steps associated with the job
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
with self.session_maker() as session:
# Build filters
filters = {}
filters["job_id"] = job_id
# Get steps
steps = StepModel.list(
db_session=session,
before=before,
after=after,
ascending=ascending,
limit=limit,
actor=actor,
**filters,
)
return [step.to_pydantic() for step in steps]
@enforce_types
def add_message_to_job(self, job_id: str, message_id: str, actor: PydanticUser) -> None:
"""
@@ -312,6 +359,57 @@ class JobManager:
return messages
@enforce_types
def get_step_messages(
self,
run_id: str,
actor: PydanticUser,
before: Optional[str] = None,
after: Optional[str] = None,
limit: Optional[int] = 100,
role: Optional[MessageRole] = None,
ascending: bool = True,
) -> List[LettaMessage]:
"""
Get steps associated with a job using cursor-based pagination.
This is a wrapper around get_job_messages that provides cursor-based pagination.
Args:
run_id: The ID of the run to get steps for
actor: The user making the request
before: Message ID to get messages after
after: Message ID to get messages before
limit: Maximum number of messages to return
ascending: Whether to return messages in ascending order
role: Optional role filter
Returns:
List of Steps associated with the job
Raises:
NoResultFound: If the job does not exist or user does not have access
"""
messages = self.get_job_messages(
job_id=run_id,
actor=actor,
before=before,
after=after,
limit=limit,
role=role,
ascending=ascending,
)
request_config = self._get_run_request_config(run_id)
messages = PydanticMessage.to_letta_messages_from_list(
messages=messages,
use_assistant_message=request_config["use_assistant_message"],
assistant_message_tool_name=request_config["assistant_message_tool_name"],
assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
)
return messages
def _verify_job_access(
self,
session: Session,

View File

@@ -3123,6 +3123,10 @@ def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_us
assert usage_stats.prompt_tokens == 50
assert usage_stats.total_tokens == 150
# get steps
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
assert len(steps) == 1
def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_user):
"""Test getting usage statistics for a job with no stats."""
@@ -3136,6 +3140,10 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
assert usage_stats.prompt_tokens == 0
assert usage_stats.total_tokens == 0
# get steps
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
assert len(steps) == 0
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
"""Test adding multiple usage statistics entries for a job."""
@@ -3181,6 +3189,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
assert usage_stats.total_tokens == 450
assert usage_stats.step_count == 2
# get steps
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
assert len(steps) == 2
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
"""Test getting usage statistics for a nonexistent job."""