From 128afeb587bed9d392c4e2d05c15ac3ad455ba9b Mon Sep 17 00:00:00 2001 From: cthomas Date: Fri, 10 Oct 2025 17:08:30 -0700 Subject: [PATCH] feat: fix cancellation bugs and add testing (#5353) --- letta/agents/letta_agent_v3.py | 1 + letta/server/rest_api/redis_stream_manager.py | 17 +++- letta/server/rest_api/routers/v1/agents.py | 14 ++- tests/integration_test_send_message_v2.py | 89 +++++++++++++------ 4 files changed, 92 insertions(+), 29 deletions(-) diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 2d1cf841..873645bd 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -308,6 +308,7 @@ class LettaAgentV3(LettaAgentV2): else: # Check for job cancellation at the start of each step if run_id and await self._check_run_cancellation(run_id): + self.should_continue = False self.stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value) self.logger.info(f"Agent execution cancelled for run {run_id}") return diff --git a/letta/server/rest_api/redis_stream_manager.py b/letta/server/rest_api/redis_stream_manager.py index 3d98205f..b361466d 100644 --- a/letta/server/rest_api/redis_stream_manager.py +++ b/letta/server/rest_api/redis_stream_manager.py @@ -213,6 +213,7 @@ async def create_background_stream_processor( run_manager: Optional run manager for updating run status actor: Optional actor for run status updates """ + stop_reason = None if writer is None: writer = RedisSSEStreamWriter(redis_client) await writer.start() @@ -232,6 +233,15 @@ async def create_background_stream_processor( if is_done: break + try: + # sorry for this + maybe_json_chunk = chunk.split("data: ")[1] + maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None + if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason": + stop_reason = maybe_stop_reason.get("stop_reason") + except: + pass + except Exception as e: logger.error(f"Error processing stream for run {run_id}: {e}") # Write error chunk @@ -251,9 +261,14 @@ async def create_background_stream_processor( if should_stop_writer: await writer.stop() if run_manager and actor: + if stop_reason == "cancelled": + run_status = RunStatus.cancelled + else: + run_status = RunStatus.completed + await run_manager.update_run_by_id_async( run_id=run_id, - update=RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn.value), + update=RunUpdate(status=run_status, stop_reason=stop_reason or StopReasonType.end_turn.value), actor=actor, ) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 9dd2ca16..314207b5 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1364,9 +1364,14 @@ async def send_message_streaming( runs_manager = RunManager() from letta.schemas.enums import RunStatus + if agent_loop.stop_reason.stop_reason.value == "cancelled": + run_status = RunStatus.cancelled + else: + run_status = RunStatus.completed + await runs_manager.update_run_by_id_async( run_id=run.id, - update=RunUpdate(status=RunStatus.completed, stop_reason=agent_loop.stop_reason.stop_reason.value), + update=RunUpdate(status=run_status, stop_reason=agent_loop.stop_reason.stop_reason.value), actor=actor, ) @@ -1620,9 +1625,14 @@ async def _process_message_background( runs_manager = RunManager() from letta.schemas.enums import RunStatus + if result.stop_reason.stop_reason == "cancelled": + run_status = RunStatus.cancelled + else: + run_status = RunStatus.completed + await runs_manager.update_run_by_id_async( run_id=run_id, - update=RunUpdate(status=RunStatus.completed, stop_reason=result.stop_reason.stop_reason), + update=RunUpdate(status=run_status, stop_reason=result.stop_reason.stop_reason), actor=actor, ) diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py index 1ea9d9ac..1eefbc31 100644 --- a/tests/integration_test_send_message_v2.py +++ b/tests/integration_test_send_message_v2.py @@ -1,4 +1,6 @@ +import asyncio import base64 +import itertools import json import os import threading @@ -164,6 +166,7 @@ def assert_tool_call_response( llm_config: LLMConfig, streaming: bool = False, from_db: bool = False, + with_cancellation: bool = False, ) -> None: """ Asserts that the messages list follows the expected sequence: @@ -175,10 +178,11 @@ def assert_tool_call_response( msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) ] - expected_message_count_min, expected_message_count_max = get_expected_message_count_range( - llm_config, tool_call=True, streaming=streaming, from_db=from_db - ) - assert expected_message_count_min <= len(messages) <= expected_message_count_max + if not with_cancellation: + expected_message_count_min, expected_message_count_max = get_expected_message_count_range( + llm_config, tool_call=True, streaming=streaming, from_db=from_db + ) + assert expected_message_count_min <= len(messages) <= expected_message_count_max # User message if loaded from db index = 0 @@ -217,28 +221,30 @@ def assert_tool_call_response( assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) index += 1 - # Reasoning message if reasoning enabled - otid_suffix = 0 - try: - if is_reasoner_model(llm_config): - assert isinstance(messages[index], ReasoningMessage) - assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) - index += 1 - otid_suffix += 1 - except: - # Reasoning is non-deterministic, so don't throw if missing - pass + # Messages from second agent step if request has not been cancelled + if not with_cancellation: + # Reasoning message if reasoning enabled + otid_suffix = 0 + try: + if is_reasoner_model(llm_config): + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) + index += 1 + otid_suffix += 1 + except: + # Reasoning is non-deterministic, so don't throw if missing + pass - # Assistant message - assert isinstance(messages[index], AssistantMessage) - assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) - index += 1 - otid_suffix += 1 + # Assistant message + assert isinstance(messages[index], AssistantMessage) + assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) + index += 1 + otid_suffix += 1 # Stop reason and usage statistics if streaming if streaming: assert isinstance(messages[index], LettaStopReason) - assert messages[index].stop_reason == "end_turn" + assert messages[index].stop_reason == ("cancelled" if with_cancellation else "end_turn") index += 1 assert isinstance(messages[index], LettaUsageStatistics) assert messages[index].prompt_tokens > 0 @@ -280,12 +286,20 @@ async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = Fa return [m for m in messages if m is not None] +async def cancel_run_after_delay(client: AsyncLetta, agent_id: str): + await asyncio.sleep(0.5) + await client.agents.messages.cancel(agent_id=agent_id) + + async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run: start = time.time() while True: run = await client.runs.retrieve(run_id) if run.status == "completed": return run + if run.status == "cancelled": + time.sleep(5) + return run if run.status == "failed": raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}") if time.time() - start > timeout: @@ -501,7 +515,20 @@ async def test_greeting( TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS], ) -@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"]) +@pytest.mark.parametrize( + ["send_type", "cancellation"], + list( + itertools.product( + ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"] + ) + ), + ids=[ + f"{s}-{c}" + for s, c in itertools.product( + ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"] + ) + ], +) @pytest.mark.asyncio(loop_scope="function") async def test_tool_call( disable_e2b_api_key: Any, @@ -509,10 +536,14 @@ async def test_tool_call( agent_state: AgentState, llm_config: LLMConfig, send_type: str, + cancellation: str, ) -> None: last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1) agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + if cancellation == "with_cancellation": + _cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id)) + if send_type == "step": response = await client.agents.messages.create( agent_id=agent_state.id, @@ -539,16 +570,22 @@ async def test_tool_call( messages = await accumulate_chunks(response) run_id = messages[0].run_id - assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config) + assert_tool_call_response( + messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation") + ) if "background" in send_type: response = client.runs.stream(run_id=run_id, starting_after=0) messages = await accumulate_chunks(response) - assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config) + assert_tool_call_response( + messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation") + ) messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) - assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config) + assert_tool_call_response( + messages_from_db, from_db=True, llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation") + ) assert run_id is not None run = await client.runs.retrieve(run_id=run_id) - assert run.status == JobStatus.completed + assert run.status == ("cancelled" if cancellation == "with_cancellation" else "completed")