feat: fix cancellation bugs and add testing (#5353)

This commit is contained in:
cthomas
2025-10-10 17:08:30 -07:00
committed by Caren Thomas
parent 7ab44e61fa
commit 128afeb587
4 changed files with 92 additions and 29 deletions

View File

@@ -308,6 +308,7 @@ class LettaAgentV3(LettaAgentV2):
else:
# Check for job cancellation at the start of each step
if run_id and await self._check_run_cancellation(run_id):
self.should_continue = False
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
self.logger.info(f"Agent execution cancelled for run {run_id}")
return

View File

@@ -213,6 +213,7 @@ async def create_background_stream_processor(
run_manager: Optional run manager for updating run status
actor: Optional actor for run status updates
"""
stop_reason = None
if writer is None:
writer = RedisSSEStreamWriter(redis_client)
await writer.start()
@@ -232,6 +233,15 @@ async def create_background_stream_processor(
if is_done:
break
try:
# sorry for this
maybe_json_chunk = chunk.split("data: ")[1]
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
stop_reason = maybe_stop_reason.get("stop_reason")
except:
pass
except Exception as e:
logger.error(f"Error processing stream for run {run_id}: {e}")
# Write error chunk
@@ -251,9 +261,14 @@ async def create_background_stream_processor(
if should_stop_writer:
await writer.stop()
if run_manager and actor:
if stop_reason == "cancelled":
run_status = RunStatus.cancelled
else:
run_status = RunStatus.completed
await run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn.value),
update=RunUpdate(status=run_status, stop_reason=stop_reason or StopReasonType.end_turn.value),
actor=actor,
)

View File

@@ -1364,9 +1364,14 @@ async def send_message_streaming(
runs_manager = RunManager()
from letta.schemas.enums import RunStatus
if agent_loop.stop_reason.stop_reason.value == "cancelled":
run_status = RunStatus.cancelled
else:
run_status = RunStatus.completed
await runs_manager.update_run_by_id_async(
run_id=run.id,
update=RunUpdate(status=RunStatus.completed, stop_reason=agent_loop.stop_reason.stop_reason.value),
update=RunUpdate(status=run_status, stop_reason=agent_loop.stop_reason.stop_reason.value),
actor=actor,
)
@@ -1620,9 +1625,14 @@ async def _process_message_background(
runs_manager = RunManager()
from letta.schemas.enums import RunStatus
if result.stop_reason.stop_reason == "cancelled":
run_status = RunStatus.cancelled
else:
run_status = RunStatus.completed
await runs_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.completed, stop_reason=result.stop_reason.stop_reason),
update=RunUpdate(status=run_status, stop_reason=result.stop_reason.stop_reason),
actor=actor,
)

View File

@@ -1,4 +1,6 @@
import asyncio
import base64
import itertools
import json
import os
import threading
@@ -164,6 +166,7 @@ def assert_tool_call_response(
llm_config: LLMConfig,
streaming: bool = False,
from_db: bool = False,
with_cancellation: bool = False,
) -> None:
"""
Asserts that the messages list follows the expected sequence:
@@ -175,10 +178,11 @@ def assert_tool_call_response(
msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
]
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
llm_config, tool_call=True, streaming=streaming, from_db=from_db
)
assert expected_message_count_min <= len(messages) <= expected_message_count_max
if not with_cancellation:
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
llm_config, tool_call=True, streaming=streaming, from_db=from_db
)
assert expected_message_count_min <= len(messages) <= expected_message_count_max
# User message if loaded from db
index = 0
@@ -217,28 +221,30 @@ def assert_tool_call_response(
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
# Reasoning message if reasoning enabled
otid_suffix = 0
try:
if is_reasoner_model(llm_config):
assert isinstance(messages[index], ReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
except:
# Reasoning is non-deterministic, so don't throw if missing
pass
# Messages from second agent step if request has not been cancelled
if not with_cancellation:
# Reasoning message if reasoning enabled
otid_suffix = 0
try:
if is_reasoner_model(llm_config):
assert isinstance(messages[index], ReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
except:
# Reasoning is non-deterministic, so don't throw if missing
pass
# Assistant message
assert isinstance(messages[index], AssistantMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
# Assistant message
assert isinstance(messages[index], AssistantMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
# Stop reason and usage statistics if streaming
if streaming:
assert isinstance(messages[index], LettaStopReason)
assert messages[index].stop_reason == "end_turn"
assert messages[index].stop_reason == ("cancelled" if with_cancellation else "end_turn")
index += 1
assert isinstance(messages[index], LettaUsageStatistics)
assert messages[index].prompt_tokens > 0
@@ -280,12 +286,20 @@ async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = Fa
return [m for m in messages if m is not None]
async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5) -> None:
    """Cancel the agent's in-flight run after waiting ``delay`` seconds.

    Test helper: scheduled as a background task so the cancel request lands
    while the agent is mid-execution, exercising the cancellation path.

    Args:
        client: Async Letta client used to issue the cancel request.
        agent_id: ID of the agent whose active run should be cancelled.
        delay: Seconds to wait before cancelling. Defaults to 0.5, the
            value previously hard-coded; parameterized so tests can tune
            how far into execution the cancellation arrives.
    """
    await asyncio.sleep(delay)
    await client.agents.messages.cancel(agent_id=agent_id)
async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
start = time.time()
while True:
run = await client.runs.retrieve(run_id)
if run.status == "completed":
return run
if run.status == "cancelled":
time.sleep(5)
return run
if run.status == "failed":
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
if time.time() - start > timeout:
@@ -501,7 +515,20 @@ async def test_greeting(
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.parametrize(
["send_type", "cancellation"],
list(
itertools.product(
["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"]
)
),
ids=[
f"{s}-{c}"
for s, c in itertools.product(
["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"]
)
],
)
@pytest.mark.asyncio(loop_scope="function")
async def test_tool_call(
disable_e2b_api_key: Any,
@@ -509,10 +536,14 @@ async def test_tool_call(
agent_state: AgentState,
llm_config: LLMConfig,
send_type: str,
cancellation: str,
) -> None:
last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
if cancellation == "with_cancellation":
_cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id))
if send_type == "step":
response = await client.agents.messages.create(
agent_id=agent_state.id,
@@ -539,16 +570,22 @@ async def test_tool_call(
messages = await accumulate_chunks(response)
run_id = messages[0].run_id
assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config)
assert_tool_call_response(
messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
)
if "background" in send_type:
response = client.runs.stream(run_id=run_id, starting_after=0)
messages = await accumulate_chunks(response)
assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config)
assert_tool_call_response(
messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
)
messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
assert_tool_call_response(
messages_from_db, from_db=True, llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
)
assert run_id is not None
run = await client.runs.retrieve(run_id=run_id)
assert run.status == JobStatus.completed
assert run.status == ("cancelled" if cancellation == "with_cancellation" else "completed")