feat: hold reference to asyncio tasks in memory (#2823)
This commit is contained in:
@@ -8,6 +8,7 @@ from letta.schemas.letta_message_content import OmittedReasoningContent, Reasoni
|
||||
from letta.schemas.provider_trace import ProviderTraceCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
|
||||
class LettaLLMRequestAdapter(LettaLLMAdapter):
|
||||
@@ -98,7 +99,7 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
|
||||
if step_id is None or actor is None or not settings.track_provider_trace:
|
||||
return
|
||||
|
||||
asyncio.create_task(
|
||||
safe_create_task(
|
||||
self.telemetry_manager.create_provider_trace_async(
|
||||
actor=actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
@@ -107,5 +108,6 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
organization_id=actor.organization_id,
|
||||
),
|
||||
)
|
||||
),
|
||||
label="create_provider_trace",
|
||||
)
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.schemas.provider_trace import ProviderTraceCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
|
||||
class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
@@ -141,7 +142,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
if step_id is None or actor is None or not settings.track_provider_trace:
|
||||
return
|
||||
|
||||
asyncio.create_task(
|
||||
safe_create_task(
|
||||
self.telemetry_manager.create_provider_trace_async(
|
||||
actor=actor,
|
||||
provider_trace_create=ProviderTraceCreate(
|
||||
@@ -165,5 +166,6 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
step_id=step_id, # Use original step_id for telemetry
|
||||
organization_id=actor.organization_id,
|
||||
),
|
||||
)
|
||||
),
|
||||
label="create_provider_trace",
|
||||
)
|
||||
|
||||
@@ -58,7 +58,7 @@ from letta.services.tool_executor.tool_execution_manager import ToolExecutionMan
|
||||
from letta.settings import model_settings, settings, summarizer_settings
|
||||
from letta.system import package_function_response
|
||||
from letta.types import JsonDict
|
||||
from letta.utils import log_telemetry, united_diff, validate_function_response
|
||||
from letta.utils import log_telemetry, safe_create_task, united_diff, validate_function_response
|
||||
|
||||
|
||||
class LettaAgentV2(BaseAgentV2):
|
||||
@@ -1151,7 +1151,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
step_metrics: StepMetrics,
|
||||
run_id: str | None = None,
|
||||
):
|
||||
task = asyncio.create_task(
|
||||
task = safe_create_task(
|
||||
self.step_manager.record_step_metrics_async(
|
||||
actor=self.actor,
|
||||
step_id=step_id,
|
||||
@@ -1163,7 +1163,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
project_id=self.agent_state.project_id,
|
||||
template_id=self.agent_state.template_id,
|
||||
base_template_id=self.agent_state.base_template_id,
|
||||
)
|
||||
),
|
||||
label="record_step_metrics",
|
||||
)
|
||||
return task
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
|
||||
# TODO needed?
|
||||
@@ -447,7 +448,7 @@ async def _send_message_to_agents_matching_tags_async(
|
||||
timeout=settings.multi_agent_send_message_timeout,
|
||||
)
|
||||
|
||||
tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in matching_agents]
|
||||
tasks = [safe_create_task(_send_single(agent_state), label=f"send_to_agent_{agent_state.id}") for agent_state in matching_agents]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
final = []
|
||||
for r in results:
|
||||
@@ -488,7 +489,7 @@ async def _send_message_to_all_agents_in_group_async(sender_agent: "Agent", mess
|
||||
timeout=settings.multi_agent_send_message_timeout,
|
||||
)
|
||||
|
||||
tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in worker_agents]
|
||||
tasks = [safe_create_task(_send_single(agent_state), label=f"send_to_worker_{agent_state.id}") for agent_state in worker_agents]
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
final = []
|
||||
for r in results:
|
||||
|
||||
@@ -24,6 +24,7 @@ from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.step_manager import NoopStepManager, StepManager
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager, TelemetryManager
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
|
||||
class SleeptimeMultiAgentV2(BaseAgent):
|
||||
@@ -236,7 +237,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
)
|
||||
run = await self.job_manager.create_job_async(pydantic_job=run, actor=self.actor)
|
||||
|
||||
asyncio.create_task(
|
||||
safe_create_task(
|
||||
self._participant_agent_step(
|
||||
foreground_agent_id=self.agent_id,
|
||||
sleeptime_agent_id=sleeptime_agent_id,
|
||||
@@ -244,7 +245,8 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
last_processed_message_id=last_processed_message_id,
|
||||
run_id=run.id,
|
||||
use_assistant_message=True,
|
||||
)
|
||||
),
|
||||
label=f"participant_agent_step_{sleeptime_agent_id}",
|
||||
)
|
||||
return run.id
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.user import User
|
||||
from letta.services.group_manager import GroupManager
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
|
||||
class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
@@ -142,7 +143,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
)
|
||||
run = await self.job_manager.create_job_async(pydantic_job=run, actor=self.actor)
|
||||
|
||||
asyncio.create_task(
|
||||
safe_create_task(
|
||||
self._participant_agent_step(
|
||||
foreground_agent_id=self.agent_state.id,
|
||||
sleeptime_agent_id=sleeptime_agent_id,
|
||||
@@ -150,7 +151,8 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
last_processed_message_id=last_processed_message_id,
|
||||
run_id=run.id,
|
||||
use_assistant_message=use_assistant_message,
|
||||
)
|
||||
),
|
||||
label=f"participant_agent_step_{sleeptime_agent_id}",
|
||||
)
|
||||
return run.id
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from typing import AsyncIterator, Dict, List, Optional
|
||||
|
||||
from letta.data_sources.redis_client import AsyncRedisClient
|
||||
from letta.log import get_logger
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -62,7 +63,7 @@ class RedisSSEStreamWriter:
|
||||
"""Start the background flush task."""
|
||||
if not self._running:
|
||||
self._running = True
|
||||
self._flush_task = asyncio.create_task(self._periodic_flush())
|
||||
self._flush_task = safe_create_task(self._periodic_flush(), label="redis_periodic_flush")
|
||||
|
||||
async def stop(self):
|
||||
"""Stop the background flush task and flush remaining data."""
|
||||
|
||||
@@ -14,6 +14,7 @@ from letta.server.rest_api.chat_completions_interface import ChatCompletionsStre
|
||||
|
||||
# TODO this belongs in a controller!
|
||||
from letta.server.rest_api.utils import get_letta_server, get_user_message_from_chat_completions_request, sse_async_generator
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.server.server import SyncServer
|
||||
@@ -98,7 +99,7 @@ async def send_message_to_agent_chat_completions(
|
||||
|
||||
# Offload the synchronous message_func to a separate thread
|
||||
streaming_interface.stream_start()
|
||||
asyncio.create_task(
|
||||
safe_create_task(
|
||||
asyncio.to_thread(
|
||||
server.send_messages,
|
||||
actor=actor,
|
||||
@@ -106,7 +107,8 @@ async def send_message_to_agent_chat_completions(
|
||||
input_messages=messages,
|
||||
interface=streaming_interface,
|
||||
put_inner_thoughts_first=False,
|
||||
)
|
||||
),
|
||||
label="openai_send_messages",
|
||||
)
|
||||
|
||||
# return a stream
|
||||
|
||||
@@ -1343,12 +1343,13 @@ async def send_message_streaming(
|
||||
),
|
||||
)
|
||||
|
||||
asyncio.create_task(
|
||||
safe_create_task(
|
||||
create_background_stream_processor(
|
||||
stream_generator=raw_stream,
|
||||
redis_client=redis_client,
|
||||
run_id=run.id,
|
||||
)
|
||||
),
|
||||
label=f"background_stream_processor_{run.id}",
|
||||
)
|
||||
|
||||
raw_stream = redis_sse_stream_generator(
|
||||
@@ -1609,7 +1610,7 @@ async def send_message_async(
|
||||
run = await server.job_manager.create_job_async(pydantic_job=run, actor=actor)
|
||||
|
||||
# Create asyncio task for background processing
|
||||
task = asyncio.create_task(
|
||||
task = safe_create_task(
|
||||
_process_message_background(
|
||||
run_id=run.id,
|
||||
server=server,
|
||||
@@ -1621,7 +1622,8 @@ async def send_message_async(
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
max_steps=request.max_steps,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
)
|
||||
),
|
||||
label=f"process_message_background_{run.id}",
|
||||
)
|
||||
|
||||
def handle_task_completion(t):
|
||||
@@ -1629,7 +1631,7 @@ async def send_message_async(
|
||||
t.result()
|
||||
except asyncio.CancelledError:
|
||||
logger.error(f"Background task for run {run.id} was cancelled")
|
||||
asyncio.create_task(
|
||||
safe_create_task(
|
||||
server.job_manager.update_job_by_id_async(
|
||||
job_id=run.id,
|
||||
job_update=JobUpdate(
|
||||
@@ -1638,11 +1640,12 @@ async def send_message_async(
|
||||
metadata={"error": "Task was cancelled"},
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
),
|
||||
label=f"update_cancelled_job_{run.id}",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unhandled exception in background task for run {run.id}: {e}")
|
||||
asyncio.create_task(
|
||||
safe_create_task(
|
||||
server.job_manager.update_job_by_id_async(
|
||||
job_id=run.id,
|
||||
job_update=JobUpdate(
|
||||
@@ -1651,7 +1654,8 @@ async def send_message_async(
|
||||
metadata={"error": str(e)},
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
),
|
||||
label=f"update_failed_job_{run.id}",
|
||||
)
|
||||
|
||||
task.add_done_callback(handle_task_completion)
|
||||
|
||||
@@ -327,7 +327,7 @@ async def upload_file_to_folder(
|
||||
logger=logger,
|
||||
label="file_processor.process",
|
||||
)
|
||||
safe_create_task(sleeptime_document_ingest_async(server, folder_id, actor), logger=logger, label="sleeptime_document_ingest_async")
|
||||
safe_create_task(sleeptime_document_ingest_async(server, folder_id, actor), label="sleeptime_document_ingest_async")
|
||||
|
||||
return file_metadata
|
||||
|
||||
@@ -467,7 +467,7 @@ async def delete_file_from_folder(
|
||||
logger.info(f"Deleting file {file_id} from pinecone index")
|
||||
await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor)
|
||||
|
||||
asyncio.create_task(sleeptime_document_ingest_async(server, folder_id, actor, clear_history=True))
|
||||
safe_create_task(sleeptime_document_ingest_async(server, folder_id, actor, clear_history=True), label="document_ingest_after_delete")
|
||||
if deleted_file is None:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
|
||||
|
||||
|
||||
@@ -325,7 +325,7 @@ async def upload_file_to_source(
|
||||
logger=logger,
|
||||
label="file_processor.process",
|
||||
)
|
||||
safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), logger=logger, label="sleeptime_document_ingest_async")
|
||||
safe_create_task(sleeptime_document_ingest_async(server, source_id, actor), label="sleeptime_document_ingest_async")
|
||||
|
||||
return file_metadata
|
||||
|
||||
@@ -452,7 +452,7 @@ async def delete_file_from_source(
|
||||
logger.info(f"Deleting file {file_id} from pinecone index")
|
||||
await delete_file_records_from_pinecone_index(file_id=file_id, actor=actor)
|
||||
|
||||
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True))
|
||||
safe_create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True), label="document_ingest_after_delete")
|
||||
if deleted_file is None:
|
||||
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import capture_sentry_exception
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -64,7 +65,7 @@ async def add_keepalive_to_stream(
|
||||
await queue.put(("end", None))
|
||||
|
||||
# Start the stream reader task
|
||||
reader_task = asyncio.create_task(stream_reader())
|
||||
reader_task = safe_create_task(stream_reader(), label="stream_reader")
|
||||
|
||||
try:
|
||||
while True:
|
||||
|
||||
@@ -109,7 +109,7 @@ from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import DatabaseChoice, model_settings, settings, tool_settings
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
from letta.utils import get_friendly_error_msg, get_persona_text, make_key
|
||||
from letta.utils import get_friendly_error_msg, get_persona_text, make_key, safe_create_task
|
||||
|
||||
config = LettaConfig.load()
|
||||
logger = get_logger(__name__)
|
||||
@@ -2248,7 +2248,7 @@ class SyncServer(Server):
|
||||
|
||||
# Offload the synchronous message_func to a separate thread
|
||||
streaming_interface.stream_start()
|
||||
task = asyncio.create_task(
|
||||
task = safe_create_task(
|
||||
asyncio.to_thread(
|
||||
self.send_messages,
|
||||
actor=actor,
|
||||
@@ -2256,7 +2256,8 @@ class SyncServer(Server):
|
||||
input_messages=input_messages,
|
||||
interface=streaming_interface,
|
||||
metadata=metadata,
|
||||
)
|
||||
),
|
||||
label="send_messages_thread",
|
||||
)
|
||||
|
||||
if stream_steps:
|
||||
@@ -2363,13 +2364,14 @@ class SyncServer(Server):
|
||||
streaming_interface.metadata = metadata
|
||||
|
||||
streaming_interface.stream_start()
|
||||
task = asyncio.create_task(
|
||||
task = safe_create_task(
|
||||
asyncio.to_thread(
|
||||
letta_multi_agent.step,
|
||||
input_messages=input_messages,
|
||||
chaining=self.chaining,
|
||||
max_chaining_steps=self.max_chaining_steps,
|
||||
)
|
||||
),
|
||||
label="multi_agent_step_thread",
|
||||
)
|
||||
|
||||
if stream_steps:
|
||||
|
||||
@@ -53,7 +53,7 @@ from letta.services.message_manager import MessageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import get_latest_alembic_revision
|
||||
from letta.utils import get_latest_alembic_revision, safe_create_task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -622,10 +622,11 @@ class AgentSerializationManager:
|
||||
|
||||
# Create background task for file processing
|
||||
# TODO: This can be moved to celery or RQ or something
|
||||
task = asyncio.create_task(
|
||||
task = safe_create_task(
|
||||
self._process_file_async(
|
||||
file_metadata=file_metadata, source_id=source_db_id, file_processor=file_processor, actor=actor
|
||||
)
|
||||
),
|
||||
label=f"process_file_{file_metadata.file_name}",
|
||||
)
|
||||
background_tasks.append(task)
|
||||
logger.info(f"Started background processing for file {file_metadata.file_name} (ID: {file_db_id})")
|
||||
|
||||
@@ -43,7 +43,7 @@ from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
||||
from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import tool_settings
|
||||
from letta.utils import enforce_types, printd
|
||||
from letta.utils import enforce_types, printd, safe_create_task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -869,7 +869,7 @@ class MCPManager:
|
||||
|
||||
# Run connect_to_server in background to avoid blocking
|
||||
# This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database
|
||||
connect_task = asyncio.create_task(temp_client.connect_to_server())
|
||||
connect_task = safe_create_task(temp_client.connect_to_server(), label="mcp_oauth_connect")
|
||||
|
||||
# Give the OAuth flow time to trigger and save the URL
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
@@ -20,6 +20,7 @@ from letta.services.message_manager import MessageManager
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
from letta.system import package_summarize_message_no_counts
|
||||
from letta.templates.template_helper import render_template
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -100,7 +101,7 @@ class Summarizer:
|
||||
return in_context_messages, False
|
||||
|
||||
def fire_and_forget(self, coro):
|
||||
task = asyncio.create_task(coro)
|
||||
task = safe_create_task(coro, label="summarizer_background_task")
|
||||
|
||||
def callback(t):
|
||||
try:
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.tool_executor.tool_executor_base import ToolExecutor
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -75,7 +76,10 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
|
||||
)
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(self._process_agent(agent_id=agent_state.id, message=augmented_message)) for agent_state in matching_agents
|
||||
safe_create_task(
|
||||
self._process_agent(agent_id=agent_state.id, message=augmented_message), label=f"process_agent_{agent_state.id}"
|
||||
)
|
||||
for agent_state in matching_agents
|
||||
]
|
||||
results = await asyncio.gather(*tasks)
|
||||
return str(results)
|
||||
@@ -123,7 +127,7 @@ class LettaMultiAgentToolExecutor(ToolExecutor):
|
||||
f"{message}"
|
||||
)
|
||||
|
||||
task = asyncio.create_task(self._process_agent(agent_id=other_agent_id, message=prefixed))
|
||||
task = safe_create_task(self._process_agent(agent_id=other_agent_id, message=prefixed), label=f"send_message_to_{other_agent_id}")
|
||||
|
||||
task.add_done_callback(lambda t: (logger.error(f"Async send_message task failed: {t.exception()}") if t.exception() else None))
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from letta.services.helpers.tool_execution_helper import (
|
||||
from letta.services.helpers.tool_parser_helper import parse_stdout_best_effort
|
||||
from letta.services.tool_sandbox.base import AsyncToolSandboxBase
|
||||
from letta.settings import tool_settings
|
||||
from letta.utils import get_friendly_error_msg, parse_stderr_error_msg
|
||||
from letta.utils import get_friendly_error_msg, parse_stderr_error_msg, safe_create_task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -89,7 +89,7 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
venv_preparation_task = None
|
||||
if use_venv:
|
||||
venv_path = str(os.path.join(sandbox_dir, local_configs.venv_name))
|
||||
venv_preparation_task = asyncio.create_task(self._prepare_venv(local_configs, venv_path, env))
|
||||
venv_preparation_task = safe_create_task(self._prepare_venv(local_configs, venv_path, env), label="prepare_venv")
|
||||
|
||||
# Generate and write execution script (always with markers, since we rely on stdout)
|
||||
code = await self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
|
||||
|
||||
@@ -16,6 +16,7 @@ from letta.log import get_logger
|
||||
from letta.schemas.tool import ToolUpdate
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.tool_sandbox.modal_constants import CACHE_TTL_SECONDS, DEFAULT_CONFIG_KEY, MODAL_DEPLOYMENTS_KEY
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -197,7 +198,7 @@ class ModalVersionManager:
|
||||
if deployment_key in self._deployments_in_progress:
|
||||
self._deployments_in_progress[deployment_key].set()
|
||||
# Clean up after a short delay to allow waiters to wake up
|
||||
asyncio.create_task(self._cleanup_deployment_marker(deployment_key))
|
||||
safe_create_task(self._cleanup_deployment_marker(deployment_key), label=f"cleanup_deployment_{deployment_key}")
|
||||
|
||||
async def _cleanup_deployment_marker(self, deployment_key: str):
|
||||
"""Clean up deployment marker after a delay."""
|
||||
|
||||
@@ -39,6 +39,7 @@ from letta.constants import (
|
||||
)
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import log_attributes, trace_method
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -1093,14 +1094,35 @@ def make_key(*args, **kwargs):
|
||||
return str((args, tuple(sorted(kwargs.items()))))
|
||||
|
||||
|
||||
def safe_create_task(coro, logger: Logger, label: str = "background task"):
|
||||
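# Global set holding strong references to in-flight background tasks. The asyncio event
# loop keeps only weak references to tasks, so a task that nothing else references could be
# garbage-collected before it finishes.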
|
||||
_background_tasks: set = set()
|
||||
|
||||
|
||||
def get_background_task_count() -> int:
    """Report how many tasks are currently held in ``_background_tasks`` (debugging/monitoring aid)."""
    tracked = len(_background_tasks)
    return tracked
|
||||
|
||||
|
||||
@trace_method
def safe_create_task(coro, label: str = "background task"):
    """Schedule ``coro`` as an asyncio task and keep a strong reference to it until it finishes.

    The coroutine runs inside a wrapper that logs any ``Exception`` it raises (tagged with
    ``label``) instead of letting it propagate. On success the task resolves to the
    coroutine's own return value, so callers that ``await`` or ``asyncio.gather`` the task
    still receive the result; after a logged failure the task resolves to ``None``.

    Args:
        coro: The coroutine (or other awaitable) to run in the background.
        label: Human-readable name included in the failure log message.

    Returns:
        The created ``asyncio.Task`` wrapping ``coro``.
    """

    async def wrapper():
        try:
            # Return (not just await) so the coroutine's result is not silently discarded.
            return await coro
        except Exception as e:
            # Failures are logged and swallowed: done-callbacks that inspect
            # ``task.exception()`` / ``task.result()`` will not observe them.
            logger.exception(f"{label} failed with {type(e).__name__}: {e}")

    task = asyncio.create_task(wrapper())

    # Add task to the set to maintain strong reference
    _background_tasks.add(task)

    # Log task count to trace
    log_attributes({"total_background_task_count": get_background_task_count()})

    # Remove task from set when done to prevent memory leaks
    task.add_done_callback(_background_tasks.discard)

    return task
|
||||
|
||||
|
||||
def safe_create_file_processing_task(coro, file_metadata, server, actor, logger: Logger, label: str = "file processing task"):
|
||||
@@ -1137,7 +1159,15 @@ def safe_create_file_processing_task(coro, file_metadata, server, actor, logger:
|
||||
except Exception as update_error:
|
||||
logger.error(f"Failed to update file status to ERROR for {file_metadata.id}: {update_error}")
|
||||
|
||||
return asyncio.create_task(wrapper())
|
||||
task = asyncio.create_task(wrapper())
|
||||
|
||||
# Add task to the set to maintain strong reference
|
||||
_background_tasks.add(task)
|
||||
|
||||
# Remove task from set when done to prevent memory leaks
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
return task
|
||||
|
||||
|
||||
class CancellationSignal:
|
||||
@@ -1289,6 +1319,12 @@ def fire_and_forget(coro, task_name: Optional[str] = None, error_callback: Optio
|
||||
|
||||
task = asyncio.create_task(coro)
|
||||
|
||||
# Add task to the set to maintain strong reference
|
||||
_background_tasks.add(task)
|
||||
|
||||
# Remove task from set when done to prevent memory leaks
|
||||
task.add_done_callback(_background_tasks.discard)
|
||||
|
||||
def callback(t):
|
||||
try:
|
||||
t.result() # this re-raises exceptions from the task
|
||||
|
||||
Reference in New Issue
Block a user