diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 0e5dff98..91a7dbf4 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -9,6 +9,7 @@ from letta.schemas.enums import JobStatus, MessageRole from letta.schemas.letta_message import LettaMessageUnion from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.run import Run +from letta.schemas.step import Step from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer @@ -137,6 +138,54 @@ def retrieve_run_usage( raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found") +@router.get( + "/{run_id}/steps", + response_model=List[Step], + operation_id="list_run_steps", +) +async def list_run_steps( + run_id: str, + server: "SyncServer" = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), + before: Optional[str] = Query(None, description="Cursor for pagination"), + after: Optional[str] = Query(None, description="Cursor for pagination"), + limit: Optional[int] = Query(100, description="Maximum number of messages to return"), + order: str = Query( + "desc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order." + ), +): + """ + Get steps associated with a run with filtering options. + + Args: + run_id: ID of the run + before: A cursor for use in pagination. `before` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, starting with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list. + after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. 
For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list. + limit: Maximum number of steps to return + order: Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order. + + Returns: + A list of steps associated with the run. + """ + if order not in ["asc", "desc"]: + raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'") + + actor = server.user_manager.get_user_or_default(user_id=actor_id) + + try: + steps = server.job_manager.get_job_steps( + job_id=run_id, + actor=actor, + limit=limit, + before=before, + after=after, + ascending=(order == "asc"), + ) + return steps + except NoResultFound as e: + raise HTTPException(status_code=404, detail=str(e)) + + @router.delete("/{run_id}", response_model=Run, operation_id="delete_run") def delete_run( run_id: str, diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 5f3f7fd0..7a4c7f3c 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -13,12 +13,14 @@ from letta.orm.job_messages import JobMessage from letta.orm.message import Message as MessageModel from letta.orm.sqlalchemy_base import AccessType from letta.orm.step import Step +from letta.orm.step import Step as StepModel from letta.schemas.enums import JobStatus, MessageRole from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate, LettaRequestConfig from letta.schemas.letta_message import LettaMessage from letta.schemas.message import Message as PydanticMessage from letta.schemas.run import Run as PydanticRun +from letta.schemas.step import Step as PydanticStep from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types @@ -161,6 +163,51 @@ class JobManager: return [message.to_pydantic() for message in 
messages] + @enforce_types + def get_job_steps( + self, + job_id: str, + actor: PydanticUser, + before: Optional[str] = None, + after: Optional[str] = None, + limit: Optional[int] = 100, + ascending: bool = True, + ) -> List[PydanticStep]: + """ + Get steps associated with a job using cursor-based pagination. + + Args: + job_id: The ID of the job to get steps for + actor: The user making the request + before: Cursor for pagination + after: Cursor for pagination + limit: Maximum number of steps to return + ascending: Optional flag to sort in ascending order + + Returns: + List of steps associated with the job + + Raises: + NoResultFound: If the job does not exist or user does not have access + """ + with self.session_maker() as session: + # Build filters + filters = {} + filters["job_id"] = job_id + + # Get steps + steps = StepModel.list( + db_session=session, + before=before, + after=after, + ascending=ascending, + limit=limit, + actor=actor, + **filters, + ) + + return [step.to_pydantic() for step in steps] + @enforce_types def add_message_to_job(self, job_id: str, message_id: str, actor: PydanticUser) -> None: """ @@ -312,6 +359,57 @@ class JobManager: return messages + @enforce_types + def get_step_messages( + self, + run_id: str, + actor: PydanticUser, + before: Optional[str] = None, + after: Optional[str] = None, + limit: Optional[int] = 100, + role: Optional[MessageRole] = None, + ascending: bool = True, + ) -> List[LettaMessage]: + """ + Get messages associated with a run using cursor-based pagination. + This is a wrapper around get_job_messages that converts the results to LettaMessage objects. 
+ + Args: + run_id: The ID of the run to get messages for + actor: The user making the request + before: Message ID cursor for pagination, forwarded to get_job_messages + after: Message ID cursor for pagination, forwarded to get_job_messages + limit: Maximum number of messages to return + ascending: Whether to return messages in ascending order + role: Optional role filter + + Returns: + List of LettaMessages associated with the run + + Raises: + NoResultFound: If the job does not exist or user does not have access + """ + messages = self.get_job_messages( + job_id=run_id, + actor=actor, + before=before, + after=after, + limit=limit, + role=role, + ascending=ascending, + ) + + request_config = self._get_run_request_config(run_id) + + messages = PydanticMessage.to_letta_messages_from_list( + messages=messages, + use_assistant_message=request_config["use_assistant_message"], + assistant_message_tool_name=request_config["assistant_message_tool_name"], + assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"], + ) + + return messages + def _verify_job_access( self, session: Session, diff --git a/tests/test_managers.py b/tests/test_managers.py index 4d207d8a..52206d72 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3123,6 +3123,10 @@ def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_us assert usage_stats.prompt_tokens == 50 assert usage_stats.total_tokens == 150 + # get steps + steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user) + assert len(steps) == 1 + def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_user): """Test getting usage statistics for a job with no stats.""" @@ -3136,6 +3140,10 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u assert usage_stats.prompt_tokens == 0 assert usage_stats.total_tokens == 0 + # get steps + steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user) + assert len(steps) == 0 + def test_job_usage_stats_add_multiple(server: 
SyncServer, default_job, default_user): """Test adding multiple usage statistics entries for a job.""" @@ -3181,6 +3189,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u assert usage_stats.total_tokens == 450 assert usage_stats.step_count == 2 + # get steps + steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user) + assert len(steps) == 2 + def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user): """Test getting usage statistics for a nonexistent job."""