feat: add stop_reason to runs (#2935)
This commit is contained in:
@@ -0,0 +1,28 @@
|
||||
"""add stop_reason to jobs table
|
||||
|
||||
Revision ID: 7f7933666957
|
||||
Revises: d06594144ef3
|
||||
Create Date: 2025-09-16 13:20:42.368007
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "7f7933666957"
|
||||
down_revision: Union[str, None] = "d06594144ef3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add stop_reason column to jobs table
|
||||
op.add_column("jobs", sa.Column("stop_reason", sa.String(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("jobs", "stop_reason")
|
||||
@@ -9477,6 +9477,24 @@
|
||||
},
|
||||
"description": "If True, filters for runs that were created in background mode."
|
||||
},
|
||||
{
|
||||
"name": "stop_reason",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/StopReasonType"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Filter runs by stop reason.",
|
||||
"title": "Stop Reason"
|
||||
},
|
||||
"description": "Filter runs by stop reason."
|
||||
},
|
||||
{
|
||||
"name": "after",
|
||||
"in": "query",
|
||||
@@ -15376,6 +15394,17 @@
|
||||
"title": "Completed At",
|
||||
"description": "The unix timestamp of when the job was completed."
|
||||
},
|
||||
"stop_reason": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/StopReasonType"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The reason why the job was stopped."
|
||||
},
|
||||
"metadata": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -22267,6 +22296,17 @@
|
||||
"title": "Completed At",
|
||||
"description": "The unix timestamp of when the job was completed."
|
||||
},
|
||||
"stop_reason": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/StopReasonType"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The reason why the job was stopped."
|
||||
},
|
||||
"metadata": {
|
||||
"anyOf": [
|
||||
{
|
||||
@@ -25509,6 +25549,17 @@
|
||||
"title": "Completed At",
|
||||
"description": "The unix timestamp of when the job was completed."
|
||||
},
|
||||
"stop_reason": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/StopReasonType"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "The reason why the run was stopped."
|
||||
},
|
||||
"metadata": {
|
||||
"anyOf": [
|
||||
{
|
||||
|
||||
@@ -1226,6 +1226,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
new_status=JobStatus.failed if is_error else JobStatus.completed,
|
||||
actor=self.actor,
|
||||
metadata=job_update_metadata,
|
||||
stop_reason=self.stop_reason.stop_reason if self.stop_reason else StopReasonType.error,
|
||||
)
|
||||
if request_span:
|
||||
request_span.end()
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, BigInteger, ForeignKey, Index, String
|
||||
from sqlalchemy import JSON, BigInteger, Boolean, ForeignKey, Index, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import UserMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.enums import JobStatus, JobType
|
||||
from letta.schemas.job import Job as PydanticJob, LettaRequestConfig
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.job_messages import JobMessage
|
||||
@@ -28,6 +29,7 @@ class Job(SqlalchemyBase, UserMixin):
|
||||
|
||||
status: Mapped[JobStatus] = mapped_column(String, default=JobStatus.created, doc="The current status of the job.")
|
||||
completed_at: Mapped[Optional[datetime]] = mapped_column(nullable=True, doc="The unix timestamp of when the job was completed.")
|
||||
stop_reason: Mapped[Optional[StopReasonType]] = mapped_column(String, nullable=True, doc="The reason why the job was stopped.")
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, doc="The metadata of the job.")
|
||||
job_type: Mapped[JobType] = mapped_column(
|
||||
String,
|
||||
|
||||
@@ -8,16 +8,26 @@ from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.enums import JobStatus, JobType
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
|
||||
|
||||
class JobBase(OrmMetadataBase):
|
||||
__id_prefix__ = "job"
|
||||
status: JobStatus = Field(default=JobStatus.created, description="The status of the job.")
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the job was created.")
|
||||
|
||||
# completion related
|
||||
completed_at: Optional[datetime] = Field(None, description="The unix timestamp of when the job was completed.")
|
||||
stop_reason: Optional[StopReasonType] = Field(None, description="The reason why the job was stopped.")
|
||||
|
||||
# metadata
|
||||
metadata: Optional[dict] = Field(None, validation_alias="metadata_", description="The metadata of the job.")
|
||||
job_type: JobType = Field(default=JobType.JOB, description="The type of the job.")
|
||||
|
||||
## TODO: Run-specific fields
|
||||
# background: Optional[bool] = Field(None, description="Whether the job was created in background mode.")
|
||||
# agent_id: Optional[str] = Field(None, description="The agent associated with this job/run.")
|
||||
|
||||
callback_url: Optional[str] = Field(None, description="If set, POST to this URL when the job completes.")
|
||||
callback_sent_at: Optional[datetime] = Field(None, description="Timestamp when the callback was last attempted.")
|
||||
callback_status_code: Optional[int] = Field(None, description="HTTP status code returned by the callback endpoint.")
|
||||
|
||||
@@ -4,6 +4,7 @@ from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import JobType
|
||||
from letta.schemas.job import Job, JobBase, LettaRequestConfig
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
|
||||
|
||||
class RunBase(JobBase):
|
||||
@@ -29,6 +30,7 @@ class Run(RunBase):
|
||||
id: str = RunBase.generate_id_field()
|
||||
user_id: Optional[str] = Field(None, description="The unique identifier of the user associated with the run.")
|
||||
request_config: Optional[LettaRequestConfig] = Field(None, description="The request configuration for the run.")
|
||||
stop_reason: Optional[StopReasonType] = Field(None, description="The reason why the run was stopped.")
|
||||
|
||||
@classmethod
|
||||
def from_job(cls, job: Job) -> "Run":
|
||||
|
||||
@@ -38,6 +38,7 @@ from letta.schemas.job import JobStatus, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion, LettaMessageUpdateUnion, MessageType
|
||||
from letta.schemas.letta_request import LettaAsyncRequest, LettaRequest, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.memory import (
|
||||
ArchivalMemorySearchResponse,
|
||||
ArchivalMemorySearchResult,
|
||||
@@ -1192,6 +1193,7 @@ async def send_message(
|
||||
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)
|
||||
|
||||
try:
|
||||
result = None
|
||||
if agent_eligible and model_compatible:
|
||||
agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
|
||||
result = await agent_loop.step(
|
||||
@@ -1229,11 +1231,17 @@ async def send_message(
|
||||
raise
|
||||
finally:
|
||||
if settings.track_agent_run:
|
||||
if result:
|
||||
stop_reason = result.stop_reason.stop_reason
|
||||
else:
|
||||
# NOTE: we could also consider this an error?
|
||||
stop_reason = None
|
||||
await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run.id,
|
||||
new_status=job_status,
|
||||
actor=actor,
|
||||
metadata=job_update_metadata,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
|
||||
|
||||
@@ -1440,10 +1448,7 @@ async def send_message_streaming(
|
||||
finally:
|
||||
if settings.track_agent_run:
|
||||
await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run.id,
|
||||
new_status=job_status,
|
||||
actor=actor,
|
||||
metadata=job_update_metadata,
|
||||
job_id=run.id, new_status=job_status, actor=actor, metadata=job_update_metadata
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.enums import JobStatus, JobType
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_request import RetrieveStreamRequest
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.step import Step
|
||||
@@ -31,6 +32,7 @@ def list_runs(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
agent_ids: Optional[List[str]] = Query(None, description="The unique identifier of the agent associated with the run."),
|
||||
background: Optional[bool] = Query(None, description="If True, filters for runs that were created in background mode."),
|
||||
stop_reason: Optional[StopReasonType] = Query(None, description="Filter runs by stop reason."),
|
||||
after: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
before: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(50, description="Maximum number of runs to return"),
|
||||
@@ -54,6 +56,7 @@ def list_runs(
|
||||
before=before,
|
||||
after=after,
|
||||
ascending=False,
|
||||
stop_reason=stop_reason,
|
||||
)
|
||||
]
|
||||
if agent_ids:
|
||||
|
||||
@@ -18,6 +18,7 @@ from letta.otel.tracing import log_event, trace_method
|
||||
from letta.schemas.enums import JobStatus, JobType, MessageRole
|
||||
from letta.schemas.job import BatchJob as PydanticBatchJob, Job as PydanticJob, JobUpdate, LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.schemas.step import Step as PydanticStep
|
||||
@@ -207,7 +208,12 @@ class JobManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def safe_update_job_status_async(
|
||||
self, job_id: str, new_status: JobStatus, actor: PydanticUser, metadata: Optional[dict] = None
|
||||
self,
|
||||
job_id: str,
|
||||
new_status: JobStatus,
|
||||
actor: PydanticUser,
|
||||
stop_reason: Optional[StopReasonType] = None,
|
||||
metadata: Optional[dict] = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Safely update job status with state transition guards.
|
||||
@@ -217,7 +223,7 @@ class JobManager:
|
||||
True if update was successful, False if update was skipped due to invalid transition
|
||||
"""
|
||||
try:
|
||||
job_update_builder = partial(JobUpdate, status=new_status)
|
||||
job_update_builder = partial(JobUpdate, status=new_status, stop_reason=stop_reason)
|
||||
|
||||
# If metadata is provided, merge it with existing metadata
|
||||
if metadata:
|
||||
@@ -268,6 +274,7 @@ class JobManager:
|
||||
statuses: Optional[List[JobStatus]] = None,
|
||||
job_type: JobType = JobType.JOB,
|
||||
ascending: bool = True,
|
||||
stop_reason: Optional[StopReasonType] = None,
|
||||
) -> List[PydanticJob]:
|
||||
"""List all jobs with optional pagination and status filter."""
|
||||
with db_registry.session() as session:
|
||||
@@ -277,6 +284,10 @@ class JobManager:
|
||||
if statuses:
|
||||
filter_kwargs["status"] = statuses
|
||||
|
||||
# Add stop_reason filter if provided
|
||||
if stop_reason is not None:
|
||||
filter_kwargs["stop_reason"] = stop_reason
|
||||
|
||||
jobs = JobModel.list(
|
||||
db_session=session,
|
||||
before=before,
|
||||
@@ -299,6 +310,7 @@ class JobManager:
|
||||
job_type: JobType = JobType.JOB,
|
||||
ascending: bool = True,
|
||||
source_id: Optional[str] = None,
|
||||
stop_reason: Optional[StopReasonType] = None,
|
||||
) -> List[PydanticJob]:
|
||||
"""List all jobs with optional pagination and status filter."""
|
||||
from sqlalchemy import and_, or_, select
|
||||
@@ -317,6 +329,10 @@ class JobManager:
|
||||
column = column.op("->>")("source_id")
|
||||
query = query.where(column == source_id)
|
||||
|
||||
# add stop_reason filter if provided
|
||||
if stop_reason is not None:
|
||||
query = query.where(JobModel.stop_reason == stop_reason)
|
||||
|
||||
# handle cursor-based pagination
|
||||
if before or after:
|
||||
# get cursor objects
|
||||
|
||||
@@ -9011,6 +9011,25 @@ async def test_list_jobs_filter_by_type(server: SyncServer, default_user, defaul
|
||||
assert jobs[0].id == run.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_jobs_by_stop_reason(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test listing jobs by stop reason."""
|
||||
|
||||
run_pydantic = PydanticRun(
|
||||
user_id=default_user.id,
|
||||
status=JobStatus.pending,
|
||||
job_type=JobType.RUN,
|
||||
stop_reason=StopReasonType.requires_approval,
|
||||
)
|
||||
run = await server.job_manager.create_job_async(pydantic_job=run_pydantic, actor=default_user)
|
||||
assert run.stop_reason == StopReasonType.requires_approval
|
||||
|
||||
# list jobs by stop reason
|
||||
jobs = await server.job_manager.list_jobs_async(actor=default_user, job_type=JobType.RUN, stop_reason=StopReasonType.requires_approval)
|
||||
assert len(jobs) == 1
|
||||
assert jobs[0].id == run.id
|
||||
|
||||
|
||||
async def test_e2e_job_callback(monkeypatch, server: SyncServer, default_user):
|
||||
"""Test that job callbacks are properly dispatched when a job is completed."""
|
||||
captured = {}
|
||||
|
||||
Reference in New Issue
Block a user