Files
letta-server/letta/services/job_manager.py
cthomas 9a95a8f976 fix: duplicate session commit in step logging (#7512)
* fix: duplicate session commit in step logging

* update all callsites
2026-01-12 10:57:19 -08:00

605 lines
25 KiB
Python

from functools import partial, reduce
from operator import add
from typing import List, Literal, Optional, Union
from httpx import AsyncClient, post
from sqlalchemy import select
from sqlalchemy.orm import Session
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.job import Job as JobModel
from letta.orm.message import Message as MessageModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step, Step as StepModel
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import JobStatus, JobType, MessageRole, PrimitiveType
from letta.schemas.job import BatchJob as PydanticBatchJob, Job as PydanticJob, JobUpdate, LettaRequestConfig
from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.run import Run as PydanticRun
from letta.schemas.step import Step as PydanticStep
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
# Module-level logger, created via letta's get_logger helper and keyed by this module's name.
logger = get_logger(__name__)
class JobManager:
    """Manager class to handle business logic related to Jobs.

    Each public method opens its own session through ``db_registry.async_session()``
    and returns pydantic objects rather than ORM rows. Writes are issued with
    ``no_commit=True``; per the inline comments in this file, the session context
    manager is what commits.
    """
@enforce_types
@trace_method
async def create_job_async(
    self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser
) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]:
    """Persist a new job owned by ``actor`` and return its pydantic form.

    Args:
        pydantic_job: The job, run or batch job to store. Its ``user_id`` is
            overwritten with ``actor.id``.
        actor: The user creating the job; also supplies the organization id.

    Returns:
        The stored job converted back to pydantic. For a ``PydanticRun`` input
        that carried an ``agent_id``, the id is copied onto the result.
    """
    async with db_registry.async_session() as db_session:
        # The job always belongs to the requesting user.
        pydantic_job.user_id = actor.id

        # agent_id is read defensively because not every accepted schema has it.
        linked_agent_id = getattr(pydantic_job, "agent_id", None)
        if linked_agent_id:
            # Fail before inserting anything if the agent cannot be found.
            await validate_agent_exists_async(db_session, linked_agent_id, actor)

        orm_fields = pydantic_job.model_dump(to_orm=True)
        # agent_id is not a field of the Job ORM model, so drop it before construction.
        orm_fields.pop("agent_id", None)

        job_row = JobModel(**orm_fields)
        job_row.organization_id = actor.organization_id
        # no_commit: the session context manager handles the commit.
        job_row = await job_row.create_async(db_session, actor=actor, no_commit=True, no_refresh=True)

        # Convert to pydantic first, then restore agent_id where applicable.
        created = super(JobModel, job_row).to_pydantic()
        if linked_agent_id and isinstance(pydantic_job, PydanticRun):
            created.agent_id = linked_agent_id
        return created
