feat: cleanup cancellation code and add more logging (#6588)

This commit is contained in:
Sarah Wooders
2025-12-10 11:56:12 -08:00
committed by Caren Thomas
parent 70c57c5072
commit c8fa77a01f
7 changed files with 71 additions and 97 deletions

View File

@@ -1261,77 +1261,6 @@ def safe_create_file_processing_task(coro, file_metadata, server, actor, logger:
return task
class CancellationSignal:
"""
A signal that can be checked for cancellation during streaming operations.
This provides a lightweight way to check if an operation should be cancelled
without having to pass job managers and other dependencies through every method.
"""
def __init__(self, job_manager=None, job_id=None, actor=None):
from letta.log import get_logger
from letta.schemas.user import User
from letta.services.job_manager import JobManager
self.job_manager: JobManager | None = job_manager
self.job_id: str | None = job_id
self.actor: User | None = actor
self._is_cancelled = False
self.logger = get_logger(__name__)
async def is_cancelled(self) -> bool:
"""
Check if the operation has been cancelled.
Returns:
True if cancelled, False otherwise
"""
from letta.schemas.enums import JobStatus
if self._is_cancelled:
return True
if not self.job_manager or not self.job_id or not self.actor:
return False
try:
job = await self.job_manager.get_job_by_id_async(job_id=self.job_id, actor=self.actor)
self._is_cancelled = job.status == JobStatus.cancelled
return self._is_cancelled
except Exception as e:
self.logger.warning(f"Failed to check cancellation status for job {self.job_id}: {e}")
return False
def cancel(self):
"""Mark this signal as cancelled locally (for testing or direct cancellation)."""
self._is_cancelled = True
async def check_and_raise_if_cancelled(self):
"""
Check for cancellation and raise CancelledError if cancelled.
Raises:
asyncio.CancelledError: If the operation has been cancelled
"""
if await self.is_cancelled():
self.logger.info(f"Operation cancelled for job {self.job_id}")
raise asyncio.CancelledError(f"Job {self.job_id} was cancelled")
class NullCancellationSignal(CancellationSignal):
"""A null cancellation signal that is never cancelled."""
def __init__(self):
super().__init__()
async def is_cancelled(self) -> bool:
return False
async def check_and_raise_if_cancelled(self):
pass
async def get_latest_alembic_revision() -> str:
"""Get the current alembic revision ID from the alembic_version table."""
from letta.server.db import db_registry