fix: use shared event + .athrow() to properly set stream_was_cancelled flag (#9019)
fix: use shared event + .athrow() to properly set stream_was_cancelled flag
**Problem:**
When a run is cancelled via /cancel endpoint, `stream_was_cancelled` remained
False because `RunCancelledException` was raised in the consumer code (wrapper),
which closes the generator from outside. This causes Python to skip the
generator's except blocks and jump directly to finally with the wrong flag value.
**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]
**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation
**Tests:**
- test_tool_call with cancellation (OpenAI models) ✅
- test_approve_with_cancellation (approval flow + concurrent cancel) ✅
**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race
👾 Generated with [Letta Code](https://letta.com)
Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -198,3 +198,8 @@ async def test_background_streaming_cancellation(
|
||||
response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
|
||||
messages_from_stream = await accumulate_chunks(response)
|
||||
assert len(messages_from_stream) > 0
|
||||
|
||||
# Verify the stream contains stop_reason: cancelled (from our new cancellation logic)
|
||||
stop_reasons = [msg for msg in messages_from_stream if hasattr(msg, "message_type") and msg.message_type == "stop_reason"]
|
||||
assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason in stream, got {len(stop_reasons)}"
|
||||
assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, List
|
||||
@@ -1333,3 +1334,69 @@ def test_agent_records_last_stop_reason_after_approval_flow(
|
||||
# Verify final agent state has the most recent stop reason
|
||||
final_agent = client.agents.retrieve(agent_id=agent.id)
|
||||
assert final_agent.last_stop_reason is not None
|
||||
|
||||
|
||||
def test_approve_with_cancellation(
    client: Letta,
    agent: AgentState,
) -> None:
    """
    Test that when approval and cancellation happen simultaneously,
    the stream returns stop_reason: cancelled and stream_was_cancelled is set.
    """
    import threading
    import time

    # First, elicit an approval request so we have a pending tool call to approve.
    create_result = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = create_result.messages[-1].tool_call.tool_call_id

    # Fire a /cancel from a background thread shortly after the stream below
    # begins, so approval handling and cancellation overlap.
    def _issue_cancel() -> None:
        time.sleep(0.3)  # give the stream a moment to start
        client.agents.messages.cancel(agent_id=agent.id)

    canceller = threading.Thread(target=_issue_cancel, daemon=True)
    canceller.start()

    # Approve the pending tool call via a token stream; the concurrent cancel
    # should interrupt it mid-processing.
    approval_payload = {
        "type": "approval",
        "approvals": [
            {
                "type": "approval",
                "approve": True,
                "tool_call_id": tool_call_id,
            },
        ],
    }
    stream = client.agents.messages.stream(
        agent_id=agent.id,
        messages=[approval_payload],
        stream_tokens=True,
    )

    # Drain the stream into a list of messages.
    messages = accumulate_chunks(stream)

    # The stream must have produced some output before being cancelled.
    assert len(messages) > 0, "Should receive at least some chunks before cancellation"

    # Exactly one stop_reason message should appear, and it must be 'cancelled'.
    stop_reasons = []
    for msg in messages:
        if hasattr(msg, "message_type") and msg.message_type == "stop_reason":
            stop_reasons.append(msg)
    assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason, got {len(stop_reasons)}"
    assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"

    # The run record itself should also be in the cancelled state.
    latest_run = client.runs.list(agent_ids=[agent.id]).items[0]
    assert latest_run.status == "cancelled", f"Expected run status 'cancelled', got '{latest_run.status}'"

    # Reap the background canceller before exiting the test.
    canceller.join(timeout=1.0)

    logger.info(f"✅ Test passed: approval with cancellation handled correctly, received {len(messages)} chunks")
Reference in New Issue
Block a user