@enforce_types
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
@trace_method
async def update_job_by_id_async(
    self, job_id: str, job_update: JobUpdate, actor: PydanticUser, safe_update: bool = False
) -> PydanticJob:
    """Update a job by its ID with the given JobUpdate object asynchronously.

    The job is loaded with write access and every explicitly-set, non-None
    field of ``job_update`` is copied onto it. When the new status is
    ``completed`` or ``failed``, ``completed_at`` is stamped with the current
    UTC time (timezone-naive). If the job had no ``completed_at`` before this
    update and has a ``callback_url``, the callback is POSTed after the first
    session has closed and its outcome is stored on the job in a second session.

    Args:
        job_id: ID of the job to update.
        job_update: Fields to change; unset and None fields are ignored.
        actor: The user performing the update.
        safe_update: When True and ``job_update`` carries a status, reject
            status changes outside Created -> Pending -> Running -> <Terminal>.

    Returns:
        The updated job. If a callback was dispatched, the result also
        reflects the recorded callback fields.

    Raises:
        ValueError: If ``safe_update`` is True and the status transition is not allowed.
        NoResultFound: If the job does not exist or the user does not have access.
    """
    # Decided inside the session, acted on after it closes.
    needs_callback = False
    callback_url = None

    async with db_registry.async_session() as session:
        job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])

        # Safely update job status with state transition guards: Created -> Pending -> Running --> <Terminal>
        # An update that carries no status makes no transition, so there is nothing to validate
        # (previously this dereferenced None and raised AttributeError).
        if safe_update and job_update.status is not None:
            current_status = JobStatus(job.status)
            new_status = job_update.status
            transition_allowed = (
                (new_status.is_terminal and not current_status.is_terminal)
                or (current_status == JobStatus.created and new_status != JobStatus.created)
                or (current_status == JobStatus.pending and new_status == JobStatus.running)
            )
            if not transition_allowed:
                logger.error(f"Invalid job status transition from {current_status} to {job_update.status} for job {job_id}")
                raise ValueError(f"Invalid job status transition from {current_status} to {job_update.status}")

        # Check if we'll need to dispatch callback (only if not already completed)
        not_completed_before = not bool(job.completed_at)
        if job_update.status in {JobStatus.completed, JobStatus.failed} and not_completed_before and job.callback_url:
            needs_callback = True
            callback_url = job.callback_url

        # Update job attributes with only the fields that were explicitly set
        update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        for key, value in update_data.items():
            # Ensure completed_at is timezone-naive for database compatibility
            if key == "completed_at" and value is not None and hasattr(value, "replace"):
                value = value.replace(tzinfo=None)
            setattr(job, key, value)

        # Completed/failed jobs always get a fresh completion timestamp,
        # overriding any completed_at copied from job_update above.
        if job_update.status in {JobStatus.completed, JobStatus.failed}:
            logger.info(f"Current job completed at: {job.completed_at}")
            job.completed_at = get_utc_time().replace(tzinfo=None)

        # Save the updated job to the database first (the context manager handles the commit)
        job = await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)

        # Capture what the callback needs while the session is still open
        final_metadata = job.metadata_
        result = job.to_pydantic()

    # Dispatch callback outside of database session if needed
    if needs_callback:
        callback_info = {
            "job_id": job_id,
            "callback_url": callback_url,
            "status": job_update.status,
            "completed_at": get_utc_time().replace(tzinfo=None),
            "metadata": final_metadata,
        }
        callback_result = await self._dispatch_callback_async(callback_info)

        # Update callback status in a separate transaction
        async with db_registry.async_session() as session:
            job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
            job.callback_sent_at = callback_result["callback_sent_at"]
            # Only one of status code / error is present, depending on how the dispatch went.
            job.callback_status_code = callback_result.get("callback_status_code")
            job.callback_error = callback_result.get("callback_error")
            await job.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
            result = job.to_pydantic()

    return result
@enforce_types
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
@trace_method
async def safe_update_job_status_async(
    self,
    job_id: str,
    new_status: JobStatus,
    actor: PydanticUser,
    stop_reason: Optional[StopReasonType] = None,
    metadata: Optional[dict] = None,
) -> bool:
    """
    Update a job's status, reporting failure as a boolean instead of raising.

    Intended state flow: Created -> Pending -> Running --> <Terminal>

    NOTE(review): the update is delegated to update_job_by_id_async without
    safe_update=True, so the transition guards implemented there are not
    engaged by this call - confirm whether that is intentional.

    Args:
        job_id: ID of the job to update.
        new_status: Status to set.
        actor: The user performing the update.
        stop_reason: Optional stop reason stored with the status.
        metadata: Optional entries merged over the job's existing metadata.

    Returns:
        True if the update went through, False if any exception was raised
        along the way (the exception is logged, not propagated).
    """
    try:
        update_fields = {"status": new_status, "stop_reason": stop_reason}

        if metadata:
            # Existing metadata is the base; caller-supplied keys win on conflict.
            existing_job = await self.get_job_by_id_async(job_id=job_id, actor=actor)
            combined_metadata = dict(existing_job.metadata) if existing_job.metadata else {}
            combined_metadata.update(metadata)
            update_fields["metadata"] = combined_metadata

        if new_status.is_terminal:
            update_fields["completed_at"] = get_utc_time()

        await self.update_job_by_id_async(job_id=job_id, job_update=JobUpdate(**update_fields), actor=actor)
    except Exception as e:
        logger.error(f"Failed to safely update job status for job {job_id}: {e}")
        return False
    return True
