feat: support for agent loop job cancelation (#2837)

This commit is contained in:
Andy Li
2025-07-02 14:31:16 -07:00
committed by GitHub
parent e9f7601892
commit f9bb757a98
17 changed files with 940 additions and 281 deletions

View File

@@ -12,11 +12,12 @@ import re
import subprocess
import sys
import uuid
from collections.abc import Coroutine
from contextlib import contextmanager
from datetime import datetime, timezone
from functools import wraps
from logging import Logger
from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints
from typing import Any, Coroutine, Union, _GenericAlias, get_args, get_origin, get_type_hints
from urllib.parse import urljoin, urlparse
import demjson3 as demjson
@@ -519,7 +520,7 @@ def enforce_types(func):
arg_names = inspect.getfullargspec(func).args
# Pair each argument with its corresponding type hint
args_with_hints = dict(zip(arg_names[1:], args[1:])) # Skipping 'self'
args_with_hints = dict(zip(arg_names[1:], args[1:], strict=False)) # Skipping 'self'
# Function to check if a value matches a given type hint
def matches_type(value, hint):
@@ -557,7 +558,7 @@ def enforce_types(func):
return wrapper
def annotate_message_json_list_with_tool_calls(messages: List[dict], allow_tool_roles: bool = False):
def annotate_message_json_list_with_tool_calls(messages: list[dict], allow_tool_roles: bool = False):
"""Add in missing tool_call_id fields to a list of messages using function call style
Walk through the list forwards:
@@ -946,7 +947,7 @@ def get_human_text(name: str, enforce_limit=True):
for file_path in list_human_files():
file = os.path.basename(file_path)
if f"{name}.txt" == file or name == file:
human_text = open(file_path, "r", encoding="utf-8").read().strip()
human_text = open(file_path, encoding="utf-8").read().strip()
if enforce_limit and len(human_text) > CORE_MEMORY_HUMAN_CHAR_LIMIT:
raise ValueError(f"Contents of {name}.txt is over the character limit ({len(human_text)} > {CORE_MEMORY_HUMAN_CHAR_LIMIT})")
return human_text
@@ -958,7 +959,7 @@ def get_persona_text(name: str, enforce_limit=True):
for file_path in list_persona_files():
file = os.path.basename(file_path)
if f"{name}.txt" == file or name == file:
persona_text = open(file_path, "r", encoding="utf-8").read().strip()
persona_text = open(file_path, encoding="utf-8").read().strip()
if enforce_limit and len(persona_text) > CORE_MEMORY_PERSONA_CHAR_LIMIT:
raise ValueError(
f"Contents of {name}.txt is over the character limit ({len(persona_text)} > {CORE_MEMORY_PERSONA_CHAR_LIMIT})"
@@ -1109,3 +1110,75 @@ def safe_create_task(coro, logger: Logger, label: str = "background task"):
logger.exception(f"{label} failed with {type(e).__name__}: {e}")
return asyncio.create_task(wrapper())
class CancellationSignal:
"""
A signal that can be checked for cancellation during streaming operations.
This provides a lightweight way to check if an operation should be cancelled
without having to pass job managers and other dependencies through every method.
"""
def __init__(self, job_manager=None, job_id=None, actor=None):
from letta.log import get_logger
from letta.schemas.user import User
from letta.services.job_manager import JobManager
self.job_manager: JobManager | None = job_manager
self.job_id: str | None = job_id
self.actor: User | None = actor
self._is_cancelled = False
self.logger = get_logger(__name__)
async def is_cancelled(self) -> bool:
"""
Check if the operation has been cancelled.
Returns:
True if cancelled, False otherwise
"""
from letta.schemas.enums import JobStatus
if self._is_cancelled:
return True
if not self.job_manager or not self.job_id or not self.actor:
return False
try:
job = await self.job_manager.get_job_by_id_async(job_id=self.job_id, actor=self.actor)
self._is_cancelled = job.status == JobStatus.cancelled
return self._is_cancelled
except Exception as e:
self.logger.warning(f"Failed to check cancellation status for job {self.job_id}: {e}")
return False
def cancel(self):
"""Mark this signal as cancelled locally (for testing or direct cancellation)."""
self._is_cancelled = True
async def check_and_raise_if_cancelled(self):
"""
Check for cancellation and raise CancelledError if cancelled.
Raises:
asyncio.CancelledError: If the operation has been cancelled
"""
if await self.is_cancelled():
self.logger.info(f"Operation cancelled for job {self.job_id}")
raise asyncio.CancelledError(f"Job {self.job_id} was cancelled")
class NullCancellationSignal(CancellationSignal):
"""A null cancellation signal that is never cancelled."""
def __init__(self):
super().__init__()
async def is_cancelled(self) -> bool:
return False
async def check_and_raise_if_cancelled(self):
pass