fix: use shared event + .athrow() to properly set stream_was_cancelled flag (#9019)

fix: use shared event + .athrow() to properly set stream_was_cancelled flag

**Problem:**
When a run is cancelled via /cancel endpoint, `stream_was_cancelled` remained
False because `RunCancelledException` was raised in the consumer code (wrapper),
which closes the generator from outside. This causes Python to skip the
generator's except blocks and jump directly to finally with the wrong flag value.

**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]

**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation

**Tests:**
- test_tool_call with cancellation (OpenAI models) 
- test_approve_with_cancellation (approval flow + concurrent cancel) 

**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race

👾 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
cthomas
2026-01-22 10:33:44 -08:00
committed by Caren Thomas
parent 5ca0f55079
commit c162de5127
11 changed files with 177 additions and 21 deletions

View File

@@ -16,6 +16,7 @@ from letta.schemas.letta_message_content import LettaMessageContentUnion
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.streaming_response import get_cancellation_event_for_run
from letta.settings import settings
from letta.utils import safe_create_task
@@ -70,6 +71,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
# Store request data
self.request_data = request_data
# Get cancellation event for this run to enable graceful cancellation (before branching)
cancellation_event = get_cancellation_event_for_run(self.run_id) if self.run_id else None
# Instantiate streaming interface
if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
# NOTE: different
@@ -102,6 +106,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
requires_approval_tools=requires_approval_tools,
run_id=self.run_id,
step_id=step_id,
cancellation_event=cancellation_event,
)
else:
self.interface = SimpleOpenAIStreamingInterface(
@@ -112,12 +117,14 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
model=self.llm_config.model,
run_id=self.run_id,
step_id=step_id,
cancellation_event=cancellation_event,
)
elif self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]:
self.interface = SimpleGeminiStreamingInterface(
requires_approval_tools=requires_approval_tools,
run_id=self.run_id,
step_id=step_id,
cancellation_event=cancellation_event,
)
else:
raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")

View File

@@ -39,6 +39,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
from letta.server.rest_api.streaming_response import RunCancelledException
from letta.server.rest_api.utils import decrement_message_uuid
logger = get_logger(__name__)
@@ -228,10 +229,10 @@ class SimpleAnthropicStreamingInterface:
prev_message_type = new_message_type
# print(f"Yielding message: {message}")
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:

View File

@@ -41,6 +41,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
from letta.server.rest_api.streaming_response import RunCancelledException
logger = get_logger(__name__)
@@ -218,10 +219,10 @@ class AnthropicStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:
@@ -726,10 +727,10 @@ class SimpleAnthropicStreamingInterface:
prev_message_type = new_message_type
# print(f"Yielding message: {message}")
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:

View File

