fix: use shared event + .athrow() to properly set stream_was_cancelled flag (#9019)

**Problem:**
When a run is cancelled via the /cancel endpoint, `stream_was_cancelled` stayed
False because `RunCancelledException` was raised in the consumer code (the wrapper),
which closes the generator from outside. Closing a generator externally delivers
`GeneratorExit` at the paused yield rather than the original exception, so Python
skips the generator's except blocks and jumps directly to finally with the wrong
flag value.
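The close-vs-throw semantics can be reproduced in a few lines. This is a standalone sketch, not project code: `RunCancelled` and `stream` are stand-ins for `RunCancelledException` and the streaming generator.

```python
# Demo: why raising in the consumer never reaches the generator's except
# block, while throwing INTO the generator does.

class RunCancelled(Exception):
    pass

def stream(state):
    try:
        while True:
            yield "chunk"
    except RunCancelled:
        state["cancelled"] = True  # only runs if the exception is raised *inside* us
    finally:
        state["finalized"] = True

# Consumer-side close: GeneratorExit is delivered at the yield, the
# RunCancelled handler is skipped, and the flag stays False.
state_a = {"cancelled": False, "finalized": False}
gen_a = stream(state_a)
next(gen_a)
gen_a.close()  # GeneratorExit, not RunCancelled

# Injecting with .throw(): the exception is raised at the paused yield,
# so the generator's own except block sets the flag before finally runs.
state_b = {"cancelled": False, "finalized": False}
gen_b = stream(state_b)
next(gen_b)
try:
    gen_b.throw(RunCancelled())
except StopIteration:
    pass  # generator handled the exception and finished normally

print(state_a)  # cancelled stays False
print(state_b)  # cancelled is True
```

The async equivalents (`aclose()` / `athrow()`) behave the same way, which is why the fix injects the exception into the generator instead of raising it in the wrapper.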

**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]
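Steps 1-4 can be sketched as follows. This is a hypothetical simplification, not the actual `streaming_response.py` API: `CANCELLATION_EVENTS`, the wrapper signature, and `poll_cancelled` are illustrative names.

```python
import asyncio

class RunCancelledException(Exception):
    pass

# Shared registry: streaming interfaces look up the same Event by run_id
# and check it in their `finally` blocks to set stream_was_cancelled.
CANCELLATION_EVENTS: dict[str, asyncio.Event] = {}

async def cancellation_aware_stream_wrapper(run_id, stream, poll_cancelled):
    """Yield chunks from `stream`; on cancellation, set the shared event and
    throw RunCancelledException INTO the generator via .athrow()."""
    event = CANCELLATION_EVENTS.setdefault(run_id, asyncio.Event())
    try:
        while True:
            try:
                chunk = await stream.__anext__()
            except StopAsyncIteration:
                break
            if await poll_cancelled(run_id):
                event.set()  # layers holding this event see it in `finally`
                try:
                    # Inject into the generator so its except/finally blocks
                    # run with the event already set (a consumer-side raise
                    # would surface as GeneratorExit inside it instead).
                    await stream.athrow(RunCancelledException(run_id))
                except (RunCancelledException, StopAsyncIteration):
                    pass
                yield "data: [DONE]\n\n"  # graceful termination for SSE clients
                return
            yield chunk
    finally:
        CANCELLATION_EVENTS.pop(run_id, None)
```

On the interface side, the `finally` check is just `event is not None and event.is_set()`, which works regardless of whether the generator exits via the injected exception or an external close.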

**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation

**Tests:**
- test_tool_call with cancellation (OpenAI models) 
- test_approve_with_cancellation (approval flow + concurrent cancel) 

**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race

👾 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>
Authored by cthomas on 2026-01-22 10:33:44 -08:00; committed by Caren Thomas.
parent 5ca0f55079 · commit c162de5127 · 11 changed files with 177 additions and 21 deletions


@@ -198,3 +198,8 @@ async def test_background_streaming_cancellation(
    response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
    messages_from_stream = await accumulate_chunks(response)
    assert len(messages_from_stream) > 0

    # Verify the stream contains stop_reason: cancelled (from our new cancellation logic)
    stop_reasons = [msg for msg in messages_from_stream if hasattr(msg, "message_type") and msg.message_type == "stop_reason"]
    assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason in stream, got {len(stop_reasons)}"
    assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"


@@ -1,3 +1,4 @@
import asyncio
import logging
import uuid
from typing import Any, List
@@ -1333,3 +1334,69 @@ def test_agent_records_last_stop_reason_after_approval_flow(
    # Verify final agent state has the most recent stop reason
    final_agent = client.agents.retrieve(agent_id=agent.id)
    assert final_agent.last_stop_reason is not None


def test_approve_with_cancellation(
    client: Letta,
    agent: AgentState,
) -> None:
    """
    Test that when approval and cancellation happen simultaneously,
    the stream returns stop_reason: cancelled and stream_was_cancelled is set.
    """
    import threading
    import time

    # Step 1: Send message that triggers approval request
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    tool_call_id = response.messages[-1].tool_call.tool_call_id

    # Step 2: Start cancellation in background thread
    def cancel_after_delay():
        time.sleep(0.3)  # Wait for stream to start
        client.agents.messages.cancel(agent_id=agent.id)

    cancel_thread = threading.Thread(target=cancel_after_delay, daemon=True)
    cancel_thread.start()

    # Step 3: Start approval stream (will be cancelled during processing)
    response = client.agents.messages.stream(
        agent_id=agent.id,
        messages=[
            {
                "type": "approval",
                "approvals": [
                    {
                        "type": "approval",
                        "approve": True,
                        "tool_call_id": tool_call_id,
                    },
                ],
            },
        ],
        stream_tokens=True,
    )

    # Step 4: Accumulate chunks
    messages = accumulate_chunks(response)

    # Step 5: Verify we got chunks AND a cancelled stop reason
    assert len(messages) > 0, "Should receive at least some chunks before cancellation"

    # Find stop_reason in messages
    stop_reasons = [msg for msg in messages if hasattr(msg, "message_type") and msg.message_type == "stop_reason"]
    assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason, got {len(stop_reasons)}"
    assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"

    # Step 6: Verify run status is cancelled
    runs = client.runs.list(agent_ids=[agent.id])
    latest_run = runs.items[0]
    assert latest_run.status == "cancelled", f"Expected run status 'cancelled', got '{latest_run.status}'"

    # Wait for cancel thread to finish
    cancel_thread.join(timeout=1.0)
    logger.info(f"✅ Test passed: approval with cancellation handled correctly, received {len(messages)} chunks")