* fix: prevent empty reasoning messages in streaming interfaces

  Prevents empty "Thinking..." indicators from appearing in clients by
  filtering out reasoning messages with no content at the source.

  Changes:
  - Gemini: Don't emit ReasoningMessage when only thought_signature exists
  - Gemini: Only emit reasoning content if text is non-empty
  - Anthropic: Don't emit ReasoningMessage for BetaSignatureDelta
  - Anthropic: Only emit reasoning content if thinking text is non-empty

  This fixes the issue where providers send signature metadata before actual
  thinking content, causing empty reasoning blocks to appear in the UI after
  responses complete.

  Affects: Gemini reasoning, Anthropic extended thinking

  👾 Generated with [Letta Code](https://letta.com)

  Co-Authored-By: Letta <noreply@letta.com>

* fix: handle Anthropic thinking signature correctly

  - Only include 'signature' in the Anthropic message payload if it is not None
    (fixes BadRequestError).
  - Capture and attach 'signature' to ReasoningMessage in the streaming interface.

* fix(anthropic): attach signature to last reasoning message in stream

---------

Co-authored-by: Letta <noreply@letta.com>
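The rule both fixes implement, as a minimal sketch (the `emit_reasoning` helper and `state` holder are hypothetical; the `thought`, `text`, and `thought_signature` fields mirror the google-genai `Part` type handled in the file below):

```python
import base64

def handle_part(part, state):
    """Sketch of the filtering rule: buffer signatures, emit only real text."""
    if part.thought_signature:
        # Signature metadata alone never produces a ReasoningMessage;
        # it is buffered and attached to the next real thinking text.
        state.thinking_signature = base64.b64encode(part.thought_signature).decode("utf-8")
    if part.thought and part.text and part.text.strip():
        # Emit only once non-empty thinking text has arrived
        emit_reasoning(text=part.text, signature=state.thinking_signature)  # hypothetical helper
```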
import asyncio
import base64
import json
import traceback
from collections.abc import AsyncGenerator
from datetime import datetime, timezone
from typing import TYPE_CHECKING, AsyncIterator, List, Optional

from google.genai.types import (
    GenerateContentResponse,
)

from letta.log import get_logger
from letta.schemas.letta_message import (
    ApprovalRequestMessage,
    AssistantMessage,
    LettaMessage,
    ReasoningMessage,
    ToolCallDelta,
    ToolCallMessage,
)
from letta.schemas.letta_message_content import (
    ReasoningContent,
    TextContent,
    ToolCallContent,
)
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.utils import decrement_message_uuid
from letta.utils import get_tool_call_id

if TYPE_CHECKING:
    # Assumed import path for the tracing Span type referenced in annotations below
    from opentelemetry.trace import Span

logger = get_logger(__name__)


class SimpleGeminiStreamingInterface:
    """
    Encapsulates the logic for streaming responses from the Gemini API:
    https://ai.google.dev/gemini-api/docs/text-generation#streaming-responses
    """

    def __init__(
        self,
        requires_approval_tools: list | None = None,
        run_id: str | None = None,
        step_id: str | None = None,
    ):
        self.run_id = run_id
        self.step_id = step_id

        # Avoid a shared mutable default for the approval list
        self.requires_approval_tools = requires_approval_tools if requires_approval_tools is not None else []

        # Response ID reported by the provider
        self.message_id = None

        # Model version, populated from streamed events
        self.model: str | None = None

        # In Gemini streaming, the tool call comes all at once
        self.tool_call_id: str | None = None
        self.tool_call_name: str | None = None
        self.tool_call_args: dict | None = None  # NOTE: Not a str!

        self.collected_tool_calls: list[ToolCall] = []

        # NOTE: a signature is only included if tools are present
        self.thinking_signature: str | None = None

        # Regular text content too (avoid O(n^2) by accumulating parts)
        self._text_parts: list[str] = []
        self.text_content: str | None = None  # legacy; not used elsewhere

        # Premake IDs for database writes
        self.letta_message_id = Message.generate_id()

        # Sadly, Gemini's encrypted reasoning logic forces us to store stream parts in state
        self.content_parts: List[ReasoningContent | TextContent | ToolCallContent] = []

        # Token counters
        self.input_tokens = 0
        self.output_tokens = 0

        # Cache token tracking (Gemini uses cached_content_token_count)
        # None means "not reported by provider", 0 means "provider reported 0"
        self.cached_tokens: int | None = None

        # Thinking/reasoning token tracking (Gemini uses thoughts_token_count)
        # None means "not reported by provider", 0 means "provider reported 0"
        self.thinking_tokens: int | None = None

        # Raw usage from provider (for transparent logging in provider trace)
        self.raw_usage: dict | None = None

    def get_content(self) -> List[ReasoningContent | TextContent | ToolCallContent]:
        """This is (unusually) in chunked format, instead of merged"""
        for content in self.content_parts:
            if isinstance(content, ReasoningContent):
                # This assumes there is only one signature per turn
                content.signature = self.thinking_signature
        return self.content_parts

    def get_tool_call_object(self) -> ToolCall:
        """Useful for the agent loop"""
        if self.collected_tool_calls:
            return self.collected_tool_calls[-1]

        if self.tool_call_id is None:
            raise ValueError("No tool call ID available")
        if self.tool_call_name is None:
            raise ValueError("No tool call name available")
        if self.tool_call_args is None:
            raise ValueError("No tool call arguments available")

        tool_call_args_str = json.dumps(self.tool_call_args)
        return ToolCall(id=self.tool_call_id, function=FunctionCall(name=self.tool_call_name, arguments=tool_call_args_str))

    def get_tool_call_objects(self) -> list[ToolCall]:
        """Return all finalized tool calls collected during this message (parallel supported)."""
        return list(self.collected_tool_calls)

    async def process(
        self,
        stream: AsyncIterator[GenerateContentResponse],
        ttft_span: Optional["Span"] = None,
    ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
        """
        Iterates over the Gemini stream, yielding SSE events.
        It also collects tokens and detects if a tool call is triggered.
        """
        prev_message_type = None
        message_index = 0
        try:
            async for event in stream:
                try:
                    async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                        new_message_type = message.message_type
                        # Bump the index whenever the message type changes, so otids
                        # stay stable within each logical message
                        if new_message_type != prev_message_type:
                            if prev_message_type is not None:
                                message_index += 1
                            prev_message_type = new_message_type
                        yield message
                except asyncio.CancelledError as e:
                    logger.info("Cancelled stream attempt but overriding %s: %s", e, traceback.format_exc())
                    # Drain the in-flight event once more instead of propagating the cancellation
                    async for message in self._process_event(event, ttft_span, prev_message_type, message_index):
                        new_message_type = message.message_type
                        if new_message_type != prev_message_type:
                            if prev_message_type is not None:
                                message_index += 1
                            prev_message_type = new_message_type
                        yield message

                    # Don't raise the exception here
                    continue

        except Exception as e:
            logger.exception("Error processing stream: %s", e)
            if ttft_span:
                ttft_span.add_event(
                    name="stop_reason",
                    attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()},
                )
            yield LettaStopReason(stop_reason=StopReasonType.error)
            raise
        finally:
            logger.info("GeminiStreamingInterface: Stream processing complete.")

    async def _process_event(
        self,
        event: GenerateContentResponse,
        ttft_span: Optional["Span"] = None,
        prev_message_type: Optional[str] = None,
        message_index: int = 0,
    ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]:
        # Every event carries usage data + model info,
        # so we can continually extract them
        self.model = event.model_version
        self.message_id = event.response_id
        usage_metadata = event.usage_metadata
        if usage_metadata:
            if usage_metadata.prompt_token_count:
                self.input_tokens = usage_metadata.prompt_token_count
            # Use candidates_token_count directly for output tokens.
            # Do NOT use (total_token_count - prompt_token_count), as that incorrectly
            # includes thinking/reasoning tokens, which can be 10-100x the actual output.
            if usage_metadata.candidates_token_count:
                self.output_tokens = usage_metadata.candidates_token_count
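            # Illustrative (hypothetical) numbers: if prompt_token_count=200,
            # thoughts_token_count=1000, and candidates_token_count=300, then
            # total_token_count=1500, so (total - prompt) would report 1300
            # output tokens when the real completion is only 300.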
            # Capture cache token data (Gemini uses cached_content_token_count)
            # Use `is not None` to capture 0 values (meaning "provider reported 0 cached tokens")
            if hasattr(usage_metadata, "cached_content_token_count") and usage_metadata.cached_content_token_count is not None:
                self.cached_tokens = usage_metadata.cached_content_token_count
            # Capture thinking/reasoning token data (Gemini uses thoughts_token_count)
            # Use `is not None` to capture 0 values (meaning "provider reported 0 reasoning tokens")
            if hasattr(usage_metadata, "thoughts_token_count") and usage_metadata.thoughts_token_count is not None:
                self.thinking_tokens = usage_metadata.thoughts_token_count
            # Store raw usage for transparent provider trace logging
            try:
                self.raw_usage = (
                    usage_metadata.to_json_dict()
                    if hasattr(usage_metadata, "to_json_dict")
                    else {
                        "prompt_token_count": usage_metadata.prompt_token_count,
                        "candidates_token_count": usage_metadata.candidates_token_count,
                        "total_token_count": usage_metadata.total_token_count,
                    }
                )
            except Exception as e:
                logger.error(f"Failed to capture raw_usage from Gemini: {e}")
                self.raw_usage = None

        if not event.candidates:
            return
        # NOTE: should always be len 1
        candidate = event.candidates[0]

        if not candidate.content or not candidate.content.parts:
            return
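
        # Each Part may carry thought metadata (`thought_signature`), thinking text
        # (`thought` + `text`), plain `text`, or a `function_call`; handle each in turn.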
        for part in candidate.content.parts:
            # NOTE: the thought signature often comes after the thought text, e.g. with the tool call
            if part.thought_signature:
                # NOTE: the thought_signature comes on the Part with the function_call
                thought_signature = part.thought_signature
                self.thinking_signature = base64.b64encode(thought_signature).decode("utf-8")
                # Don't emit an empty reasoning message - the signature will be attached to the actual reasoning content

            # Thinking summary content part (the `thought` bool marks the text as a thought)
            if part.thought:
                reasoning_summary = part.text
                # Only emit a reasoning message if we have actual content
                if reasoning_summary and reasoning_summary.strip():
                    if prev_message_type and prev_message_type != "reasoning_message":
                        message_index += 1
                    yield ReasoningMessage(
                        id=self.letta_message_id,
                        date=datetime.now(timezone.utc),
                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                        source="reasoner_model",
                        reasoning=reasoning_summary,
                        run_id=self.run_id,
                        step_id=self.step_id,
                    )
                    prev_message_type = "reasoning_message"
                    self.content_parts.append(
                        ReasoningContent(
                            is_native=True,
                            reasoning=reasoning_summary,
                            signature=self.thinking_signature,
                        )
                    )

            # Plain text content part
            elif part.text:
                content = part.text
                self._text_parts.append(content)
                if prev_message_type and prev_message_type != "assistant_message":
                    message_index += 1
                yield AssistantMessage(
                    id=self.letta_message_id,
                    otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                    date=datetime.now(timezone.utc),
                    content=content,
                    run_id=self.run_id,
                    step_id=self.step_id,
                )
                prev_message_type = "assistant_message"
                self.content_parts.append(
                    TextContent(
                        text=content,
                        signature=self.thinking_signature,
                    )
                )

            # Tool call function part
            # NOTE: in Gemini, this comes all at once, and the args are a JSON dict, not stringified
            elif part.function_call:
                function_call = part.function_call

                # Look for call_id, name, and possibly arguments (though likely always an empty string)
                call_id = get_tool_call_id()
                name = function_call.name
                arguments = function_call.args  # NOTE: dict, not str
                arguments_str = json.dumps(arguments)  # NOTE: use json_dumps?

                self.tool_call_id = call_id
                self.tool_call_name = name
                self.tool_call_args = arguments

                self.collected_tool_calls.append(ToolCall(id=call_id, function=FunctionCall(name=name, arguments=arguments_str)))

                if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
                    yield ApprovalRequestMessage(
                        id=decrement_message_uuid(self.letta_message_id),
                        otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
                        date=datetime.now(timezone.utc),
                        tool_call=ToolCallDelta(
                            name=name,
                            arguments=arguments_str,
                            tool_call_id=call_id,
                        ),
                        run_id=self.run_id,
                        step_id=self.step_id,
                    )
                else:
                    if prev_message_type and prev_message_type != "tool_call_message":
                        message_index += 1
                    tool_call_delta = ToolCallDelta(
                        name=name,
                        arguments=arguments_str,
                        tool_call_id=call_id,
                    )
                    yield ToolCallMessage(
                        id=self.letta_message_id,
                        otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
                        date=datetime.now(timezone.utc),
                        tool_call=tool_call_delta,
                        tool_calls=tool_call_delta,
                        run_id=self.run_id,
                        step_id=self.step_id,
                    )
                    prev_message_type = "tool_call_message"
                    self.content_parts.append(
                        ToolCallContent(
                            id=call_id,
                            name=name,
                            input=arguments,
                            signature=self.thinking_signature,
                        )
                    )
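

# A minimal usage sketch (hypothetical: `stream` would come from the google-genai
# client's streaming call, e.g. client.aio.models.generate_content_stream(...);
# `run_id`/`step_id` are supplied by the surrounding agent loop):
#
#     interface = SimpleGeminiStreamingInterface(run_id="run-123", step_id="step-1")
#     async for letta_message in interface.process(stream):
#         ...  # forward each LettaMessage / LettaStopReason as an SSE event
#     content = interface.get_content()               # merged content parts for DB writes
#     tool_calls = interface.get_tool_call_objects()  # finalized tool calls, if any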