From 8eef166c3a9065f83e4641420b7ad1b49adb4f78 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 16 Sep 2025 14:27:00 -0700 Subject: [PATCH] feat: add `stop_reason` to runs (#2935) --- ...933666957_add_stop_reason_to_jobs_table.py | 28 ++++++++++ fern/openapi.json | 51 +++++++++++++++++++ letta/agents/letta_agent_v2.py | 1 + letta/orm/job.py | 4 +- letta/schemas/job.py | 10 ++++ letta/schemas/run.py | 2 + letta/server/rest_api/routers/v1/agents.py | 13 +++-- letta/server/rest_api/routers/v1/runs.py | 3 ++ letta/services/job_manager.py | 20 +++++++- tests/test_managers.py | 19 +++++++ 10 files changed, 144 insertions(+), 7 deletions(-) create mode 100644 alembic/versions/7f7933666957_add_stop_reason_to_jobs_table.py diff --git a/alembic/versions/7f7933666957_add_stop_reason_to_jobs_table.py b/alembic/versions/7f7933666957_add_stop_reason_to_jobs_table.py new file mode 100644 index 00000000..b138ab24 --- /dev/null +++ b/alembic/versions/7f7933666957_add_stop_reason_to_jobs_table.py @@ -0,0 +1,28 @@ +"""add stop_reason to jobs table + +Revision ID: 7f7933666957 +Revises: d06594144ef3 +Create Date: 2025-09-16 13:20:42.368007 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "7f7933666957" +down_revision: Union[str, None] = "d06594144ef3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add stop_reason column to jobs table + op.add_column("jobs", sa.Column("stop_reason", sa.String(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("jobs", "stop_reason") diff --git a/fern/openapi.json b/fern/openapi.json index 5b2cf3b8..e6945ab5 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -9477,6 +9477,24 @@ }, "description": "If True, filters for runs that were created in background mode." }, + { + "name": "stop_reason", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "$ref": "#/components/schemas/StopReasonType" + }, + { + "type": "null" + } + ], + "description": "Filter runs by stop reason.", + "title": "Stop Reason" + }, + "description": "Filter runs by stop reason." + }, { "name": "after", "in": "query", @@ -15376,6 +15394,17 @@ "title": "Completed At", "description": "The unix timestamp of when the job was completed." }, + "stop_reason": { + "anyOf": [ + { + "$ref": "#/components/schemas/StopReasonType" + }, + { + "type": "null" + } + ], + "description": "The reason why the job was stopped." + }, "metadata": { "anyOf": [ { @@ -22267,6 +22296,17 @@ "title": "Completed At", "description": "The unix timestamp of when the job was completed." }, + "stop_reason": { + "anyOf": [ + { + "$ref": "#/components/schemas/StopReasonType" + }, + { + "type": "null" + } + ], + "description": "The reason why the job was stopped." + }, "metadata": { "anyOf": [ { @@ -25509,6 +25549,17 @@ "title": "Completed At", "description": "The unix timestamp of when the job was completed." }, + "stop_reason": { + "anyOf": [ + { + "$ref": "#/components/schemas/StopReasonType" + }, + { + "type": "null" + } + ], + "description": "The reason why the run was stopped." + }, "metadata": { "anyOf": [ { diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index 524f5d6b..b8814d5c 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -1226,6 +1226,7 @@ class LettaAgentV2(BaseAgentV2): new_status=JobStatus.failed if is_error else JobStatus.completed, actor=self.actor, metadata=job_update_metadata, + stop_reason=self.stop_reason.stop_reason if self.stop_reason else StopReasonType.error, ) if request_span: request_span.end() diff --git a/letta/orm/job.py b/letta/orm/job.py index 362e71f1..29c09c4b 100644 --- a/letta/orm/job.py +++ b/letta/orm/job.py @@ -1,13 +1,14 @@ from datetime import datetime from typing import TYPE_CHECKING, List, Optional -from sqlalchemy import JSON, BigInteger, ForeignKey, Index, String +from sqlalchemy import JSON, BigInteger, Boolean, ForeignKey, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import UserMixin from letta.orm.sqlalchemy_base import SqlalchemyBase from letta.schemas.enums import JobStatus, JobType from letta.schemas.job import Job as PydanticJob, LettaRequestConfig +from letta.schemas.letta_stop_reason import StopReasonType if TYPE_CHECKING: from letta.orm.job_messages import JobMessage @@ -28,6 +29,7 @@ class Job(SqlalchemyBase, UserMixin): status: Mapped[JobStatus] = mapped_column(String, default=JobStatus.created, doc="The current status of the job.") completed_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="The unix timestamp of when the job was completed.") + stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(String, nullable=True, doc="The reason why the job was stopped.") metadata_: Mapped[Optional[dict]] = mapped_column(JSON, doc="The metadata of the job.") job_type: Mapped[JobType] = mapped_column( String, diff --git a/letta/schemas/job.py b/letta/schemas/job.py index 257917a0..a0a5aebb 100644 --- a/letta/schemas/job.py +++ b/letta/schemas/job.py @@ -8,16 +8,26 @@ from letta.helpers.datetime_helpers import get_utc_time from letta.schemas.enums import JobStatus, JobType from letta.schemas.letta_base import OrmMetadataBase from letta.schemas.letta_message import MessageType +from letta.schemas.letta_stop_reason import StopReasonType class JobBase(OrmMetadataBase): __id_prefix__ = "job" status: JobStatus = Field(default=JobStatus.created, description="The status of the job.") created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.") + + # completion related completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.") + stop_reason: Optional[StopReasonType] = Field(None, description="The reason why the job was stopped.") + + # metadata metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="The metadata of the job.") job_type: JobType = Field(default=JobType.JOB, description="The type of the job.") + ## TODO: Run-specific fields + # background: Optional[bool] = Field(None, description="Whether the job was created in background mode.") + # agent_id: Optional[str] = Field(None, description="The agent associated with this job/run.") + callback_url: Optional[str] = Field(None, description="If set, POST to this URL when the job completes.") callback_sent_at: Optional[datetime] = Field(None, description="Timestamp when the callback was last attempted.") callback_status_code: Optional[int] = Field(None, description="HTTP status code returned by the callback endpoint.") diff --git a/letta/schemas/run.py b/letta/schemas/run.py index 11e05839..433552aa 100644 --- a/letta/schemas/run.py +++ b/letta/schemas/run.py @@ -4,6 +4,7 @@ from pydantic import Field from letta.schemas.enums import JobType from letta.schemas.job import Job, JobBase, LettaRequestConfig +from letta.schemas.letta_stop_reason import StopReasonType class RunBase(JobBase): @@ -29,6 +30,7 @@ class Run(RunBase): id: str = RunBase.generate_id_field() user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the run.") request_config: Optional[LettaRequestConfig] = Field(None, description="The request configuration for the run.") + stop_reason: Optional[StopReasonType] = Field(None, description="The reason why the run was stopped.") @classmethod def from_job(cls, job: Job) -> "Run": diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index f1d63a78..01d3ab40 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -38,6 +38,7 @@ from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse +from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.memory import ( ArchivalMemorySearchResponse, ArchivalMemorySearchResult, @@ -1192,6 +1193,7 @@ async def send_message( await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None) try: + result = None if agent_eligible and model_compatible: agent_loop = AgentLoop.load(agent_state=agent, actor=actor) result = await agent_loop.step( @@ -1229,11 +1231,17 @@ async def send_message( raise finally: if settings.track_agent_run: + if result: + stop_reason = result.stop_reason.stop_reason + else: + # NOTE: we could also consider this an error? + stop_reason = None await server.job_manager.safe_update_job_status_async( job_id=run.id, new_status=job_status, actor=actor, metadata=job_update_metadata, + stop_reason=stop_reason, ) @@ -1440,10 +1448,7 @@ async def send_message_streaming( finally: if settings.track_agent_run: await server.job_manager.safe_update_job_status_async( - job_id=run.id, - new_status=job_status, - actor=actor, - metadata=job_update_metadata, + job_id=run.id, new_status=job_status, actor=actor, metadata=job_update_metadata ) diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 86d305ea..8bfaaddf 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -10,6 +10,7 @@ from letta.orm.errors import NoResultFound from letta.schemas.enums import JobStatus, JobType from letta.schemas.letta_message import LettaMessageUnion from letta.schemas.letta_request import RetrieveStreamRequest +from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.openai.chat_completion_response import UsageStatistics from letta.schemas.run import Run from letta.schemas.step import Step @@ -31,6 +32,7 @@ def list_runs( server: "SyncServer" = Depends(get_letta_server), agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."), background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."), + stop_reason: Optional[StopReasonType] = Query(None, description="Filter runs by stop reason."), after: Optional[str] = Query(None, description="Cursor for pagination"), before: Optional[str] = Query(None, description="Cursor for pagination"), limit: Optional[int] = Query(50, description="Maximum number of runs to return"), @@ -54,6 +56,7 @@ def list_runs( before=before, after=after, ascending=False, + stop_reason=stop_reason, ) ] if agent_ids: diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index df25c064..4cac2c07 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -18,6 +18,7 @@ from letta.otel.tracing import log_event, trace_method from letta.schemas.enums import JobStatus, JobType, MessageRole from letta.schemas.job import BatchJob as PydanticBatchJob, Job as PydanticJob, JobUpdate, LettaRequestConfig from letta.schemas.letta_message import LettaMessage +from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.message import Message as PydanticMessage from letta.schemas.run import Run as PydanticRun from letta.schemas.step import Step as PydanticStep @@ -207,7 +208,12 @@ class JobManager: @enforce_types @trace_method async def safe_update_job_status_async( - self, job_id: str, new_status: JobStatus, actor: PydanticUser, metadata: Optional[dict] = None + self, + job_id: str, + new_status: JobStatus, + actor: PydanticUser, + stop_reason: Optional[StopReasonType] = None, + metadata: Optional[dict] = None, ) -> bool: """ Safely update job status with state transition guards. @@ -217,7 +223,7 @@ class JobManager: True if update was successful, False if update was skipped due to invalid transition """ try: - job_update_builder = partial(JobUpdate, status=new_status) + job_update_builder = partial(JobUpdate, status=new_status, stop_reason=stop_reason) # If metadata is provided, merge it with existing metadata if metadata: @@ -268,6 +274,7 @@ class JobManager: statuses: Optional[List[JobStatus]] = None, job_type: JobType = JobType.JOB, ascending: bool = True, + stop_reason: Optional[StopReasonType] = None, ) -> List[PydanticJob]: """List all jobs with optional pagination and status filter.""" with db_registry.session() as session: @@ -277,6 +284,10 @@ class JobManager: if statuses: filter_kwargs["status"] = statuses + # Add stop_reason filter if provided + if stop_reason is not None: + filter_kwargs["stop_reason"] = stop_reason + jobs = JobModel.list( db_session=session, before=before, @@ -299,6 +310,7 @@ class JobManager: job_type: JobType = JobType.JOB, ascending: bool = True, source_id: Optional[str] = None, + stop_reason: Optional[StopReasonType] = None, ) -> List[PydanticJob]: """List all jobs with optional pagination and status filter.""" from sqlalchemy import and_, or_, select @@ -317,6 +329,10 @@ class JobManager: column = column.op("->>")("source_id") query = query.where(column == source_id) + # add stop_reason filter if provided + if stop_reason is not None: + query = query.where(JobModel.stop_reason == stop_reason) + # handle cursor-based pagination if before or after: # get cursor objects diff --git a/tests/test_managers.py b/tests/test_managers.py index 2c05ed0d..579d7924 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -9011,6 +9011,25 @@ async def test_list_jobs_filter_by_type(server: SyncServer, default_user, defaul assert jobs[0].id == run.id +@pytest.mark.asyncio +async def test_list_jobs_by_stop_reason(server: SyncServer, sarah_agent, default_user): + """Test listing jobs by stop reason.""" + + run_pydantic = PydanticRun( + user_id=default_user.id, + status=JobStatus.pending, + job_type=JobType.RUN, + stop_reason=StopReasonType.requires_approval, + ) + run = await server.job_manager.create_job_async(pydantic_job=run_pydantic, actor=default_user) + assert run.stop_reason == StopReasonType.requires_approval + + # list jobs by stop reason + jobs = await server.job_manager.list_jobs_async(actor=default_user, job_type=JobType.RUN, stop_reason=StopReasonType.requires_approval) + assert len(jobs) == 1 + assert jobs[0].id == run.id + + async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user): """Test that job callbacks are properly dispatched when a job is completed.""" captured = {}