feat: clean up cancellation code and add more logging (#6588)

This commit is contained in:
Sarah Wooders
2025-12-10 11:56:12 -08:00
committed by Caren Thomas
parent 70c57c5072
commit c8fa77a01f
7 changed files with 71 additions and 97 deletions

View File

@@ -377,3 +377,10 @@ class AgentExportProcessingError(AgentFileExportError):
class AgentFileImportError(Exception):
    """Exception raised during agent file import operations."""
class RunCancelError(LettaError):
    """Error raised when a run cannot be cancelled."""

    def __init__(self, message: str) -> None:
        # LettaError takes the human-readable description as a keyword argument.
        super().__init__(message=message)

View File

@@ -168,7 +168,12 @@ class AnthropicClient(LLMClientBase):
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
betas.append("structured-outputs-2025-11-13")
return await client.beta.messages.create(**request_data, betas=betas)
# log failed requests
try:
return await client.beta.messages.create(**request_data, betas=betas)
except Exception as e:
logger.error(f"Error streaming Anthropic request: {e} with request data: {json.dumps(request_data)}")
raise e
@trace_method
async def send_llm_batch_request_async(

View File

@@ -140,11 +140,16 @@ class GoogleVertexClient(LLMClientBase):
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
client = self._get_client()
response = await client.aio.models.generate_content_stream(
model=llm_config.model,
contents=request_data["contents"],
config=request_data["config"],
)
try:
response = await client.aio.models.generate_content_stream(
model=llm_config.model,
contents=request_data["contents"],
config=request_data["config"],
)
except Exception as e:
logger.error(f"Error streaming Google Vertex request: {e} with request data: {json.dumps(request_data)}")
raise e
# Direct yield - keeps response alive in generator's local scope throughout iteration
# This is required because the SDK's connection lifecycle is tied to the response object
async for chunk in response:

View File

@@ -1,4 +1,5 @@
import asyncio
import json
import os
import time
from typing import Any, List, Optional
@@ -762,17 +763,25 @@ class OpenAIClient(LLMClientBase):
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
if "input" in request_data and "messages" not in request_data:
response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
**request_data,
stream=True,
# stream_options={"include_usage": True},
)
try:
response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
**request_data,
stream=True,
# stream_options={"include_usage": True},
)
except Exception as e:
logger.error(f"Error streaming OpenAI Responses request: {e} with request data: {json.dumps(request_data)}")
raise e
else:
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
**request_data,
stream=True,
stream_options={"include_usage": True},
)
try:
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
**request_data,
stream=True,
stream_options={"include_usage": True},
)
except Exception as e:
logger.error(f"Error streaming OpenAI Chat Completions request: {e} with request data: {json.dumps(request_data)}")
raise e
return response_stream
@trace_method

View File

@@ -25,6 +25,7 @@ from letta.errors import (
AgentFileImportError,
AgentNotFoundForExportError,
PendingApprovalError,
RunCancelError,
)
from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
@@ -1682,6 +1683,7 @@ async def cancel_message(
run_ids = [run_id]
results = {}
failed_to_cancel = []
for run_id in run_ids:
run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
if run.metadata.get("lettuce"):
@@ -1691,8 +1693,14 @@ async def cancel_message(
run = await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id)
except Exception as e:
results[run_id] = "failed"
logger.error(f"Failed to cancel run {run_id}: {str(e)}")
failed_to_cancel.append(run_id)
continue
results[run_id] = "cancelled"
logger.info(f"Cancelled run {run_id}")
if failed_to_cancel:
raise RunCancelError(f"Failed to cancel runs: {failed_to_cancel}")
return results

View File

@@ -686,17 +686,28 @@ class RunManager:
# Combine approval response and tool messages
new_messages = approval_response_messages + [tool_message]
# Insert the approval response and tool messages into the database
persisted_messages = await self.message_manager.create_many_messages_async(
pydantic_msgs=new_messages,
actor=actor,
run_id=run_id,
)
logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")
# Checkpoint the new messages
from letta.agents.agent_loop import AgentLoop
# Update the agent's message_ids to include the new messages (approval + tool message)
agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)
agent_loop = AgentLoop.load(agent_state=agent_state, actor=actor)
new_in_context_messages = current_in_context_messages + new_messages
await agent_loop._checkpoint_messages(
run_id=run_id,
step_id=approval_request_message.step_id,
new_messages=new_messages,
in_context_messages=new_in_context_messages,
)
# persisted_messages = await self.message_manager.create_many_messages_async(
# pydantic_msgs=new_messages,
# actor=actor,
# run_id=run_id,
# )
# logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")
## Update the agent's message_ids to include the new messages (approval + tool message)
# agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
# await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)
logger.debug(
f"Inserted approval response with {len(denials)} denials and tool return message for cancelled run {run_id}. "

View File

@@ -1261,77 +1261,6 @@ def safe_create_file_processing_task(coro, file_metadata, server, actor, logger:
return task
class CancellationSignal:
    """
    Lightweight handle for detecting cancellation during streaming operations.

    Instead of threading job managers and related dependencies through every
    method, callers hold one of these and poll it (or ask it to raise) when a
    convenient checkpoint is reached.
    """

    def __init__(self, job_manager=None, job_id=None, actor=None):
        # Imports are deferred to avoid import cycles at module load time.
        from letta.log import get_logger
        from letta.schemas.user import User
        from letta.services.job_manager import JobManager

        self.logger = get_logger(__name__)
        self._is_cancelled = False
        self.job_manager: JobManager | None = job_manager
        self.job_id: str | None = job_id
        self.actor: User | None = actor

    async def is_cancelled(self) -> bool:
        """
        Report whether the tracked operation has been cancelled.

        Returns:
            True if cancelled, False otherwise
        """
        from letta.schemas.enums import JobStatus

        # A locally-set flag short-circuits any remote lookup.
        if self._is_cancelled:
            return True

        # Without a manager, job id, and actor we cannot query job state.
        if not (self.job_manager and self.job_id and self.actor):
            return False

        try:
            job = await self.job_manager.get_job_by_id_async(job_id=self.job_id, actor=self.actor)
            self._is_cancelled = job.status == JobStatus.cancelled
        except Exception as exc:
            # Best-effort: a failed status lookup is treated as "not cancelled".
            self.logger.warning(f"Failed to check cancellation status for job {self.job_id}: {exc}")
            return False
        return self._is_cancelled

    def cancel(self):
        """Mark this signal as cancelled locally (for testing or direct cancellation)."""
        self._is_cancelled = True

    async def check_and_raise_if_cancelled(self):
        """
        Poll for cancellation and abort if it has occurred.

        Raises:
            asyncio.CancelledError: If the operation has been cancelled
        """
        if not await self.is_cancelled():
            return
        self.logger.info(f"Operation cancelled for job {self.job_id}")
        raise asyncio.CancelledError(f"Job {self.job_id} was cancelled")
class NullCancellationSignal(CancellationSignal):
    """Cancellation signal that never reports cancellation (null-object pattern)."""

    def __init__(self):
        super().__init__()

    async def is_cancelled(self) -> bool:
        # Null object: always "not cancelled".
        return False

    async def check_and_raise_if_cancelled(self):
        # Nothing can ever be cancelled, so there is never anything to raise.
        return None
async def get_latest_alembic_revision() -> str:
"""Get the current alembic revision ID from the alembic_version table."""
from letta.server.db import db_registry