feat: add asyncio shield to async message route (#2825)
This commit is contained in:
@@ -60,7 +60,7 @@ from letta.server.server import SyncServer
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task, truncate_file_visible_content
|
||||
from letta.utils import safe_create_shielded_task, safe_create_task, truncate_file_visible_content
|
||||
|
||||
# These can be forward refs, but because Fastapi needs them at runtime the must be imported normally
|
||||
|
||||
@@ -1609,8 +1609,8 @@ async def send_message_async(
|
||||
)
|
||||
run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor)
|
||||
|
||||
# Create asyncio task for background processing
|
||||
task = safe_create_task(
|
||||
# Create asyncio task for background processing (shielded to prevent cancellation)
|
||||
task = safe_create_shielded_task(
|
||||
_process_message_background(
|
||||
run_id=run.id,
|
||||
server=server,
|
||||
@@ -1630,19 +1630,9 @@ async def send_message_async(
|
||||
try:
|
||||
t.result()
|
||||
except asyncio.CancelledError:
|
||||
logger.error(f"Background task for run {run.id} was cancelled")
|
||||
safe_create_task(
|
||||
server.job_manager.update_job_by_id_async(
|
||||
job_id=run.id,
|
||||
job_update=JobUpdate(
|
||||
status=JobStatus.failed,
|
||||
completed_at=datetime.now(timezone.utc),
|
||||
metadata={"error": "Task was cancelled"},
|
||||
),
|
||||
actor=actor,
|
||||
),
|
||||
label=f"update_cancelled_job_{run.id}",
|
||||
)
|
||||
# Note: With shielded tasks, cancellation attempts don't actually stop the task
|
||||
logger.info(f"Cancellation attempted on shielded background task for run {run.id}, but task continues running")
|
||||
# Don't mark as failed since the shielded task is still running
|
||||
except Exception as e:
|
||||
logger.error(f"Unhandled exception in background task for run {run.id}: {e}")
|
||||
safe_create_task(
|
||||
|
||||
@@ -1125,6 +1125,39 @@ def safe_create_task(coro, label: str = "background task"):
|
||||
return task
|
||||
|
||||
|
||||
def safe_create_shielded_task(coro, label: str = "shielded background task"):
|
||||
"""
|
||||
Create a shielded background task that cannot be cancelled externally.
|
||||
|
||||
This is useful for critical operations that must complete even if the
|
||||
parent operation is cancelled. The task is internally shielded but the
|
||||
returned task can still have callbacks added to it.
|
||||
"""
|
||||
|
||||
async def shielded_wrapper():
|
||||
try:
|
||||
# Shield the original coroutine to prevent cancellation
|
||||
result = await asyncio.shield(coro)
|
||||
return result
|
||||
except Exception as e:
|
||||
logger.exception(f"{label} failed with {type(e).__name__}: {e}")
|
||||
raise
|
||||
|
||||
# Create the task with the shielded wrapper
|
||||
task = asyncio.create_task(shielded_wrapper())
|
||||
|
||||
# Add task to the set to maintain strong reference
|
||||
_background_tasks.add(task)
|
||||
|
||||
# Log task count to trace
|
||||
log_attributes({"total_background_task_count": get_background_task_count()})
|
||||
|
||||
# Remove task from set when done to prevent memory leaks
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
def safe_create_file_processing_task(coro, file_metadata, server, actor, logger: Logger, label: str = "file processing task"):
|
||||
"""
|
||||
Create a task for file processing that updates file status on failure.
|
||||
|
||||
Reference in New Issue
Block a user