diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index fddad2cf..281268d2 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -130,7 +130,9 @@ class JobManager: @enforce_types @trace_method - async def update_job_by_id_async(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob: + async def update_job_by_id_async( + self, job_id: str, job_update: JobUpdate, actor: PydanticUser, safe_update: bool = False + ) -> PydanticJob: """Update a job by its ID with the given JobUpdate object asynchronously.""" # First check if we need to dispatch a callback needs_callback = False @@ -138,16 +140,24 @@ class JobManager: async with db_registry.async_session() as session: job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"]) + # Safely update job status with state transition guards: Created -> Pending -> Running --> + if safe_update: + current_status = JobStatus(job.status) + if not any( + ( + job_update.status.is_terminal and not current_status.is_terminal, + current_status == JobStatus.created and job_update.status != JobStatus.created, + current_status == JobStatus.pending and job_update.status == JobStatus.running, + ) + ): + logger.error(f"Invalid job status transition from {current_status} to {job_update.status} for job {job_id}") + raise ValueError(f"Invalid job status transition from {current_status} to {job_update.status}") + # Check if we'll need to dispatch callback if job_update.status in {JobStatus.completed, JobStatus.failed} and job.callback_url: needs_callback = True callback_url = job.callback_url - # Update the job first to get the final metadata - async with db_registry.async_session() as session: - # Fetch the job by ID - job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"]) - # Update job attributes with only the fields that were explicitly set update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) @@ -205,20 +215,6 @@ class JobManager: True if update was successful, False if update was skipped due to invalid transition """ try: - # Get current job state - current_job = await self.get_job_by_id_async(job_id=job_id, actor=actor) - - current_status = current_job.status - if not any( - ( - new_status.is_terminal and not current_status.is_terminal, - current_status == JobStatus.created and new_status != JobStatus.created, - current_status == JobStatus.pending and new_status == JobStatus.running, - ) - ): - logger.warning(f"Invalid job status transition from {current_job.status} to {new_status} for job {job_id}") - return False - job_update_builder = partial(JobUpdate, status=new_status) if metadata: job_update_builder = partial(job_update_builder, metadata=metadata)