From f417e536385cd42473b42fcf1bf09ebeb4767724 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Sat, 29 Nov 2025 23:08:19 -0800 Subject: [PATCH] fix: fix cancellation issues without making too many changes to `message_ids` persistence (#6442) --- letta/agents/helpers.py | 3 +- letta/agents/letta_agent_v3.py | 25 +- letta/constants.py | 2 + letta/server/rest_api/routers/v1/agents.py | 13 +- letta/server/rest_api/utils.py | 74 + letta/services/run_manager.py | 130 +- tests/integration_test_cancellation.py | 2 +- tests/managers/test_cancellation.py | 1455 ++++++++++++++++++++ 8 files changed, 1685 insertions(+), 19 deletions(-) create mode 100644 tests/managers/test_cancellation.py diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 5cddbb89..796a4771 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -162,7 +162,8 @@ async def _prepare_in_context_messages_no_persist_async( new_in_context_messages.extend(follow_up_messages) else: # User is trying to send a regular message - if current_in_context_messages and current_in_context_messages[-1].role == "approval": + # if current_in_context_messages and current_in_context_messages[-1].role == "approval": + if current_in_context_messages and current_in_context_messages[-1].is_approval_request(): raise PendingApprovalError(pending_request_id=current_in_context_messages[-1].id) # Create a new user message from the input but dont store it yet diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 3af7fad0..ce691540 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -44,6 +44,7 @@ from letta.server.rest_api.utils import ( create_approval_request_message_from_llm_response, create_letta_messages_from_llm_response, create_parallel_tool_messages_from_llm_response, + create_tool_returns_for_denials, ) from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema from letta.services.summarizer.summarizer_all 
import summarize_all @@ -701,6 +702,14 @@ class LettaAgentV3(LettaAgentV2): finally: self.logger.debug("Running cleanup for agent loop run: %s", run_id) self.logger.info("Running final update. Step Progression: %s", step_progression) + + # update message ids + message_ids = [m.id for m in messages] + await self.agent_manager.update_message_ids_async( + agent_id=self.agent_state.id, + message_ids=message_ids, + actor=self.actor, + ) try: if step_progression == StepProgression.FINISHED: if not self.should_continue: @@ -932,19 +941,15 @@ class LettaAgentV3(LettaAgentV2): # 4. Handle denial cases if tool_call_denials: + # Convert ToolCallDenial objects to ToolReturn objects using shared helper + # Group denials by reason to potentially batch them, but for now process individually for tool_call_denial in tool_call_denials: - tool_call_id = tool_call_denial.id or f"call_{uuid.uuid4().hex[:8]}" - packaged_function_response = package_function_response( - was_success=False, - response_string=f"Error: request to call tool denied. User reason: {tool_call_denial.reason}", + denial_returns = create_tool_returns_for_denials( + tool_calls=[tool_call_denial], + denial_reason=tool_call_denial.reason, timezone=agent_state.timezone, ) - tool_return = ToolReturn( - tool_call_id=tool_call_id, - func_response=packaged_function_response, - status="error", - ) - result_tool_returns.append(tool_return) + result_tool_returns.extend(denial_returns) # 5. Unified tool execution path (works for both single and multiple tools) diff --git a/letta/constants.py b/letta/constants.py index 4b3549a5..72ec844e 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -195,6 +195,8 @@ PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg" REQUEST_HEARTBEAT_PARAM = "request_heartbeat" REQUEST_HEARTBEAT_DESCRIPTION = "Request an immediate heartbeat after function execution. You MUST set this value to `True` if you want to send a follow-up message or run a follow-up tool call (chain multiple tools together). 
If set to `False` (the default), then the chain of execution will end immediately after this function call." +# Automated tool call denials +TOOL_CALL_DENIAL_ON_CANCEL = "The user cancelled the request, so the tool call was denied." # Structured output models STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"} diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 41ae7022..7e8330c9 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1659,6 +1659,7 @@ async def cancel_message( Note to cancel active runs associated with an agent, redis is required. """ + # TODO: WHY DOES THIS CANCEL A LIST OF RUNS? actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) if not settings.track_agent_run: raise HTTPException(status_code=400, detail="Agent run tracking is disabled") @@ -1685,12 +1686,12 @@ async def cancel_message( if run.metadata.get("lettuce"): lettuce_client = await LettuceClient.create() await lettuce_client.cancel(run_id) - success = await server.run_manager.update_run_by_id_async( - run_id=run_id, - update=RunUpdate(status=RunStatus.cancelled), - actor=actor, - ) - results[run_id] = "cancelled" if success else "failed" + try: + run = await server.run_manager.cancel_run(actor=actor, agent_id=agent_id, run_id=run_id) + except Exception as e: + results[run_id] = "failed" + continue + results[run_id] = "cancelled" return results diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 8b2125d0..81084731 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -213,6 +213,80 @@ def create_approval_response_message_from_input( ] +def create_tool_returns_for_denials( + tool_calls: List[OpenAIToolCall], + denial_reason: str, + timezone: str, +) -> List[ToolReturn]: + """ + Create ToolReturn objects with error status for denied tool calls. 
+ + This is used when tool calls are denied either by: + - User explicitly denying approval + - Run cancellation (automated denial) + + Args: + tool_calls: List of tool calls that were denied + denial_reason: Reason for denial (e.g., user reason or cancellation message) + timezone: Agent timezone for timestamp formatting + + Returns: + List of ToolReturn objects with error status + """ + tool_returns = [] + for tool_call in tool_calls: + tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}" + packaged_function_response = package_function_response( + was_success=False, + response_string=f"Error: request to call tool denied. User reason: {denial_reason}", + timezone=timezone, + ) + tool_return = ToolReturn( + tool_call_id=tool_call_id, + func_response=packaged_function_response, + status="error", + ) + tool_returns.append(tool_return) + return tool_returns + + +def create_tool_message_from_returns( + agent_id: str, + model: str, + tool_returns: List[ToolReturn], + run_id: Optional[str] = None, + step_id: Optional[str] = None, +) -> Message: + """ + Create a tool message with error returns for denied/failed tool calls. + + This creates a properly formatted tool message that can be added to the + conversation history to reflect tool call denials or failures. 
+ + Args: + agent_id: ID of the agent + model: Model identifier + tool_returns: List of ToolReturn objects (typically with error status) + run_id: Optional run ID + step_id: Optional step ID + + Returns: + Message with role="tool" containing the tool returns + """ + return Message( + role=MessageRole.tool, + content=[TextContent(text=tr.func_response) for tr in tool_returns], + agent_id=agent_id, + model=model, + tool_calls=[], + tool_call_id=tool_returns[0].tool_call_id if tool_returns else None, + tool_returns=tool_returns, + run_id=run_id, + step_id=step_id, + created_at=get_utc_time(), + ) + + def create_approval_request_message_from_llm_response( agent_id: str, model: str, diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py index b890de50..52f6ad2a 100644 --- a/letta/services/run_manager.py +++ b/letta/services/run_manager.py @@ -1,9 +1,11 @@ from datetime import datetime +from multiprocessing import Value from pickletools import pyunicode from typing import List, Literal, Optional from httpx import AsyncClient +from letta.errors import LettaInvalidArgumentError from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger from letta.orm.agent import Agent as AgentModel @@ -314,7 +316,7 @@ class RunManager: needs_callback = False callback_url = None not_completed_before = not bool(run.completed_at) - is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed} + is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed, RunStatus.cancelled} if is_terminal_update and not_completed_before and run.callback_url: needs_callback = True callback_url = run.callback_url @@ -558,3 +560,129 @@ class RunManager: actor=actor, run_id=run_id, limit=limit, before=before, after=after, order="asc" if ascending else "desc" ) return steps + + @enforce_types + async def cancel_run(self, actor: PydanticUser, agent_id: Optional[str] = None, run_id: Optional[str] = None) -> None: + """Cancel a 
run.""" + + # make sure run_id and agent_id are not both None + if not run_id: + # get the last agent run + if not agent_id: + raise ValueError("Agent ID is required to cancel a run by ID") + logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.") + run_ids = await self.list_runs( + actor=actor, + ascending=False, + agent_id=agent_id, + ) + run_ids = [run.id for run in run_ids] + else: + # get the agent + run = await self.get_run_by_id(run_id=run_id, actor=actor) + if not run: + raise NoResultFound(f"Run with id {run_id} not found") + agent_id = run.agent_id + + logger.debug(f"Cancelling run {run_id} for agent {agent_id}") + + # check if run can be cancelled (cannot cancel a completed, failed, or cancelled run) + if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]: + logger.error(f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}") + raise LettaInvalidArgumentError( + f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}" + ) + + # Check if agent is waiting for approval by examining the last message + agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) + current_in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor) + was_pending_approval = current_in_context_messages and current_in_context_messages[-1].is_approval_request() + + # cancel the run + # NOTE: this should update the agent's last stop reason to cancelled + run = await self.update_run_by_id_async( + run_id=run_id, update=RunUpdate(status=RunStatus.cancelled, stop_reason=StopReasonType.cancelled), actor=actor + ) + + # cleanup the agent's state + # if was pending approval, we need to cleanup the approval state + if was_pending_approval: + logger.debug(f"Agent was waiting for approval, adding denial messages for run 
{run_id}") + approval_request_message = current_in_context_messages[-1] + + # Ensure the approval request has tool calls to deny + if approval_request_message.tool_calls: + from letta.constants import TOOL_CALL_DENIAL_ON_CANCEL + from letta.schemas.letta_message import ApprovalReturn + from letta.schemas.message import ApprovalCreate + from letta.server.rest_api.utils import ( + create_approval_response_message_from_input, + create_tool_message_from_returns, + create_tool_returns_for_denials, + ) + + # Create denials for ALL pending tool calls + denials = [ + ApprovalReturn( + tool_call_id=tool_call.id, + approve=False, + reason=TOOL_CALL_DENIAL_ON_CANCEL, + ) + for tool_call in approval_request_message.tool_calls + ] + + # Create an ApprovalCreate input with the denials + approval_input = ApprovalCreate( + approvals=denials, + approval_request_id=approval_request_message.id, + ) + + # Use the standard function to create properly formatted approval response messages + approval_response_messages = create_approval_response_message_from_input( + agent_state=agent_state, + input_message=approval_input, + run_id=run_id, + ) + + # Create tool returns for ALL denied tool calls using shared helper + # This handles all pending tool calls at once since they all have the same denial reason + tool_returns = create_tool_returns_for_denials( + tool_calls=approval_request_message.tool_calls, # ALL pending tool calls + denial_reason=TOOL_CALL_DENIAL_ON_CANCEL, + timezone=agent_state.timezone, + ) + + # Create tool message with all denial returns using shared helper + tool_message = create_tool_message_from_returns( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + tool_returns=tool_returns, + run_id=run_id, + ) + + # Combine approval response and tool messages + new_messages = approval_response_messages + [tool_message] + + # Insert the approval response and tool messages into the database + persisted_messages = await 
self.message_manager.create_many_messages_async( + pydantic_msgs=new_messages, + actor=actor, + run_id=run_id, + ) + logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)") + + # Update the agent's message_ids to include the new messages (approval + tool message) + agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages] + await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor) + + logger.debug( + f"Inserted approval response with {len(denials)} denials and tool return message for cancelled run {run_id}. " + f"Approval request message ID: {approval_request_message.id}" + ) + else: + logger.warning( + f"Last message is an approval request but has no tool_calls. " + f"Message ID: {approval_request_message.id}, Run ID: {run_id}" + ) + + return run diff --git a/tests/integration_test_cancellation.py b/tests/integration_test_cancellation.py index 45042897..c6791612 100644 --- a/tests/integration_test_cancellation.py +++ b/tests/integration_test_cancellation.py @@ -173,7 +173,7 @@ async def test_background_streaming_cancellation( ) -> None: agent_state = await client.agents.update(agent_id=agent_state.id, llm_config=llm_config) - delay = 5 if llm_config.model == "gpt-5" else 1.5 + delay = 1.5 _cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay)) response = await client.agents.messages.stream( diff --git a/tests/managers/test_cancellation.py b/tests/managers/test_cancellation.py new file mode 100644 index 00000000..554b700f --- /dev/null +++ b/tests/managers/test_cancellation.py @@ -0,0 +1,1455 @@ +""" +Tests for agent cancellation at different points in the execution loop. + +These tests use mocking and deterministic control to test cancellation at specific +points in the agent execution flow, covering all the issues documented in CANCELLATION_ISSUES.md. 
+""" + +import asyncio +from typing import AsyncGenerator +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest + +from letta.agents.agent_loop import AgentLoop +from letta.constants import TOOL_CALL_DENIAL_ON_CANCEL +from letta.schemas.agent import CreateAgent +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import MessageRole, RunStatus +from letta.schemas.letta_request import LettaStreamingRequest +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate +from letta.schemas.model import ModelSettings +from letta.schemas.run import Run as PydanticRun, RunUpdate +from letta.server.server import SyncServer +from letta.services.streaming_service import StreamingService + + +@pytest.fixture +async def test_agent_with_tool(server: SyncServer, default_user, print_tool): + """Create a test agent with letta_v1_agent type (uses LettaAgentV3).""" + agent_state = await server.agent_manager.create_agent_async( + agent_create=CreateAgent( + name="test_cancellation_agent", + agent_type="letta_v1_agent", + memory_blocks=[], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tool_ids=[print_tool.id], + include_base_tools=False, + ), + actor=default_user, + ) + yield agent_state + + +@pytest.fixture +async def test_run(server: SyncServer, default_user, test_agent_with_tool): + """Create a test run for cancellation testing.""" + run = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=test_agent_with_tool.id, + status=RunStatus.created, + ), + actor=default_user, + ) + yield run + + +class TestMessageStateDesyncIssues: + """ + Test Issue #2: Message State Desync Issues + Tests that message state stays consistent between client and server during cancellation. 
+ """ + + @pytest.mark.asyncio + async def test_message_state_consistency_after_cancellation( + self, + server: SyncServer, + default_user, + test_agent_with_tool, + test_run, + ): + """ + Test that message state is consistent after cancellation. + + Verifies: + - response_messages list matches persisted messages + - response_messages_for_metadata list matches persisted messages + - agent.message_ids includes all persisted messages + """ + # Load agent loop + agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user) + + input_messages = [MessageCreate(role=MessageRole.user, content="Call print_tool with 'test'")] + + # Cancel after first step + call_count = [0] + + async def mock_check_cancellation(run_id): + call_count[0] += 1 + if call_count[0] > 1: + await server.run_manager.cancel_run( + actor=default_user, + run_id=run_id, + ) + return True + return False + + agent_loop._check_run_cancellation = mock_check_cancellation + + # Execute step + result = await agent_loop.step( + input_messages=input_messages, + max_steps=5, + run_id=test_run.id, + ) + + # Get messages from database + db_messages = await server.message_manager.list_messages( + actor=default_user, + agent_id=test_agent_with_tool.id, + run_id=test_run.id, + limit=1000, + ) + + # Verify response_messages count matches result messages + assert len(agent_loop.response_messages) == len(result.messages), ( + f"response_messages ({len(agent_loop.response_messages)}) should match result.messages ({len(result.messages)})" + ) + + # Verify persisted message count is reasonable + assert len(db_messages) > 0, "Should have persisted messages from completed step" + + # CRITICAL CHECK: Verify agent state after cancellation + agent_after_cancel = await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + ) + + # Verify last_stop_reason is set to cancelled + assert agent_after_cancel.last_stop_reason == "cancelled", ( + f"Agent's 
last_stop_reason should be 'cancelled', got '{agent_after_cancel.last_stop_reason}'" + ) + + agent_message_ids = set(agent_after_cancel.message_ids or []) + db_message_ids = {m.id for m in db_messages} + + # Check for desync: every message in DB must be in agent.message_ids + messages_in_db_not_in_agent = db_message_ids - agent_message_ids + + assert len(messages_in_db_not_in_agent) == 0, ( + f"MESSAGE DESYNC: {len(messages_in_db_not_in_agent)} messages in DB but not in agent.message_ids\n" + f"Missing message IDs: {messages_in_db_not_in_agent}\n" + f"This indicates message_ids was not updated after cancellation." + ) + + @pytest.mark.asyncio + async def test_agent_can_continue_after_cancellation( + self, + server: SyncServer, + default_user, + test_agent_with_tool, + test_run, + ): + """ + Test that agent can continue execution after a cancelled run. + + Verifies: + - Agent state is not corrupted after cancellation + - Subsequent runs complete successfully + - Message IDs are properly updated + """ + # Load agent loop + agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user) + + # First run: cancel it + input_messages_1 = [MessageCreate(role=MessageRole.user, content="First message")] + + # Cancel immediately + await server.run_manager.cancel_run( + actor=default_user, + run_id=test_run.id, + ) + + result_1 = await agent_loop.step( + input_messages=input_messages_1, + max_steps=5, + run_id=test_run.id, + ) + + assert result_1.stop_reason.stop_reason == "cancelled" + + # Get agent state after cancellation + agent_after_cancel = await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + ) + + # Verify last_stop_reason is set to cancelled + assert agent_after_cancel.last_stop_reason == "cancelled", ( + f"Agent's last_stop_reason should be 'cancelled', got '{agent_after_cancel.last_stop_reason}'" + ) + + message_ids_after_cancel = len(agent_after_cancel.message_ids or []) + + # Second run: 
complete it successfully + test_run_2 = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=test_agent_with_tool.id, + status=RunStatus.created, + ), + actor=default_user, + ) + + # Reload agent loop with fresh state + agent_loop_2 = AgentLoop.load( + agent_state=await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + include_relationships=["memory", "tools", "sources"], + ), + actor=default_user, + ) + + input_messages_2 = [MessageCreate(role=MessageRole.user, content="Second message")] + + result_2 = await agent_loop_2.step( + input_messages=input_messages_2, + max_steps=5, + run_id=test_run_2.id, + ) + + # Verify second run completed successfully + assert result_2.stop_reason.stop_reason != "cancelled", f"Second run should complete, got {result_2.stop_reason.stop_reason}" + + # Get agent state after completion + agent_after_complete = await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + ) + message_ids_after_complete = len(agent_after_complete.message_ids or []) + + # Verify message count increased + assert message_ids_after_complete >= message_ids_after_cancel, ( + f"Message IDs should increase or stay same: " + f"after_cancel={message_ids_after_cancel}, after_complete={message_ids_after_complete}" + ) + + # CRITICAL CHECK: Verify agent.message_ids consistency with DB for BOTH runs + # Check first run (cancelled) + db_messages_run1 = await server.message_manager.list_messages( + actor=default_user, + agent_id=test_agent_with_tool.id, + run_id=test_run.id, + limit=1000, + ) + + # Check second run (completed) + db_messages_run2 = await server.message_manager.list_messages( + actor=default_user, + agent_id=test_agent_with_tool.id, + run_id=test_run_2.id, + limit=1000, + ) + + agent_message_ids = set(agent_after_complete.message_ids or []) + all_db_message_ids = {m.id for m in db_messages_run1} | {m.id for m in db_messages_run2} + + 
# Check for desync: every message in DB must be in agent.message_ids + messages_in_db_not_in_agent = all_db_message_ids - agent_message_ids + + assert len(messages_in_db_not_in_agent) == 0, ( + f"MESSAGE DESYNC: {len(messages_in_db_not_in_agent)} messages in DB but not in agent.message_ids\n" + f"Missing message IDs: {messages_in_db_not_in_agent}\n" + f"Run 1 (cancelled) had {len(db_messages_run1)} messages\n" + f"Run 2 (completed) had {len(db_messages_run2)} messages\n" + f"Agent has {len(agent_message_ids)} message_ids total\n" + f"This indicates message_ids was not updated properly after cancellation or continuation." + ) + + @pytest.mark.asyncio + async def test_approval_request_message_ids_desync_with_background_token_streaming( + self, + server: SyncServer, + default_user, + test_agent_with_tool, + bash_tool, + ): + """ + Test for the specific desync bug with BACKGROUND + TOKEN STREAMING. + + This is the EXACT scenario where the bug occurs in production: + - background=True (background streaming) + - stream_tokens=True (token streaming) + - Agent calls HITL tool requiring approval + - Run is cancelled during approval + + Bug Scenario: + 1. Agent calls HITL tool requiring approval + 2. Approval request message is persisted to DB + 3. Run is cancelled while processing in background with token streaming + 4. Approval request message ID is NOT in agent.message_ids + 5. 
Result: "Desync detected - cursor last: X, in-context last: Y" + """ + # Add bash_tool to agent (requires approval) + await server.agent_manager.attach_tool_async( + agent_id=test_agent_with_tool.id, + tool_id=bash_tool.id, + actor=default_user, + ) + + # Get initial message count + agent_before = await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + ) + initial_message_ids = set(agent_before.message_ids or []) + print(f"\nInitial message_ids count: {len(initial_message_ids)}") + + # Create streaming service + streaming_service = StreamingService(server) + + # Create request with BACKGROUND + TOKEN STREAMING (the key conditions!) + request = LettaStreamingRequest( + messages=[MessageCreate(role=MessageRole.user, content="Please run the bash_tool with operation 'test'")], + max_steps=5, + stream_tokens=True, # TOKEN STREAMING - KEY CONDITION + background=True, # BACKGROUND STREAMING - KEY CONDITION + ) + + print("\nšŸ”„ Starting agent with BACKGROUND + TOKEN STREAMING...") + print(f" stream_tokens={request.stream_tokens}") + print(f" background={request.background}") + + # Start the background streaming agent + run, stream_response = await streaming_service.create_agent_stream( + agent_id=test_agent_with_tool.id, + actor=default_user, + request=request, + run_type="test_desync", + ) + + assert run is not None, "Run should be created for background streaming" + print(f"\nāœ… Run created: {run.id}") + print(f" Status: {run.status}") + + # Cancel almost immediately - we want to interrupt DURING processing, not after + # The bug happens when cancellation interrupts the approval flow mid-execution + print("\nā³ Starting background task, will cancel quickly to catch mid-execution...") + await asyncio.sleep(0.3) # Just enough time for LLM to start, but not complete + + # NOW CANCEL THE RUN WHILE IT'S STILL PROCESSING - This is where the bug happens! 
+ print("\nāŒ CANCELLING RUN while in background + token streaming mode (MID-EXECUTION)...") + await server.run_manager.cancel_run( + actor=default_user, + run_id=run.id, + ) + + # Give cancellation time to propagate and background task to react + print("ā³ Waiting for cancellation to propagate through background task...") + await asyncio.sleep(2) # Let the background task detect cancellation and clean up + + # Check run status after cancellation + run_status = await server.run_manager.get_run_by_id(run.id, actor=default_user) + print(f"\nšŸ“Š Run status after cancel: {run_status.status}") + print(f" Stop reason: {run_status.stop_reason}") + + # Get messages from DB AFTER cancellation + db_messages_after_cancel = await server.message_manager.list_messages( + actor=default_user, + agent_id=test_agent_with_tool.id, + run_id=run.id, + limit=1000, + ) + print(f"\nšŸ“Ø Messages in DB after cancel: {len(db_messages_after_cancel)}") + for msg in db_messages_after_cancel: + print(f" - {msg.id}: role={msg.role}") + + # Get agent state AFTER cancellation + agent_after_cancel = await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + ) + + # Verify last_stop_reason is set to cancelled + print(f"\nšŸ” Agent last_stop_reason: {agent_after_cancel.last_stop_reason}") + assert agent_after_cancel.last_stop_reason == "cancelled", ( + f"Agent's last_stop_reason should be 'cancelled', got '{agent_after_cancel.last_stop_reason}'" + ) + + agent_message_ids = set(agent_after_cancel.message_ids or []) + new_message_ids = agent_message_ids - initial_message_ids + print(f"\nšŸ“ Agent message_ids after cancel: {len(agent_message_ids)}") + print(f" New message_ids in this run: {len(new_message_ids)}") + + db_message_ids = {m.id for m in db_messages_after_cancel} + + # CRITICAL CHECK: Every message in DB must be in agent.message_ids + messages_in_db_not_in_agent = db_message_ids - agent_message_ids + + if messages_in_db_not_in_agent: + # 
THIS IS THE DESYNC BUG! + print("\nāŒ DESYNC BUG DETECTED!") + print(f"šŸ› Found {len(messages_in_db_not_in_agent)} messages in DB but NOT in agent.message_ids") + print(" This bug occurs specifically with: background=True + stream_tokens=True") + + missing_messages = [m for m in db_messages_after_cancel if m.id in messages_in_db_not_in_agent] + + print("\nšŸ” Missing messages details:") + for m in missing_messages: + print(f" - ID: {m.id}") + print(f" Role: {m.role}") + print(f" Created: {m.created_at}") + if hasattr(m, "content"): + content_preview = str(m.content)[:100] if m.content else "None" + print(f" Content: {content_preview}...") + + # Get the last message IDs for the exact error message format + cursor_last = list(db_message_ids)[-1] if db_message_ids else None + in_context_last = list(agent_message_ids)[-1] if agent_message_ids else None + + print("\nšŸ’„ This causes the EXACT error reported:") + print(f" 'Desync detected - cursor last: {cursor_last},") + print(f" in-context last: {in_context_last}'") + + assert False, ( + f"šŸ› DESYNC DETECTED IN BACKGROUND + TOKEN STREAMING MODE\n\n" + f"Found {len(messages_in_db_not_in_agent)} messages in DB but not in agent.message_ids\n\n" + f"This reproduces the reported bug:\n" + f" 'Desync detected - cursor last: {cursor_last},\n" + f" in-context last: {in_context_last}'\n\n" + f"Missing message IDs: {messages_in_db_not_in_agent}\n\n" + f"Root cause: With background=True + stream_tokens=True, approval request messages\n" + f"are persisted to DB but NOT added to agent.message_ids when cancellation occurs\n" + f"during HITL approval flow.\n\n" + f"Fix location: Check approval flow in letta_agent_v3.py:442-486 and background\n" + f"streaming wrapper in streaming_service.py:138-146" + ) + + # Also check reverse: agent.message_ids shouldn't have messages not in DB + messages_in_agent_not_in_db = agent_message_ids - db_message_ids + messages_in_agent_not_in_db = messages_in_agent_not_in_db - initial_message_ids + 
+ if messages_in_agent_not_in_db: + print("\nāŒ REVERSE DESYNC DETECTED!") + print(f"Found {len(messages_in_agent_not_in_db)} message IDs in agent.message_ids but NOT in DB") + + assert False, ( + f"REVERSE DESYNC: {len(messages_in_agent_not_in_db)} messages in agent.message_ids but not in DB\n" + f"Message IDs: {messages_in_agent_not_in_db}" + ) + + # If we get here, message IDs are consistent! + print("\nāœ… No desync detected - message IDs are consistent between DB and agent state") + print(f" DB message count: {len(db_message_ids)}") + print(f" Agent message_ids count: {len(agent_message_ids)}") + print("\n Either the bug is fixed, or we need to adjust test timing/conditions.") + + +class TestStreamingCancellation: + """ + Test cancellation during different streaming modes. + """ + + @pytest.mark.asyncio + async def test_token_streaming_cancellation( + self, + server: SyncServer, + default_user, + test_agent_with_tool, + test_run, + ): + """ + Test cancellation during token streaming mode. + + This tests Issue #3: Cancellation During LLM Streaming (token mode). 
+ + Verifies: + - Cancellation can be detected during token streaming + - Partial messages are handled correctly + - Stop reason is set to 'cancelled' + """ + # Load agent loop + agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user) + + input_messages = [MessageCreate(role=MessageRole.user, content="Hello")] + + # Cancel after first chunk + cancel_triggered = [False] + + async def mock_check_cancellation(run_id): + if cancel_triggered[0]: + return True + return False + + agent_loop._check_run_cancellation = mock_check_cancellation + + # Mock streaming + async def cancel_during_stream(): + """Generator that simulates streaming and cancels mid-stream.""" + chunks_yielded = 0 + stream = agent_loop.stream( + input_messages=input_messages, + max_steps=5, + stream_tokens=True, + run_id=test_run.id, + ) + + async for chunk in stream: + chunks_yielded += 1 + yield chunk + + # Cancel after a few chunks + if chunks_yielded == 2 and not cancel_triggered[0]: + cancel_triggered[0] = True + await server.run_manager.cancel_run( + actor=default_user, + run_id=test_run.id, + ) + + # Consume the stream + chunks = [] + try: + async for chunk in cancel_during_stream(): + chunks.append(chunk) + except Exception as e: + # May raise exception on cancellation + pass + + # Verify we got some chunks before cancellation + assert len(chunks) > 0, "Should receive at least some chunks before cancellation" + + @pytest.mark.asyncio + async def test_step_streaming_cancellation( + self, + server: SyncServer, + default_user, + test_agent_with_tool, + test_run, + ): + """ + Test cancellation during step streaming mode (not token streaming). 
class TestToolExecutionCancellation:
    """
    Test cancellation during tool execution.
    This tests Issue #2C: Token streaming tool return desync.
    """

    @pytest.mark.asyncio
    async def test_cancellation_during_tool_execution(
        self,
        server: SyncServer,
        default_user,
        test_agent_with_tool,
        test_run,
        print_tool,
    ):
        """
        Test cancellation while tool is executing.

        Verifies:
        - Tool execution completes or is interrupted cleanly
        - Tool return messages are consistent
        - Database state matches client state
        """
        # Load agent loop
        agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user)

        input_messages = [MessageCreate(role=MessageRole.user, content="Call print_tool with 'test message'")]

        # Mock the tool execution to detect cancellation.
        # One-element lists are mutable cells shared with the wrapper coroutine below.
        tool_execution_started = [False]
        tool_execution_completed = [False]

        # Keep a reference to the real executor before monkeypatching it.
        original_execute = agent_loop._execute_tool

        async def mock_execute_tool(target_tool, tool_args, agent_state, agent_step_span, step_id):
            tool_execution_started[0] = True

            # Cancel during tool execution
            await server.run_manager.cancel_run(
                actor=default_user,
                run_id=test_run.id,
            )

            # Call original (tool execution should complete)
            result = await original_execute(target_tool, tool_args, agent_state, agent_step_span, step_id)

            tool_execution_completed[0] = True
            return result

        agent_loop._execute_tool = mock_execute_tool

        # Execute step
        result = await agent_loop.step(
            input_messages=input_messages,
            max_steps=5,
            run_id=test_run.id,
        )

        # Verify tool execution started
        assert tool_execution_started[0], "Tool execution should have started"

        # Verify cancellation was eventually detected
        # (may be after tool completes, at next step boundary)
        assert result.stop_reason.stop_reason == "cancelled"

        # If tool completed, verify its messages are persisted
        if tool_execution_completed[0]:
            db_messages = await server.message_manager.list_messages(
                agent_id=test_agent_with_tool.id,
                actor=default_user,
            )
            run_messages = [m for m in db_messages if m.run_id == test_run.id]
            tool_returns = [m for m in run_messages if m.role == "tool"]

            # If tool executed, should have a tool return message
            assert len(tool_returns) > 0, "Should have persisted tool return message"


