fix: use shared event + .athrow() to properly set stream_was_cancelled flag (#9019)

fix: use shared event + .athrow() to properly set stream_was_cancelled flag

**Problem:**
When a run is cancelled via the /cancel endpoint, `stream_was_cancelled` remained
False because `RunCancelledException` was raised in the consumer code (the wrapper),
which closes the generator from the outside. When a generator is closed externally,
Python skips its except blocks and jumps directly to finally, leaving the flag with
the wrong value.

**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]

**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation

**Tests:**
- test_tool_call with cancellation (OpenAI models) 
- test_approve_with_cancellation (approval flow + concurrent cancel) 

**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race

👾 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
cthomas
2026-01-22 10:33:44 -08:00
committed by Caren Thomas
parent 5ca0f55079
commit c162de5127
11 changed files with 177 additions and 21 deletions

View File

@@ -54,6 +54,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.server.rest_api.streaming_response import RunCancelledException
from letta.server.rest_api.utils import decrement_message_uuid
from letta.services.context_window_calculator.token_counter import create_token_counter
from letta.streaming_utils import (
@@ -82,6 +83,7 @@ class OpenAIStreamingInterface:
requires_approval_tools: list = [],
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.use_assistant_message = use_assistant_message
@@ -93,6 +95,7 @@ class OpenAIStreamingInterface:
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
@@ -226,14 +229,15 @@ class OpenAIStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
self.stream_was_cancelled = True
logger.warning(
"Stream was cancelled (CancelledError). Attempting to process current event. "
"Stream was cancelled (%s). Attempting to process current event. "
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
f"Error: %s, trace: %s",
type(e).__name__,
e,
traceback.format_exc(),
)
@@ -267,6 +271,10 @@ class OpenAIStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(
f"OpenAIStreamingInterface: Stream processing complete. "
f"Received {self.total_events_received} events, "
@@ -561,9 +569,11 @@ class SimpleOpenAIStreamingInterface:
model: str = None,
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -715,14 +725,15 @@ class SimpleOpenAIStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
self.stream_was_cancelled = True
logger.warning(
"Stream was cancelled (CancelledError). Attempting to process current event. "
"Stream was cancelled (%s). Attempting to process current event. "
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
f"Error: %s, trace: %s",
type(e).__name__,
e,
traceback.format_exc(),
)
@@ -764,6 +775,10 @@ class SimpleOpenAIStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(
f"SimpleOpenAIStreamingInterface: Stream processing complete. "
f"Received {self.total_events_received} events, "
@@ -932,6 +947,7 @@ class SimpleOpenAIResponsesStreamingInterface:
model: str = None,
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.is_openai_proxy = is_openai_proxy
self.messages = messages
@@ -946,6 +962,7 @@ class SimpleOpenAIResponsesStreamingInterface:
self.message_id = None
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -1102,14 +1119,15 @@ class SimpleOpenAIResponsesStreamingInterface:
)
# Continue to next event rather than killing the stream
continue
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
self.stream_was_cancelled = True
logger.warning(
"Stream was cancelled (CancelledError). Attempting to process current event. "
"Stream was cancelled (%s). Attempting to process current event. "
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
f"Error: %s, trace: %s",
type(e).__name__,
e,
traceback.format_exc(),
)
@@ -1136,6 +1154,10 @@ class SimpleOpenAIResponsesStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(
f"ResponsesAPI Stream processing complete. "
f"Received {self.total_events_received} events, "