diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py index c3d14ffa..ef1a6c26 100644 --- a/letta/adapters/simple_llm_stream_adapter.py +++ b/letta/adapters/simple_llm_stream_adapter.py @@ -16,6 +16,7 @@ from letta.schemas.letta_message_content import LettaMessageContentUnion from letta.schemas.provider_trace import ProviderTrace from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User +from letta.server.rest_api.streaming_response import get_cancellation_event_for_run from letta.settings import settings from letta.utils import safe_create_task @@ -70,6 +71,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): # Store request data self.request_data = request_data + # Get cancellation event for this run to enable graceful cancellation (before branching) + cancellation_event = get_cancellation_event_for_run(self.run_id) if self.run_id else None + # Instantiate streaming interface if self.llm_config.model_endpoint_type in [ProviderType.anthropic, ProviderType.bedrock]: # NOTE: different @@ -102,6 +106,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): requires_approval_tools=requires_approval_tools, run_id=self.run_id, step_id=step_id, + cancellation_event=cancellation_event, ) else: self.interface = SimpleOpenAIStreamingInterface( @@ -112,12 +117,14 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): model=self.llm_config.model, run_id=self.run_id, step_id=step_id, + cancellation_event=cancellation_event, ) elif self.llm_config.model_endpoint_type in [ProviderType.google_ai, ProviderType.google_vertex]: self.interface = SimpleGeminiStreamingInterface( requires_approval_tools=requires_approval_tools, run_id=self.run_id, step_id=step_id, + cancellation_event=cancellation_event, ) else: raise ValueError(f"Streaming not supported for provider {self.llm_config.model_endpoint_type}") diff --git a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py index ffe6ac63..0df80176 100644 --- a/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py +++ b/letta/interfaces/anthropic_parallel_tool_call_streaming_interface.py @@ -39,6 +39,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser +from letta.server.rest_api.streaming_response import RunCancelledException from letta.server.rest_api.utils import decrement_message_uuid logger = get_logger(__name__) @@ -228,10 +229,10 @@ class SimpleAnthropicStreamingInterface: prev_message_type = new_message_type # print(f"Yielding message: {message}") yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback - logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc()) + logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc()) async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index e27d38bc..8c2be4c6 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -41,6 +41,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser +from letta.server.rest_api.streaming_response import RunCancelledException logger = get_logger(__name__) @@ -218,10 +219,10 @@ class AnthropicStreamingInterface: message_index += 1 prev_message_type = new_message_type yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback - logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc()) + logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc()) async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: @@ -726,10 +727,10 @@ class SimpleAnthropicStreamingInterface: prev_message_type = new_message_type # print(f"Yielding message: {message}") yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback - logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc()) + logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc()) async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: diff --git a/letta/interfaces/gemini_streaming_interface.py b/letta/interfaces/gemini_streaming_interface.py index 91fbb502..629da143 100644 --- a/letta/interfaces/gemini_streaming_interface.py +++ b/letta/interfaces/gemini_streaming_interface.py @@ -26,6 +26,7 @@ from letta.schemas.letta_message_content import ( from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall +from letta.server.rest_api.streaming_response import RunCancelledException from letta.server.rest_api.utils import decrement_message_uuid from letta.utils import get_tool_call_id @@ -43,9 +44,11 @@ class SimpleGeminiStreamingInterface: requires_approval_tools: list = [], run_id: str | None = None, step_id: str | None = None, + cancellation_event: Optional["asyncio.Event"] = None, ): self.run_id = run_id self.step_id = step_id + self.cancellation_event = cancellation_event # self.messages = messages # self.tools = tools @@ -89,6 +92,9 @@ class SimpleGeminiStreamingInterface: # Raw usage from provider (for transparent logging in provider trace) self.raw_usage: dict | None = None + # Track cancellation status + self.stream_was_cancelled: bool = False + def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]: """This is (unusually) in chunked format, instead of merged""" for content in self.content_parts: @@ -137,10 +143,10 @@ class SimpleGeminiStreamingInterface: message_index += 1 prev_message_type = new_message_type yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback - logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc()) + logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc()) async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: @@ -164,7 +170,11 @@ class SimpleGeminiStreamingInterface: yield LettaStopReason(stop_reason=StopReasonType.error) raise e finally: - logger.info("GeminiStreamingInterface: Stream processing complete.") + # Check if cancellation was signaled via shared event + if self.cancellation_event and self.cancellation_event.is_set(): + self.stream_was_cancelled = True + + logger.info(f"GeminiStreamingInterface: Stream processing complete. stream was cancelled: {self.stream_was_cancelled}") async def _process_event( self, diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 36e3dfa6..7a78d813 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -54,6 +54,7 @@ from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import OptimisticJSONParser +from letta.server.rest_api.streaming_response import RunCancelledException from letta.server.rest_api.utils import decrement_message_uuid from letta.services.context_window_calculator.token_counter import create_token_counter from letta.streaming_utils import ( @@ -82,6 +83,7 @@ class OpenAIStreamingInterface: requires_approval_tools: list = [], run_id: str | None = None, step_id: str | None = None, + cancellation_event: Optional["asyncio.Event"] = None, ): self.use_assistant_message = use_assistant_message @@ -93,6 +95,7 @@ class OpenAIStreamingInterface: self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg self.run_id = run_id self.step_id = step_id + self.cancellation_event = cancellation_event self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser() self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg) @@ -226,14 +229,15 @@ class OpenAIStreamingInterface: message_index += 1 prev_message_type = new_message_type yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback self.stream_was_cancelled = True logger.warning( - "Stream was cancelled (CancelledError). Attempting to process current event. " + "Stream was cancelled (%s). Attempting to process current event. " f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. " f"Error: %s, trace: %s", + type(e).__name__, e, traceback.format_exc(), ) @@ -267,6 +271,10 @@ class OpenAIStreamingInterface: yield LettaStopReason(stop_reason=StopReasonType.error) raise e finally: + # Check if cancellation was signaled via shared event + if self.cancellation_event and self.cancellation_event.is_set(): + self.stream_was_cancelled = True + logger.info( f"OpenAIStreamingInterface: Stream processing complete. " f"Received {self.total_events_received} events, " @@ -561,9 +569,11 @@ class SimpleOpenAIStreamingInterface: model: str = None, run_id: str | None = None, step_id: str | None = None, + cancellation_event: Optional["asyncio.Event"] = None, ): self.run_id = run_id self.step_id = step_id + self.cancellation_event = cancellation_event # Premake IDs for database writes self.letta_message_id = Message.generate_id() @@ -715,14 +725,15 @@ class SimpleOpenAIStreamingInterface: message_index += 1 prev_message_type = new_message_type yield message - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback self.stream_was_cancelled = True logger.warning( - "Stream was cancelled (CancelledError). Attempting to process current event. " + "Stream was cancelled (%s). Attempting to process current event. " f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. " f"Error: %s, trace: %s", + type(e).__name__, e, traceback.format_exc(), ) @@ -764,6 +775,10 @@ class SimpleOpenAIStreamingInterface: yield LettaStopReason(stop_reason=StopReasonType.error) raise e finally: + # Check if cancellation was signaled via shared event + if self.cancellation_event and self.cancellation_event.is_set(): + self.stream_was_cancelled = True + logger.info( f"SimpleOpenAIStreamingInterface: Stream processing complete. " f"Received {self.total_events_received} events, " @@ -932,6 +947,7 @@ class SimpleOpenAIResponsesStreamingInterface: model: str = None, run_id: str | None = None, step_id: str | None = None, + cancellation_event: Optional["asyncio.Event"] = None, ): self.is_openai_proxy = is_openai_proxy self.messages = messages @@ -946,6 +962,7 @@ class SimpleOpenAIResponsesStreamingInterface: self.message_id = None self.run_id = run_id self.step_id = step_id + self.cancellation_event = cancellation_event # Premake IDs for database writes self.letta_message_id = Message.generate_id() @@ -1102,14 +1119,15 @@ class SimpleOpenAIResponsesStreamingInterface: ) # Continue to next event rather than killing the stream continue - except asyncio.CancelledError as e: + except (asyncio.CancelledError, RunCancelledException) as e: import traceback self.stream_was_cancelled = True logger.warning( - "Stream was cancelled (CancelledError). Attempting to process current event. " + "Stream was cancelled (%s). Attempting to process current event. " f"Events received so far: {self.total_events_received}, last event: {self.last_event_type}. " f"Error: %s, trace: %s", + type(e).__name__, e, traceback.format_exc(), ) @@ -1136,6 +1154,10 @@ class SimpleOpenAIResponsesStreamingInterface: yield LettaStopReason(stop_reason=StopReasonType.error) raise e finally: + # Check if cancellation was signaled via shared event + if self.cancellation_event and self.cancellation_event.is_set(): + self.stream_was_cancelled = True + logger.info( f"ResponsesAPI Stream processing complete. " f"Received {self.total_events_received} events, " diff --git a/letta/server/rest_api/routers/v1/conversations.py b/letta/server/rest_api/routers/v1/conversations.py index 9cbba747..691e89af 100644 --- a/letta/server/rest_api/routers/v1/conversations.py +++ b/letta/server/rest_api/routers/v1/conversations.py @@ -289,11 +289,14 @@ async def retrieve_conversation_stream( ) if settings.enable_cancellation_aware_streaming: + from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run + stream = cancellation_aware_stream_wrapper( stream_generator=stream, run_manager=server.run_manager, run_id=run.id, actor=actor, + cancellation_event=get_cancellation_event_for_run(run.id), ) if request and request.include_pings and settings.enable_keepalive: diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index 30316d46..b4c3973d 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -393,11 +393,14 @@ async def retrieve_stream_for_run( ) if settings.enable_cancellation_aware_streaming: + from letta.server.rest_api.streaming_response import cancellation_aware_stream_wrapper, get_cancellation_event_for_run + stream = cancellation_aware_stream_wrapper( stream_generator=stream, run_manager=server.run_manager, run_id=run_id, actor=actor, + cancellation_event=get_cancellation_event_for_run(run_id), ) if request.include_pings and settings.enable_keepalive: diff --git a/letta/server/rest_api/streaming_response.py b/letta/server/rest_api/streaming_response.py index 9869ff5c..02d727ff 100644 --- a/letta/server/rest_api/streaming_response.py +++ b/letta/server/rest_api/streaming_response.py @@ -7,6 +7,7 @@ import json import re from collections.abc import AsyncIterator from datetime import datetime, timezone +from typing import Dict, Optional from uuid import uuid4 import anyio @@ -26,6 +27,17 @@ from letta.utils import safe_create_task logger = get_logger(__name__) +# Global registry of cancellation events per run_id +# Note: Events are small and we don't bother cleaning them up +_cancellation_events: Dict[str, asyncio.Event] = {} + + +def get_cancellation_event_for_run(run_id: str) -> asyncio.Event: + """Get or create a cancellation event for a run.""" + if run_id not in _cancellation_events: + _cancellation_events[run_id] = asyncio.Event() + return _cancellation_events[run_id] + class RunCancelledException(Exception): """Exception raised when a run is explicitly cancelled (not due to client timeout)""" @@ -125,6 +137,7 @@ async def cancellation_aware_stream_wrapper( run_id: str, actor: User, cancellation_check_interval: float = 0.5, + cancellation_event: Optional[asyncio.Event] = None, ) -> AsyncIterator[str | bytes]: """ Wraps a stream generator to provide real-time run cancellation checking. @@ -156,11 +169,22 @@ async def cancellation_aware_stream_wrapper( run = await run_manager.get_run_by_id(run_id=run_id, actor=actor) if run.status == RunStatus.cancelled: logger.info(f"Stream cancelled for run {run_id}, interrupting stream") + + # Signal cancellation via shared event if available + if cancellation_event: + cancellation_event.set() + logger.info(f"Set cancellation event for run {run_id}") + # Send cancellation event to client - cancellation_event = {"message_type": "stop_reason", "stop_reason": "cancelled"} - yield f"data: {json.dumps(cancellation_event)}\n\n" - # Raise custom exception for explicit run cancellation - raise RunCancelledException(run_id, f"Run {run_id} was cancelled") + stop_event = {"message_type": "stop_reason", "stop_reason": "cancelled"} + yield f"data: {json.dumps(stop_event)}\n\n" + + # Inject exception INTO the generator so its except blocks can catch it + try: + await stream_generator.athrow(RunCancelledException(run_id, f"Run {run_id} was cancelled")) + except (StopAsyncIteration, RunCancelledException): + # Generator closed gracefully or raised the exception back + break except RunCancelledException: # Re-raise cancellation immediately, don't catch it raise @@ -173,9 +197,10 @@ async def cancellation_aware_stream_wrapper( yield chunk except RunCancelledException: - # Re-raise RunCancelledException to distinguish from client timeout + # Don't re-raise - we already injected the exception into the generator + # The generator has handled it and set its stream_was_cancelled flag logger.info(f"Stream for run {run_id} was explicitly cancelled and cleaned up") - raise + # Don't raise - let it exit gracefully except asyncio.CancelledError: # Re-raise CancelledError (likely client timeout) to ensure proper cleanup logger.info(f"Stream for run {run_id} was cancelled (likely client timeout) and cleaned up") diff --git a/letta/services/streaming_service.py b/letta/services/streaming_service.py index bc5a4cdc..354e4490 100644 --- a/letta/services/streaming_service.py +++ b/letta/services/streaming_service.py @@ -38,9 +38,11 @@ from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.server.rest_api.redis_stream_manager import create_background_stream_processor, redis_sse_stream_generator from letta.server.rest_api.streaming_response import ( + RunCancelledException, StreamingResponseWithStatusCode, add_keepalive_to_stream, cancellation_aware_stream_wrapper, + get_cancellation_event_for_run, ) from letta.server.rest_api.utils import capture_sentry_exception from letta.services.run_manager import RunManager @@ -168,6 +170,7 @@ class StreamingService: run_manager=self.runs_manager, run_id=run.id, actor=actor, + cancellation_event=get_cancellation_event_for_run(run.id), ) safe_create_task( @@ -195,6 +198,7 @@ class StreamingService: run_manager=self.runs_manager, run_id=run.id, actor=actor, + cancellation_event=get_cancellation_event_for_run(run.id), ) # conditionally wrap with keepalive based on request parameter @@ -451,6 +455,14 @@ class StreamingService: yield f"event: error\ndata: {error_message.model_dump_json()}\n\n" # Send [DONE] marker to properly close the stream yield "data: [DONE]\n\n" + except RunCancelledException as e: + # Run was explicitly cancelled - this is not an error + # The cancellation has already been handled by cancellation_aware_stream_wrapper + logger.info(f"Run {run_id} was cancelled, exiting stream gracefully") + # Send [DONE] to properly close the stream + yield "data: [DONE]\n\n" + # Don't update run status in finally - cancellation is already recorded + run_status = None # Signal to finally block to skip update except Exception as e: run_status = RunStatus.failed stop_reason = LettaStopReason(stop_reason=StopReasonType.error) diff --git a/tests/integration_test_cancellation.py b/tests/integration_test_cancellation.py index 57f34e2d..6cc7a0bc 100644 --- a/tests/integration_test_cancellation.py +++ b/tests/integration_test_cancellation.py @@ -198,3 +198,8 @@ async def test_background_streaming_cancellation( response = await client.runs.messages.stream(run_id=run_id, starting_after=0) messages_from_stream = await accumulate_chunks(response) assert len(messages_from_stream) > 0 + + # Verify the stream contains stop_reason: cancelled (from our new cancellation logic) + stop_reasons = [msg for msg in messages_from_stream if hasattr(msg, "message_type") and msg.message_type == "stop_reason"] + assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason in stream, got {len(stop_reasons)}" + assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'" diff --git a/tests/integration_test_human_in_the_loop.py b/tests/integration_test_human_in_the_loop.py index 73584629..be9078a8 100644 --- a/tests/integration_test_human_in_the_loop.py +++ b/tests/integration_test_human_in_the_loop.py @@ -1,3 +1,4 @@ +import asyncio import logging import uuid from typing import Any, List @@ -1333,3 +1334,69 @@ def test_agent_records_last_stop_reason_after_approval_flow( # Verify final agent state has the most recent stop reason final_agent = client.agents.retrieve(agent_id=agent.id) assert final_agent.last_stop_reason is not None + + +def test_approve_with_cancellation( + client: Letta, + agent: AgentState, +) -> None: + """ + Test that when approval and cancellation happen simultaneously, + the stream returns stop_reason: cancelled and stream_was_cancelled is set. + """ + import threading + import time + + # Step 1: Send message that triggers approval request + response = client.agents.messages.create( + agent_id=agent.id, + messages=USER_MESSAGE_TEST_APPROVAL, + ) + tool_call_id = response.messages[-1].tool_call.tool_call_id + + # Step 2: Start cancellation in background thread + def cancel_after_delay(): + time.sleep(0.3) # Wait for stream to start + client.agents.messages.cancel(agent_id=agent.id) + + cancel_thread = threading.Thread(target=cancel_after_delay, daemon=True) + cancel_thread.start() + + # Step 3: Start approval stream (will be cancelled during processing) + response = client.agents.messages.stream( + agent_id=agent.id, + messages=[ + { + "type": "approval", + "approvals": [ + { + "type": "approval", + "approve": True, + "tool_call_id": tool_call_id, + }, + ], + }, + ], + stream_tokens=True, + ) + + # Step 4: Accumulate chunks + messages = accumulate_chunks(response) + + # Step 5: Verify we got chunks AND a cancelled stop reason + assert len(messages) > 0, "Should receive at least some chunks before cancellation" + + # Find stop_reason in messages + stop_reasons = [msg for msg in messages if hasattr(msg, "message_type") and msg.message_type == "stop_reason"] + assert len(stop_reasons) == 1, f"Expected exactly 1 stop_reason, got {len(stop_reasons)}" + assert stop_reasons[0].stop_reason == "cancelled", f"Expected stop_reason 'cancelled', got '{stop_reasons[0].stop_reason}'" + + # Step 6: Verify run status is cancelled + runs = client.runs.list(agent_ids=[agent.id]) + latest_run = runs.items[0] + assert latest_run.status == "cancelled", f"Expected run status 'cancelled', got '{latest_run.status}'" + + # Wait for cancel thread to finish + cancel_thread.join(timeout=1.0) + + logger.info(f"✅ Test passed: approval with cancellation handled correctly, received {len(messages)} chunks")