class TestResourceCleanupAfterCancellation:
    """
    Test Issue #6: Resource Cleanup Issues
    Tests that resources are properly cleaned up after cancellation.
    """
    @pytest.mark.asyncio
    async def test_stop_reason_set_correctly_on_cancellation(
        self,
        server: SyncServer,
        default_user,
        test_agent_with_tool,
        test_run,
    ):
        """
        Test that stop_reason is set to 'cancelled' not 'end_turn' or other.

        This tests Issue #6: Resource Cleanup Issues.
        The finally block should set stop_reason to 'cancelled' when appropriate.

        Verifies:
        - stop_reason is 'cancelled' when run is cancelled
        - stop_reason is not 'end_turn' or 'completed' for cancelled runs
        """
        # Load agent loop
        agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user)

        # Cancel before execution
        await server.run_manager.cancel_run(
            actor=default_user,
            run_id=test_run.id,
        )

        input_messages = [MessageCreate(role=MessageRole.user, content="Hello")]

        result = await agent_loop.step(
            input_messages=input_messages,
            max_steps=5,
            run_id=test_run.id,
        )

        # Verify stop reason is cancelled, not end_turn
        assert result.stop_reason.stop_reason == "cancelled", f"Stop reason should be 'cancelled', got '{result.stop_reason.stop_reason}'"

        # Verify run status in database
        run = await server.run_manager.get_run_by_id(run_id=test_run.id, actor=default_user)
        assert run.status == RunStatus.cancelled, f"Run status should be cancelled, got {run.status}"

    @pytest.mark.asyncio
    async def test_response_messages_cleared_after_cancellation(
        self,
        server: SyncServer,
        default_user,
        test_agent_with_tool,
        test_run,
    ):
        """
        Test that internal message buffers are properly managed after cancellation.

        Verifies:
        - response_messages list is in expected state after cancellation
        - No memory leaks from accumulated messages
        """
        # Load agent loop
        agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user)

        # Execute and cancel; counter in a list so the closure can mutate it.
        call_count = [0]

        async def mock_check_cancellation(run_id):
            call_count[0] += 1
            if call_count[0] > 1:
                await server.run_manager.cancel_run(
                    actor=default_user,
                    run_id=run_id,
                )
                return True
            return False

        agent_loop._check_run_cancellation = mock_check_cancellation

        input_messages = [MessageCreate(role=MessageRole.user, content="Call print_tool with 'test'")]

        result = await agent_loop.step(
            input_messages=input_messages,
            max_steps=5,
            run_id=test_run.id,
        )

        # Verify response_messages is not empty (contains messages from completed step)
        # or is properly cleared depending on implementation
        response_msg_count = len(agent_loop.response_messages)

        # The exact behavior may vary, but we're checking that the state is reasonable
        # NOTE(review): len() is always >= 0, so the first assert is a documented sanity
        # placeholder rather than a behavioral check; the upper-bound assert is the real guard.
        assert response_msg_count >= 0, "response_messages should be in valid state"

        # Verify no excessive accumulation
        assert response_msg_count < 100, "response_messages should not have excessive accumulation"