@enforce_types
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
@trace_method
async def get_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
    """Load a single job by ID on behalf of ``actor`` and return it as pydantic.

    Args:
        job_id: ID of the job to read.
        actor: The user making the request; the read uses user-level access.

    Returns:
        The job converted to its pydantic schema.
    """
    async with db_registry.async_session() as db_session:
        job_row = await JobModel.read_async(
            db_session=db_session,
            identifier=job_id,
            actor=actor,
            access_type=AccessType.USER,
        )
        return job_row.to_pydantic()
@enforce_types
@trace_method
async def list_jobs_async(
    self,
    actor: PydanticUser,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 50,
    statuses: Optional[List[JobStatus]] = None,
    job_type: JobType = JobType.JOB,
    ascending: bool = True,
    source_id: Optional[str] = None,
    stop_reason: Optional[StopReasonType] = None,
    # agent_id: Optional[str] = None,
    agent_ids: Optional[List[str]] = None,
    background: Optional[bool] = None,
) -> List[PydanticJob]:
    """List the actor's jobs of one type, with optional filters and cursor pagination.

    Args:
        actor: Only jobs whose ``user_id`` equals ``actor.id`` are returned.
        before: Job ID cursor; keep jobs ordered strictly before it by (created_at, id).
        after: Job ID cursor; keep jobs ordered strictly after it by (created_at, id).
        limit: Maximum number of jobs; a falsy value disables the limit.
        statuses: If non-empty, keep only jobs in one of these statuses.
        job_type: Job type to list.
        ascending: Sort by (created_at, id) ascending when True, descending otherwise.
        source_id: If set, keep only jobs whose metadata ``source_id`` equals it.
        stop_reason: If not None, keep only jobs with this stop reason.
        agent_ids: Accepted but not used by this implementation.
        background: If not None, keep only jobs with this background flag.

    Returns:
        The matching jobs as pydantic objects.

    Raises:
        ValueError: If a cursor ID does not resolve to a job, or if ``before``
            sorts earlier than ``after``.
    """
    from sqlalchemy import and_, or_, select

    async with db_registry.async_session() as session:
        # Base statement: this user's jobs of the requested type.
        stmt = select(JobModel).where(JobModel.user_id == actor.id).where(JobModel.job_type == job_type)

        # Optional equality / membership filters.
        if statuses:
            stmt = stmt.where(JobModel.status.in_(statuses))
        if stop_reason is not None:
            stmt = stmt.where(JobModel.stop_reason == stop_reason)
        if background is not None:
            stmt = stmt.where(JobModel.background == background)
        if source_id:
            # Compare against the "source_id" entry extracted from metadata via the ->> operator.
            source_id_expr = JobModel.metadata_.op("->>")("source_id")
            stmt = stmt.where(source_id_expr == source_id)

        # Cursor-based pagination.
        if before or after:
            before_job = None
            after_job = None

            if before:
                before_job = await session.get(JobModel, before)
                if not before_job:
                    raise ValueError(f"Job with id {before} not found")
            if after:
                after_job = await session.get(JobModel, after)
                if not after_job:
                    raise ValueError(f"Job with id {after} not found")

            # With both cursors present, "before" must not sort ahead of "after".
            if before_job and after_job:
                if before_job.created_at < after_job.created_at or (
                    before_job.created_at == after_job.created_at and before_job.id < after_job.id
                ):
                    raise ValueError("'before' reference must be later than 'after' reference")

            cursor_clauses = []
            if before_job:
                # Older than the cursor, with id as the tie-breaker on equal timestamps.
                before_ts = before_job.created_at
                cursor_clauses.append(
                    or_(
                        JobModel.created_at < before_ts,
                        and_(JobModel.created_at == before_ts, JobModel.id < before_job.id),
                    )
                )
            if after_job:
                # Newer than the cursor, with id as the tie-breaker on equal timestamps.
                after_ts = after_job.created_at
                cursor_clauses.append(
                    or_(
                        JobModel.created_at > after_ts,
                        and_(JobModel.created_at == after_ts, JobModel.id > after_job.id),
                    )
                )
            if cursor_clauses:
                stmt = stmt.where(and_(*cursor_clauses))

        # Ordering matches the (created_at, id) key used by the cursors.
        if ascending:
            stmt = stmt.order_by(JobModel.created_at.asc(), JobModel.id.asc())
        else:
            stmt = stmt.order_by(JobModel.created_at.desc(), JobModel.id.desc())

        if limit:
            stmt = stmt.limit(limit)

        rows = (await session.execute(stmt)).scalars().all()
        return [row.to_pydantic() for row in rows]
