From 858b8aa5c3c18dd40b8fde305249b5f571f3a2b3 Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 10 Sep 2025 17:38:43 -0700 Subject: [PATCH] feat: add asyncio shield to async message route (#2825) --- letta/server/rest_api/routers/v1/agents.py | 22 ++++----------- letta/utils.py | 33 ++++++++++++++++++++++ 2 files changed, 39 insertions(+), 16 deletions(-) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 9d6342db..c64e7b7b 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -60,7 +60,7 @@ from letta.server.server import SyncServer from letta.services.summarizer.enums import SummarizationMode from letta.services.telemetry_manager import NoopTelemetryManager from letta.settings import settings -from letta.utils import safe_create_task, truncate_file_visible_content +from letta.utils import safe_create_shielded_task, safe_create_task, truncate_file_visible_content # These can be forward refs, but because Fastapi needs them at runtime the must be imported normally @@ -1609,8 +1609,8 @@ async def send_message_async( ) run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor) - # Create asyncio task for background processing - task = safe_create_task( + # Create asyncio task for background processing (shielded to prevent cancellation) + task = safe_create_shielded_task( _process_message_background( run_id=run.id, server=server, @@ -1630,19 +1630,9 @@ async def send_message_async( try: t.result() except asyncio.CancelledError: - logger.error(f"Background task for run {run.id} was cancelled") - safe_create_task( - server.job_manager.update_job_by_id_async( - job_id=run.id, - job_update=JobUpdate( - status=JobStatus.failed, - completed_at=datetime.now(timezone.utc), - metadata={"error": "Task was cancelled"}, - ), - actor=actor, - ), - label=f"update_cancelled_job_{run.id}", - ) + # Note: With shielded tasks, cancellation attempts don't actually stop the task + logger.info(f"Cancellation attempted on shielded background task for run {run.id}, but task continues running") + # Don't mark as failed since the shielded task is still running except Exception as e: logger.error(f"Unhandled exception in background task for run {run.id}: {e}") safe_create_task( diff --git a/letta/utils.py b/letta/utils.py index c8a08547..793dfe99 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -1125,6 +1125,39 @@ def safe_create_task(coro, label: str = "background task"): return task +def safe_create_shielded_task(coro, label: str = "shielded background task"): + """ + Create a shielded background task that cannot be cancelled externally. + + This is useful for critical operations that must complete even if the + parent operation is cancelled. The task is internally shielded but the + returned task can still have callbacks added to it. + """ + + async def shielded_wrapper(): + try: + # Shield the original coroutine to prevent cancellation + result = await asyncio.shield(coro) + return result + except Exception as e: + logger.exception(f"{label} failed with {type(e).__name__}: {e}") + raise + + # Create the task with the shielded wrapper + task = asyncio.create_task(shielded_wrapper()) + + # Add task to the set to maintain strong reference + _background_tasks.add(task) + + # Log task count to trace + log_attributes({"total_background_task_count": get_background_task_count()}) + + # Remove task from set when done to prevent memory leaks + task.add_done_callback(_background_tasks.discard) + + return task + + def safe_create_file_processing_task(coro, file_metadata, server, actor, logger: Logger, label: str = "file processing task"): """ Create a task for file processing that updates file status on failure.