feat: cleanup cancellation code and add more logging (#6588)
This commit is contained in:
committed by
Caren Thomas
parent
70c57c5072
commit
c8fa77a01f
@@ -377,3 +377,10 @@ class AgentExportProcessingError(AgentFileExportError):
|
||||
|
||||
class AgentFileImportError(Exception):
|
||||
"""Exception raised during agent file import operations"""
|
||||
|
||||
|
||||
class RunCancelError(LettaError):
|
||||
"""Error raised when a run cannot be cancelled."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message=message)
|
||||
|
||||
@@ -168,7 +168,12 @@ class AnthropicClient(LLMClientBase):
|
||||
if hasattr(llm_config, "response_format") and isinstance(llm_config.response_format, JsonSchemaResponseFormat):
|
||||
betas.append("structured-outputs-2025-11-13")
|
||||
|
||||
return await client.beta.messages.create(**request_data, betas=betas)
|
||||
# log failed requests
|
||||
try:
|
||||
return await client.beta.messages.create(**request_data, betas=betas)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Anthropic request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
|
||||
@trace_method
|
||||
async def send_llm_batch_request_async(
|
||||
|
||||
@@ -140,11 +140,16 @@ class GoogleVertexClient(LLMClientBase):
|
||||
@trace_method
|
||||
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
|
||||
client = self._get_client()
|
||||
response = await client.aio.models.generate_content_stream(
|
||||
model=llm_config.model,
|
||||
contents=request_data["contents"],
|
||||
config=request_data["config"],
|
||||
)
|
||||
|
||||
try:
|
||||
response = await client.aio.models.generate_content_stream(
|
||||
model=llm_config.model,
|
||||
contents=request_data["contents"],
|
||||
config=request_data["config"],
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming Google Vertex request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
# Direct yield - keeps response alive in generator's local scope throughout iteration
|
||||
# This is required because the SDK's connection lifecycle is tied to the response object
|
||||
async for chunk in response:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from typing import Any, List, Optional
|
||||
@@ -762,17 +763,25 @@ class OpenAIClient(LLMClientBase):
|
||||
|
||||
# Route based on payload shape: Responses uses 'input', Chat Completions uses 'messages'
|
||||
if "input" in request_data and "messages" not in request_data:
|
||||
response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
|
||||
**request_data,
|
||||
stream=True,
|
||||
# stream_options={"include_usage": True},
|
||||
)
|
||||
try:
|
||||
response_stream: AsyncStream[ResponseStreamEvent] = await client.responses.create(
|
||||
**request_data,
|
||||
stream=True,
|
||||
# stream_options={"include_usage": True},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming OpenAI Responses request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
else:
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
**request_data,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
try:
|
||||
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
|
||||
**request_data,
|
||||
stream=True,
|
||||
stream_options={"include_usage": True},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming OpenAI Chat Completions request: {e} with request data: {json.dumps(request_data)}")
|
||||
raise e
|
||||
return response_stream
|
||||
|
||||
@trace_method
|
||||
|
||||
@@ -25,6 +25,7 @@ from letta.errors import (
|
||||
AgentFileImportError,
|
||||
AgentNotFoundForExportError,
|
||||
PendingApprovalError,
|
||||
RunCancelError,
|
||||
)
|
||||
from letta.groups.sleeptime_multi_agent_v4 import SleeptimeMultiAgentV4
|
||||
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
|
||||
@@ -1682,6 +1683,7 @@ async def cancel_message(
|
||||
run_ids = [run_id]
|
||||
|
||||
results = {}
|
||||
failed_to_cancel = []
|
||||
for run_id in run_ids:
|
||||
run = await server.run_manager.get_run_by_id(run_id=run_id, actor=actor)
|
||||
if run.metadata.get("lettuce"):
|
||||
@@ -1691,8 +1693,14 @@ async def cancel_message(
|
||||
run = await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id)
|
||||
except Exception as e:
|
||||
results[run_id] = "failed"
|
||||
logger.error(f"Failed to cancel run {run_id}: {str(e)}")
|
||||
failed_to_cancel.append(run_id)
|
||||
continue
|
||||
results[run_id] = "cancelled"
|
||||
logger.info(f"Cancelled run {run_id}")
|
||||
|
||||
if failed_to_cancel:
|
||||
raise RunCancelError(f"Failed to cancel runs: {failed_to_cancel}")
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -686,17 +686,28 @@ class RunManager:
|
||||
# Combine approval response and tool messages
|
||||
new_messages = approval_response_messages + [tool_message]
|
||||
|
||||
# Insert the approval response and tool messages into the database
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=new_messages,
|
||||
actor=actor,
|
||||
run_id=run_id,
|
||||
)
|
||||
logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")
|
||||
# Checkpoint the new messages
|
||||
from letta.agents.agent_loop import AgentLoop
|
||||
|
||||
# Update the agent's message_ids to include the new messages (approval + tool message)
|
||||
agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
|
||||
await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)
|
||||
agent_loop = AgentLoop.load(agent_state=agent_state, actor=actor)
|
||||
new_in_context_messages = current_in_context_messages + new_messages
|
||||
await agent_loop._checkpoint_messages(
|
||||
run_id=run_id,
|
||||
step_id=approval_request_message.step_id,
|
||||
new_messages=new_messages,
|
||||
in_context_messages=new_in_context_messages,
|
||||
)
|
||||
|
||||
# persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
# pydantic_msgs=new_messages,
|
||||
# actor=actor,
|
||||
# run_id=run_id,
|
||||
# )
|
||||
# logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")
|
||||
|
||||
## Update the agent's message_ids to include the new messages (approval + tool message)
|
||||
# agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
|
||||
# await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)
|
||||
|
||||
logger.debug(
|
||||
f"Inserted approval response with {len(denials)} denials and tool return message for cancelled run {run_id}. "
|
||||
|
||||
@@ -1261,77 +1261,6 @@ def safe_create_file_processing_task(coro, file_metadata, server, actor, logger:
|
||||
return task
|
||||
|
||||
|
||||
class CancellationSignal:
|
||||
"""
|
||||
A signal that can be checked for cancellation during streaming operations.
|
||||
|
||||
This provides a lightweight way to check if an operation should be cancelled
|
||||
without having to pass job managers and other dependencies through every method.
|
||||
"""
|
||||
|
||||
def __init__(self, job_manager=None, job_id=None, actor=None):
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.user import User
|
||||
from letta.services.job_manager import JobManager
|
||||
|
||||
self.job_manager: JobManager | None = job_manager
|
||||
self.job_id: str | None = job_id
|
||||
self.actor: User | None = actor
|
||||
self._is_cancelled = False
|
||||
self.logger = get_logger(__name__)
|
||||
|
||||
async def is_cancelled(self) -> bool:
|
||||
"""
|
||||
Check if the operation has been cancelled.
|
||||
|
||||
Returns:
|
||||
True if cancelled, False otherwise
|
||||
"""
|
||||
from letta.schemas.enums import JobStatus
|
||||
|
||||
if self._is_cancelled:
|
||||
return True
|
||||
|
||||
if not self.job_manager or not self.job_id or not self.actor:
|
||||
return False
|
||||
|
||||
try:
|
||||
job = await self.job_manager.get_job_by_id_async(job_id=self.job_id, actor=self.actor)
|
||||
self._is_cancelled = job.status == JobStatus.cancelled
|
||||
return self._is_cancelled
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to check cancellation status for job {self.job_id}: {e}")
|
||||
return False
|
||||
|
||||
def cancel(self):
|
||||
"""Mark this signal as cancelled locally (for testing or direct cancellation)."""
|
||||
self._is_cancelled = True
|
||||
|
||||
async def check_and_raise_if_cancelled(self):
|
||||
"""
|
||||
Check for cancellation and raise CancelledError if cancelled.
|
||||
|
||||
Raises:
|
||||
asyncio.CancelledError: If the operation has been cancelled
|
||||
"""
|
||||
if await self.is_cancelled():
|
||||
self.logger.info(f"Operation cancelled for job {self.job_id}")
|
||||
raise asyncio.CancelledError(f"Job {self.job_id} was cancelled")
|
||||
|
||||
|
||||
class NullCancellationSignal(CancellationSignal):
|
||||
"""A null cancellation signal that is never cancelled."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
async def is_cancelled(self) -> bool:
|
||||
return False
|
||||
|
||||
async def check_and_raise_if_cancelled(self):
|
||||
pass
|
||||
|
||||
|
||||
async def get_latest_alembic_revision() -> str:
|
||||
"""Get the current alembic revision ID from the alembic_version table."""
|
||||
from letta.server.db import db_registry
|
||||
|
||||
Reference in New Issue
Block a user