@enforce_types
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
@trace_method
async def delete_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
    """Hard-delete a job and return the job that was deleted.

    Args:
        job_id: ID of the job to delete.
        actor: The user performing the deletion.

    Returns:
        The deleted job, converted to pydantic after the delete call.

    Raises:
        NoResultFound: If the job does not exist or the user does not have access.
    """
    async with db_registry.async_session() as db_session:
        # NOTE(review): relies on the helper's default access (["read"]) even though
        # this is a delete - confirm that is the intended access level.
        doomed_job = await self._verify_job_access_async(session=db_session, job_id=job_id, actor=actor)
        await doomed_job.hard_delete_async(db_session=db_session, actor=actor)
        return doomed_job.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
@trace_method
async def get_run_messages(
    self,
    run_id: str,
    actor: PydanticUser,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 100,
    role: Optional[MessageRole] = None,
    ascending: bool = True,
) -> List[LettaMessage]:
    """
    Get the messages of a run as LettaMessages, using cursor-based pagination.

    This is a wrapper around get_job_messages: the raw messages are fetched
    there, converted according to the run's stored request config, and
    optionally filtered by message type.

    Args:
        run_id: The ID of the run to get messages for
        actor: The user making the request
        before: Cursor passed through to get_job_messages as ``before``
        after: Cursor passed through to get_job_messages as ``after``
        limit: Maximum number of messages to fetch
        role: Optional role filter
        ascending: Whether to return messages in ascending order

    Returns:
        List of LettaMessages associated with the run. If the request config
        has a non-empty ``include_return_message_types``, only messages of
        those types are kept.
    """
    # NOTE(review): get_job_messages is not defined in the portion of the class
    # visible here - confirm it is defined elsewhere.
    messages = await self.get_job_messages(
        job_id=run_id,
        actor=actor,
        before=before,
        after=after,
        limit=limit,
        role=role,
        ascending=ascending,
    )
    request_config = await self._get_run_request_config(run_id)
    # Emit through the module logger rather than print(), so nothing is written to stdout.
    logger.debug(f"request_config {request_config}")
    messages = PydanticMessage.to_letta_messages_from_list(
        messages=messages,
        use_assistant_message=request_config["use_assistant_message"],
        assistant_message_tool_name=request_config["assistant_message_tool_name"],
        assistant_message_tool_kwarg=request_config["assistant_message_tool_kwarg"],
        # The converter is told to reverse when the caller asked for descending order.
        reverse=not ascending,
    )
    if request_config["include_return_message_types"]:
        messages = [msg for msg in messages if msg.message_type in request_config["include_return_message_types"]]
    return messages
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
@trace_method
async def get_step_messages(
    self,
    run_id: str,
    actor: PydanticUser,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 100,
    role: Optional[MessageRole] = None,
    ascending: bool = True,
) -> List[LettaMessage]:
    """
    Fetch a run's messages via get_job_messages and convert them to LettaMessages.

    Unlike get_run_messages, no ``reverse`` flag is passed to the converter and
    no message-type filter is applied.

    Args:
        run_id: The ID of the run to get messages for
        actor: The user making the request
        before: Cursor passed through to get_job_messages as ``before``
        after: Cursor passed through to get_job_messages as ``after``
        limit: Maximum number of messages to fetch
        role: Optional role filter
        ascending: Whether to fetch messages in ascending order

    Returns:
        List of LettaMessages associated with the run
    """
    raw_messages = await self.get_job_messages(
        job_id=run_id, actor=actor, before=before, after=after, limit=limit, role=role, ascending=ascending
    )
    # Conversion options come from the request config stored on the run.
    run_config = await self._get_run_request_config(run_id)
    return PydanticMessage.to_letta_messages_from_list(
        messages=raw_messages,
        use_assistant_message=run_config["use_assistant_message"],
        assistant_message_tool_name=run_config["assistant_message_tool_name"],
        assistant_message_tool_kwarg=run_config["assistant_message_tool_kwarg"],
    )
