feat: fix cancellation bugs and add testing (#5353)
This commit is contained in:
@@ -308,6 +308,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
else:
|
||||
# Check for job cancellation at the start of each step
|
||||
if run_id and await self._check_run_cancellation(run_id):
|
||||
self.should_continue = False
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.cancelled.value)
|
||||
self.logger.info(f"Agent execution cancelled for run {run_id}")
|
||||
return
|
||||
|
||||
@@ -213,6 +213,7 @@ async def create_background_stream_processor(
|
||||
run_manager: Optional run manager for updating run status
|
||||
actor: Optional actor for run status updates
|
||||
"""
|
||||
stop_reason = None
|
||||
if writer is None:
|
||||
writer = RedisSSEStreamWriter(redis_client)
|
||||
await writer.start()
|
||||
@@ -232,6 +233,15 @@ async def create_background_stream_processor(
|
||||
if is_done:
|
||||
break
|
||||
|
||||
try:
|
||||
# HACK: sniff the raw SSE chunk for a "stop_reason" event so the final run status can reflect cancellation
|
||||
maybe_json_chunk = chunk.split("data: ")[1]
|
||||
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
|
||||
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
|
||||
stop_reason = maybe_stop_reason.get("stop_reason")
|
||||
except:
|
||||
pass
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing stream for run {run_id}: {e}")
|
||||
# Write error chunk
|
||||
@@ -251,9 +261,14 @@ async def create_background_stream_processor(
|
||||
if should_stop_writer:
|
||||
await writer.stop()
|
||||
if run_manager and actor:
|
||||
if stop_reason == "cancelled":
|
||||
run_status = RunStatus.cancelled
|
||||
else:
|
||||
run_status = RunStatus.completed
|
||||
|
||||
await run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn.value),
|
||||
update=RunUpdate(status=run_status, stop_reason=stop_reason or StopReasonType.end_turn.value),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
@@ -1364,9 +1364,14 @@ async def send_message_streaming(
|
||||
runs_manager = RunManager()
|
||||
from letta.schemas.enums import RunStatus
|
||||
|
||||
if agent_loop.stop_reason.stop_reason.value == "cancelled":
|
||||
run_status = RunStatus.cancelled
|
||||
else:
|
||||
run_status = RunStatus.completed
|
||||
|
||||
await runs_manager.update_run_by_id_async(
|
||||
run_id=run.id,
|
||||
update=RunUpdate(status=RunStatus.completed, stop_reason=agent_loop.stop_reason.stop_reason.value),
|
||||
update=RunUpdate(status=run_status, stop_reason=agent_loop.stop_reason.stop_reason.value),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
@@ -1620,9 +1625,14 @@ async def _process_message_background(
|
||||
runs_manager = RunManager()
|
||||
from letta.schemas.enums import RunStatus
|
||||
|
||||
if result.stop_reason.stop_reason == "cancelled":
|
||||
run_status = RunStatus.cancelled
|
||||
else:
|
||||
run_status = RunStatus.completed
|
||||
|
||||
await runs_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=RunStatus.completed, stop_reason=result.stop_reason.stop_reason),
|
||||
update=RunUpdate(status=run_status, stop_reason=result.stop_reason.stop_reason),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import base64
|
||||
import itertools
|
||||
import json
|
||||
import os
|
||||
import threading
|
||||
@@ -164,6 +166,7 @@ def assert_tool_call_response(
|
||||
llm_config: LLMConfig,
|
||||
streaming: bool = False,
|
||||
from_db: bool = False,
|
||||
with_cancellation: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Asserts that the messages list follows the expected sequence:
|
||||
@@ -175,10 +178,11 @@ def assert_tool_call_response(
|
||||
msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
|
||||
]
|
||||
|
||||
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
|
||||
llm_config, tool_call=True, streaming=streaming, from_db=from_db
|
||||
)
|
||||
assert expected_message_count_min <= len(messages) <= expected_message_count_max
|
||||
if not with_cancellation:
|
||||
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
|
||||
llm_config, tool_call=True, streaming=streaming, from_db=from_db
|
||||
)
|
||||
assert expected_message_count_min <= len(messages) <= expected_message_count_max
|
||||
|
||||
# User message if loaded from db
|
||||
index = 0
|
||||
@@ -217,28 +221,30 @@ def assert_tool_call_response(
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
|
||||
# Reasoning message if reasoning enabled
|
||||
otid_suffix = 0
|
||||
try:
|
||||
if is_reasoner_model(llm_config):
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
except:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
# Messages from second agent step if request has not been cancelled
|
||||
if not with_cancellation:
|
||||
# Reasoning message if reasoning enabled
|
||||
otid_suffix = 0
|
||||
try:
|
||||
if is_reasoner_model(llm_config):
|
||||
assert isinstance(messages[index], ReasoningMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
except:
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
# Assistant message
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
# Assistant message
|
||||
assert isinstance(messages[index], AssistantMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
otid_suffix += 1
|
||||
|
||||
# Stop reason and usage statistics if streaming
|
||||
if streaming:
|
||||
assert isinstance(messages[index], LettaStopReason)
|
||||
assert messages[index].stop_reason == "end_turn"
|
||||
assert messages[index].stop_reason == ("cancelled" if with_cancellation else "end_turn")
|
||||
index += 1
|
||||
assert isinstance(messages[index], LettaUsageStatistics)
|
||||
assert messages[index].prompt_tokens > 0
|
||||
@@ -280,12 +286,20 @@ async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = Fa
|
||||
return [m for m in messages if m is not None]
|
||||
|
||||
|
||||
async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5) -> None:
    """Wait briefly, then ask the server to cancel the agent's message run.

    Sleeps for ``delay`` seconds and then calls
    ``client.agents.messages.cancel`` for ``agent_id``. The sleep is
    non-blocking, so the coroutine can be scheduled as a background task
    while the caller goes on to send a message.

    Args:
        client: Async Letta client used to issue the cancel request.
        agent_id: ID of the agent whose run should be cancelled.
        delay: Seconds to wait before issuing the cancel. Defaults to 0.5,
            the value that was previously hard-coded.
    """
    # Give the request being cancelled a moment to start before cancelling it.
    await asyncio.sleep(delay)
    await client.agents.messages.cancel(agent_id=agent_id)
|
||||
|
||||
|
||||
async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
|
||||
start = time.time()
|
||||
while True:
|
||||
run = await client.runs.retrieve(run_id)
|
||||
if run.status == "completed":
|
||||
return run
|
||||
if run.status == "cancelled":
|
||||
time.sleep(5)
|
||||
return run
|
||||
if run.status == "failed":
|
||||
raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
|
||||
if time.time() - start > timeout:
|
||||
@@ -501,7 +515,20 @@ async def test_greeting(
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
|
||||
@pytest.mark.parametrize(
|
||||
["send_type", "cancellation"],
|
||||
list(
|
||||
itertools.product(
|
||||
["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"]
|
||||
)
|
||||
),
|
||||
ids=[
|
||||
f"{s}-{c}"
|
||||
for s, c in itertools.product(
|
||||
["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"]
|
||||
)
|
||||
],
|
||||
)
|
||||
@pytest.mark.asyncio(loop_scope="function")
|
||||
async def test_tool_call(
|
||||
disable_e2b_api_key: Any,
|
||||
@@ -509,10 +536,14 @@ async def test_tool_call(
|
||||
agent_state: AgentState,
|
||||
llm_config: LLMConfig,
|
||||
send_type: str,
|
||||
cancellation: str,
|
||||
) -> None:
|
||||
last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
|
||||
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||
|
||||
if cancellation == "with_cancellation":
|
||||
_cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id))
|
||||
|
||||
if send_type == "step":
|
||||
response = await client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
@@ -539,16 +570,22 @@ async def test_tool_call(
|
||||
messages = await accumulate_chunks(response)
|
||||
run_id = messages[0].run_id
|
||||
|
||||
assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config)
|
||||
assert_tool_call_response(
|
||||
messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
|
||||
)
|
||||
|
||||
if "background" in send_type:
|
||||
response = client.runs.stream(run_id=run_id, starting_after=0)
|
||||
messages = await accumulate_chunks(response)
|
||||
assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config)
|
||||
assert_tool_call_response(
|
||||
messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
|
||||
)
|
||||
|
||||
messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
assert_tool_call_response(
|
||||
messages_from_db, from_db=True, llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
|
||||
)
|
||||
|
||||
assert run_id is not None
|
||||
run = await client.runs.retrieve(run_id=run_id)
|
||||
assert run.status == JobStatus.completed
|
||||
assert run.status == ("cancelled" if cancellation == "with_cancellation" else "completed")
|
||||
|
||||
Reference in New Issue
Block a user