class TestApprovalFlowCancellation:
    """
    Test Issue #5: Approval Flow + Cancellation
    Tests edge cases with HITL tool approvals and cancellation.
    """

    @pytest.mark.asyncio
    async def test_cancellation_while_waiting_for_approval(
        self,
        server: SyncServer,
        default_user,
        test_agent_with_tool,
        test_run,
        bash_tool,
    ):
        """
        Test cancellation while agent is waiting for tool approval.

        This tests the scenario where:
        1. Agent calls a tool requiring approval
        2. Run is cancelled while waiting for approval
        3. Agent should detect cancellation and not process approval

        Verifies:
        - Run status is cancelled
        - Agent does not process approval after cancellation
        - No tool execution happens
        """
        # Add bash_tool which requires approval to agent
        await server.agent_manager.attach_tool_async(
            agent_id=test_agent_with_tool.id,
            tool_id=bash_tool.id,
            actor=default_user,
        )

        # Reload agent with new tool
        test_agent_with_tool = await server.agent_manager.get_agent_by_id_async(
            agent_id=test_agent_with_tool.id,
            actor=default_user,
            include_relationships=["memory", "tools", "sources"],
        )

        # Load agent loop
        agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user)

        input_messages = [MessageCreate(role=MessageRole.user, content="Call bash_tool with operation 'test'")]

        # Execute step - should stop at approval request
        result = await agent_loop.step(
            input_messages=input_messages,
            max_steps=5,
            run_id=test_run.id,
        )

        # Verify we got approval request
        assert result.stop_reason.stop_reason == "requires_approval", f"Should stop for approval, got {result.stop_reason.stop_reason}"

        # Now cancel the run while "waiting for approval"
        await server.run_manager.cancel_run(
            actor=default_user,
            run_id=test_run.id,
        )

        # Reload agent loop with fresh state
        agent_loop_2 = AgentLoop.load(
            agent_state=await server.agent_manager.get_agent_by_id_async(
                agent_id=test_agent_with_tool.id,
                actor=default_user,
                include_relationships=["memory", "tools", "sources"],
            ),
            actor=default_user,
        )

        # Try to continue - should detect cancellation
        result_2 = await agent_loop_2.step(
            input_messages=[MessageCreate(role=MessageRole.user, content="Hello")],  # No new input, just continuing
            max_steps=5,
            run_id=test_run.id,
        )

        # Should detect cancellation
        assert result_2.stop_reason.stop_reason == "cancelled", f"Should detect cancellation, got {result_2.stop_reason.stop_reason}"
    @pytest.mark.asyncio
    async def test_agent_state_after_cancelled_approval(
        self,
        server: SyncServer,
        default_user,
        test_agent_with_tool,
        test_run,
        bash_tool,
    ):
        """
        Test that agent state is consistent after approval request is cancelled.

        This addresses the issue where agents say they are "awaiting approval"
        even though the run is cancelled.

        Verifies:
        - Agent can continue after cancelled approval
        - No phantom "awaiting approval" state
        - Messages reflect actual state
        """
        # Add bash_tool which requires approval
        await server.agent_manager.attach_tool_async(
            agent_id=test_agent_with_tool.id,
            tool_id=bash_tool.id,
            actor=default_user,
        )

        # Reload agent with new tool
        test_agent_with_tool = await server.agent_manager.get_agent_by_id_async(
            agent_id=test_agent_with_tool.id,
            actor=default_user,
            include_relationships=["memory", "tools", "sources"],
        )

        agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user)

        # First run: trigger approval request then cancel
        input_messages_1 = [MessageCreate(role=MessageRole.user, content="Call bash_tool with operation 'test'")]

        result_1 = await agent_loop.step(
            input_messages=input_messages_1,
            max_steps=5,
            run_id=test_run.id,
        )

        assert result_1.stop_reason.stop_reason == "requires_approval"

        # Cancel the run
        await server.run_manager.cancel_run(
            actor=default_user,
            run_id=test_run.id,
        )

        # Get messages to check for "awaiting approval" state
        messages_after_cancel = await server.message_manager.list_messages(
            actor=default_user,
            agent_id=test_agent_with_tool.id,
            run_id=test_run.id,
            limit=1000,
        )

        # Check for approval request messages
        approval_messages = [m for m in messages_after_cancel if m.role == "approval_request"]

        # Second run: try to execute normally (should work, not stuck in approval)
        test_run_2 = await server.run_manager.create_run(
            pydantic_run=PydanticRun(
                agent_id=test_agent_with_tool.id,
                status=RunStatus.created,
            ),
            actor=default_user,
        )

        agent_loop_2 = AgentLoop.load(
            agent_state=await server.agent_manager.get_agent_by_id_async(
                agent_id=test_agent_with_tool.id,
                actor=default_user,
                include_relationships=["memory", "tools", "sources"],
            ),
            actor=default_user,
        )

        # Call a different tool that doesn't require approval
        input_messages_2 = [MessageCreate(role=MessageRole.user, content="Call print_tool with message 'hello'")]

        result_2 = await agent_loop_2.step(
            input_messages=input_messages_2,
            max_steps=5,
            run_id=test_run_2.id,
        )

        # Should complete normally, not be stuck in approval state
        assert result_2.stop_reason.stop_reason != "requires_approval", "Agent should not be stuck in approval state from cancelled run"

    @pytest.mark.asyncio
    async def test_approval_state_persisted_correctly_after_cancel(
        self,
        server: SyncServer,
        default_user,
        test_agent_with_tool,
        test_run,
        bash_tool,
    ):
        """
        Test that approval state is correctly persisted/cleaned after cancellation.

        This addresses the specific issue mentioned:
        "agents say they are awaiting approval despite the run not being shown as pending approval"

        Verifies:
        - Run status matches actual state
        - No phantom "pending approval" status
        - Messages accurately reflect cancellation
        """
        # Add bash_tool
        await server.agent_manager.attach_tool_async(
            agent_id=test_agent_with_tool.id,
            tool_id=bash_tool.id,
            actor=default_user,
        )

        test_agent_with_tool = await server.agent_manager.get_agent_by_id_async(
            agent_id=test_agent_with_tool.id,
            actor=default_user,
            include_relationships=["memory", "tools", "sources"],
        )

        agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user)

        # Trigger approval
        result = await agent_loop.step(
            input_messages=[MessageCreate(role=MessageRole.user, content="Call bash_tool with 'test'")],
            max_steps=5,
            run_id=test_run.id,
        )

        assert result.stop_reason.stop_reason == "requires_approval"

        # Cancel the run
        await server.run_manager.cancel_run(
            actor=default_user,
            run_id=test_run.id,
        )

        # Verify run status is cancelled, NOT pending_approval
        run_after_cancel = await server.run_manager.get_run_by_id(run_id=test_run.id, actor=default_user)
        assert run_after_cancel.status == RunStatus.cancelled, f"Run status should be cancelled, got {run_after_cancel.status}"

        # Agent should be able to start fresh run
        test_run_3 = await server.run_manager.create_run(
            pydantic_run=PydanticRun(
                agent_id=test_agent_with_tool.id,
                status=RunStatus.created,
            ),
            actor=default_user,
        )

        agent_loop_3 = AgentLoop.load(
            agent_state=await server.agent_manager.get_agent_by_id_async(
                agent_id=test_agent_with_tool.id,
                actor=default_user,
                include_relationships=["memory", "tools", "sources"],
            ),
            actor=default_user,
        )

        # Should be able to make normal call
        result_3 = await agent_loop_3.step(
            input_messages=[MessageCreate(role=MessageRole.user, content="Call print_tool with 'test'")],
            max_steps=5,
            run_id=test_run_3.id,
        )

        # Should complete normally
        assert result_3.stop_reason.stop_reason != "requires_approval", "New run should not be stuck in approval state"