async def _verify_job_access_async(
    self,
    session: Session,
    job_id: str,
    actor: PydanticUser,
    access: Optional[List[Literal["read", "write", "delete"]]] = None,
) -> JobModel:
    """
    Verify that a job exists and the user has the required access.

    Args:
        session: The database session. It is awaited on, so an async session
            is expected despite the ``Session`` annotation.
        job_id: The ID of the job to verify
        actor: The user making the request
        access: Access levels to require; defaults to ``["read"]`` when omitted

    Returns:
        The job if it exists and the user has access

    Raises:
        NoResultFound: If the job does not exist or user does not have access
    """
    # Build the default per call instead of sharing one mutable list across calls.
    if access is None:
        access = ["read"]

    job_query = select(JobModel).where(JobModel.id == job_id)
    job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER)
    result = await session.execute(job_query)
    job = result.scalar_one_or_none()
    if not job:
        raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
    return job
@enforce_types
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
async def record_ttft(self, job_id: str, ttft_ns: int, actor: PydanticUser) -> None:
    """Store the time to first token on a job, best-effort.

    Any exception raised while loading or saving the job is logged as a
    warning and not propagated.

    Args:
        job_id: ID of the job to annotate.
        ttft_ns: Value written to the job's ``ttft_ns`` field.
        actor: The user performing the update; write access is required.
    """
    try:
        async with db_registry.async_session() as db_session:
            target_job = await self._verify_job_access_async(
                session=db_session, job_id=job_id, actor=actor, access=["write"]
            )
            target_job.ttft_ns = ttft_ns
            # no_commit: the session context manager handles the commit.
            await target_job.update_async(db_session=db_session, actor=actor, no_commit=True, no_refresh=True)
    except Exception as e:
        logger.warning(f"Failed to record TTFT for job {job_id}: {e}")
@enforce_types
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
async def record_response_duration(self, job_id: str, total_duration_ns: int, actor: PydanticUser) -> None:
    """Store the total response duration on a job, best-effort.

    Any exception raised while loading or saving the job is logged as a
    warning and not propagated.

    Args:
        job_id: ID of the job to annotate.
        total_duration_ns: Value written to the job's ``total_duration_ns`` field.
        actor: The user performing the update; write access is required.
    """
    try:
        async with db_registry.async_session() as db_session:
            target_job = await self._verify_job_access_async(
                session=db_session, job_id=job_id, actor=actor, access=["write"]
            )
            target_job.total_duration_ns = total_duration_ns
            # no_commit: the session context manager handles the commit.
            await target_job.update_async(db_session=db_session, actor=actor, no_commit=True, no_refresh=True)
    except Exception as e:
        logger.warning(f"Failed to record response duration for job {job_id}: {e}")
@trace_method
def _dispatch_callback_sync(self, callback_info: dict) -> dict:
    """
    POST a standard JSON payload to callback_url and return callback status.

    Args:
        callback_info: Dict with the keys ``job_id``, ``callback_url``,
            ``status``, ``completed_at`` (may be falsy) and ``metadata``.

    Returns:
        Dict that always contains ``callback_sent_at`` (timezone-naive UTC),
        plus ``callback_status_code`` when the POST returned a response, or
        ``callback_error`` when it raised an ``Exception``.
    """
    payload = {
        "job_id": callback_info["job_id"],
        "status": callback_info["status"],
        "completed_at": callback_info["completed_at"].isoformat() if callback_info["completed_at"] else None,
        "metadata": callback_info["metadata"],
    }
    callback_sent_at = get_utc_time().replace(tzinfo=None)
    result = {"callback_sent_at": callback_sent_at}
    try:
        log_event("POST callback dispatched", payload)
        resp = post(callback_info["callback_url"], json=payload, timeout=5.0)
        log_event("POST callback finished")
        result["callback_status_code"] = resp.status_code
    except Exception as e:
        error_message = f"Failed to dispatch callback for job {callback_info['job_id']} to {callback_info['callback_url']}: {e!s}"
        logger.error(error_message)
        result["callback_error"] = error_message
        # Continue silently - callback failures should not affect job completion
    # Return after the try block, not from a `finally`: returning inside `finally`
    # would also discard KeyboardInterrupt / SystemExit raised during the POST.
    return result
