fix: use shared event + .athrow() to properly set stream_was_cancelled flag (#9019)
fix: use shared event + .athrow() to properly set stream_was_cancelled flag
**Problem:**
When a run is cancelled via /cancel endpoint, `stream_was_cancelled` remained
False because `RunCancelledException` was raised in the consumer code (wrapper),
which closes the generator from outside. This causes Python to skip the
generator's except blocks and jump directly to finally with the wrong flag value.
**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]
**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation
**Tests:**
- test_tool_call with cancellation (OpenAI models) ✅
- test_approve_with_cancellation (approval flow + concurrent cancel) ✅
**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race
👾 Generated with [Letta Code](https://letta.com)
Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -16,6 +16,7 @@ from letta.schemas.letta_message_content import LettaMessageContentUnion
|
||||
from letta.schemas.provider_trace import ProviderTrace
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.streaming_response import get_cancellation_event_for_run
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
|
||||
@@ -70,6 +71,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
# Store request data
|
||||
self.request_data = request_data
|
||||
|
||||
# Get cancellation event for this run to enable graceful cancellation (before branching)
|
||||
cancellation_event = get_cancellation_event_for_run(self.run_id) if self.run_id else None
|
||||
|
||||
# Instantiate streaming interface
|
||||
if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]:
|
||||
# NOTE: different
|
||||
@@ -102,6 +106,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
requires_approval_tools=requires_approval_tools,
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
cancellation_event=cancellation_event,
|
||||
)
|
||||
else:
|
||||
self.interface = SimpleOpenAIStreamingInterface(
|
||||
@@ -112,12 +117,14 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
model=self.llm_config.model,
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
cancellation_event=cancellation_event,
|
||||
)
|
||||
elif self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]:
|
||||
self.interface = SimpleGeminiStreamingInterface(
|
||||
requires_approval_tools=requires_approval_tools,
|
||||
run_id=self.run_id,
|
||||
step_id=step_id,
|
||||
cancellation_event=cancellation_event,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}")
|
||||
|
||||
@@ -39,6 +39,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
|
||||
from letta.server.rest_api.streaming_response import RunCancelledException
|
||||
from letta.server.rest_api.utils import decrement_message_uuid
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -228,10 +229,10 @@ class SimpleAnthropicStreamingInterface:
|
||||
prev_message_type = new_message_type
|
||||
# print(f"Yielding message: {message}")
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
|
||||
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
|
||||
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
|
||||
@@ -41,6 +41,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser
|
||||
from letta.server.rest_api.streaming_response import RunCancelledException
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -218,10 +219,10 @@ class AnthropicStreamingInterface:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
|
||||
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
|
||||
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
@@ -726,10 +727,10 @@ class SimpleAnthropicStreamingInterface:
|
||||
prev_message_type = new_message_type
|
||||
# print(f"Yielding message: {message}")
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
|
||||
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
|
||||
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
|
||||
@@ -26,6 +26,7 @@ from letta.schemas.letta_message_content import (
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.streaming_response import RunCancelledException
|
||||
from letta.server.rest_api.utils import decrement_message_uuid
|
||||
from letta.utils import get_tool_call_id
|
||||
|
||||
@@ -43,9 +44,11 @@ class SimpleGeminiStreamingInterface:
|
||||
requires_approval_tools: list = [],
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
):
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
self.cancellation_event = cancellation_event
|
||||
|
||||
# self.messages = messages
|
||||
# self.tools = tools
|
||||
@@ -89,6 +92,9 @@ class SimpleGeminiStreamingInterface:
|
||||
# Raw usage from provider (for transparent logging in provider trace)
|
||||
self.raw_usage: dict | None = None
|
||||
|
||||
# Track cancellation status
|
||||
self.stream_was_cancelled: bool = False
|
||||
|
||||
def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
|
||||
"""This is (unusually) in chunked format, instead of merged"""
|
||||
for content in self.content_parts:
|
||||
@@ -137,10 +143,10 @@ class SimpleGeminiStreamingInterface:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
|
||||
logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
|
||||
async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
|
||||
new_message_type = message.message_type
|
||||
if new_message_type != prev_message_type:
|
||||
@@ -164,7 +170,11 @@ class SimpleGeminiStreamingInterface:
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
logger.info("GeminiStreamingInterface: Stream processing complete.")
|
||||
# Check if cancellation was signaled via shared event
|
||||
if self.cancellation_event and self.cancellation_event.is_set():
|
||||
self.stream_was_cancelled = True
|
||||
|
||||
logger.info(f"GeminiStreamingInterface: Stream processing complete. stream was cancelled: {self.stream_was_cancelled}")
|
||||
|
||||
async def _process_event(
|
||||
self,
|
||||
|
||||
@@ -54,6 +54,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.streaming_response import RunCancelledException
|
||||
from letta.server.rest_api.utils import decrement_message_uuid
|
||||
from letta.services.context_window_calculator.token_counter import create_token_counter
|
||||
from letta.streaming_utils import (
|
||||
@@ -82,6 +83,7 @@ class OpenAIStreamingInterface:
|
||||
requires_approval_tools: list = [],
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
):
|
||||
self.use_assistant_message = use_assistant_message
|
||||
|
||||
@@ -93,6 +95,7 @@ class OpenAIStreamingInterface:
|
||||
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
self.cancellation_event = cancellation_event
|
||||
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
|
||||
@@ -226,14 +229,15 @@ class OpenAIStreamingInterface:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
self.stream_was_cancelled = True
|
||||
logger.warning(
|
||||
"Stream was cancelled (CancelledError). Attempting to process current event. "
|
||||
"Stream was cancelled (%s). Attempting to process current event. "
|
||||
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
|
||||
f"Error: %s, trace: %s",
|
||||
type(e).__name__,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
@@ -267,6 +271,10 @@ class OpenAIStreamingInterface:
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
# Check if cancellation was signaled via shared event
|
||||
if self.cancellation_event and self.cancellation_event.is_set():
|
||||
self.stream_was_cancelled = True
|
||||
|
||||
logger.info(
|
||||
f"OpenAIStreamingInterface: Stream processing complete. "
|
||||
f"Received {self.total_events_received} events, "
|
||||
@@ -561,9 +569,11 @@ class SimpleOpenAIStreamingInterface:
|
||||
model: str = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
):
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
self.cancellation_event = cancellation_event
|
||||
# Premake IDs for database writes
|
||||
self.letta_message_id = Message.generate_id()
|
||||
|
||||
@@ -715,14 +725,15 @@ class SimpleOpenAIStreamingInterface:
|
||||
message_index += 1
|
||||
prev_message_type = new_message_type
|
||||
yield message
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
self.stream_was_cancelled = True
|
||||
logger.warning(
|
||||
"Stream was cancelled (CancelledError). Attempting to process current event. "
|
||||
"Stream was cancelled (%s). Attempting to process current event. "
|
||||
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
|
||||
f"Error: %s, trace: %s",
|
||||
type(e).__name__,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
@@ -764,6 +775,10 @@ class SimpleOpenAIStreamingInterface:
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
# Check if cancellation was signaled via shared event
|
||||
if self.cancellation_event and self.cancellation_event.is_set():
|
||||
self.stream_was_cancelled = True
|
||||
|
||||
logger.info(
|
||||
f"SimpleOpenAIStreamingInterface: Stream processing complete. "
|
||||
f"Received {self.total_events_received} events, "
|
||||
@@ -932,6 +947,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
model: str = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
cancellation_event: Optional["asyncio.Event"] = None,
|
||||
):
|
||||
self.is_openai_proxy = is_openai_proxy
|
||||
self.messages = messages
|
||||
@@ -946,6 +962,7 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
self.message_id = None
|
||||
self.run_id = run_id
|
||||
self.step_id = step_id
|
||||
self.cancellation_event = cancellation_event
|
||||
|
||||
# Premake IDs for database writes
|
||||
self.letta_message_id = Message.generate_id()
|
||||
@@ -1102,14 +1119,15 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
)
|
||||
# Continue to next event rather than killing the stream
|
||||
continue
|
||||
except asyncio.CancelledError as e:
|
||||
except (asyncio.CancelledError, RunCancelledException) as e:
|
||||
import traceback
|
||||
|
||||
self.stream_was_cancelled = True
|
||||
logger.warning(
|
||||
"Stream was cancelled (CancelledError). Attempting to process current event. "
|
||||
"Stream was cancelled (%s). Attempting to process current event. "
|
||||
f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. "
|
||||
f"Error: %s, trace: %s",
|
||||
type(e).__name__,
|
||||
e,
|
||||
traceback.format_exc(),
|
||||
)
|
||||
@@ -1136,6 +1154,10 @@ class SimpleOpenAIResponsesStreamingInterface:
|
||||
yield LettaStopReason(stop_reason=StopReasonType.error)
|
||||
raise e
|
||||
finally:
|
||||
# Check if cancellation was signaled via shared event
|
||||
if self.cancellation_event and self.cancellation_event.is_set():
|
||||
self.stream_was_cancelled = True
|
||||
|
||||
logger.info(
|
||||
f"ResponsesAPI Stream processing complete. "
|
||||
f"Received {self.total_events_received} events, "
|
||||
|
||||
@@ -289,11 +289,14 @@ async def retrieve_conversation_stream(
|
||||
)
|
||||
|
||||
if settings.enable_cancellation_aware_streaming:
|
||||
from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run
|
||||
|
||||
stream = cancellation_aware_stream_wrapper(
|
||||
stream_generator=stream,
|
||||
run_manager=server.run_manager,
|
||||
run_id=run.id,
|
||||
actor=actor,
|
||||
cancellation_event=get_cancellation_event_for_run(run.id),
|
||||
)
|
||||
|
||||
if request and request.include_pings and settings.enable_keepalive:
|
||||
|
||||
@@ -393,11 +393,14 @@ async def retrieve_stream_for_run(
|
||||
)
|
||||
|
||||
if settings.enable_cancellation_aware_streaming:
|
||||
from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run
|
||||
|
||||
stream = cancellation_aware_stream_wrapper(
|
||||
stream_generator=stream,
|
||||
run_manager=server.run_manager,
|
||||
run_id=run_id,
|
||||
actor=actor,
|
||||
cancellation_event=get_cancellation_event_for_run(run_id),
|
||||
)
|
||||
|
||||
if request.include_pings and settings.enable_keepalive:
|
||||
|
||||
@@ -7,6 +7,7 @@ import json
|
||||
import re
|
||||
from collections.abc import AsyncIterator
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
import anyio
|
||||
@@ -26,6 +27,17 @@ from letta.utils import safe_create_task
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# Global registry of cancellation events per run_id
|
||||
# Note: Events are small and we don't bother cleaning them up
|
||||
_cancellation_events: Dict[str, asyncio.Event] = {}
|
||||
|
||||
|
||||
def get_cancellation_event_for_run(run_id: str) -> asyncio.Event:
|
||||
"""Get or create a cancellation event for a run."""
|
||||
if run_id not in _cancellation_events:
|
||||
_cancellation_events[run_id] = asyncio.Event()
|
||||
return _cancellation_events[run_id]
|
||||
|
||||
|
||||
class RunCancelledException(Exception):
|
||||
"""Exception raised when a run is explicitly cancelled (not due to client timeout)"""
|
||||
@@ -125,6 +137,7 @@ async def cancellation_aware_stream_wrapper(
|
||||
run_id: str,
|
||||
actor: User,
|
||||
cancellation_check_interval: float = 0.5,
|
||||
cancellation_event: Optional[asyncio.Event] = None,
|
||||
) -> AsyncIterator[str | bytes]:
|
||||
"""
|
||||
Wraps a stream generator to provide real-time run cancellation checking.
|
||||
@@ -156,11 +169,22 @@ async def cancellation_aware_stream_wrapper(
|
||||
run = await run_manager.get_run_by_id(run_id=run_id, actor=actor)
|
||||
if run.status == RunStatus.cancelled:
|
||||
logger.info(f"Stream cancelled for run {run_id}, interrupting stream")
|
||||
|
||||
# Signal cancellation via shared event if available
|
||||
if cancellation_event:
|
||||
cancellation_event.set()
|
||||
logger.info(f"Set cancellation event for run {run_id}")
|
||||
|
||||
# Send cancellation event to client
|
||||
cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
|
||||
yield f"data: {json.dumps(cancellation_event)}\n\n"
|
||||
# Raise custom exception for explicit run cancellation
|
||||
raise RunCancelledException(run_id, f"Run {run_id} was cancelled")
|
||||
stop_event = {"message_type": "stop_reason", "stop_reason": "cancelled"}
|
||||
yield f"data: {json.dumps(stop_event)}\n\n"
|
||||
|
||||
# Inject exception INTO the generator so its except blocks can catch it
|
||||
try:
|
||||
await stream_generator.athrow(RunCancelledException(run_id, f"Run {run_id} was cancelled"))
|
||||
except (StopAsyncIteration, RunCancelledException):
|
||||
# Generator closed gracefully or raised the exception back
|
||||
break
|
||||
except RunCancelledException:
|
||||
# Re-raise cancellation immediately, don't catch it
|
||||
raise
|
||||
@@ -173,9 +197,10 @@ async def cancellation_aware_stream_wrapper(
|
||||
yield chunk
|
||||
|
||||
except RunCancelledException:
|
||||
# Re-raise RunCancelledException to distinguish from client timeout
|
||||
# Don't re-raise - we already injected the exception into the generator
|
||||
# The generator has handled it and set its stream_was_cancelled flag
|
||||
logger.info(f"Stream for run {run_id} was explicitly cancelled and cleaned up")
|
||||
raise
|
||||
# Don't raise - let it exit gracefully
|
||||
except asyncio.CancelledError:
|
||||
# Re-raise CancelledError (likely client timeout) to ensure proper cleanup
|
||||
logger.info(f"Stream for run {run_id} was cancelled (likely client timeout) and cleaned up")
|
||||
|
||||
@@ -38,9 +38,11 @@ from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator
|
||||
from letta.server.rest_api.streaming_response import (
|
||||
RunCancelledException,
|
||||
StreamingResponseWithStatusCode,
|
||||
add_keepalive_to_stream,
|
||||
cancellation_aware_stream_wrapper,
|
||||
get_cancellation_event_for_run,
|
||||
)
|
||||
from letta.server.rest_api.utils import capture_sentry_exception
|
||||
from letta.services.run_manager import RunManager
|
||||
@@ -168,6 +170,7 @@ class StreamingService:
|
||||
run_manager=self.runs_manager,
|
||||
run_id=run.id,
|
||||
actor=actor,
|
||||
cancellation_event=get_cancellation_event_for_run(run.id),
|
||||
)
|
||||
|
||||
safe_create_task(
|
||||
@@ -195,6 +198,7 @@ class StreamingService:
|
||||
run_manager=self.runs_manager,
|
||||
run_id=run.id,
|
||||
actor=actor,
|
||||
cancellation_event=get_cancellation_event_for_run(run.id),
|
||||
)
|
||||
|
||||
# conditionally wrap with keepalive based on request parameter
|
||||
@@ -451,6 +455,14 @@ class StreamingService:
|
||||
yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
|
||||
# Send [DONE] marker to properly close the stream
|
||||
yield "data: [DONE]\n\n"
|
||||
except RunCancelledException as e:
|
||||
# Run was explicitly cancelled - this is not an error
|
||||
# The cancellation has already been handled by cancellation_aware_stream_wrapper
|
||||
logger.info(f"Run {run_id} was cancelled, exiting stream gracefully")
|
||||
# Send [DONE] to properly close the stream
|
||||
yield "data: [DONE]\n\n"
|
||||
# Don't update run status in finally - cancellation is already recorded
|
||||
run_status = None # Signal to finally block to skip update
|
||||
except Exception as e:
|
||||
run_status = RunStatus.failed
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.error)
|
||||
|
||||
@@ -198,3 +198,8 @@ async def test_background_streaming_cancellation(
|
||||
response = await client.runs.messages.stream(run_id=run_id, starting_after=0)
|
||||
messages_from_stream = await accumulate_chunks(response)
|
||||
assert len(messages_from_stream) > 0
|
||||
|
||||
# Verify the stream contains stop_reason: cancelled (from our new cancellation logic)
|
||||
stop_reasons = [msg for msg in messages_from_stream if hasattr(msg, "message_type") and msg.message_type == "stop_reason"]
|
||||
assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason in stream, got {len(stop_reasons)}"
|
||||
assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from typing import Any, List
|
||||
@@ -1333,3 +1334,69 @@ def test_agent_records_last_stop_reason_after_approval_flow(
|
||||
# Verify final agent state has the most recent stop reason
|
||||
final_agent = client.agents.retrieve(agent_id=agent.id)
|
||||
assert final_agent.last_stop_reason is not None
|
||||
|
||||
|
||||
def test_approve_with_cancellation(
|
||||
client: Letta,
|
||||
agent: AgentState,
|
||||
) -> None:
|
||||
"""
|
||||
Test that when approval and cancellation happen simultaneously,
|
||||
the stream returns stop_reason: cancelled and stream_was_cancelled is set.
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
# Step 1: Send message that triggers approval request
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
tool_call_id = response.messages[-1].tool_call.tool_call_id
|
||||
|
||||
# Step 2: Start cancellation in background thread
|
||||
def cancel_after_delay():
|
||||
time.sleep(0.3) # Wait for stream to start
|
||||
client.agents.messages.cancel(agent_id=agent.id)
|
||||
|
||||
cancel_thread = threading.Thread(target=cancel_after_delay, daemon=True)
|
||||
cancel_thread.start()
|
||||
|
||||
# Step 3: Start approval stream (will be cancelled during processing)
|
||||
response = client.agents.messages.stream(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
{
|
||||
"type": "approval",
|
||||
"approvals": [
|
||||
{
|
||||
"type": "approval",
|
||||
"approve": True,
|
||||
"tool_call_id": tool_call_id,
|
||||
},
|
||||
],
|
||||
},
|
||||
],
|
||||
stream_tokens=True,
|
||||
)
|
||||
|
||||
# Step 4: Accumulate chunks
|
||||
messages = accumulate_chunks(response)
|
||||
|
||||
# Step 5: Verify we got chunks AND a cancelled stop reason
|
||||
assert len(messages) > 0, "Should receive at least some chunks before cancellation"
|
||||
|
||||
# Find stop_reason in messages
|
||||
stop_reasons = [msg for msg in messages if hasattr(msg, "message_type") and msg.message_type == "stop_reason"]
|
||||
assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason, got {len(stop_reasons)}"
|
||||
assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'"
|
||||
|
||||
# Step 6: Verify run status is cancelled
|
||||
runs = client.runs.list(agent_ids=[agent.id])
|
||||
latest_run = runs.items[0]
|
||||
assert latest_run.status == "cancelled", f"Expected run status 'cancelled', got '{latest_run.status}'"
|
||||
|
||||
# Wait for cancel thread to finish
|
||||
cancel_thread.join(timeout=1.0)
|
||||
|
||||
logger.info(f"✅ Test passed: approval with cancellation handled correctly, received {len(messages)} chunks")
|
||||
|
||||
Reference in New Issue
Block a user