fix: use shared event + .athrow() to properly set stream_was_cancelled flag
**Problem:**
When a run is cancelled via the /cancel endpoint, `stream_was_cancelled` stayed
False because `RunCancelledException` was raised in the consumer code (the wrapper),
so the generator was closed from the outside. In that case Python skips the
generator's `except` blocks and jumps directly to `finally`, leaving the flag at the wrong (False) value.
**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject exception INTO generator (not consumer-side raise)
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]
**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation
**Tests:**
- test_tool_call with cancellation (OpenAI models) ✅
- test_approve_with_cancellation (approval flow + concurrent cancel) ✅
**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race
👾 Generated with [Letta Code](https://letta.com)
Co-authored-by: Letta <noreply@letta.com>
342 lines
15 KiB
Python
342 lines
15 KiB
Python
import asyncio
|
|
import base64
|
|
import json
|
|
from collections.abc import AsyncGenerator
|
|
from datetime import datetime, timezone
|
|
from typing import AsyncIterator, List, Optional
|
|
|
|
from google.genai.types import (
|
|
GenerateContentResponse,
|
|
)
|
|
|
|
from letta.log import get_logger
|
|
from letta.schemas.letta_message import (
|
|
ApprovalRequestMessage,
|
|
AssistantMessage,
|
|
LettaMessage,
|
|
ReasoningMessage,
|
|
ToolCallDelta,
|
|
ToolCallMessage,
|
|
)
|
|
from letta.schemas.letta_message_content import (
|
|
ReasoningContent,
|
|
TextContent,
|
|
ToolCallContent,
|
|
)
|
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
|
from letta.schemas.message import Message
|
|
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
|
from letta.server.rest_api.streaming_response import RunCancelledException
|
|
from letta.server.rest_api.utils import decrement_message_uuid
|
|
from letta.utils import get_tool_call_id
|
|
|
|
# Module-level logger, named after this module.
logger = get_logger(__name__)
|
|
|
|
|
|
class SimpleGeminiStreamingInterface:
    """
    Encapsulates the logic for streaming responses from Gemini API:
    https://ai.google.dev/gemini-api/docs/text-generation#streaming-responses

    An instance accumulates per-stream state while `process()` is iterated:
    content parts, collected tool calls, the latest thought signature, token
    usage counters, and whether the stream was cancelled.
    """

    def __init__(
        self,
        requires_approval_tools: list | None = None,
        run_id: str | None = None,
        step_id: str | None = None,
        cancellation_event: Optional["asyncio.Event"] = None,
    ):
        """
        Args:
            requires_approval_tools: Names of tools whose calls are emitted as
                `ApprovalRequestMessage` instead of `ToolCallMessage`.
                `None` (the default) means no tool requires approval.
            run_id: Attached to every emitted message.
            step_id: Attached to every emitted message.
            cancellation_event: Optional shared event; if it is set by the time
                `process()` finishes, `stream_was_cancelled` becomes True.
        """
        self.run_id = run_id
        self.step_id = step_id
        self.cancellation_event = cancellation_event

        # A fresh list per instance: a `[]` default in the signature would be
        # shared by every instance constructed without this argument.
        self.requires_approval_tools = requires_approval_tools if requires_approval_tools is not None else []

        # ID responses used
        self.message_id = None
        # Model version reported by the stream; filled in by the first processed event.
        self.model: str | None = None

        # In Gemini streaming, tool call comes all at once
        self.tool_call_id: str | None = None
        self.tool_call_name: str | None = None
        self.tool_call_args: dict | None = None  # NOTE: Not a str!

        self.collected_tool_calls: list[ToolCall] = []

        # NOTE: signature only is included if tools are present
        self.thinking_signature: str | None = None

        # Regular text content too (avoid O(n^2) by accumulating parts)
        self._text_parts: list[str] = []
        self.text_content: str | None = None  # legacy; not used elsewhere

        # Premake IDs for database writes
        self.letta_message_id = Message.generate_id()

        # Sadly, Gemini's encrypted reasoning logic forces us to store stream parts in state
        self.content_parts: List[ReasoningContent | TextContent | ToolCallContent] = []

        # Token counters
        self.input_tokens = 0
        self.output_tokens = 0

        # Cache token tracking (Gemini uses cached_content_token_count)
        # None means "not reported by provider", 0 means "provider reported 0"
        self.cached_tokens: int | None = None

        # Thinking/reasoning token tracking (Gemini uses thoughts_token_count)
        # None means "not reported by provider", 0 means "provider reported 0"
        self.thinking_tokens: int | None = None

        # Raw usage from provider (for transparent logging in provider trace)
        self.raw_usage: dict | None = None

        # Track cancellation status
        self.stream_was_cancelled: bool = False

    def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
        """
        Return the collected content parts in chunked (not merged) form.

        Every `ReasoningContent` part is stamped in place with the most recent
        thinking signature before the list is returned. The returned list is
        the interface's own `content_parts` list, not a copy.
        """
        for content in self.content_parts:
            if isinstance(content, ReasoningContent):
                # This assumes there is only one signature per turn
                content.signature = self.thinking_signature
        return self.content_parts

    def get_tool_call_object(self) -> ToolCall:
        """
        Return the most recent tool call (useful for the agent loop).

        Prefers the last entry of `collected_tool_calls`; otherwise builds a
        `ToolCall` from the individual `tool_call_*` fields.

        Raises:
            ValueError: If no tool call was collected and any of the tool call
                id, name or arguments is missing.
        """
        if self.collected_tool_calls:
            return self.collected_tool_calls[-1]

        if self.tool_call_id is None:
            raise ValueError("No tool call ID available")
        if self.tool_call_name is None:
            raise ValueError("No tool call name available")
        if self.tool_call_args is None:
            raise ValueError("No tool call arguments available")

        tool_call_args_str = json.dumps(self.tool_call_args)
        return ToolCall(id=self.tool_call_id, function=FunctionCall(name=self.tool_call_name, arguments=tool_call_args_str))

    def get_tool_call_objects(self) -> list[ToolCall]:
        """Return all finalized tool calls collected during this message (parallel supported)."""
        return list(self.collected_tool_calls)

    async def process(
        self,
        stream: AsyncIterator[GenerateContentResponse],
        ttft_span: Optional["Span"] = None,
    ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
        """
        Iterates over the Gemini stream, yielding SSE events.
        It also collects tokens and detects if a tool call is triggered.

        Cancellation (`asyncio.CancelledError` / `RunCancelledException`) raised
        while an event is being relayed is logged and overridden: the current
        event is processed again and the loop moves on to the next event.
        Any other exception is logged, recorded on `ttft_span` (if given),
        followed by an error `LettaStopReason`, and then re-raised.

        On exit, `stream_was_cancelled` is set to True if the shared
        `cancellation_event` is set.
        """
        import traceback

        prev_message_type = None
        message_index = 0
        try:
            async for event in stream:
                try:
                    async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                        new_message_type = message.message_type
                        if new_message_type != prev_message_type:
                            # Only bump the index on a transition between two real message types
                            if prev_message_type is not None:
                                message_index += 1
                            prev_message_type = new_message_type
                        yield message
                except (asyncio.CancelledError, RunCancelledException) as e:
                    logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
                    # NOTE(review): this re-runs `_process_event` on the same event, so side
                    # effects that happen before its `yield` (e.g. appends to `_text_parts`
                    # and `collected_tool_calls`) are repeated and the message is yielded
                    # again - confirm this is intended.
                    async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                        new_message_type = message.message_type
                        if new_message_type != prev_message_type:
                            if prev_message_type is not None:
                                message_index += 1
                            prev_message_type = new_message_type
                        yield message

                    # Don't raise the exception here
                    continue

        except Exception as e:
            logger.exception("Error processing stream: %s", e)
            if ttft_span:
                ttft_span.add_event(
                    name="stop_reason",
                    attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
                )
            yield LettaStopReason(stop_reason=StopReasonType.error)
            raise
        finally:
            # Check if cancellation was signaled via shared event
            if self.cancellation_event is not None and self.cancellation_event.is_set():
                self.stream_was_cancelled = True

            logger.info("GeminiStreamingInterface: Stream processing complete. stream was cancelled: %s", self.stream_was_cancelled)

    async def _process_event(
        self,
        event: GenerateContentResponse,
        ttft_span: Optional["Span"] = None,
        prev_message_type: Optional[str] = None,
        message_index: int = 0,
    ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
        """
        Translate one Gemini stream event into Letta messages.

        Updates model/response ids and usage counters from the event, then walks
        the parts of the first candidate and yields a `ReasoningMessage`,
        `AssistantMessage`, `ToolCallMessage` or `ApprovalRequestMessage` per
        part, recording the matching content part on the instance.

        `prev_message_type` and `message_index` are only used to compute the
        `otid` of emitted messages; they are rebound locally and changes are not
        reported back to the caller. `ttft_span` is accepted but not used here.
        """
        # Every event has usage data + model info on it,
        # so we can continually extract
        self.model = event.model_version
        self.message_id = event.response_id
        usage_metadata = event.usage_metadata
        if usage_metadata:
            if usage_metadata.prompt_token_count:
                self.input_tokens = usage_metadata.prompt_token_count
            # Use candidates_token_count directly for output tokens.
            # Do NOT use (total_token_count - prompt_token_count) as that incorrectly
            # includes thinking/reasoning tokens which can be 10-100x the actual output.
            if usage_metadata.candidates_token_count:
                self.output_tokens = usage_metadata.candidates_token_count
            # Capture cache token data (Gemini uses cached_content_token_count)
            # Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
            cached_count = getattr(usage_metadata, "cached_content_token_count", None)
            if cached_count is not None:
                self.cached_tokens = cached_count
            # Capture thinking/reasoning token data (Gemini uses thoughts_token_count)
            # Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
            thoughts_count = getattr(usage_metadata, "thoughts_token_count", None)
            if thoughts_count is not None:
                self.thinking_tokens = thoughts_count
            # Store raw usage for transparent provider trace logging
            try:
                self.raw_usage = (
                    usage_metadata.to_json_dict()
                    if hasattr(usage_metadata, "to_json_dict")
                    else {
                        "prompt_token_count": usage_metadata.prompt_token_count,
                        "candidates_token_count": usage_metadata.candidates_token_count,
                        "total_token_count": usage_metadata.total_token_count,
                    }
                )
            except Exception as e:
                # Best effort only: usage logging must never break the stream
                logger.error("Failed to capture raw_usage from Gemini: %s", e)
                self.raw_usage = None

        if not event.candidates:
            return

        # NOTE: should always be len 1
        candidate = event.candidates[0]

        if not candidate.content or not candidate.content.parts:
            return

        for part in candidate.content.parts:
            # NOTE: the thought signature often comes after the thought text, eg with the tool call
            if part.thought_signature:
                # NOTE: the thought_signature comes on the Part with the function_call
                thought_signature = part.thought_signature
                self.thinking_signature = base64.b64encode(thought_signature).decode("utf-8")
                # Don't emit empty reasoning message - signature will be attached to actual reasoning content

            # Thinking summary content part (bool means text is thought part)
            if part.thought:
                reasoning_summary = part.text
                # Only emit reasoning message if we have actual content
                if reasoning_summary and reasoning_summary.strip():
                    if prev_message_type and prev_message_type != "reasoning_message":
                        message_index += 1
                    yield ReasoningMessage(
                        id=self.letta_message_id,
                        # datetime object, consistent with the other message types below
                        date=datetime.now(timezone.utc),
                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                        source="reasoner_model",
                        reasoning=reasoning_summary,
                        run_id=self.run_id,
                        step_id=self.step_id,
                    )
                    prev_message_type = "reasoning_message"
                    self.content_parts.append(
                        ReasoningContent(
                            is_native=True,
                            reasoning=reasoning_summary,
                            signature=self.thinking_signature,
                        )
                    )

            # Plain text content part
            elif part.text:
                content = part.text
                self._text_parts.append(content)
                if prev_message_type and prev_message_type != "assistant_message":
                    message_index += 1
                yield AssistantMessage(
                    id=self.letta_message_id,
                    otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                    date=datetime.now(timezone.utc),
                    content=content,
                    run_id=self.run_id,
                    step_id=self.step_id,
                )
                prev_message_type = "assistant_message"
                self.content_parts.append(
                    TextContent(
                        text=content,
                        signature=self.thinking_signature,
                    )
                )

            # Tool call function part
            # NOTE: in gemini, this comes all at once, and the args are JSON dict, not stringified
            elif part.function_call:
                function_call = part.function_call

                # The call id comes from get_tool_call_id() rather than from the event payload
                call_id = get_tool_call_id()
                name = function_call.name
                # NOTE: dict, not str. A call without arguments may carry None; normalize
                # to {} so we never serialize "null" or record a non-dict input.
                arguments = function_call.args or {}
                arguments_str = json.dumps(arguments)  # NOTE: use json_dumps?

                self.tool_call_id = call_id
                self.tool_call_name = name
                self.tool_call_args = arguments

                self.collected_tool_calls.append(ToolCall(id=call_id, function=FunctionCall(name=name, arguments=arguments_str)))

                if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
                    yield ApprovalRequestMessage(
                        id=decrement_message_uuid(self.letta_message_id),
                        otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
                        date=datetime.now(timezone.utc),
                        tool_call=ToolCallDelta(
                            name=name,
                            arguments=arguments_str,
                            tool_call_id=call_id,
                        ),
                        run_id=self.run_id,
                        step_id=self.step_id,
                    )
                else:
                    if prev_message_type and prev_message_type != "tool_call_message":
                        message_index += 1
                    tool_call_delta = ToolCallDelta(
                        name=name,
                        arguments=arguments_str,
                        tool_call_id=call_id,
                    )
                    yield ToolCallMessage(
                        id=self.letta_message_id,
                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                        date=datetime.now(timezone.utc),
                        tool_call=tool_call_delta,
                        tool_calls=tool_call_delta,
                        run_id=self.run_id,
                        step_id=self.step_id,
                    )
                    prev_message_type = "tool_call_message"
                # Recorded for both approval-gated and regular tool calls
                self.content_parts.append(
                    ToolCallContent(
                        id=call_id,
                        name=name,
                        input=arguments,
                        signature=self.thinking_signature,
                    )
                )
|