letta-server/letta/interfaces/gemini_streaming_interface.py
cthomas c162de5127 fix: use shared event + .athrow() to properly set stream_was_cancelle… (#9019)
fix: use shared event + .athrow() to properly set stream_was_cancelled flag

**Problem:**
When a run was cancelled via the /cancel endpoint, `stream_was_cancelled` stayed
False because `RunCancelledException` was raised in the consumer code (the wrapper),
which closes the generator from the outside. Python therefore skips the generator's
except blocks and jumps directly to finally, leaving the flag with the wrong value.
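
This is easy to reproduce outside Letta. A minimal standalone snippet (illustrative
only, not Letta code) showing why a flag set in the generator's `except` block never
gets set when the consumer closes the generator instead of throwing into it:

```python
import asyncio


async def stream():
    cancelled = False
    try:
        while True:
            yield "chunk"
    except Exception:
        # Never reached: aclose() raises GeneratorExit inside the generator,
        # which `except Exception` (or any specific exception type) does not catch.
        cancelled = True
        raise
    finally:
        print(f"cancelled={cancelled}")  # prints cancelled=False


async def consumer():
    gen = stream()
    await gen.__anext__()
    # Consumer-side cancellation: closing the generator from the outside,
    # so the generator body only ever sees GeneratorExit.
    await gen.aclose()


asyncio.run(consumer())
```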

**Solution:**
1. Shared `asyncio.Event` registry for cross-layer cancellation signaling
2. `cancellation_aware_stream_wrapper` sets the event when cancellation detected
3. Wrapper uses `.athrow()` to inject the exception INTO the generator (not a consumer-side raise); see the sketch after this list
4. All streaming interfaces check event in `finally` block to set flag correctly
5. `streaming_service.py` handles `RunCancelledException` gracefully, yields [DONE]
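
A rough sketch of the wrapper-side pattern (illustrative only; `inner` and
`is_cancelled` are assumed names, and the real implementation lives in
`streaming_response.py`):

```python
import asyncio


class RunCancelledException(Exception):
    """Stand-in for letta.server.rest_api.streaming_response.RunCancelledException."""


async def cancellation_aware_stream_wrapper(inner, cancellation_event: asyncio.Event, is_cancelled):
    """Sketch: detect cancellation, set the shared event, then throw INTO the generator."""
    gen = inner.__aiter__()
    try:
        while True:
            try:
                chunk = await gen.__anext__()
            except StopAsyncIteration:
                break
            if await is_cancelled():  # hypothetical async check of the run's status
                # 1) signal interfaces that check the event in their `finally` block
                cancellation_event.set()
                # 2) resume the generator by throwing at its yield point, so its own
                #    except/finally blocks run with the flag already set
                try:
                    await gen.athrow(RunCancelledException("run cancelled"))
                except (RunCancelledException, StopAsyncIteration):
                    pass
                break
            yield chunk
    finally:
        await gen.aclose()
```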

**Changes:**
- streaming_response.py: Event registry + .athrow() injection + graceful handling
- openai_streaming_interface.py: 3 classes check event in finally
- gemini_streaming_interface.py: Check event in finally
- anthropic_*.py: Catch RunCancelledException
- simple_llm_stream_adapter.py: Create & pass event to interfaces
- streaming_service.py: Handle RunCancelledException, yield [DONE], skip double-update
- routers/v1/{conversations,runs}.py: Pass event to wrapper
- integration_test_human_in_the_loop.py: New test for approval + cancellation

**Tests:**
- test_tool_call with cancellation (OpenAI models) 
- test_approve_with_cancellation (approval flow + concurrent cancel) 

**Known cosmetic warnings (pre-existing):**
- "Run already in terminal state" - agent loop tries to update after /cancel
- "Stream ended without terminal event" - background streaming timing race

👾 Generated with [Letta Code](https://letta.com)

Co-authored-by: Letta <noreply@letta.com>
2026-01-29 12:44:04 -08:00

342 lines
15 KiB
Python

import asyncio
import base64
import json
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import AsyncIterator, List, Optional

from google.genai.types import (
    GenerateContentResponse,
)

from letta.log import get_logger
from letta.schemas.letta_message import (
    ApprovalRequestMessage,
    AssistantMessage,
    LettaMessage,
    ReasoningMessage,
    ToolCallDelta,
    ToolCallMessage,
)
from letta.schemas.letta_message_content import (
    ReasoningContent,
    TextContent,
    ToolCallContent,
)
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.streaming_response import RunCancelledException
from letta.server.rest_api.utils import decrement_message_uuid
from letta.utils import get_tool_call_id

logger = get_logger(__name__)
class SimpleGeminiStreamingInterface:
    """
    Encapsulates the logic for streaming responses from Gemini API:
    https://ai.google.dev/gemini-api/docs/text-generation#streaming-responses
    """

    def __init__(
        self,
        requires_approval_tools: list = [],
        run_id: str | None = None,
        step_id: str | None = None,
        cancellation_event: Optional["asyncio.Event"] = None,
    ):
        self.run_id = run_id
        self.step_id = step_id
        self.cancellation_event = cancellation_event
        # self.messages = messages
        # self.tools = tools
        self.requires_approval_tools = requires_approval_tools

        # ID responses used
        self.message_id = None

        # In Gemini streaming, tool call comes all at once
        self.tool_call_id: str | None = None
        self.tool_call_name: str | None = None
        self.tool_call_args: dict | None = None  # NOTE: Not a str!
        self.collected_tool_calls: list[ToolCall] = []

        # NOTE: signature only is included if tools are present
        self.thinking_signature: str | None = None

        # Regular text content too (avoid O(n^2) by accumulating parts)
        self._text_parts: list[str] = []
        self.text_content: str | None = None  # legacy; not used elsewhere

        # Premake IDs for database writes
        self.letta_message_id = Message.generate_id()
        # self.model = model

        # Sadly, Gemini's encrypted reasoning logic forces us to store stream parts in state
        self.content_parts: List[ReasoningContent | TextContent | ToolCallContent] = []

        # Token counters
        self.input_tokens = 0
        self.output_tokens = 0

        # Cache token tracking (Gemini uses cached_content_token_count)
        # None means "not reported by provider", 0 means "provider reported 0"
        self.cached_tokens: int | None = None

        # Thinking/reasoning token tracking (Gemini uses thoughts_token_count)
        # None means "not reported by provider", 0 means "provider reported 0"
        self.thinking_tokens: int | None = None

        # Raw usage from provider (for transparent logging in provider trace)
        self.raw_usage: dict | None = None

        # Track cancellation status
        self.stream_was_cancelled: bool = False
    def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
        """This is (unusually) in chunked format, instead of merged"""
        for content in self.content_parts:
            if isinstance(content, ReasoningContent):
                # This assumes there is only one signature per turn
                content.signature = self.thinking_signature
        return self.content_parts

    def get_tool_call_object(self) -> ToolCall:
        """Useful for agent loop"""
        if self.collected_tool_calls:
            return self.collected_tool_calls[-1]

        if self.tool_call_id is None:
            raise ValueError("No tool call ID available")
        if self.tool_call_name is None:
            raise ValueError("No tool call name available")
        if self.tool_call_args is None:
            raise ValueError("No tool call arguments available")

        tool_call_args_str = json.dumps(self.tool_call_args)
        return ToolCall(id=self.tool_call_id, function=FunctionCall(name=self.tool_call_name, arguments=tool_call_args_str))

    def get_tool_call_objects(self) -> list[ToolCall]:
        """Return all finalized tool calls collected during this message (parallel supported)."""
        return list(self.collected_tool_calls)
    async def process(
        self,
        stream: AsyncIterator[GenerateContentResponse],
        ttft_span: Optional["Span"] = None,
    ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
        """
        Iterates over the Gemini stream, yielding SSE events.
        It also collects tokens and detects if a tool call is triggered.
        """
        prev_message_type = None
        message_index = 0
        try:
            async for event in stream:
                try:
                    async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                        new_message_type = message.message_type
                        if new_message_type != prev_message_type:
                            if prev_message_type is not None:
                                message_index += 1
                            prev_message_type = new_message_type
                        yield message
                except (asyncio.CancelledError, RunCancelledException) as e:
                    import traceback

                    logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc())
                    async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                        new_message_type = message.message_type
                        if new_message_type != prev_message_type:
                            if prev_message_type is not None:
                                message_index += 1
                            prev_message_type = new_message_type
                        yield message
                    # Don't raise the exception here
                    continue
        except Exception as e:
            import traceback

            logger.exception("Error processing stream: %s", e)
            if ttft_span:
                ttft_span.add_event(
                    name="stop_reason",
                    attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
                )
            yield LettaStopReason(stop_reason=StopReasonType.error)
            raise e
        finally:
            # Check if cancellation was signaled via shared event
            if self.cancellation_event and self.cancellation_event.is_set():
                self.stream_was_cancelled = True
            logger.info(f"GeminiStreamingInterface: Stream processing complete. stream was cancelled: {self.stream_was_cancelled}")
    async def _process_event(
        self,
        event: GenerateContentResponse,
        ttft_span: Optional["Span"] = None,
        prev_message_type: Optional[str] = None,
        message_index: int = 0,
    ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
        # Every event has usage data + model info on it,
        # so we can continually extract
        self.model = event.model_version
        self.message_id = event.response_id

        usage_metadata = event.usage_metadata
        if usage_metadata:
            if usage_metadata.prompt_token_count:
                self.input_tokens = usage_metadata.prompt_token_count
            # Use candidates_token_count directly for output tokens.
            # Do NOT use (total_token_count - prompt_token_count) as that incorrectly
            # includes thinking/reasoning tokens which can be 10-100x the actual output.
            if usage_metadata.candidates_token_count:
                self.output_tokens = usage_metadata.candidates_token_count

            # Capture cache token data (Gemini uses cached_content_token_count)
            # Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
            if hasattr(usage_metadata, "cached_content_token_count") and usage_metadata.cached_content_token_count is not None:
                self.cached_tokens = usage_metadata.cached_content_token_count

            # Capture thinking/reasoning token data (Gemini uses thoughts_token_count)
            # Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
            if hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count is not None:
                self.thinking_tokens = usage_metadata.thoughts_token_count

            # Store raw usage for transparent provider trace logging
            try:
                self.raw_usage = (
                    usage_metadata.to_json_dict()
                    if hasattr(usage_metadata, "to_json_dict")
                    else {
                        "prompt_token_count": usage_metadata.prompt_token_count,
                        "candidates_token_count": usage_metadata.candidates_token_count,
                        "total_token_count": usage_metadata.total_token_count,
                    }
                )
            except Exception as e:
                logger.error(f"Failed to capture raw_usage from Gemini: {e}")
                self.raw_usage = None

        if not event.candidates or len(event.candidates) == 0:
            return
        else:
            # NOTE: should always be len 1
            candidate = event.candidates[0]

        if not candidate.content or not candidate.content.parts:
            return

        for part in candidate.content.parts:
            # NOTE: the thought signature often comes after the thought text, eg with the tool call
            if part.thought_signature:
                # NOTE: the thought_signature comes on the Part with the function_call
                thought_signature = part.thought_signature
                self.thinking_signature = base64.b64encode(thought_signature).decode("utf-8")
                # Don't emit empty reasoning message - signature will be attached to actual reasoning content

            # Thinking summary content part (bool means text is thought part)
            if part.thought:
                reasoning_summary = part.text
                # Only emit reasoning message if we have actual content
                if reasoning_summary and reasoning_summary.strip():
                    if prev_message_type and prev_message_type != "reasoning_message":
                        message_index += 1
                    yield ReasoningMessage(
                        id=self.letta_message_id,
                        date=datetime.now(timezone.utc).isoformat(),
                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                        source="reasoner_model",
                        reasoning=reasoning_summary,
                        run_id=self.run_id,
                        step_id=self.step_id,
                    )
                    prev_message_type = "reasoning_message"
                    self.content_parts.append(
                        ReasoningContent(
                            is_native=True,
                            reasoning=reasoning_summary,
                            signature=self.thinking_signature,
                        )
                    )

            # Plain text content part
            elif part.text:
                content = part.text
                self._text_parts.append(content)
                if prev_message_type and prev_message_type != "assistant_message":
                    message_index += 1
                yield AssistantMessage(
                    id=self.letta_message_id,
                    otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                    date=datetime.now(timezone.utc),
                    content=content,
                    run_id=self.run_id,
                    step_id=self.step_id,
                )
                prev_message_type = "assistant_message"
                self.content_parts.append(
                    TextContent(
                        text=content,
                        signature=self.thinking_signature,
                    )
                )

            # Tool call function part
            # NOTE: in gemini, this comes all at once, and the args are JSON dict, not stringified
            elif part.function_call:
                function_call = part.function_call

                # Look for call_id, name, and possibly arguments (though likely always empty string)
                call_id = get_tool_call_id()
                name = function_call.name
                arguments = function_call.args  # NOTE: dict, not str
                arguments_str = json.dumps(arguments)  # NOTE: use json_dumps?

                self.tool_call_id = call_id
                self.tool_call_name = name
                self.tool_call_args = arguments
                self.collected_tool_calls.append(ToolCall(id=call_id, function=FunctionCall(name=name, arguments=arguments_str)))

                if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
                    yield ApprovalRequestMessage(
                        id=decrement_message_uuid(self.letta_message_id),
                        otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
                        date=datetime.now(timezone.utc),
                        tool_call=ToolCallDelta(
                            name=name,
                            arguments=arguments_str,
                            tool_call_id=call_id,
                        ),
                        run_id=self.run_id,
                        step_id=self.step_id,
                    )
                else:
                    if prev_message_type and prev_message_type != "tool_call_message":
                        message_index += 1
                    tool_call_delta = ToolCallDelta(
                        name=name,
                        arguments=arguments_str,
                        tool_call_id=call_id,
                    )
                    yield ToolCallMessage(
                        id=self.letta_message_id,
                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                        date=datetime.now(timezone.utc),
                        tool_call=tool_call_delta,
                        tool_calls=tool_call_delta,
                        run_id=self.run_id,
                        step_id=self.step_id,
                    )
                    prev_message_type = "tool_call_message"

                self.content_parts.append(
                    ToolCallContent(
                        id=call_id,
                        name=name,
                        input=arguments,
                        signature=self.thinking_signature,
                    )
                )