@@ -26,6 +26,7 @@ from letta.schemas.letta_message_content import (
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.streaming_response import RunCancelledException
from letta.server.rest_api.utils import decrement_message_uuid
from letta.utils import get_tool_call_id
@@ -43,9 +44,11 @@ class SimpleGeminiStreamingInterface:
requires_approval_tools: list = [],
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
# self.messages = messages
# self.tools = tools
@@ -89,6 +92,9 @@ class SimpleGeminiStreamingInterface:
# Raw usage from provider (for transparent logging in provider trace)
self.raw_usage: dict | None = None
# Track cancellation status
self.stream_was_cancelled: bool = False
def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
"""This is (unusually) in chunked format, instead of merged"""
for content in self.content_parts:
@@ -137,10 +143,10 @@ class SimpleGeminiStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
new_message_type = message.message_type
if new_message_type != prev_message_type:
@@ -164,7 +170,11 @@ class SimpleGeminiStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
logger.info("GeminiStreamingInterface: Stream processing complete.")
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(f"GeminiStreamingInterface: Stream processing complete. stream was cancelled: {self.stream_was_cancelled}")
async def _process_event(
self,

View File

@@ -54,6 +54,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.server.rest_api.streaming_response import RunCancelledException
from letta.server.rest_api.utils import decrement_message_uuid
from letta.services.context_window_calculator.token_counter import create_token_counter
from letta.streaming_utils import (
@@ -82,6 +83,7 @@ class OpenAIStreamingInterface:
requires_approval_tools: list = [],
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.use_assistant_message = use_assistant_message
@@ -93,6 +95,7 @@ class OpenAIStreamingInterface:
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
@@ -226,14 +229,15 @@ class OpenAIStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
self.stream_was_cancelled = True
logger.warning(
"Stream was cancelled (CancelledError). Attempting to process current event. "
"Stream was cancelled (%s). Attempting to process current event. "
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
f"Error: %s, trace: %s",
type(e).__name__,
e,
traceback.format_exc(),
)
@@ -267,6 +271,10 @@ class OpenAIStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(
f"OpenAIStreamingInterface: Stream processing complete. "
f"Received {self.total_events_received} events, "
@@ -561,9 +569,11 @@ class SimpleOpenAIStreamingInterface:
model: str = None,
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -715,14 +725,15 @@ class SimpleOpenAIStreamingInterface:
message_index += 1
prev_message_type = new_message_type
yield message
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
self.stream_was_cancelled = True
logger.warning(
"Stream was cancelled (CancelledError). Attempting to process current event. "
"Stream was cancelled (%s). Attempting to process current event. "
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
f"Error: %s, trace: %s",
type(e).__name__,
e,
traceback.format_exc(),
)
@@ -764,6 +775,10 @@ class SimpleOpenAIStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(
f"SimpleOpenAIStreamingInterface: Stream processing complete. "
f"Received {self.total_events_received} events, "
@@ -932,6 +947,7 @@ class SimpleOpenAIResponsesStreamingInterface:
model: str = None,
run_id: str | None = None,
step_id: str | None = None,
cancellation_event: Optional["asyncio.Event"] = None,
):
self.is_openai_proxy = is_openai_proxy
self.messages = messages
@@ -946,6 +962,7 @@ class SimpleOpenAIResponsesStreamingInterface:
self.message_id = None
self.run_id = run_id
self.step_id = step_id
self.cancellation_event = cancellation_event
# Premake IDs for database writes
self.letta_message_id = Message.generate_id()
@@ -1102,14 +1119,15 @@ class SimpleOpenAIResponsesStreamingInterface:
)
# Continue to next event rather than killing the stream
continue
except asyncio.CancelledError as e:
except (asyncio.CancelledError, RunCancelledException) as e:
import traceback
self.stream_was_cancelled = True
logger.warning(
"Stream was cancelled (CancelledError). Attempting to process current event. "
"Stream was cancelled (%s). Attempting to process current event. "
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
f"Error: %s, trace: %s",
type(e).__name__,
e,
traceback.format_exc(),
)
@@ -1136,6 +1154,10 @@ class SimpleOpenAIResponsesStreamingInterface:
yield LettaStopReason(stop_reason=StopReasonType.error)
raise e
finally:
# Check if cancellation was signaled via shared event
if self.cancellation_event and self.cancellation_event.is_set():
self.stream_was_cancelled = True
logger.info(
f"ResponsesAPI Stream processing complete. "
f"Received {self.total_events_received} events, "

View File

@@ -289,11 +289,14 @@ async def retrieve_conversation_stream(
)
if settings.enable_cancellation_aware_streaming:
from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run
stream = cancellation_aware_stream_wrapper(
stream_generator=stream,
run_manager=server.run_manager,
run_id=run.id,
actor=actor,
cancellation_event=get_cancellation_event_for_run(run.id),
)
if request and request.include_pings and settings.enable_keepalive:

View File

@@ -393,11 +393,14 @@ async def retrieve_stream_for_run(
)
if settings.enable_cancellation_aware_streaming:
from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run
stream = cancellation_aware_stream_wrapper(
stream_generator=stream,
run_manager=server.run_manager,
run_id=run_id,
actor=actor,
cancellation_event=get_cancellation_event_for_run(run_id),
)
if request.include_pings and settings.enable_keepalive:

View File

@@ -7,6 +7,7 @@ import json
import re
from collections.abc import AsyncIterator
from datetime import datetime, timezone
from typing import Dict, Optional
from uuid import uuid4
import anyio
@@ -26,6 +27,17 @@ from letta.utils import safe_create_task
logger = get_logger(__name__)
# Global registry mapping run_id -> shared cancellation event.
# NOTE: entries are tiny asyncio.Event objects and are intentionally
# never removed from the registry.
_cancellation_events: Dict[str, asyncio.Event] = {}


def get_cancellation_event_for_run(run_id: str) -> asyncio.Event:
    """Return the shared cancellation event for *run_id*, creating it on first use.

    Every caller passing the same run_id receives the same Event instance,
    so the stream wrapper (which sets the event on cancellation) and the
    streaming interfaces (which check it in their finally blocks) observe
    a single cross-layer cancellation signal.
    """
    return _cancellation_events.setdefault(run_id, asyncio.Event())
# Constructed elsewhere in this module as RunCancelledException(run_id, message)
# and injected INTO the stream generator via .athrow(), so the generator's own
# except blocks run (a consumer-side raise would skip them and jump to finally).
class RunCancelledException(Exception):
    """Exception raised when a run is explicitly cancelled (not due to client timeout)."""
@@ -125,6 +137,7 @@ async def cancellation_aware_stream_wrapper(
run_id: str,
actor: User,
cancellation_check_interval: float = 0.5,
cancellation_event: Optional[asyncio.Event] = None,
) -> AsyncIterator[str | bytes]:
"""
Wraps a stream generator to provide real-time run cancellation checking.
@@ -156,11 +169,22 @@ async def cancellation_aware_stream_wrapper(
run = await run_manager.get_run_by_id(run_id=run_id, actor=actor)
if run.status == RunStatus.cancelled:
logger.info(f"Stream cancelled for run {run_id}, interrupting stream")
# Signal cancellation via shared event if available
if cancellation_event:
cancellation_event.set()
logger.info(f"Set cancellation event for run {run_id}")
# Send cancellation event to client
cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
yield f"data: {json.dumps(cancellation_event)}\n\n"
# Raise custom exception for explicit run cancellation
raise RunCancelledException(run_id, f"Run {run_id} was cancelled")
stop_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
yield f"data: {json.dumps(stop_event)}\n\n"
# Inject exception INTO the generator so its except blocks can catch it
try:
await stream_generator.athrow(RunCancelledException(run_id, f"Run {run_id} was cancelled"))
except (StopAsyncIteration, RunCancelledException):
# Generator closed gracefully or raised the exception back
break
except RunCancelledException:
# Re-raise cancellation immediately, don't catch it
raise
@@ -173,9 +197,10 @@ async def cancellation_aware_stream_wrapper(
yield chunk
except RunCancelledException:
# Re-raise RunCancelledException to distinguish from client timeout
# Don't re-raise - we already injected the exception into the generator
# The generator has handled it and set its stream_was_cancelled flag
logger.info(f"Stream for run {run_id} was explicitly cancelled and cleaned up")
raise
# Don't raise - let it exit gracefully
except asyncio.CancelledError:
# Re-raise CancelledError (likely client timeout) to ensure proper cleanup
logger.info(f"Stream for run {run_id} was cancelled (likely client timeout) and cleaned up")

View File

@@ -38,9 +38,11 @@ from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator
from letta.server.rest_api.streaming_response import (
RunCancelledException,
StreamingResponseWithStatusCode,
add_keepalive_to_stream,
cancellation_aware_stream_wrapper,
get_cancellation_event_for_run,
)
from letta.server.rest_api.utils import capture_sentry_exception
from letta.services.run_manager import RunManager
@@ -168,6 +170,7 @@ class StreamingService:
run_manager=self.runs_manager,
run_id=run.id,
actor=actor,
cancellation_event=get_cancellation_event_for_run(run.id),
)
safe_create_task(
@@ -195,6 +198,7 @@ class StreamingService:
run_manager=self.runs_manager,
run_id=run.id,
actor=actor,
cancellation_event=get_cancellation_event_for_run(run.id),
)
# conditionally wrap with keepalive based on request parameter
@@ -451,6 +455,14 @@ class StreamingService:
yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
# Send [DONE] marker to properly close the stream
yield "data: [DONE]\n\n"
except RunCancelledException as e:
# Run was explicitly cancelled - this is not an error
# The cancellation has already been handled by cancellation_aware_stream_wrapper
logger.info(f"Run {run_id} was cancelled, exiting stream gracefully")
# Send [DONE] to properly close the stream
yield "data: [DONE]\n\n"
# Don't update run status in finally - cancellation is already recorded
run_status = None # Signal to finally block to skip update
except Exception as e:
run_status = RunStatus.failed
stop_reason = LettaStopReason(stop_reason=StopReasonType.error)

View File

@@ -198,3 +198,8 @@ async def test_background_streaming_cancellation(
response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
messages_from_stream = await accumulate_chunks(response)
assert len(messages_from_stream) > 0
# Verify the stream contains stop_reason: cancelled (from our new cancellation logic)
stop_reasons = [msg for msg in messages_from_stream if hasattr(msg, "message_type") and msg.message_type == "stop_reason"]
assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason in stream, got {len(stop_reasons)}"
assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
import uuid
from typing import Any, List
@@ -1333,3 +1334,69 @@ def test_agent_records_last_stop_reason_after_approval_flow(
# Verify final agent state has the most recent stop reason
final_agent = client.agents.retrieve(agent_id=agent.id)
assert final_agent.last_stop_reason is not None
def test_approve_with_cancellation(
    client: Letta,
    agent: AgentState,
) -> None:
    """
    Test that when approval and cancellation happen simultaneously,
    the stream returns stop_reason: cancelled and stream_was_cancelled is set.
    """
    import threading
    import time

    # Step 1: Send message that triggers approval request
    response = client.agents.messages.create(
        agent_id=agent.id,
        messages=USER_MESSAGE_TEST_APPROVAL,
    )
    # The approval request is expected to be the last message returned;
    # its tool_call_id is needed to approve the pending tool call below.
    tool_call_id = response.messages[-1].tool_call.tool_call_id

    # Step 2: Start cancellation in background thread
    def cancel_after_delay():
        # NOTE(review): the 0.3s delay assumes the approval stream has started
        # by then — timing-dependent; a slow environment could cancel before
        # any chunk is produced. TODO confirm this is tolerable in CI.
        time.sleep(0.3)  # Wait for stream to start
        client.agents.messages.cancel(agent_id=agent.id)

    # daemon=True so a stuck cancel call cannot keep the test process alive.
    cancel_thread = threading.Thread(target=cancel_after_delay, daemon=True)
    cancel_thread.start()

    # Step 3: Start approval stream (will be cancelled during processing)
    response = client.agents.messages.stream(
        agent_id=agent.id,
        messages=[
            {
                "type": "approval",
                "approvals": [
                    {
                        "type": "approval",
                        "approve": True,
                        "tool_call_id": tool_call_id,
                    },
                ],
            },
        ],
        stream_tokens=True,
    )

    # Step 4: Accumulate chunks (consumes the stream to completion)
    messages = accumulate_chunks(response)

    # Step 5: Verify we got chunks AND a cancelled stop reason
    assert len(messages) > 0, "Should receive at least some chunks before cancellation"

    # Find stop_reason in messages — exactly one is expected, emitted by the
    # cancellation-aware stream wrapper with stop_reason "cancelled".
    stop_reasons = [msg for msg in messages if hasattr(msg, "message_type") and msg.message_type == "stop_reason"]
    assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason, got {len(stop_reasons)}"
    assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"

    # Step 6: Verify run status is cancelled
    # presumably runs.list returns newest-first so items[0] is the latest run — verify against API docs
    runs = client.runs.list(agent_ids=[agent.id])
    latest_run = runs.items[0]
    assert latest_run.status == "cancelled", f"Expected run status 'cancelled', got '{latest_run.status}'"

    # Wait for cancel thread to finish (timeout is not checked; the thread is
    # a daemon, so a hang here cannot block test shutdown)
    cancel_thread.join(timeout=1.0)

    logger.info(f"✅ Test passed: approval with cancellation handled correctly, received {len(messages)} chunks")