fix: use shared event + .athrow() to properly set stream_was_cancelle… (#9019)
fix: use shared event + .athrow() to properly set stream_was_cancelled flag
**Problem:**
When a run is cancelled via /cancel endpoint, `stream_was_cancelled` remained
False because `RunCancelledException` was raised in the consumer code (wrapper),
which closes the generator from outside. This causes Python to skip the
generator's except blocks and jump directly to finally with the wrong flag value.
**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]
**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation
**Tests:**
- test_tool_call with cancellation (OpenAI models) ✅
- test_approve_with_cancellation (approval flow + concurrent cancel) ✅
**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race
👾 Generated with [Letta Code](https://letta.com)
Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -54,6 +54,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.streaming_response import RunCancelledException
|
||||
from letta.server.rest_api.utils import decrement_message_uuid
|
||||
from letta.services.context_window_calculator.token_counter import create_token_counter
|
||||
from letta.streaming_utils import (
|
||||
@@ -82,6 +83,7 @@ class OpenAIStreamingInterface:
|
||||
requires_approval_tools: list = [],
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
):
|
||||
self.use_assistant_message = use_assistant_message
|
||||
|
||||
@@ -93,6 +95,7 @@ class OpenAIStreamingInterface:
|
||||
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
self.cancellation_event = cancellation_event
|
||||
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
|
||||
@@ -226,14 +229,15 @@ class OpenAIStreamingInterface:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
self.stream_was_cancelled = True
|
||||
logger.warning(
|
||||
"Stream was cancelled (CancelledError). Attempting to process current event. "
|
||||
"Stream was cancelled (%s). Attempting to process current event. "
|
||||
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
|
||||
f"Error: %s, trace: %s",
|
||||
type(e).__name__,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
@@ -267,6 +271,10 @@ class OpenAIStreamingInterface:
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
# Check if cancellation was signaled via shared event
|
||||
if self.cancellation_event and self.cancellation_event.is_set():
|
||||
self.stream_was_cancelled = True
|
||||
|
||||
logger.info(
|
||||
f"OpenAIStreamingInterface: Stream processing complete. "
|
||||
f"Received {self.total_events_received} events, "
|
||||
@@ -561,9 +569,11 @@ class SimpleOpenAIStreamingInterface:
|
||||
model: str = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
):
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
self.cancellation_event = cancellation_event
|
||||
# Premake IDs for database writes
|
||||
self.letta_message_id = Message.generate_id()
|
||||
|
||||
@@ -715,14 +725,15 @@ class SimpleOpenAIStreamingInterface:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
self.stream_was_cancelled = True
|
||||
logger.warning(
|
||||
"Stream was cancelled (CancelledError). Attempting to process current event. "
|
||||
"Stream was cancelled (%s). Attempting to process current event. "
|
||||
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
|
||||
f"Error: %s, trace: %s",
|
||||
type(e).__name__,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
@@ -764,6 +775,10 @@ class SimpleOpenAIStreamingInterface:
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
# Check if cancellation was signaled via shared event
|
||||
if self.cancellation_event and self.cancellation_event.is_set():
|
||||
self.stream_was_cancelled = True
|
||||
|
||||
logger.info(
|
||||
f"SimpleOpenAIStreamingInterface: Stream processing complete. "
|
||||
f"Received {self.total_events_received} events, "
|
||||
@@ -932,6 +947,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
model: str = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
):
|
||||
self.is_openai_proxy = is_openai_proxy
|
||||
self.messages = messages
|
||||
@@ -946,6 +962,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
self.message_id = None
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
self.cancellation_event = cancellation_event
|
||||
|
||||
# Premake IDs for database writes
|
||||
self.letta_message_id = Message.generate_id()
|
||||
@@ -1102,14 +1119,15 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
)
|
||||
# Continue to next event rather than killing the stream
|
||||
continue
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
self.stream_was_cancelled = True
|
||||
logger.warning(
|
||||
"Stream was cancelled (CancelledError). Attempting to process current event. "
|
||||
"Stream was cancelled (%s). Attempting to process current event. "
|
||||
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
|
||||
f"Error: %s, trace: %s",
|
||||
type(e).__name__,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
@@ -1136,6 +1154,10 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
# Check if cancellation was signaled via shared event
|
||||
if self.cancellation_event and self.cancellation_event.is_set():
|
||||
self.stream_was_cancelled = True
|
||||
|
||||
logger.info(
|
||||
f"ResponsesAPI Stream processing complete. "
|
||||
f"Received {self.total_events_received} events, "
|
||||
|
||||
Reference in New Issue
Block a user