feat: support for agent loop job cancelation (#2837)
This commit is contained in:
@@ -12,11 +12,12 @@ import re
|
||||
import subprocess
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import Coroutine
|
||||
from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from logging import Logger
|
||||
from typing import Any, Coroutine, List, Union, _GenericAlias, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Coroutine, Union, _GenericAlias, get_args, get_origin, get_type_hints
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import demjson3 as demjson
|
||||
@@ -519,7 +520,7 @@ def enforce_types(func):
|
||||
arg_names = inspect.getfullargspec(func).args
|
||||
|
||||
# Pair each argument with its corresponding type hint
|
||||
args_with_hints = dict(zip(arg_names[1:], args[1:])) # Skipping 'self'
|
||||
args_with_hints = dict(zip(arg_names[1:], args[1:], strict=False)) # Skipping 'self'
|
||||
|
||||
# Function to check if a value matches a given type hint
|
||||
def matches_type(value, hint):
|
||||
@@ -557,7 +558,7 @@ def enforce_types(func):
|
||||
return wrapper
|
||||
|
||||
|
||||
def annotate_message_json_list_with_tool_calls(messages: List[dict], allow_tool_roles: bool = False):
|
||||
def annotate_message_json_list_with_tool_calls(messages: list[dict], allow_tool_roles: bool = False):
|
||||
"""Add in missing tool_call_id fields to a list of messages using function call style
|
||||
|
||||
Walk through the list forwards:
|
||||
@@ -946,7 +947,7 @@ def get_human_text(name: str, enforce_limit=True):
|
||||
for file_path in list_human_files():
|
||||
file = os.path.basename(file_path)
|
||||
if f"{name}.txt" == file or name == file:
|
||||
human_text = open(file_path, "r", encoding="utf-8").read().strip()
|
||||
human_text = open(file_path, encoding="utf-8").read().strip()
|
||||
if enforce_limit and len(human_text) > CORE_MEMORY_HUMAN_CHAR_LIMIT:
|
||||
raise ValueError(f"Contents of {name}.txt is over the character limit ({len(human_text)} > {CORE_MEMORY_HUMAN_CHAR_LIMIT})")
|
||||
return human_text
|
||||
@@ -958,7 +959,7 @@ def get_persona_text(name: str, enforce_limit=True):
|
||||
for file_path in list_persona_files():
|
||||
file = os.path.basename(file_path)
|
||||
if f"{name}.txt" == file or name == file:
|
||||
persona_text = open(file_path, "r", encoding="utf-8").read().strip()
|
||||
persona_text = open(file_path, encoding="utf-8").read().strip()
|
||||
if enforce_limit and len(persona_text) > CORE_MEMORY_PERSONA_CHAR_LIMIT:
|
||||
raise ValueError(
|
||||
f"Contents of {name}.txt is over the character limit ({len(persona_text)} > {CORE_MEMORY_PERSONA_CHAR_LIMIT})"
|
||||
@@ -1109,3 +1110,75 @@ def safe_create_task(coro, logger: Logger, label: str = "background task"):
|
||||
logger.exception(f"{label} failed with {type(e).__name__}: {e}")
|
||||
|
||||
return asyncio.create_task(wrapper())
|
||||
|
||||
|
||||
class CancellationSignal:
|
||||
"""
|
||||
A signal that can be checked for cancellation during streaming operations.
|
||||
|
||||
This provides a lightweight way to check if an operation should be cancelled
|
||||
without having to pass job managers and other dependencies through every method.
|
||||
"""
|
||||
|
||||
def __init__(self, job_manager=None, job_id=None, actor=None):
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.user import User
|
||||
from letta.services.job_manager import JobManager
|
||||
|
||||
self.job_manager: JobManager | None = job_manager
|
||||
self.job_id: str | None = job_id
|
||||
self.actor: User | None = actor
|
||||
self._is_cancelled = False
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def is_cancelled(self) -> bool:
|
||||
"""
|
||||
Check if the operation has been cancelled.
|
||||
|
||||
Returns:
|
||||
True if cancelled, False otherwise
|
||||
"""
|
||||
from letta.schemas.enums import JobStatus
|
||||
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
||||
if not self.job_manager or not self.job_id or not self.actor:
|
||||
return False
|
||||
|
||||
try:
|
||||
job = await self.job_manager.get_job_by_id_async(job_id=self.job_id, actor=self.actor)
|
||||
self._is_cancelled = job.status == JobStatus.cancelled
|
||||
return self._is_cancelled
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to check cancellation status for job {self.job_id}: {e}")
|
||||
return False
|
||||
|
||||
def cancel(self):
|
||||
"""Mark this signal as cancelled locally (for testing or direct cancellation)."""
|
||||
self._is_cancelled = True
|
||||
|
||||
async def check_and_raise_if_cancelled(self):
|
||||
"""
|
||||
Check for cancellation and raise CancelledError if cancelled.
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: If the operation has been cancelled
|
||||
"""
|
||||
if await self.is_cancelled():
|
||||
self.logger.info(f"Operation cancelled for job {self.job_id}")
|
||||
raise asyncio.CancelledError(f"Job {self.job_id} was cancelled")
|
||||
|
||||
|
||||
class NullCancellationSignal(CancellationSignal):
|
||||
"""A null cancellation signal that is never cancelled."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def is_cancelled(self) -> bool:
|
||||
return False
|
||||
|
||||
async def check_and_raise_if_cancelled(self):
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user