feat: add route to get steps for job (#1174)
This commit is contained in:
@@ -9,6 +9,7 @@ from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.step import Step
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@@ -137,6 +138,54 @@ def retrieve_run_usage(
|
||||
raise HTTPException(status_code=404, detail=f"Run '{run_id}' not found")
|
||||
|
||||
|
||||
@router.get(
    "/{run_id}/steps",
    response_model=List[Step],
    operation_id="list_run_steps",
)
async def list_run_steps(
    run_id: str,
    server: "SyncServer" = Depends(get_letta_server),
    actor_id: Optional[str] = Header(None, alias="user_id"),
    before: Optional[str] = Query(None, description="Cursor for pagination"),
    after: Optional[str] = Query(None, description="Cursor for pagination"),
    limit: Optional[int] = Query(100, description="Maximum number of steps to return"),
    order: str = Query(
        "desc", description="Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order."
    ),
):
    """
    Get steps associated with a run with filtering options.

    Args:
        run_id: ID of the run
        before: A cursor for use in pagination. `before` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, starting with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list.
        after: A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list.
        limit: Maximum number of steps to return
        order: Sort order by the created_at timestamp of the objects. asc for ascending order and desc for descending order.

    Returns:
        A list of steps associated with the run.

    Raises:
        HTTPException: 400 if `order` is neither "asc" nor "desc"; 404 if the
            job manager raises NoResultFound.
    """
    # Reject unknown sort orders before touching the server.
    if order not in ("asc", "desc"):
        raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")

    actor = server.user_manager.get_user_or_default(user_id=actor_id)

    # NOTE(review): this coroutine calls the job manager synchronously; consider
    # whether the route should be a plain `def` like delete_run -- confirm.
    try:
        return server.job_manager.get_job_steps(
            job_id=run_id,
            actor=actor,
            limit=limit,
            before=before,
            after=after,
            ascending=(order == "asc"),
        )
    except NoResultFound as e:
        # Chain the cause so the original lookup failure stays in the traceback.
        raise HTTPException(status_code=404, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.delete("/{run_id}", response_model=Run, operation_id="delete_run")
|
||||
def delete_run(
|
||||
run_id: str,
|
||||
|
||||
@@ -13,12 +13,14 @@ from letta.orm.job_messages import JobMessage
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.orm.sqlalchemy_base import AccessType
|
||||
from letta.orm.step import Step
|
||||
from letta.orm.step import Step as StepModel
|
||||
from letta.schemas.enums import JobStatus, MessageRole
|
||||
from letta.schemas.job import Job as PydanticJob
|
||||
from letta.schemas.job import JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.schemas.step import Step as PydanticStep
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
@@ -161,6 +163,51 @@ class JobManager:
|
||||
|
||||
return [message.to_pydantic() for message in messages]
|
||||
|
||||
@enforce_types
def get_job_steps(
    self,
    job_id: str,
    actor: PydanticUser,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 100,
    ascending: bool = True,
) -> List[PydanticStep]:
    """
    Return a page of the steps recorded for a job.

    Args:
        job_id: The ID of the job whose steps are listed
        actor: The user making the request
        before: Pagination cursor, forwarded to StepModel.list
        after: Pagination cursor, forwarded to StepModel.list
        limit: Maximum number of steps to return
        ascending: Sort flag, forwarded to StepModel.list

    Returns:
        The matching steps converted to their pydantic form.

    NOTE(review): nothing in this method raises NoResultFound itself; whether
    StepModel.list does so for a missing or inaccessible job is not visible
    here -- confirm before relying on it.
    """
    with self.session_maker() as session:
        # The job ID is the only filter, so it is passed as a keyword directly.
        step_rows = StepModel.list(
            db_session=session,
            before=before,
            after=after,
            ascending=ascending,
            limit=limit,
            actor=actor,
            job_id=job_id,
        )

        return [row.to_pydantic() for row in step_rows]
|
||||
|
||||
@enforce_types
|
||||
def add_message_to_job(self, job_id: str, message_id: str, actor: PydanticUser) -> None:
|
||||
"""
|
||||
@@ -312,6 +359,57 @@ class JobManager:
|
||||
|
||||
return messages
|
||||
|
||||
@enforce_types
def get_step_messages(
    self,
    run_id: str,
    actor: PydanticUser,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 100,
    role: Optional[MessageRole] = None,
    ascending: bool = True,
) -> List[LettaMessage]:
    """
    Get the messages associated with a run, converted to Letta messages.

    This is a wrapper around get_job_messages: it forwards the pagination and
    filter arguments unchanged, then converts the result using the run's
    stored request config.

    Args:
        run_id: The ID of the run to get messages for
        actor: The user making the request
        before: Pagination cursor, forwarded to get_job_messages
        after: Pagination cursor, forwarded to get_job_messages
        limit: Maximum number of messages to request from get_job_messages
        role: Optional role filter
        ascending: Whether to return messages in ascending order

    Returns:
        List of Letta messages associated with the run

    NOTE(review): any NoResultFound for a missing or inaccessible job would
    come from get_job_messages / _get_run_request_config, which are not
    visible here -- confirm.
    """
    job_messages = self.get_job_messages(
        job_id=run_id,
        actor=actor,
        before=before,
        after=after,
        limit=limit,
        role=role,
        ascending=ascending,
    )

    # The run's request config decides how assistant messages are rendered.
    request_config = self._get_run_request_config(run_id)

    return PydanticMessage.to_letta_messages_from_list(
        messages=job_messages,
        use_assistant_message=request_config["use_assistant_message"],
        assistant_message_tool_name=request_config["assistant_message_tool_name"],
        assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
    )
|
||||
|
||||
def _verify_job_access(
|
||||
self,
|
||||
session: Session,
|
||||
|
||||
@@ -3123,6 +3123,10 @@ def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_us
|
||||
assert usage_stats.prompt_tokens == 50
|
||||
assert usage_stats.total_tokens == 150
|
||||
|
||||
# get steps
|
||||
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
|
||||
assert len(steps) == 1
|
||||
|
||||
|
||||
def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_user):
|
||||
"""Test getting usage statistics for a job with no stats."""
|
||||
@@ -3136,6 +3140,10 @@ def test_job_usage_stats_get_no_stats(server: SyncServer, default_job, default_u
|
||||
assert usage_stats.prompt_tokens == 0
|
||||
assert usage_stats.total_tokens == 0
|
||||
|
||||
# get steps
|
||||
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
|
||||
assert len(steps) == 0
|
||||
|
||||
|
||||
def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_user):
|
||||
"""Test adding multiple usage statistics entries for a job."""
|
||||
@@ -3181,6 +3189,10 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u
|
||||
assert usage_stats.total_tokens == 450
|
||||
assert usage_stats.step_count == 2
|
||||
|
||||
# get steps
|
||||
steps = job_manager.get_job_steps(job_id=default_job.id, actor=default_user)
|
||||
assert len(steps) == 2
|
||||
|
||||
|
||||
def test_job_usage_stats_get_nonexistent_job(server: SyncServer, default_user):
|
||||
"""Test getting usage statistics for a nonexistent job."""
|
||||
|
||||
Reference in New Issue
Block a user