@trace_method
async def _dispatch_callback_async(self, callback_info: dict) -> dict:
    """
    POST a standard JSON payload to callback_url and return callback status asynchronously.

    Args:
        callback_info: Dict with the keys ``job_id``, ``callback_url``,
            ``status``, ``completed_at`` (may be falsy) and ``metadata``.

    Returns:
        Dict that always contains ``callback_sent_at`` (timezone-naive UTC),
        plus ``callback_status_code`` when the POST returned a response, or
        ``callback_error`` when it raised an ``Exception``.
    """
    payload = {
        "job_id": callback_info["job_id"],
        "status": callback_info["status"],
        "completed_at": callback_info["completed_at"].isoformat() if callback_info["completed_at"] else None,
        "metadata": callback_info["metadata"],
    }
    callback_sent_at = get_utc_time().replace(tzinfo=None)
    result = {"callback_sent_at": callback_sent_at}
    try:
        async with AsyncClient() as client:
            log_event("POST callback dispatched", payload)
            resp = await client.post(callback_info["callback_url"], json=payload, timeout=5.0)
            log_event("POST callback finished")
            result["callback_status_code"] = resp.status_code
    except Exception as e:
        error_message = f"Failed to dispatch callback for job {callback_info['job_id']} to {callback_info['callback_url']}: {e!s}"
        logger.error(error_message)
        result["callback_error"] = error_message
        # Continue silently - callback failures should not affect job completion
    # Return after the try block, not from a `finally`: returning inside `finally`
    # would also swallow asyncio.CancelledError and break task cancellation.
    return result
@enforce_types
@raise_on_invalid_id(param_name="job_id", expected_prefix=PrimitiveType.JOB)
@trace_method
async def get_job_steps(
    self,
    job_id: str,
    actor: PydanticUser,
    before: Optional[str] = None,
    after: Optional[str] = None,
    limit: Optional[int] = 100,
    ascending: bool = True,
) -> List[PydanticStep]:
    """
    Get all steps associated with a job.

    The job itself is not looked up here; the query is delegated to
    StepModel.list_async with a ``job_id`` filter.

    Args:
        job_id: The ID of the job to get steps for
        actor: The user making the request
        before: Cursor for pagination
        after: Cursor for pagination
        limit: Maximum number of steps to return
        ascending: Optional flag to sort in ascending order

    Returns:
        List of steps associated with the job
    """
    async with db_registry.async_session() as session:
        # list_async must be awaited: without it the comprehension below would
        # iterate a coroutine object and fail.
        steps = await StepModel.list_async(
            db_session=session,
            before=before,
            after=after,
            ascending=ascending,
            limit=limit,
            actor=actor,
            job_id=job_id,
        )
        return [step.to_pydantic() for step in steps]
async def _get_run_request_config(self, run_id: str) -> LettaRequestConfig:
    """
    Return the request config stored on a run, or a default one.

    Args:
        run_id: The ID of the run whose request config is wanted

    Returns:
        The run's ``request_config`` when it is truthy, otherwise a freshly
        constructed ``LettaRequestConfig()``
    """
    async with db_registry.async_session() as db_session:
        # The read is done without an actor.
        run_row = await JobModel.read_async(db_session=db_session, identifier=run_id)
        stored_config = run_row.request_config
        if stored_config:
            return stored_config
        return LettaRequestConfig()