From 2d971cdcf0deab9a373d465f51735c93232bf495 Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 10 Sep 2025 17:08:07 -0700 Subject: [PATCH] feat: hold reference to asyncio tasks in memory (#2823) --- letta/adapters/letta_llm_request_adapter.py | 6 ++- letta/adapters/letta_llm_stream_adapter.py | 6 ++- letta/agents/letta_agent_v2.py | 7 ++-- letta/functions/helpers.py | 5 ++- letta/groups/sleeptime_multi_agent_v2.py | 6 ++- letta/groups/sleeptime_multi_agent_v3.py | 6 ++- letta/server/rest_api/redis_stream_manager.py | 3 +- .../chat_completions/chat_completions.py | 6 ++- letta/server/rest_api/routers/v1/agents.py | 20 +++++---- letta/server/rest_api/routers/v1/folders.py | 4 +- letta/server/rest_api/routers/v1/sources.py | 4 +- letta/server/rest_api/streaming_response.py | 3 +- letta/server/server.py | 12 +++--- letta/services/agent_serialization_manager.py | 7 ++-- letta/services/mcp_manager.py | 4 +- letta/services/summarizer/summarizer.py | 3 +- .../multi_agent_tool_executor.py | 8 +++- letta/services/tool_sandbox/local_sandbox.py | 4 +- .../tool_sandbox/modal_version_manager.py | 3 +- letta/utils.py | 42 +++++++++++++++++-- 20 files changed, 111 insertions(+), 48 deletions(-) diff --git a/letta/adapters/letta_llm_request_adapter.py b/letta/adapters/letta_llm_request_adapter.py index a21663f4..6d98a6fd 100644 --- a/letta/adapters/letta_llm_request_adapter.py +++ b/letta/adapters/letta_llm_request_adapter.py @@ -8,6 +8,7 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni from letta.schemas.provider_trace import ProviderTraceCreate from letta.schemas.user import User from letta.settings import settings +from letta.utils import safe_create_task class LettaLLMRequestAdapter(LettaLLMAdapter): @@ -98,7 +99,7 @@ class LettaLLMRequestAdapter(LettaLLMAdapter): if step_id is None or actor is None or not settings.track_provider_trace: return - asyncio.create_task( + safe_create_task( self.telemetry_manager.create_provider_trace_async( actor=actor, provider_trace_create=ProviderTraceCreate( @@ -107,5 +108,6 @@ class LettaLLMRequestAdapter(LettaLLMAdapter): step_id=step_id, # Use original step_id for telemetry organization_id=actor.organization_id, ), - ) + ), + label="create_provider_trace", ) diff --git a/letta/adapters/letta_llm_stream_adapter.py b/letta/adapters/letta_llm_stream_adapter.py index c0bf2e9a..8985daa6 100644 --- a/letta/adapters/letta_llm_stream_adapter.py +++ b/letta/adapters/letta_llm_stream_adapter.py @@ -13,6 +13,7 @@ from letta.schemas.provider_trace import ProviderTraceCreate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.settings import settings +from letta.utils import safe_create_task class LettaLLMStreamAdapter(LettaLLMAdapter): @@ -141,7 +142,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): if step_id is None or actor is None or not settings.track_provider_trace: return - asyncio.create_task( + safe_create_task( self.telemetry_manager.create_provider_trace_async( actor=actor, provider_trace_create=ProviderTraceCreate( @@ -165,5 +166,6 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): step_id=step_id, # Use original step_id for telemetry organization_id=actor.organization_id, ), - ) + ), + label="create_provider_trace", ) diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index f9164ecc..7d8b397c 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -58,7 +58,7 @@ from letta.services.tool_executor.tool_execution_manager import ToolExecutionMan from letta.settings import model_settings, settings, summarizer_settings from letta.system import package_function_response from letta.types import JsonDict -from letta.utils import log_telemetry, united_diff, validate_function_response +from letta.utils import log_telemetry, safe_create_task, united_diff, validate_function_response class LettaAgentV2(BaseAgentV2): @@ -1151,7 +1151,7 @@ class LettaAgentV2(BaseAgentV2): step_metrics: StepMetrics, run_id: str | None = None, ): - task = asyncio.create_task( + task = safe_create_task( self.step_manager.record_step_metrics_async( actor=self.actor, step_id=step_id, @@ -1163,7 +1163,8 @@ class LettaAgentV2(BaseAgentV2): project_id=self.agent_state.project_id, template_id=self.agent_state.template_id, base_template_id=self.agent_state.base_template_id, - ) + ), + label="record_step_metrics", ) return task diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index dc1a3b0b..76a01235 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -19,6 +19,7 @@ from letta.schemas.message import Message, MessageCreate from letta.schemas.user import User from letta.server.rest_api.utils import get_letta_server from letta.settings import settings +from letta.utils import safe_create_task # TODO needed? @@ -447,7 +448,7 @@ async def _send_message_to_agents_matching_tags_async( timeout=settings.multi_agent_send_message_timeout, ) - tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in matching_agents] + tasks = [safe_create_task(_send_single(agent_state), label=f"send_to_agent_{agent_state.id}") for agent_state in matching_agents] results = await asyncio.gather(*tasks, return_exceptions=True) final = [] for r in results: @@ -488,7 +489,7 @@ async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", mess timeout=settings.multi_agent_send_message_timeout, ) - tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in worker_agents] + tasks = [safe_create_task(_send_single(agent_state), label=f"send_to_worker_{agent_state.id}") for agent_state in worker_agents] results = await asyncio.gather(*tasks, return_exceptions=True) final = [] for r in results: diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index 275fe3bf..879241c2 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -24,6 +24,7 @@ from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.step_manager import NoopStepManager, StepManager from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager +from letta.utils import safe_create_task class SleeptimeMultiAgentV2(BaseAgent): @@ -236,7 +237,7 @@ class SleeptimeMultiAgentV2(BaseAgent): ) run = await self.job_manager.create_job_async(pydantic_job=run, actor=self.actor) - asyncio.create_task( + safe_create_task( self._participant_agent_step( foreground_agent_id=self.agent_id, sleeptime_agent_id=sleeptime_agent_id, @@ -244,7 +245,8 @@ class SleeptimeMultiAgentV2(BaseAgent): last_processed_message_id=last_processed_message_id, run_id=run.id, use_assistant_message=True, - ) + ), + label=f"participant_agent_step_{sleeptime_agent_id}", ) return run.id diff --git a/letta/groups/sleeptime_multi_agent_v3.py b/letta/groups/sleeptime_multi_agent_v3.py index f9d6ad8f..d8c49399 100644 --- a/letta/groups/sleeptime_multi_agent_v3.py +++ b/letta/groups/sleeptime_multi_agent_v3.py @@ -17,6 +17,7 @@ from letta.schemas.message import Message, MessageCreate from letta.schemas.run import Run from letta.schemas.user import User from letta.services.group_manager import GroupManager +from letta.utils import safe_create_task class SleeptimeMultiAgentV3(LettaAgentV2): @@ -142,7 +143,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2): ) run = await self.job_manager.create_job_async(pydantic_job=run, actor=self.actor) - asyncio.create_task( + safe_create_task( self._participant_agent_step( foreground_agent_id=self.agent_state.id, sleeptime_agent_id=sleeptime_agent_id, @@ -150,7 +151,8 @@ class SleeptimeMultiAgentV3(LettaAgentV2): last_processed_message_id=last_processed_message_id, run_id=run.id, use_assistant_message=use_assistant_message, - ) + ), + label=f"participant_agent_step_{sleeptime_agent_id}", ) return run.id diff --git a/letta/server/rest_api/redis_stream_manager.py b/letta/server/rest_api/redis_stream_manager.py index 951b511a..adfdf7c2 100644 --- a/letta/server/rest_api/redis_stream_manager.py +++ b/letta/server/rest_api/redis_stream_manager.py @@ -8,6 +8,7 @@ from typing import AsyncIterator, Dict, List, Optional from letta.data_sources.redis_client import AsyncRedisClient from letta.log import get_logger +from letta.utils import safe_create_task logger = get_logger(__name__) @@ -62,7 +63,7 @@ class RedisSSEStreamWriter: """Start the background flush task.""" if not self._running: self._running = True - self._flush_task = asyncio.create_task(self._periodic_flush()) + self._flush_task = safe_create_task(self._periodic_flush(), label="redis_periodic_flush") async def stop(self): """Stop the background flush task and flush remaining data.""" diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index 75579145..793e96ca 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -14,6 +14,7 @@ from letta.server.rest_api.chat_completions_interface import ChatCompletionsStre # TODO this belongs in a controller! from letta.server.rest_api.utils import get_letta_server, get_user_message_from_chat_completions_request, sse_async_generator +from letta.utils import safe_create_task if TYPE_CHECKING: from letta.server.server import SyncServer @@ -98,7 +99,7 @@ async def send_message_to_agent_chat_completions( # Offload the synchronous message_func to a separate thread streaming_interface.stream_start() - asyncio.create_task( + safe_create_task( asyncio.to_thread( server.send_messages, actor=actor, @@ -106,7 +107,8 @@ async def send_message_to_agent_chat_completions( input_messages=messages, interface=streaming_interface, put_inner_thoughts_first=False, - ) + ), + label="openai_send_messages", ) # return a stream diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 21a03d7a..9d6342db 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1343,12 +1343,13 @@ async def send_message_streaming( ), ) - asyncio.create_task( + safe_create_task( create_background_stream_processor( stream_generator=raw_stream, redis_client=redis_client, run_id=run.id, - ) + ), + label=f"background_stream_processor_{run.id}", ) raw_stream = redis_sse_stream_generator( @@ -1609,7 +1610,7 @@ async def send_message_async( run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor) # Create asyncio task for background processing - task = asyncio.create_task( + task = safe_create_task( _process_message_background( run_id=run.id, server=server, @@ -1621,7 +1622,8 @@ async def send_message_async( assistant_message_tool_kwarg=request.assistant_message_tool_kwarg, max_steps=request.max_steps, include_return_message_types=request.include_return_message_types, - ) + ), + label=f"process_message_background_{run.id}", ) def handle_task_completion(t): @@ -1629,7 +1631,7 @@ async def send_message_async( t.result() except asyncio.CancelledError: logger.error(f"Background task for run {run.id} was cancelled") - asyncio.create_task( + safe_create_task( server.job_manager.update_job_by_id_async( job_id=run.id, job_update=JobUpdate( @@ -1638,11 +1640,12 @@ async def send_message_async( metadata={"error": "Task was cancelled"}, ), actor=actor, - ) + ), + label=f"update_cancelled_job_{run.id}", ) except Exception as e: logger.error(f"Unhandled exception in background task for run {run.id}: {e}") - asyncio.create_task( + safe_create_task( server.job_manager.update_job_by_id_async( job_id=run.id, job_update=JobUpdate( @@ -1651,7 +1654,8 @@ async def send_message_async( metadata={"error": str(e)}, ), actor=actor, - ) + ), + label=f"update_failed_job_{run.id}", ) task.add_done_callback(handle_task_completion) diff --git a/letta/server/rest_api/routers/v1/folders.py b/letta/server/rest_api/routers/v1/folders.py index 84a59723..760b1930 100644 --- a/letta/server/rest_api/routers/v1/folders.py +++ b/letta/server/rest_api/routers/v1/folders.py @@ -327,7 +327,7 @@ async def upload_file_to_folder( logger=logger, label="file_processor.process", ) - safe_create_task(sleeptime_document_ingest_async(server, folder_id, actor), logger=logger, label="sleeptime_document_ingest_async") + safe_create_task(sleeptime_document_ingest_async(server, folder_id, actor), label="sleeptime_document_ingest_async") return file_metadata @@ -467,7 +467,7 @@ async def delete_file_from_folder( logger.info(f"Deleting file {file_id} from pinecone index") await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor) - asyncio.create_task(sleeptime_document_ingest_async(server, folder_id, actor, clear_history=True)) + safe_create_task(sleeptime_document_ingest_async(server, folder_id, actor, clear_history=True), label="document_ingest_after_delete") if deleted_file is None: raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.") diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index a5fee7b8..7496be32 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -325,7 +325,7 @@ async def upload_file_to_source( logger=logger, label="file_processor.process", ) - safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async") + safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), label="sleeptime_document_ingest_async") return file_metadata @@ -452,7 +452,7 @@ async def delete_file_from_source( logger.info(f"Deleting file {file_id} from pinecone index") await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor) - asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True)) + safe_create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True), label="document_ingest_after_delete") if deleted_file is None: raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.") diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py index 8b11ab33..295e3f1f 100644 --- a/letta/server/rest_api/streaming_response.py +++ b/letta/server/rest_api/streaming_response.py @@ -19,6 +19,7 @@ from letta.schemas.user import User from letta.server.rest_api.utils import capture_sentry_exception from letta.services.job_manager import JobManager from letta.settings import settings +from letta.utils import safe_create_task logger = get_logger(__name__) @@ -64,7 +65,7 @@ async def add_keepalive_to_stream( await queue.put(("end", None)) # Start the stream reader task - reader_task = asyncio.create_task(stream_reader()) + reader_task = safe_create_task(stream_reader(), label="stream_reader") try: while True: diff --git a/letta/server/server.py b/letta/server/server.py index 48fc8801..390b1874 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -109,7 +109,7 @@ from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager from letta.settings import DatabaseChoice, model_settings, settings, tool_settings from letta.streaming_interface import AgentChunkStreamingInterface -from letta.utils import get_friendly_error_msg, get_persona_text, make_key +from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task config = LettaConfig.load() logger = get_logger(__name__) @@ -2248,7 +2248,7 @@ class SyncServer(Server): # Offload the synchronous message_func to a separate thread streaming_interface.stream_start() - task = asyncio.create_task( + task = safe_create_task( asyncio.to_thread( self.send_messages, actor=actor, @@ -2256,7 +2256,8 @@ class SyncServer(Server): input_messages=input_messages, interface=streaming_interface, metadata=metadata, - ) + ), + label="send_messages_thread", ) if stream_steps: @@ -2363,13 +2364,14 @@ class SyncServer(Server): streaming_interface.metadata = metadata streaming_interface.stream_start() - task = asyncio.create_task( + task = safe_create_task( asyncio.to_thread( letta_multi_agent.step, input_messages=input_messages, chaining=self.chaining, max_chaining_steps=self.max_chaining_steps, - ) + ), + label="multi_agent_step_thread", ) if stream_steps: diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index a5dc89d4..dbfb3250 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -53,7 +53,7 @@ from letta.services.message_manager import MessageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import settings -from letta.utils import get_latest_alembic_revision +from letta.utils import get_latest_alembic_revision, safe_create_task logger = get_logger(__name__) @@ -622,10 +622,11 @@ class AgentSerializationManager: # Create background task for file processing # TODO: This can be moved to celery or RQ or something - task = asyncio.create_task( + task = safe_create_task( self._process_file_async( file_metadata=file_metadata, source_id=source_db_id, file_processor=file_processor, actor=actor - ) + ), + label=f"process_file_{file_metadata.file_name}", ) background_tasks.append(task) logger.info(f"Started background processing for file {file_metadata.file_name} (ID: {file_db_id})") diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 8668984a..648af4e6 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -43,7 +43,7 @@ from letta.services.mcp.stdio_client import AsyncStdioMCPClient from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient from letta.services.tool_manager import ToolManager from letta.settings import tool_settings -from letta.utils import enforce_types, printd +from letta.utils import enforce_types, printd, safe_create_task logger = get_logger(__name__) @@ -869,7 +869,7 @@ class MCPManager: # Run connect_to_server in background to avoid blocking # This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database - connect_task = asyncio.create_task(temp_client.connect_to_server()) + connect_task = safe_create_task(temp_client.connect_to_server(), label="mcp_oauth_connect") # Give the OAuth flow time to trigger and save the URL await asyncio.sleep(1.0) diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index 6dc99ea1..48a320ac 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -20,6 +20,7 @@ from letta.services.message_manager import MessageManager from letta.services.summarizer.enums import SummarizationMode from letta.system import package_summarize_message_no_counts from letta.templates.template_helper import render_template +from letta.utils import safe_create_task logger = get_logger(__name__) @@ -100,7 +101,7 @@ class Summarizer: return in_context_messages, False def fire_and_forget(self, coro): - task = asyncio.create_task(coro) + task = safe_create_task(coro, label="summarizer_background_task") def callback(t): try: diff --git a/letta/services/tool_executor/multi_agent_tool_executor.py b/letta/services/tool_executor/multi_agent_tool_executor.py index 7aa57bae..f502e757 100644 --- a/letta/services/tool_executor/multi_agent_tool_executor.py +++ b/letta/services/tool_executor/multi_agent_tool_executor.py @@ -13,6 +13,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult from letta.schemas.user import User from letta.services.tool_executor.tool_executor_base import ToolExecutor from letta.settings import settings +from letta.utils import safe_create_task logger = get_logger(__name__) @@ -75,7 +76,10 @@ class LettaMultiAgentToolExecutor(ToolExecutor): ) tasks = [ - asyncio.create_task(self._process_agent(agent_id=agent_state.id, message=augmented_message)) for agent_state in matching_agents + safe_create_task( + self._process_agent(agent_id=agent_state.id, message=augmented_message), label=f"process_agent_{agent_state.id}" + ) + for agent_state in matching_agents ] results = await asyncio.gather(*tasks) return str(results) @@ -123,7 +127,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor): f"{message}" ) - task = asyncio.create_task(self._process_agent(agent_id=other_agent_id, message=prefixed)) + task = safe_create_task(self._process_agent(agent_id=other_agent_id, message=prefixed), label=f"send_message_to_{other_agent_id}") task.add_done_callback(lambda t: (logger.error(f"Async send_message task failed: {t.exception()}") if t.exception() else None)) diff --git a/letta/services/tool_sandbox/local_sandbox.py b/letta/services/tool_sandbox/local_sandbox.py index 29b353bb..d83fb057 100644 --- a/letta/services/tool_sandbox/local_sandbox.py +++ b/letta/services/tool_sandbox/local_sandbox.py @@ -23,7 +23,7 @@ from letta.services.helpers.tool_execution_helper import ( from letta.services.helpers.tool_parser_helper import parse_stdout_best_effort from letta.services.tool_sandbox.base import AsyncToolSandboxBase from letta.settings import tool_settings -from letta.utils import get_friendly_error_msg, parse_stderr_error_msg +from letta.utils import get_friendly_error_msg, parse_stderr_error_msg, safe_create_task logger = get_logger(__name__) @@ -89,7 +89,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): venv_preparation_task = None if use_venv: venv_path = str(os.path.join(sandbox_dir, local_configs.venv_name)) - venv_preparation_task = asyncio.create_task(self._prepare_venv(local_configs, venv_path, env)) + venv_preparation_task = safe_create_task(self._prepare_venv(local_configs, venv_path, env), label="prepare_venv") # Generate and write execution script (always with markers, since we rely on stdout) code = await self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True) diff --git a/letta/services/tool_sandbox/modal_version_manager.py b/letta/services/tool_sandbox/modal_version_manager.py index 29179386..41be9ce5 100644 --- a/letta/services/tool_sandbox/modal_version_manager.py +++ b/letta/services/tool_sandbox/modal_version_manager.py @@ -16,6 +16,7 @@ from letta.log import get_logger from letta.schemas.tool import ToolUpdate from letta.services.tool_manager import ToolManager from letta.services.tool_sandbox.modal_constants import CACHE_TTL_SECONDS, DEFAULT_CONFIG_KEY, MODAL_DEPLOYMENTS_KEY +from letta.utils import safe_create_task logger = get_logger(__name__) @@ -197,7 +198,7 @@ class ModalVersionManager: if deployment_key in self._deployments_in_progress: self._deployments_in_progress[deployment_key].set() # Clean up after a short delay to allow waiters to wake up - asyncio.create_task(self._cleanup_deployment_marker(deployment_key)) + safe_create_task(self._cleanup_deployment_marker(deployment_key), label=f"cleanup_deployment_{deployment_key}") async def _cleanup_deployment_marker(self, deployment_key: str): """Clean up deployment marker after a delay.""" diff --git a/letta/utils.py b/letta/utils.py index 581b469e..c8a08547 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -39,6 +39,7 @@ from letta.constants import ( ) from letta.helpers.json_helpers import json_dumps, json_loads from letta.log import get_logger +from letta.otel.tracing import log_attributes, trace_method from letta.schemas.openai.chat_completion_response import ChatCompletionResponse logger = get_logger(__name__) @@ -1093,14 +1094,35 @@ def make_key(*args, **kwargs): return str((args, tuple(sorted(kwargs.items())))) -def safe_create_task(coro, logger: Logger, label: str = "background task"): +# Global set to keep strong references to background tasks +_background_tasks: set = set() + + +def get_background_task_count() -> int: + """Get the current number of background tasks for debugging/monitoring.""" + return len(_background_tasks) + + +@trace_method +def safe_create_task(coro, label: str = "background task"): async def wrapper(): try: await coro except Exception as e: logger.exception(f"{label} failed with {type(e).__name__}: {e}") - return asyncio.create_task(wrapper()) + task = asyncio.create_task(wrapper()) + + # Add task to the set to maintain strong reference + _background_tasks.add(task) + + # Log task count to trace + log_attributes({"total_background_task_count": get_background_task_count()}) + + # Remove task from set when done to prevent memory leaks + task.add_done_callback(_background_tasks.discard) + + return task def safe_create_file_processing_task(coro, file_metadata, server, actor, logger: Logger, label: str = "file processing task"): @@ -1137,7 +1159,15 @@ def safe_create_file_processing_task(coro, file_metadata, server, actor, logger: except Exception as update_error: logger.error(f"Failed to update file status to ERROR for {file_metadata.id}: {update_error}") - return asyncio.create_task(wrapper()) + task = asyncio.create_task(wrapper()) + + # Add task to the set to maintain strong reference + _background_tasks.add(task) + + # Remove task from set when done to prevent memory leaks + task.add_done_callback(_background_tasks.discard) + + return task class CancellationSignal: @@ -1289,6 +1319,12 @@ def fire_and_forget(coro, task_name: Optional[str] = None, error_callback: Optio task = asyncio.create_task(coro) + # Add task to the set to maintain strong reference + _background_tasks.add(task) + + # Remove task from set when done to prevent memory leaks + task.add_done_callback(_background_tasks.discard) + def callback(t): try: t.result() # this re-raises exceptions from the task