fix: fix cancellation issues without making too many changes to message_ids persistence (#6442)

This commit is contained in:
Sarah Wooders
2025-11-29 23:08:19 -08:00
committed by Caren Thomas
parent 1f7165afc4
commit f417e53638
8 changed files with 1685 additions and 19 deletions

View File

@@ -162,7 +162,8 @@ async def _prepare_in_context_messages_no_persist_async(
new_in_context_messages.extend(follow_up_messages)
else:
# User is trying to send a regular message
if current_in_context_messages and current_in_context_messages[-1].role == "approval":
# if current_in_context_messages and current_in_context_messages[-1].role == "approval":
if current_in_context_messages and current_in_context_messages[-1].is_approval_request():
raise PendingApprovalError(pending_request_id=current_in_context_messages[-1].id)
# Create a new user message from the input, but don't store it yet

View File

@@ -44,6 +44,7 @@ from letta.server.rest_api.utils import (
create_approval_request_message_from_llm_response,
create_letta_messages_from_llm_response,
create_parallel_tool_messages_from_llm_response,
create_tool_returns_for_denials,
)
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
from letta.services.summarizer.summarizer_all import summarize_all
@@ -701,6 +702,14 @@ class LettaAgentV3(LettaAgentV2):
finally:
self.logger.debug("Running cleanup for agent loop run: %s", run_id)
self.logger.info("Running final update. Step Progression: %s", step_progression)
# update message ids
message_ids = [m.id for m in messages]
await self.agent_manager.update_message_ids_async(
agent_id=self.agent_state.id,
message_ids=message_ids,
actor=self.actor,
)
try:
if step_progression == StepProgression.FINISHED:
if not self.should_continue:
@@ -932,19 +941,15 @@ class LettaAgentV3(LettaAgentV2):
# 4. Handle denial cases
if tool_call_denials:
# Convert ToolCallDenial objects to ToolReturn objects using shared helper
# Group denials by reason to potentially batch them, but for now process individually
for tool_call_denial in tool_call_denials:
tool_call_id = tool_call_denial.id or f"call_{uuid.uuid4().hex[:8]}"
packaged_function_response = package_function_response(
was_success=False,
response_string=f"Error: request to call tool denied. User reason: {tool_call_denial.reason}",
denial_returns = create_tool_returns_for_denials(
tool_calls=[tool_call_denial],
denial_reason=tool_call_denial.reason,
timezone=agent_state.timezone,
)
tool_return = ToolReturn(
tool_call_id=tool_call_id,
func_response=packaged_function_response,
status="error",
)
result_tool_returns.append(tool_return)
result_tool_returns.extend(denial_returns)
# 5. Unified tool execution path (works for both single and multiple tools)

View File

@@ -195,6 +195,8 @@ PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"
REQUEST_HEARTBEAT_PARAM = "request_heartbeat"
REQUEST_HEARTBEAT_DESCRIPTION = "Request an immediate heartbeat after function execution. You MUST set this value to `True` if you want to send a follow-up message or run a follow-up tool call (chain multiple tools together). If set to `False` (the default), then the chain of execution will end immediately after this function call."
# Automated tool call denials
# Reason text attached to tool calls that are denied automatically (rather than
# by an explicit user decision) when a run is cancelled.
TOOL_CALL_DENIAL_ON_CANCEL = "The user cancelled the request, so the tool call was denied."
# Structured output models
# Set of model names; presumably those treated as supporting structured
# output -- NOTE(review): confirm against the call sites.
STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"}

View File

@@ -1659,6 +1659,7 @@ async def cancel_message(
Note to cancel active runs associated with an agent, redis is required.
"""
# TODO: WHY DOES THIS CANCEL A LIST OF RUNS?
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
if not settings.track_agent_run:
raise HTTPException(status_code=400, detail="Agent run tracking is disabled")
@@ -1685,12 +1686,12 @@ async def cancel_message(
if run.metadata.get("lettuce"):
lettuce_client = await LettuceClient.create()
await lettuce_client.cancel(run_id)
success = await server.run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.cancelled),
actor=actor,
)
results[run_id] = "cancelled" if success else "failed"
try:
run = await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id)
except Exception as e:
results[run_id] = "failed"
continue
results[run_id] = "cancelled"
return results

View File

@@ -213,6 +213,80 @@ def create_approval_response_message_from_input(
]
def create_tool_returns_for_denials(
    tool_calls: List[OpenAIToolCall],
    denial_reason: str,
    timezone: str,
) -> List[ToolReturn]:
    """
    Build one error-status ToolReturn for every denied tool call.

    Denials come from two places:
    - a user explicitly rejecting an approval request
    - a run being cancelled (automated denial)

    Args:
        tool_calls: The tool calls that were denied. Only each item's ``id``
            attribute is read.
        denial_reason: Text embedded in the error response (a user-supplied
            reason or the cancellation message).
        timezone: Agent timezone, forwarded to ``package_function_response``.

    Returns:
        A list with one ``ToolReturn`` (``status="error"``) per input tool
        call, in the same order. Empty when ``tool_calls`` is empty.
    """
    # All denied calls share the same reason, so the response text is built once.
    denial_text = f"Error: request to call tool denied. User reason: {denial_reason}"

    return [
        ToolReturn(
            # Fall back to a synthetic short id when the call carries none.
            tool_call_id=denied_call.id or f"call_{uuid.uuid4().hex[:8]}",
            func_response=package_function_response(
                was_success=False,
                response_string=denial_text,
                timezone=timezone,
            ),
            status="error",
        )
        for denied_call in tool_calls
    ]
def create_tool_message_from_returns(
    agent_id: str,
    model: str,
    tool_returns: List[ToolReturn],
    run_id: Optional[str] = None,
    step_id: Optional[str] = None,
) -> Message:
    """
    Wrap a batch of tool returns in a single ``role="tool"`` message.

    The resulting message can be appended to the conversation history so that
    denied or failed tool calls are reflected there.

    Args:
        agent_id: ID of the agent the message belongs to.
        model: Model identifier recorded on the message.
        tool_returns: ToolReturn objects to carry (typically error status).
        run_id: Optional run ID to attach.
        step_id: Optional step ID to attach.

    Returns:
        A ``Message`` whose content holds one text part per tool return and
        whose ``tool_returns`` field is the list passed in.
    """
    # One text part per return, mirroring each return's packaged response.
    text_parts = [TextContent(text=tool_return.func_response) for tool_return in tool_returns]

    # The message-level tool_call_id comes from the first return; None when
    # there are no returns at all.
    if tool_returns:
        leading_call_id = tool_returns[0].tool_call_id
    else:
        leading_call_id = None

    return Message(
        role=MessageRole.tool,
        content=text_parts,
        agent_id=agent_id,
        model=model,
        tool_calls=[],
        tool_call_id=leading_call_id,
        tool_returns=tool_returns,
        run_id=run_id,
        step_id=step_id,
        created_at=get_utc_time(),
    )
def create_approval_request_message_from_llm_response(
agent_id: str,
model: str,

View File

@@ -1,9 +1,11 @@
from datetime import datetime
from multiprocessing import Value
from pickletools import pyunicode
from typing import List, Literal, Optional
from httpx import AsyncClient
from letta.errors import LettaInvalidArgumentError
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.orm.agent import Agent as AgentModel
@@ -314,7 +316,7 @@ class RunManager:
needs_callback = False
callback_url = None
not_completed_before = not bool(run.completed_at)
is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed}
is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed, RunStatus.cancelled}
if is_terminal_update and not_completed_before and run.callback_url:
needs_callback = True
callback_url = run.callback_url
@@ -558,3 +560,129 @@ class RunManager:
actor=actor, run_id=run_id, limit=limit, before=before, after=after, order="asc" if ascending else "desc"
)
return steps
@enforce_types
async def cancel_run(self, actor: PydanticUser, agent_id: Optional[str] = None, run_id: Optional[str] = None) -> None:
    """
    Cancel a run and resolve any approval request it left pending.

    The run is identified by ``run_id``. When ``run_id`` is omitted, the first
    run returned by ``list_runs(ascending=False)`` for ``agent_id`` is used
    (presumably the most recent one -- confirm against ``list_runs`` ordering).

    After marking the run as cancelled, if the agent's last in-context message
    is an approval request with tool calls, a denial approval response and a
    tool message with error returns are persisted and appended to the agent's
    ``message_ids``.

    Args:
        actor: User performing the cancellation.
        agent_id: Agent whose run should be cancelled; required only when
            ``run_id`` is not given. Overwritten with the run's own
            ``agent_id`` once the run is loaded.
        run_id: ID of the run to cancel.

    Returns:
        The value returned by ``update_run_by_id_async`` for the cancelled run.
        NOTE(review): the signature is annotated ``-> None`` although a run is
        returned; the annotation is left untouched here because the run type's
        name is not visible from this block.

    Raises:
        ValueError: If neither ``run_id`` nor ``agent_id`` is provided.
        NoResultFound: If no run can be found to cancel.
        LettaInvalidArgumentError: If the run already has a stop reason other
            than ``requires_approval``.
    """
    if not run_id:
        if not agent_id:
            raise ValueError("Agent ID is required to cancel a run by ID")
        logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.")
        agent_runs = await self.list_runs(
            actor=actor,
            ascending=False,
            agent_id=agent_id,
        )
        if not agent_runs:
            raise NoResultFound(f"No runs found for agent {agent_id}")
        # Resolve to a concrete run id so both entry paths share the lookup
        # below; previously this branch left `run` unbound and `run_id` None.
        run_id = agent_runs[0].id

    run = await self.get_run_by_id(run_id=run_id, actor=actor)
    if not run:
        raise NoResultFound(f"Run with id {run_id} not found")
    agent_id = run.agent_id

    logger.debug(f"Cancelling run {run_id} for agent {agent_id}")

    # A run that already stopped cannot be cancelled, unless it stopped only
    # because it is waiting for approval.
    if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]:
        already_terminated = f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}"
        logger.error(already_terminated)
        raise LettaInvalidArgumentError(already_terminated)

    # Check if the agent is waiting for approval by examining the last
    # in-context message. This is captured before the run is updated.
    agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
    current_in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
    was_pending_approval = bool(current_in_context_messages) and current_in_context_messages[-1].is_approval_request()

    # Cancel the run.
    # NOTE: this should update the agent's last stop reason to cancelled
    run = await self.update_run_by_id_async(
        run_id=run_id, update=RunUpdate(status=RunStatus.cancelled, stop_reason=StopReasonType.cancelled), actor=actor
    )

    if not was_pending_approval:
        return run

    # The agent was waiting on an approval: deny the pending tool calls so the
    # in-context history no longer ends with an approval request.
    logger.debug(f"Agent was waiting for approval, adding denial messages for run {run_id}")
    approval_request_message = current_in_context_messages[-1]

    if not approval_request_message.tool_calls:
        logger.warning(
            "Last message is an approval request but has no tool_calls. "
            f"Message ID: {approval_request_message.id}, Run ID: {run_id}"
        )
        return run

    # Imported locally, as in the original (presumably to avoid import cycles).
    from letta.constants import TOOL_CALL_DENIAL_ON_CANCEL
    from letta.schemas.letta_message import ApprovalReturn
    from letta.schemas.message import ApprovalCreate
    from letta.server.rest_api.utils import (
        create_approval_response_message_from_input,
        create_tool_message_from_returns,
        create_tool_returns_for_denials,
    )

    # Deny ALL pending tool calls with the automated cancellation reason.
    denials = [
        ApprovalReturn(
            tool_call_id=tool_call.id,
            approve=False,
            reason=TOOL_CALL_DENIAL_ON_CANCEL,
        )
        for tool_call in approval_request_message.tool_calls
    ]
    approval_input = ApprovalCreate(
        approvals=denials,
        approval_request_id=approval_request_message.id,
    )

    # Use the standard helper to create properly formatted approval response messages.
    approval_response_messages = create_approval_response_message_from_input(
        agent_state=agent_state,
        input_message=approval_input,
        run_id=run_id,
    )

    # All pending calls share the same denial reason, so one helper call covers them.
    tool_returns = create_tool_returns_for_denials(
        tool_calls=approval_request_message.tool_calls,
        denial_reason=TOOL_CALL_DENIAL_ON_CANCEL,
        timezone=agent_state.timezone,
    )
    tool_message = create_tool_message_from_returns(
        agent_id=agent_state.id,
        model=agent_state.llm_config.model,
        tool_returns=tool_returns,
        run_id=run_id,
    )

    # Persist the approval response followed by the tool message.
    new_messages = approval_response_messages + [tool_message]
    persisted_messages = await self.message_manager.create_many_messages_async(
        pydantic_msgs=new_messages,
        actor=actor,
        run_id=run_id,
    )
    logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")

    # Append the new messages to the agent's in-context message ids.
    agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
    await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)

    logger.debug(
        f"Inserted approval response with {len(denials)} denials and tool return message for cancelled run {run_id}. "
        f"Approval request message ID: {approval_request_message.id}"
    )
    return run

View File

@@ -173,7 +173,7 @@ async def test_background_streaming_cancellation(
) -> None:
agent_state = await client.agents.update(agent_id=agent_state.id, llm_config=llm_config)
delay = 5 if llm_config.model == "gpt-5" else 1.5
delay = 1.5
_cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))
response = await client.agents.messages.stream(

File diff suppressed because it is too large Load Diff