fix: fix cancellation issues without making too many changes to message_ids persistence (#6442)
This commit is contained in:
committed by
Caren Thomas
parent
1f7165afc4
commit
f417e53638
@@ -162,7 +162,8 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
new_in_context_messages.extend(follow_up_messages)
|
||||
else:
|
||||
# User is trying to send a regular message
|
||||
if current_in_context_messages and current_in_context_messages[-1].role == "approval":
|
||||
# if current_in_context_messages and current_in_context_messages[-1].role == "approval":
|
||||
if current_in_context_messages and current_in_context_messages[-1].is_approval_request():
|
||||
raise PendingApprovalError(pending_request_id=current_in_context_messages[-1].id)
|
||||
|
||||
# Create a new user message from the input but dont store it yet
|
||||
|
||||
@@ -44,6 +44,7 @@ from letta.server.rest_api.utils import (
|
||||
create_approval_request_message_from_llm_response,
|
||||
create_letta_messages_from_llm_response,
|
||||
create_parallel_tool_messages_from_llm_response,
|
||||
create_tool_returns_for_denials,
|
||||
)
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
from letta.services.summarizer.summarizer_all import summarize_all
|
||||
@@ -701,6 +702,14 @@ class LettaAgentV3(LettaAgentV2):
|
||||
finally:
|
||||
self.logger.debug("Running cleanup for agent loop run: %s", run_id)
|
||||
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||||
|
||||
# update message ids
|
||||
message_ids = [m.id for m in messages]
|
||||
await self.agent_manager.update_message_ids_async(
|
||||
agent_id=self.agent_state.id,
|
||||
message_ids=message_ids,
|
||||
actor=self.actor,
|
||||
)
|
||||
try:
|
||||
if step_progression == StepProgression.FINISHED:
|
||||
if not self.should_continue:
|
||||
@@ -932,19 +941,15 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
# 4. Handle denial cases
|
||||
if tool_call_denials:
|
||||
# Convert ToolCallDenial objects to ToolReturn objects using shared helper
|
||||
# Group denials by reason to potentially batch them, but for now process individually
|
||||
for tool_call_denial in tool_call_denials:
|
||||
tool_call_id = tool_call_denial.id or f"call_{uuid.uuid4().hex[:8]}"
|
||||
packaged_function_response = package_function_response(
|
||||
was_success=False,
|
||||
response_string=f"Error: request to call tool denied. User reason: {tool_call_denial.reason}",
|
||||
denial_returns = create_tool_returns_for_denials(
|
||||
tool_calls=[tool_call_denial],
|
||||
denial_reason=tool_call_denial.reason,
|
||||
timezone=agent_state.timezone,
|
||||
)
|
||||
tool_return = ToolReturn(
|
||||
tool_call_id=tool_call_id,
|
||||
func_response=packaged_function_response,
|
||||
status="error",
|
||||
)
|
||||
result_tool_returns.append(tool_return)
|
||||
result_tool_returns.extend(denial_returns)
|
||||
|
||||
# 5. Unified tool execution path (works for both single and multiple tools)
|
||||
|
||||
|
||||
@@ -195,6 +195,8 @@ PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"
|
||||
REQUEST_HEARTBEAT_PARAM = "request_heartbeat"
|
||||
REQUEST_HEARTBEAT_DESCRIPTION = "Request an immediate heartbeat after function execution. You MUST set this value to `True` if you want to send a follow-up message or run a follow-up tool call (chain multiple tools together). If set to `False` (the default), then the chain of execution will end immediately after this function call."
|
||||
|
||||
# Automated tool call denials
|
||||
TOOL_CALL_DENIAL_ON_CANCEL = "The user cancelled the request, so the tool call was denied."
|
||||
|
||||
# Structured output models
|
||||
STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"}
|
||||
|
||||
@@ -1659,6 +1659,7 @@ async def cancel_message(
|
||||
|
||||
Note to cancel active runs associated with an agent, redis is required.
|
||||
"""
|
||||
# TODO: WHY DOES THIS CANCEL A LIST OF RUNS?
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
if not settings.track_agent_run:
|
||||
raise HTTPException(status_code=400, detail="Agent run tracking is disabled")
|
||||
@@ -1685,12 +1686,12 @@ async def cancel_message(
|
||||
if run.metadata.get("lettuce"):
|
||||
lettuce_client = await LettuceClient.create()
|
||||
await lettuce_client.cancel(run_id)
|
||||
success = await server.run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=RunStatus.cancelled),
|
||||
actor=actor,
|
||||
)
|
||||
results[run_id] = "cancelled" if success else "failed"
|
||||
try:
|
||||
run = await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id)
|
||||
except Exception as e:
|
||||
results[run_id] = "failed"
|
||||
continue
|
||||
results[run_id] = "cancelled"
|
||||
return results
|
||||
|
||||
|
||||
|
||||
@@ -213,6 +213,80 @@ def create_approval_response_message_from_input(
|
||||
]
|
||||
|
||||
|
||||
def create_tool_returns_for_denials(
|
||||
tool_calls: List[OpenAIToolCall],
|
||||
denial_reason: str,
|
||||
timezone: str,
|
||||
) -> List[ToolReturn]:
|
||||
"""
|
||||
Create ToolReturn objects with error status for denied tool calls.
|
||||
|
||||
This is used when tool calls are denied either by:
|
||||
- User explicitly denying approval
|
||||
- Run cancellation (automated denial)
|
||||
|
||||
Args:
|
||||
tool_calls: List of tool calls that were denied
|
||||
denial_reason: Reason for denial (e.g., user reason or cancellation message)
|
||||
timezone: Agent timezone for timestamp formatting
|
||||
|
||||
Returns:
|
||||
List of ToolReturn objects with error status
|
||||
"""
|
||||
tool_returns = []
|
||||
for tool_call in tool_calls:
|
||||
tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
|
||||
packaged_function_response = package_function_response(
|
||||
was_success=False,
|
||||
response_string=f"Error: request to call tool denied. User reason: {denial_reason}",
|
||||
timezone=timezone,
|
||||
)
|
||||
tool_return = ToolReturn(
|
||||
tool_call_id=tool_call_id,
|
||||
func_response=packaged_function_response,
|
||||
status="error",
|
||||
)
|
||||
tool_returns.append(tool_return)
|
||||
return tool_returns
|
||||
|
||||
|
||||
def create_tool_message_from_returns(
|
||||
agent_id: str,
|
||||
model: str,
|
||||
tool_returns: List[ToolReturn],
|
||||
run_id: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
) -> Message:
|
||||
"""
|
||||
Create a tool message with error returns for denied/failed tool calls.
|
||||
|
||||
This creates a properly formatted tool message that can be added to the
|
||||
conversation history to reflect tool call denials or failures.
|
||||
|
||||
Args:
|
||||
agent_id: ID of the agent
|
||||
model: Model identifier
|
||||
tool_returns: List of ToolReturn objects (typically with error status)
|
||||
run_id: Optional run ID
|
||||
step_id: Optional step ID
|
||||
|
||||
Returns:
|
||||
Message with role="tool" containing the tool returns
|
||||
"""
|
||||
return Message(
|
||||
role=MessageRole.tool,
|
||||
content=[TextContent(text=tr.func_response) for tr in tool_returns],
|
||||
agent_id=agent_id,
|
||||
model=model,
|
||||
tool_calls=[],
|
||||
tool_call_id=tool_returns[0].tool_call_id if tool_returns else None,
|
||||
tool_returns=tool_returns,
|
||||
run_id=run_id,
|
||||
step_id=step_id,
|
||||
created_at=get_utc_time(),
|
||||
)
|
||||
|
||||
|
||||
def create_approval_request_message_from_llm_response(
|
||||
agent_id: str,
|
||||
model: str,
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
from datetime import datetime
|
||||
from multiprocessing import Value
|
||||
from pickletools import pyunicode
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from httpx import AsyncClient
|
||||
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.orm.agent import Agent as AgentModel
|
||||
@@ -314,7 +316,7 @@ class RunManager:
|
||||
needs_callback = False
|
||||
callback_url = None
|
||||
not_completed_before = not bool(run.completed_at)
|
||||
is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed}
|
||||
is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed, RunStatus.cancelled}
|
||||
if is_terminal_update and not_completed_before and run.callback_url:
|
||||
needs_callback = True
|
||||
callback_url = run.callback_url
|
||||
@@ -558,3 +560,129 @@ class RunManager:
|
||||
actor=actor, run_id=run_id, limit=limit, before=before, after=after, order="asc" if ascending else "desc"
|
||||
)
|
||||
return steps
|
||||
|
||||
@enforce_types
|
||||
async def cancel_run(self, actor: PydanticUser, agent_id: Optional[str] = None, run_id: Optional[str] = None) -> None:
|
||||
"""Cancel a run."""
|
||||
|
||||
# make sure run_id and agent_id are not both None
|
||||
if not run_id:
|
||||
# get the last agent run
|
||||
if not agent_id:
|
||||
raise ValueError("Agent ID is required to cancel a run by ID")
|
||||
logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.")
|
||||
run_ids = await self.list_runs(
|
||||
actor=actor,
|
||||
ascending=False,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
run_ids = [run.id for run in run_ids]
|
||||
else:
|
||||
# get the agent
|
||||
run = await self.get_run_by_id(run_id=run_id, actor=actor)
|
||||
if not run:
|
||||
raise NoResultFound(f"Run with id {run_id} not found")
|
||||
agent_id = run.agent_id
|
||||
|
||||
logger.debug(f"Cancelling run {run_id} for agent {agent_id}")
|
||||
|
||||
# check if run can be cancelled (cannot cancel a completed, failed, or cancelled run)
|
||||
if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]:
|
||||
logger.error(f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}")
|
||||
raise LettaInvalidArgumentError(
|
||||
f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}"
|
||||
)
|
||||
|
||||
# Check if agent is waiting for approval by examining the last message
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
current_in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
||||
was_pending_approval = current_in_context_messages and current_in_context_messages[-1].is_approval_request()
|
||||
|
||||
# cancel the run
|
||||
# NOTE: this should update the agent's last stop reason to cancelled
|
||||
run = await self.update_run_by_id_async(
|
||||
run_id=run_id, update=RunUpdate(status=RunStatus.cancelled, stop_reason=StopReasonType.cancelled), actor=actor
|
||||
)
|
||||
|
||||
# cleanup the agent's state
|
||||
# if was pending approval, we need to cleanup the approval state
|
||||
if was_pending_approval:
|
||||
logger.debug(f"Agent was waiting for approval, adding denial messages for run {run_id}")
|
||||
approval_request_message = current_in_context_messages[-1]
|
||||
|
||||
# Ensure the approval request has tool calls to deny
|
||||
if approval_request_message.tool_calls:
|
||||
from letta.constants import TOOL_CALL_DENIAL_ON_CANCEL
|
||||
from letta.schemas.letta_message import ApprovalReturn
|
||||
from letta.schemas.message import ApprovalCreate
|
||||
from letta.server.rest_api.utils import (
|
||||
create_approval_response_message_from_input,
|
||||
create_tool_message_from_returns,
|
||||
create_tool_returns_for_denials,
|
||||
)
|
||||
|
||||
# Create denials for ALL pending tool calls
|
||||
denials = [
|
||||
ApprovalReturn(
|
||||
tool_call_id=tool_call.id,
|
||||
approve=False,
|
||||
reason=TOOL_CALL_DENIAL_ON_CANCEL,
|
||||
)
|
||||
for tool_call in approval_request_message.tool_calls
|
||||
]
|
||||
|
||||
# Create an ApprovalCreate input with the denials
|
||||
approval_input = ApprovalCreate(
|
||||
approvals=denials,
|
||||
approval_request_id=approval_request_message.id,
|
||||
)
|
||||
|
||||
# Use the standard function to create properly formatted approval response messages
|
||||
approval_response_messages = create_approval_response_message_from_input(
|
||||
agent_state=agent_state,
|
||||
input_message=approval_input,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
# Create tool returns for ALL denied tool calls using shared helper
|
||||
# This handles all pending tool calls at once since they all have the same denial reason
|
||||
tool_returns = create_tool_returns_for_denials(
|
||||
tool_calls=approval_request_message.tool_calls, # ALL pending tool calls
|
||||
denial_reason=TOOL_CALL_DENIAL_ON_CANCEL,
|
||||
timezone=agent_state.timezone,
|
||||
)
|
||||
|
||||
# Create tool message with all denial returns using shared helper
|
||||
tool_message = create_tool_message_from_returns(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
tool_returns=tool_returns,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
# Combine approval response and tool messages
|
||||
new_messages = approval_response_messages + [tool_message]
|
||||
|
||||
# Insert the approval response and tool messages into the database
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=new_messages,
|
||||
actor=actor,
|
||||
run_id=run_id,
|
||||
)
|
||||
logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")
|
||||
|
||||
# Update the agent's message_ids to include the new messages (approval + tool message)
|
||||
agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
|
||||
await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)
|
||||
|
||||
logger.debug(
|
||||
f"Inserted approval response with {len(denials)} denials and tool return message for cancelled run {run_id}. "
|
||||
f"Approval request message ID: {approval_request_message.id}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Last message is an approval request but has no tool_calls. "
|
||||
f"Message ID: {approval_request_message.id}, Run ID: {run_id}"
|
||||
)
|
||||
|
||||
return run
|
||||
|
||||
@@ -173,7 +173,7 @@ async def test_background_streaming_cancellation(
|
||||
) -> None:
|
||||
agent_state = await client.agents.update(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
delay = 5 if llm_config.model == "gpt-5" else 1.5
|
||||
delay = 1.5
|
||||
_cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))
|
||||
|
||||
response = await client.agents.messages.stream(
|
||||
|
||||
1455
tests/managers/test_cancellation.py
Normal file
1455
tests/managers/test_cancellation.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user