+ + This addresses the specific issue mentioned: + "agents say they are awaiting approval despite the run not being shown as pending approval" + + Verifies: + - Run status matches actual state + - No phantom "pending approval" status + - Messages accurately reflect cancellation + """ + # Add bash_tool + await server.agent_manager.attach_tool_async( + agent_id=test_agent_with_tool.id, + tool_id=bash_tool.id, + actor=default_user, + ) + + test_agent_with_tool = await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + include_relationships=["memory", "tools", "sources"], + ) + + agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user) + + # Trigger approval + result = await agent_loop.step( + input_messages=[MessageCreate(role=MessageRole.user, content="Call bash_tool with 'test'")], + max_steps=5, + run_id=test_run.id, + ) + + assert result.stop_reason.stop_reason == "requires_approval" + + # Cancel the run + await server.run_manager.cancel_run( + actor=default_user, + run_id=test_run.id, + ) + + # Verify run status is cancelled, NOT pending_approval + run_after_cancel = await server.run_manager.get_run_by_id(run_id=test_run.id, actor=default_user) + assert run_after_cancel.status == RunStatus.cancelled, f"Run status should be cancelled, got {run_after_cancel.status}" + + # Agent should be able to start fresh run + test_run_3 = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=test_agent_with_tool.id, + status=RunStatus.created, + ), + actor=default_user, + ) + + agent_loop_3 = AgentLoop.load( + agent_state=await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + include_relationships=["memory", "tools", "sources"], + ), + actor=default_user, + ) + + # Should be able to make normal call + result_3 = await agent_loop_3.step( + input_messages=[MessageCreate(role=MessageRole.user, content="Call print_tool with 
'test'")], + max_steps=5, + run_id=test_run_3.id, + ) + + # Should complete normally + assert result_3.stop_reason.stop_reason != "requires_approval", "New run should not be stuck in approval state" + + @pytest.mark.asyncio + async def test_approval_request_message_ids_desync( + self, + server: SyncServer, + default_user, + test_agent_with_tool, + test_run, + bash_tool, + ): + """ + Test for the specific desync bug reported: + "Desync detected - cursor last: message-X, in-context last: message-Y" + + Bug Scenario: + 1. Agent calls HITL tool requiring approval + 2. Approval request message is persisted to DB + 3. Run is cancelled + 4. Approval request message ID is NOT in agent.message_ids + 5. Result: cursor desync between DB and agent state + + This is the root cause of the reported error: + "Desync detected - cursor last: message-c07fa1ec..., in-context last: message-a2615dc3..." + + The bug happens because: + - Database contains the approval_request message + - Agent's message_ids list does NOT contain the approval_request message ID + - Causes cursor/pagination to fail + + Verifies: + - If approval request is in DB, it must be in agent.message_ids + - Cancellation doesn't cause partial message persistence + - Cursor consistency between DB and agent state + """ + # Add bash_tool which requires approval + await server.agent_manager.attach_tool_async( + agent_id=test_agent_with_tool.id, + tool_id=bash_tool.id, + actor=default_user, + ) + + # Get initial message count + agent_before = await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + ) + initial_message_ids = set(agent_before.message_ids or []) + + # Reload agent with new tool + test_agent_with_tool = await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + include_relationships=["memory", "tools", "sources"], + ) + + agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user) + + # 
Trigger approval request + result = await agent_loop.step( + input_messages=[MessageCreate(role=MessageRole.user, content="Call bash_tool with 'test'")], + max_steps=5, + run_id=test_run.id, + ) + + assert result.stop_reason.stop_reason == "requires_approval", f"Expected requires_approval, got {result.stop_reason.stop_reason}" + + # Get all messages from database for this run + db_messages = await server.message_manager.list_messages( + actor=default_user, + agent_id=test_agent_with_tool.id, + run_id=test_run.id, + limit=1000, + ) + + # Cancel the run + await server.run_manager.cancel_run( + actor=default_user, + run_id=test_run.id, + ) + + # Get agent state after cancellation + agent_after_cancel = await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + ) + + agent_message_ids = set(agent_after_cancel.message_ids or []) + + # Get all messages from database again + db_messages_after = await server.message_manager.list_messages( + actor=default_user, + agent_id=test_agent_with_tool.id, + run_id=test_run.id, + limit=1000, + ) + + db_message_ids = {m.id for m in db_messages_after} + + # CRITICAL CHECK: Every message in DB must be in agent.message_ids + messages_in_db_not_in_agent = db_message_ids - agent_message_ids + + if messages_in_db_not_in_agent: + # THIS IS THE DESYNC BUG! 
+ missing_messages = [m for m in db_messages_after if m.id in messages_in_db_not_in_agent] + missing_details = [f"ID: {m.id}, Role: {m.role}, Created: {m.created_at}" for m in missing_messages] + + # Get the cursor values that would cause the error + cursor_last = list(db_message_ids)[-1] if db_message_ids else None + in_context_last = list(agent_message_ids)[-1] if agent_message_ids else None + + assert False, ( + f"DESYNC DETECTED: {len(messages_in_db_not_in_agent)} messages in DB but not in agent.message_ids\n\n" + f"This is the reported bug:\n" + f" 'Desync detected - cursor last: {cursor_last}, in-context last: {in_context_last}'\n\n" + f"Missing messages:\n" + "\n".join(missing_details) + "\n\n" + f"Agent message_ids count: {len(agent_message_ids)}\n" + f"DB messages count: {len(db_message_ids)}\n\n" + f"Root cause: Approval request message was persisted to DB but not added to agent.message_ids\n" + f"when cancellation occurred during HITL approval flow." + ) + + # Also check the inverse: agent.message_ids shouldn't have messages not in DB + messages_in_agent_not_in_db = agent_message_ids - db_message_ids + messages_in_agent_not_in_db = messages_in_agent_not_in_db - initial_message_ids + + if messages_in_agent_not_in_db: + assert False, ( + f"REVERSE DESYNC: {len(messages_in_agent_not_in_db)} messages in agent.message_ids but not in DB\n" + f"Message IDs: {messages_in_agent_not_in_db}" + ) + + @pytest.mark.asyncio + async def test_parallel_tool_calling_cancellation_with_denials( + self, + server: SyncServer, + default_user, + bash_tool, + ): + """ + Test that parallel tool calls receive proper denial messages on cancellation. + + This tests the scenario where: + 1. Agent has parallel tool calling enabled + 2. Agent calls a tool 3 times in parallel (requiring approval) + 3. Run is cancelled while waiting for approval + 4. All 3 tool calls receive denial messages with TOOL_CALL_DENIAL_ON_CANCEL + 5. 
Agent can still be messaged again (creating a new run) + + Verifies: + - All parallel tool calls get proper denial messages + - Denial messages contain TOOL_CALL_DENIAL_ON_CANCEL reason + - Agent state is not corrupted + - New runs can be created after cancellation + """ + # Create agent with parallel tool calling enabled + config = LLMConfig.default_config("gpt-4o-mini") + config.parallel_tool_calls = True + agent_state = await server.agent_manager.create_agent_async( + agent_create=CreateAgent( + name="test_parallel_tool_calling_agent", + agent_type="letta_v1_agent", + memory_blocks=[], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tool_ids=[bash_tool.id], + include_base_tools=False, + ), + actor=default_user, + ) + + # Create a run + test_run = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=agent_state.id, + status=RunStatus.created, + ), + actor=default_user, + ) + + # Load agent loop + agent_loop = AgentLoop.load(agent_state=agent_state, actor=default_user) + + # Prompt the agent to call bash_tool 3 times + # The agent should make parallel tool calls since parallel_tool_calls is enabled + input_messages = [ + MessageCreate( + role=MessageRole.user, + content="Please call bash_tool three times with operations: 'ls', 'pwd', and 'echo test'", + ) + ] + + # Execute step - should stop at approval request with multiple tool calls + result = await agent_loop.step( + input_messages=input_messages, + max_steps=5, + run_id=test_run.id, + ) + + # Verify we got approval request + assert result.stop_reason.stop_reason == "requires_approval", f"Should stop for approval, got {result.stop_reason.stop_reason}" + + # Get the approval request message to see how many tool calls were made + db_messages_before_cancel = await server.message_manager.list_messages( + actor=default_user, + agent_id=agent_state.id, + run_id=test_run.id, + limit=1000, + ) + + # should not be 
possible to message the agent (Pending approval) + from letta.errors import PendingApprovalError + + with pytest.raises(PendingApprovalError): + test_run2 = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=agent_state.id, + status=RunStatus.created, + ), + actor=default_user, + ) + await agent_loop.step( + input_messages=[MessageCreate(role=MessageRole.user, content="Hello, how are you?")], + max_steps=5, + run_id=test_run2.id, + ) + + from letta.schemas.letta_message import ApprovalRequestMessage + + approval_request_messages = [m for m in result.messages if isinstance(m, ApprovalRequestMessage)] + assert len(approval_request_messages) > 0, "Should have at least one approval request message" + + # Get the last approval request message (should have the tool calls) + approval_request = approval_request_messages[-1] + tool_calls = approval_request.tool_calls if hasattr(approval_request, "tool_calls") else [] + num_tool_calls = len(tool_calls) + + print(f"\nFound {num_tool_calls} tool calls in approval request") + + # The agent might not always make exactly 3 parallel calls depending on the LLM, + # but we should have at least 1 tool call. 
For the test to be meaningful, + # we want multiple tool calls, but we'll verify whatever the LLM decides + assert num_tool_calls >= 1, f"Should have at least 1 tool call, got {num_tool_calls}" + + # Now cancel the run while "waiting for approval" + await server.run_manager.cancel_run( + actor=default_user, + run_id=test_run.id, + ) + + # Get messages after cancellation + db_messages_after_cancel = await server.message_manager.list_messages( + actor=default_user, + agent_id=agent_state.id, + run_id=test_run.id, + limit=1000, + ) + + # Find tool return messages (these should be the denial messages) + tool_return_messages = [m for m in db_messages_after_cancel if m.role == "tool"] + + print(f"Found {len(tool_return_messages)} tool return messages after cancellation") + + # Verify we got denial messages for all tool calls + assert len(tool_return_messages) == num_tool_calls, ( + f"Should have {num_tool_calls} tool return messages (one per tool call), got {len(tool_return_messages)}" + ) + + # Verify each tool return message contains the denial reason + for tool_return_msg in tool_return_messages: + # Check if message has tool_returns (new format) or tool_return (old format) + print("TOOL RETURN MESSAGE:\n\n", tool_return_msg) + if hasattr(tool_return_msg, "tool_returns") and tool_return_msg.tool_returns: + # New format: list of tool returns + for tool_return in tool_return_msg.tool_returns: + assert TOOL_CALL_DENIAL_ON_CANCEL in tool_return.func_response, ( + f"Tool return should contain denial message, got: {tool_return.tool_return}" + ) + elif hasattr(tool_return_msg, "tool_return"): + # Old format: single tool_return field + assert TOOL_CALL_DENIAL_ON_CANCEL in tool_return_msg.content, ( + f"Tool return should contain denial message, got: {tool_return_msg.tool_return}" + ) + elif hasattr(tool_return_msg, "content"): + # Check content field + content_str = str(tool_return_msg.content) + assert TOOL_CALL_DENIAL_ON_CANCEL in content_str, f"Tool return content should 
contain denial message, got: {content_str}" + + # Verify run status is cancelled + run_after_cancel = await server.run_manager.get_run_by_id(run_id=test_run.id, actor=default_user) + assert run_after_cancel.status == RunStatus.cancelled, f"Run status should be cancelled, got {run_after_cancel.status}" + + # Verify agent can be messaged again (create a new run) + test_run_2 = await server.run_manager.create_run( + pydantic_run=PydanticRun( + agent_id=agent_state.id, + status=RunStatus.created, + ), + actor=default_user, + ) + + # Reload agent loop with fresh state + agent_loop_2 = AgentLoop.load( + agent_state=await server.agent_manager.get_agent_by_id_async( + agent_id=agent_state.id, + actor=default_user, + include_relationships=["memory", "tools", "sources"], + ), + actor=default_user, + ) + + # Send a simple message that doesn't require approval + input_messages_2 = [MessageCreate(role=MessageRole.user, content="Hello, how are you?")] + + result_2 = await agent_loop_2.step( + input_messages=input_messages_2, + max_steps=5, + run_id=test_run_2.id, + ) + + # Verify second run completed successfully (not cancelled, not stuck in approval) + assert result_2.stop_reason.stop_reason != "cancelled", ( + f"Second run should not be cancelled, got {result_2.stop_reason.stop_reason}" + ) + assert result_2.stop_reason.stop_reason != "requires_approval", ( + f"Second run should not require approval for simple message, got {result_2.stop_reason.stop_reason}" + ) + + # Verify the second run has messages + db_messages_run2 = await server.message_manager.list_messages( + actor=default_user, + agent_id=agent_state.id, + run_id=test_run_2.id, + limit=1000, + ) + assert len(db_messages_run2) > 0, "Second run should have messages" + + +class TestEdgeCases: + """ + Test edge cases and boundary conditions for cancellation. 
+ """ + + @pytest.mark.asyncio + async def test_cancellation_with_max_steps_reached( + self, + server: SyncServer, + default_user, + test_agent_with_tool, + test_run, + ): + """ + Test interaction between max_steps and cancellation. + + Verifies: + - If both max_steps and cancellation occur, correct stop_reason is set + - Cancellation takes precedence over max_steps + """ + # Load agent loop + agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user) + + # Cancel after second step, but max_steps=2 + call_count = [0] + + async def mock_check_cancellation(run_id): + call_count[0] += 1 + if call_count[0] >= 2: + await server.run_manager.cancel_run( + actor=default_user, + run_id=run_id, + ) + return True + return False + + agent_loop._check_run_cancellation = mock_check_cancellation + + input_messages = [MessageCreate(role=MessageRole.user, content="Call print_tool with 'test'")] + + result = await agent_loop.step( + input_messages=input_messages, + max_steps=2, # Will hit max_steps around the same time as cancellation + run_id=test_run.id, + ) + + # Stop reason could be either cancelled or max_steps depending on timing + # Both are acceptable in this edge case + assert result.stop_reason.stop_reason in ["cancelled", "max_steps"], ( + f"Stop reason should be cancelled or max_steps, got {result.stop_reason.stop_reason}" + ) + + @pytest.mark.asyncio + async def test_double_cancellation( + self, + server: SyncServer, + default_user, + test_agent_with_tool, + test_run, + ): + """ + Test that cancelling an already-cancelled run is handled gracefully. 
+ + Verifies: + - No errors when checking already-cancelled run + - State remains consistent + """ + # Cancel the run + await server.run_manager.cancel_run( + actor=default_user, + run_id=test_run.id, + ) + + # Load agent loop + agent_loop = AgentLoop.load(agent_state=test_agent_with_tool, actor=default_user) + + input_messages = [MessageCreate(role=MessageRole.user, content="Hello")] + + # First execution - should detect cancellation + result_1 = await agent_loop.step( + input_messages=input_messages, + max_steps=5, + run_id=test_run.id, + ) + + assert result_1.stop_reason.stop_reason == "cancelled" + + # Try to execute again with same cancelled run - should handle gracefully + agent_loop_2 = AgentLoop.load( + agent_state=await server.agent_manager.get_agent_by_id_async( + agent_id=test_agent_with_tool.id, + actor=default_user, + include_relationships=["memory", "tools", "sources"], + ), + actor=default_user, + ) + + result_2 = await agent_loop_2.step( + input_messages=input_messages, + max_steps=5, + run_id=test_run.id, + ) + + # Should still detect as cancelled + assert result_2.stop_reason.stop_reason == "cancelled"