From c162de51278145f55f0e6c987af362e06d341f93 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 22 Jan 2026 10:33:44 -0800 Subject: [PATCH] =?UTF-8?q?fix:=20use=20shared=20event=20+=20.athrow()=20t?= =?UTF-8?q?o=20properly=20set=20stream=5Fwas=5Fcancelle=E2=80=A6=20(#9019)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit fix: use shared event + .athrow() to properly set stream_was_cancelled flag **Problem:** When a run is cancelled via /cancel endpoint, `stream_was_cancelled` remained False because `RunCancelledException` was raised in the consumer code (wrapper), which closes the generator from outside. This causes Python to skip the generator's except blocks and jump directly to finally with the wrong flag value. **Solution:** 1. Shared `asyncio.Event` registry for cross-layer cancellation signaling 2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected 3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise) 4. All streaming interfaces check event in `finally` block to set flag correctly 5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE] **Changes:** - streaming_response.py: Event registry + .athrow() injection + graceful handling - openai_streaming_interface.py: 3 classes check event in finally - gemini_streaming_interface.py: Check event in finally - anthropic_*.py: Catch RunCancelledException - simple_llm_stream_adapter.py: Create & pass event to interfaces - streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update - routers/v1/{conversations,runs}.py: Pass event to wrapper - integration_test_human_in_the_loop.py: New test for approval + cancellation **Tests:** - test_tool_call with cancellation (OpenAI models) ✅ - test_approve_with_cancellation (approval flow + concurrent cancel) ✅ **Known cosmetic warnings (pre-existing):** - "Run already in terminal state" - agent loop tries to update after /cancel - "Stream ended without terminal event" - background streaming timing race 👾 Generated with [Letta Code](https://letta.com) Co-authored-by: Letta --- letta/adapters/simple_llm_stream_adapter.py | 7 ++ ..._parallel_tool_call_streaming_interface.py | 5 +- .../anthropic_streaming_interface.py | 9 +-- .../interfaces/gemini_streaming_interface.py | 16 ++++- .../interfaces/openai_streaming_interface.py | 34 ++++++++-- .../rest_api/routers/v1/conversations.py | 3 + letta/server/rest_api/routers/v1/runs.py | 3 + letta/server/rest_api/streaming_response.py | 37 ++++++++-- letta/services/streaming_service.py | 12 ++++ tests/integration_test_cancellation.py | 5 ++ tests/integration_test_human_in_the_loop.py | 67 +++++++++++++++++++ 11 files changed, 177 insertions(+), 21 deletions(-) diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py index c3d14ffa..ef1a6c26 100644 --- a/letta/adapters/simple_llm_stream_adapter.py +++ b/letta/adapters/simple_llm_stream_adapter.py @@ -16,6 +16,7 @@ from letta.schemas.letta_message_content import LettaMessageContentUnion from letta.schemas.provider_trace import ProviderTrace from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User +from letta.server.rest_api.streaming_response import get_cancellation_event_for_run from letta.settings import settings from letta.utils import safe_create_task @@ -70,6 +71,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): # Store request data self.request_data = request_data + # Get cancellation event for this run to enable graceful cancellation (before branching) + cancellation_event = get_cancellation_event_for_run(self.run_id) if self.run_id else None + # Instantiate streaming interface if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]: # NOTE: different @@ -102,6 +106,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): requires_approval_tools=requires_approval_tools, run_id=self.run_id, step_id=step_id, + cancellation_event=cancellation_event, ) else: self.interface = SimpleOpenAIStreamingInterface( @@ -112,12 +117,14 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): model=self.llm_config.model, run_id=self.run_id, step_id=step_id, + cancellation_event=cancellation_event, ) elif self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]: self.interface = SimpleGeminiStreamingInterface( requires_approval_tools=requires_approval_tools, run_id=self.run_id, step_id=step_id, + cancellation_event=cancellation_event, ) else: raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}") diff --git a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py index ffe6ac63..0df80176 100644 --- a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py +++ b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py @@ -39,6 +39,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser +from letta.server.rest_api.streaming_response import RunCancelledException from letta.server.rest_api.utils import decrement_message_uuid logger = get_logger(__name__) @@ -228,10 +229,10 @@ class SimpleAnthropicStreamingInterface: prev_message_type = new_message_type # print(f"Yielding message: {message}") yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback - logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc()) + logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc()) async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index e27d38bc..8c2be4c6 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -41,6 +41,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser +from letta.server.rest_api.streaming_response import RunCancelledException logger = get_logger(__name__) @@ -218,10 +219,10 @@ class AnthropicStreamingInterface: message_index += 1 prev_message_type = new_message_type yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback - logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc()) + logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc()) async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: @@ -726,10 +727,10 @@ class SimpleAnthropicStreamingInterface: prev_message_type = new_message_type # print(f"Yielding message: {message}") yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback - logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc()) + logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc()) async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: diff --git a/letta/interfaces/gemini_streaming_interface.py b/letta/interfaces/gemini_streaming_interface.py index 91fbb502..629da143 100644 --- a/letta/interfaces/gemini_streaming_interface.py +++ b/letta/interfaces/gemini_streaming_interface.py @@ -26,6 +26,7 @@ from letta.schemas.letta_message_content import ( from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall +from letta.server.rest_api.streaming_response import RunCancelledException from letta.server.rest_api.utils import decrement_message_uuid from letta.utils import get_tool_call_id @@ -43,9 +44,11 @@ class SimpleGeminiStreamingInterface: requires_approval_tools: list = [], run_id: str | None = None, step_id: str | None = None, + cancellation_event: Optional["asyncio.Event"] = None, ): self.run_id = run_id self.step_id = step_id + self.cancellation_event = cancellation_event # self.messages = messages # self.tools = tools @@ -89,6 +92,9 @@ class SimpleGeminiStreamingInterface: # Raw usage from provider (for transparent logging in provider trace) self.raw_usage: dict | None = None + # Track cancellation status + self.stream_was_cancelled: bool = False + def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]: """This is (unusually) in chunked format, instead of merged""" for content in self.content_parts: @@ -137,10 +143,10 @@ class SimpleGeminiStreamingInterface: message_index += 1 prev_message_type = new_message_type yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback - logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc()) + logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc()) async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: @@ -164,7 +170,11 @@ class SimpleGeminiStreamingInterface: yield LettaStopReason(stop_reason=StopReasonType.error) raise e finally: - logger.info("GeminiStreamingInterface: Stream processing complete.") + # Check if cancellation was signaled via shared event + if self.cancellation_event and self.cancellation_event.is_set(): + self.stream_was_cancelled = True + + logger.info(f"GeminiStreamingInterface: Stream processing complete. stream was cancelled: {self.stream_was_cancelled}") async def _process_event( self, diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 36e3dfa6..7a78d813 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -54,6 +54,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import OptimisticJSONParser +from letta.server.rest_api.streaming_response import RunCancelledException from letta.server.rest_api.utils import decrement_message_uuid from letta.services.context_window_calculator.token_counter import create_token_counter from letta.streaming_utils import ( @@ -82,6 +83,7 @@ class OpenAIStreamingInterface: requires_approval_tools: list = [], run_id: str | None = None, step_id: str | None = None, + cancellation_event: Optional["asyncio.Event"] = None, ): self.use_assistant_message = use_assistant_message @@ -93,6 +95,7 @@ class OpenAIStreamingInterface: self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg self.run_id = run_id self.step_id = step_id + self.cancellation_event = cancellation_event self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser() self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg) @@ -226,14 +229,15 @@ class OpenAIStreamingInterface: message_index += 1 prev_message_type = new_message_type yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback self.stream_was_cancelled = True logger.warning( - "Stream was cancelled (CancelledError). Attempting to process current event. " + "Stream was cancelled (%s). Attempting to process current event. " f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. " f"Error: %s, trace: %s", + type(e).__name__, e, traceback.format_exc(), ) @@ -267,6 +271,10 @@ class OpenAIStreamingInterface: yield LettaStopReason(stop_reason=StopReasonType.error) raise e finally: + # Check if cancellation was signaled via shared event + if self.cancellation_event and self.cancellation_event.is_set(): + self.stream_was_cancelled = True + logger.info( f"OpenAIStreamingInterface: Stream processing complete. " f"Received {self.total_events_received} events, " @@ -561,9 +569,11 @@ class SimpleOpenAIStreamingInterface: model: str = None, run_id: str | None = None, step_id: str | None = None, + cancellation_event: Optional["asyncio.Event"] = None, ): self.run_id = run_id self.step_id = step_id + self.cancellation_event = cancellation_event # Premake IDs for database writes self.letta_message_id = Message.generate_id() @@ -715,14 +725,15 @@ class SimpleOpenAIStreamingInterface: message_index += 1 prev_message_type = new_message_type yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback self.stream_was_cancelled = True logger.warning( - "Stream was cancelled (CancelledError). Attempting to process current event. " + "Stream was cancelled (%s). Attempting to process current event. " f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. " f"Error: %s, trace: %s", + type(e).__name__, e, traceback.format_exc(), ) @@ -764,6 +775,10 @@ class SimpleOpenAIStreamingInterface: yield LettaStopReason(stop_reason=StopReasonType.error) raise e finally: + # Check if cancellation was signaled via shared event + if self.cancellation_event and self.cancellation_event.is_set(): + self.stream_was_cancelled = True + logger.info( f"SimpleOpenAIStreamingInterface: Stream processing complete. " f"Received {self.total_events_received} events, " @@ -932,6 +947,7 @@ class SimpleOpenAIResponsesStreamingInterface: model: str = None, run_id: str | None = None, step_id: str | None = None, + cancellation_event: Optional["asyncio.Event"] = None, ): self.is_openai_proxy = is_openai_proxy self.messages = messages @@ -946,6 +962,7 @@ class SimpleOpenAIResponsesStreamingInterface: self.message_id = None self.run_id = run_id self.step_id = step_id + self.cancellation_event = cancellation_event # Premake IDs for database writes self.letta_message_id = Message.generate_id() @@ -1102,14 +1119,15 @@ class SimpleOpenAIResponsesStreamingInterface: ) # Continue to next event rather than killing the stream continue - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback self.stream_was_cancelled = True logger.warning( - "Stream was cancelled (CancelledError). Attempting to process current event. " + "Stream was cancelled (%s). Attempting to process current event. " f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. " f"Error: %s, trace: %s", + type(e).__name__, e, traceback.format_exc(), ) @@ -1136,6 +1154,10 @@ class SimpleOpenAIResponsesStreamingInterface: yield LettaStopReason(stop_reason=StopReasonType.error) raise e finally: + # Check if cancellation was signaled via shared event + if self.cancellation_event and self.cancellation_event.is_set(): + self.stream_was_cancelled = True + logger.info( f"ResponsesAPI Stream processing complete. " f"Received {self.total_events_received} events, " diff --git a/letta/server/rest_api/routers/v1/conversations.py b/letta/server/rest_api/routers/v1/conversations.py index 9cbba747..691e89af 100644 --- a/letta/server/rest_api/routers/v1/conversations.py +++ b/letta/server/rest_api/routers/v1/conversations.py @@ -289,11 +289,14 @@ async def retrieve_conversation_stream( ) if settings.enable_cancellation_aware_streaming: + from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run + stream = cancellation_aware_stream_wrapper( stream_generator=stream, run_manager=server.run_manager, run_id=run.id, actor=actor, + cancellation_event=get_cancellation_event_for_run(run.id), ) if request and request.include_pings and settings.enable_keepalive: diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 30316d46..b4c3973d 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -393,11 +393,14 @@ async def retrieve_stream_for_run( ) if settings.enable_cancellation_aware_streaming: + from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run + stream = cancellation_aware_stream_wrapper( stream_generator=stream, run_manager=server.run_manager, run_id=run_id, actor=actor, + cancellation_event=get_cancellation_event_for_run(run_id), ) if request.include_pings and settings.enable_keepalive: diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py index 9869ff5c..02d727ff 100644 --- a/letta/server/rest_api/streaming_response.py +++ b/letta/server/rest_api/streaming_response.py @@ -7,6 +7,7 @@ import json import re from collections.abc import AsyncIterator from datetime import datetime, timezone +from typing import Dict, Optional from uuid import uuid4 import anyio @@ -26,6 +27,17 @@ from letta.utils import safe_create_task logger = get_logger(__name__) +# Global registry of cancellation events per run_id +# Note: Events are small and we don't bother cleaning them up +_cancellation_events: Dict[str, asyncio.Event] = {} + + +def get_cancellation_event_for_run(run_id: str) -> asyncio.Event: + """Get or create a cancellation event for a run.""" + if run_id not in _cancellation_events: + _cancellation_events[run_id] = asyncio.Event() + return _cancellation_events[run_id] + class RunCancelledException(Exception): """Exception raised when a run is explicitly cancelled (not due to client timeout)""" @@ -125,6 +137,7 @@ async def cancellation_aware_stream_wrapper( run_id: str, actor: User, cancellation_check_interval: float = 0.5, + cancellation_event: Optional[asyncio.Event] = None, ) -> AsyncIterator[str | bytes]: """ Wraps a stream generator to provide real-time run cancellation checking. @@ -156,11 +169,22 @@ async def cancellation_aware_stream_wrapper( run = await run_manager.get_run_by_id(run_id=run_id, actor=actor) if run.status == RunStatus.cancelled: logger.info(f"Stream cancelled for run {run_id}, interrupting stream") + + # Signal cancellation via shared event if available + if cancellation_event: + cancellation_event.set() + logger.info(f"Set cancellation event for run {run_id}") + # Send cancellation event to client - cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"} - yield f"data: {json.dumps(cancellation_event)}\n\n" - # Raise custom exception for explicit run cancellation - raise RunCancelledException(run_id, f"Run {run_id} was cancelled") + stop_event = {"message_type": "stop_reason", "stop_reason": "cancelled"} + yield f"data: {json.dumps(stop_event)}\n\n" + + # Inject exception INTO the generator so its except blocks can catch it + try: + await stream_generator.athrow(RunCancelledException(run_id, f"Run {run_id} was cancelled")) + except (StopAsyncIteration, RunCancelledException): + # Generator closed gracefully or raised the exception back + break except RunCancelledException: # Re-raise cancellation immediately, don't catch it raise @@ -173,9 +197,10 @@ async def cancellation_aware_stream_wrapper( yield chunk except RunCancelledException: - # Re-raise RunCancelledException to distinguish from client timeout + # Don't re-raise - we already injected the exception into the generator + # The generator has handled it and set its stream_was_cancelled flag logger.info(f"Stream for run {run_id} was explicitly cancelled and cleaned up") - raise + # Don't raise - let it exit gracefully except asyncio.CancelledError: # Re-raise CancelledError (likely client timeout) to ensure proper cleanup logger.info(f"Stream for run {run_id} was cancelled (likely client timeout) and cleaned up") diff --git a/letta/services/streaming_service.py b/letta/services/streaming_service.py index bc5a4cdc..354e4490 100644 --- a/letta/services/streaming_service.py +++ b/letta/services/streaming_service.py @@ -38,9 +38,11 @@ from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator from letta.server.rest_api.streaming_response import ( + RunCancelledException, StreamingResponseWithStatusCode, add_keepalive_to_stream, cancellation_aware_stream_wrapper, + get_cancellation_event_for_run, ) from letta.server.rest_api.utils import capture_sentry_exception from letta.services.run_manager import RunManager @@ -168,6 +170,7 @@ class StreamingService: run_manager=self.runs_manager, run_id=run.id, actor=actor, + cancellation_event=get_cancellation_event_for_run(run.id), ) safe_create_task( @@ -195,6 +198,7 @@ class StreamingService: run_manager=self.runs_manager, run_id=run.id, actor=actor, + cancellation_event=get_cancellation_event_for_run(run.id), ) # conditionally wrap with keepalive based on request parameter @@ -451,6 +455,14 @@ class StreamingService: yield f"event: error\ndata: {error_message.model_dump_json()}\n\n" # Send [DONE] marker to properly close the stream yield "data: [DONE]\n\n" + except RunCancelledException as e: + # Run was explicitly cancelled - this is not an error + # The cancellation has already been handled by cancellation_aware_stream_wrapper + logger.info(f"Run {run_id} was cancelled, exiting stream gracefully") + # Send [DONE] to properly close the stream + yield "data: [DONE]\n\n" + # Don't update run status in finally - cancellation is already recorded + run_status = None # Signal to finally block to skip update except Exception as e: run_status = RunStatus.failed stop_reason = LettaStopReason(stop_reason=StopReasonType.error) diff --git a/tests/integration_test_cancellation.py b/tests/integration_test_cancellation.py index 57f34e2d..6cc7a0bc 100644 --- a/tests/integration_test_cancellation.py +++ b/tests/integration_test_cancellation.py @@ -198,3 +198,8 @@ async def test_background_streaming_cancellation( response = await client.runs.messages.stream(run_id=run_id, starting_after=0) messages_from_stream = await accumulate_chunks(response) assert len(messages_from_stream) > 0 + + # Verify the stream contains stop_reason: cancelled (from our new cancellation logic) + stop_reasons = [msg for msg in messages_from_stream if hasattr(msg, "message_type") and msg.message_type == "stop_reason"] + assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason in stream, got {len(stop_reasons)}" + assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'" diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 73584629..be9078a8 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -1,3 +1,4 @@ +import asyncio import logging import uuid from typing import Any, List @@ -1333,3 +1334,69 @@ def test_agent_records_last_stop_reason_after_approval_flow( # Verify final agent state has the most recent stop reason final_agent = client.agents.retrieve(agent_id=agent.id) assert final_agent.last_stop_reason is not None + + +def test_approve_with_cancellation( + client: Letta, + agent: AgentState, +) -> None: + """ + Test that when approval and cancellation happen simultaneously, + the stream returns stop_reason: cancelled and stream_was_cancelled is set. + """ + import threading + import time + + # Step 1: Send message that triggers approval request + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[-1].tool_call.tool_call_id + + # Step 2: Start cancellation in background thread + def cancel_after_delay(): + time.sleep(0.3) # Wait for stream to start + client.agents.messages.cancel(agent_id=agent.id) + + cancel_thread = threading.Thread(target=cancel_after_delay, daemon=True) + cancel_thread.start() + + # Step 3: Start approval stream (will be cancelled during processing) + response = client.agents.messages.stream( + agent_id=agent.id, + messages=[ + { + "type": "approval", + "approvals": [ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], + }, + ], + stream_tokens=True, + ) + + # Step 4: Accumulate chunks + messages = accumulate_chunks(response) + + # Step 5: Verify we got chunks AND a cancelled stop reason + assert len(messages) > 0, "Should receive at least some chunks before cancellation" + + # Find stop_reason in messages + stop_reasons = [msg for msg in messages if hasattr(msg, "message_type") and msg.message_type == "stop_reason"] + assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason, got {len(stop_reasons)}" + assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'" + + # Step 6: Verify run status is cancelled + runs = client.runs.list(agent_ids=[agent.id]) + latest_run = runs.items[0] + assert latest_run.status == "cancelled", f"Expected run status 'cancelled', got '{latest_run.status}'" + + # Wait for cancel thread to finish + cancel_thread.join(timeout=1.0) + + logger.info(f"✅ Test passed: approval with cancellation handled correctly, received {len(messages)} chunks")