Files
letta-server/letta/llm_api/google_vertex_client.py
Kian Jones b9c4ed3b15 fix: catch contextwindowexceeded error on gemini (#9450)
* catch contextwindowexceeded error

* fix(core): detect Google token limit errors as ContextWindowExceededError

Google's error message reads "input token count exceeds the maximum
number of tokens allowed", which doesn't contain the word "context",
so it was falling through to the generic LLMBadRequestError instead of
ContextWindowExceededError, meaning compaction would never auto-trigger.

Expands the detection to also match "token count" and "tokens allowed"
in addition to the existing "context" keyword.
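
A condensed sketch of the broadened check (the full version lives in
handle_llm_error below; helper names here are illustrative). Both a
token/context keyword and a limit keyword must appear in the message:

    error_str = str(e).lower()
    has_token_kw = "context" in error_str or "token count" in error_str or "tokens allowed" in error_str
    has_limit_kw = "exceed" in error_str or "limit" in error_str or "too long" in error_str
    if has_token_kw and has_limit_kw:
        return ContextWindowExceededError(message=str(e))  # compaction can auto-trigger on this type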

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix(core): add missing message arg to LLMBadRequestError in OpenAI client

The generic 400 path in handle_llm_error was constructing
LLMBadRequestError without the required message positional arg,
causing a TypeError in prod during summarization.
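
As a hedged sketch (the OpenAI client is not shown in this file, so the
exact message text is illustrative), the fix amounts to:

    # before: raise LLMBadRequestError(code=ErrorCode.INVALID_ARGUMENT)  # TypeError: missing 'message'
    raise LLMBadRequestError(
        message=f"Bad request to OpenAI: {e}",
        code=ErrorCode.INVALID_ARGUMENT,
    )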

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* ci: add adapters/ test suite to core unit test matrix

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix(tests): update adapter error handling test expectations to match actual behavior

The streaming adapter's error handling double-wraps errors: the
AnthropicStreamingInterface calls handle_llm_error first, then the
adapter catches the result and calls handle_llm_error again, which
falls through to the base class LLMError. Updated test expectations
to match this behavior.

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix(core): prevent double-wrapping of LLMError in stream adapter

The AnthropicStreamingInterface.process() already transforms raw
provider errors into LLMError subtypes via handle_llm_error. The
adapter was catching the result and calling handle_llm_error again,
which didn't recognize the already-transformed LLMError and wrapped
it in a generic LLMError("Unhandled LLM error"). This downgraded
specific error types (LLMConnectionError, LLMServerError, etc.)
and broke retry logic that matches on specific subtypes.

Now the adapter checks if the error is already an LLMError and
re-raises it as-is. Tests restored to original correct expectations.
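
An illustrative version of the guard (the adapter lives outside this file,
so the surrounding names are assumptions; LLMError is the base class from
letta.errors):

    try:
        async for chunk in interface.process(stream):
            yield chunk
    except Exception as e:
        if isinstance(e, LLMError):
            raise  # already transformed by handle_llm_error; keep the specific subtype
        raise self.handle_llm_error(e, llm_config=llm_config)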

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

---------

Co-authored-by: Letta <noreply@letta.com>
2026-02-24 10:52:07 -08:00

1101 lines
54 KiB
Python

import base64
import copy
import json
import uuid
from typing import AsyncIterator, List, Optional
import httpx
import pydantic_core
from google.genai import Client, errors
from google.genai.types import (
FunctionCallingConfig,
FunctionCallingConfigMode,
GenerateContentResponse,
HttpOptions,
ThinkingConfig,
ToolConfig,
)
from letta.constants import NON_USER_MSG_PREFIX
from letta.errors import (
ContextWindowExceededError,
ErrorCode,
LLMAuthenticationError,
LLMBadRequestError,
LLMConnectionError,
LLMNotFoundError,
LLMPermissionDeniedError,
LLMRateLimitError,
LLMServerError,
LLMTimeoutError,
LLMUnprocessableEntityError,
)
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps, json_loads, sanitize_unicode_surrogates
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentType
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool, Tool as OpenAITool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings, settings
from letta.utils import get_tool_call_id
logger = get_logger(__name__)
class GoogleVertexClient(LLMClientBase):
MAX_RETRIES = model_settings.gemini_max_retries
provider_label = "Google Vertex"
def _get_client(self, llm_config: Optional[LLMConfig] = None):
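        """Build a GenAI client: API-key based when a BYOK override exists, Vertex AI otherwise."""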
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
if llm_config:
api_key, _, _ = self.get_byok_overrides(llm_config)
if api_key:
return Client(
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)
return Client(
vertexai=True,
project=model_settings.google_cloud_project,
location=model_settings.google_cloud_location,
http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
)
async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
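        """Async variant of _get_client; resolves BYOK overrides asynchronously."""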
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
if llm_config:
api_key, _, _ = await self.get_byok_overrides_async(llm_config)
if api_key:
return Client(
api_key=api_key,
http_options=HttpOptions(timeout=timeout_ms),
)
return Client(
vertexai=True,
project=model_settings.google_cloud_project,
location=model_settings.google_cloud_location,
http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
)
def _provider_prefix(self) -> str:
return f"[{self.provider_label}]"
def _provider_name(self) -> str:
return self.provider_label
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
        Performs the underlying request to the LLM and returns the raw response.
"""
try:
client = self._get_client(llm_config)
response = client.models.generate_content(
model=llm_config.model,
contents=request_data["contents"],
config=request_data["config"],
)
return response.model_dump()
except pydantic_core._pydantic_core.ValidationError as e:
# Handle Pydantic validation errors from the Google SDK
# This occurs when tool schemas contain unsupported fields
logger.error(
f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}"
)
raise LLMBadRequestError(
message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
f"Please check your tool definitions. Error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
except Exception as e:
raise self.handle_llm_error(e, llm_config=llm_config)
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
        Performs the underlying request to the LLM and returns the raw response.
"""
request_data = sanitize_unicode_surrogates(request_data)
client = await self._get_client_async(llm_config)
# Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
# https://github.com/googleapis/python-aiplatform/issues/4472
retry_count = 1
should_retry = True
response_data = None
while should_retry and retry_count <= self.MAX_RETRIES:
try:
response = await client.aio.models.generate_content(
model=llm_config.model,
contents=request_data["contents"],
config=request_data["config"],
)
except pydantic_core._pydantic_core.ValidationError as e:
# Handle Pydantic validation errors from the Google SDK
# This occurs when tool schemas contain unsupported fields
logger.error(
f"Pydantic validation error when calling {self._provider_name()} API. "
f"Tool schema contains unsupported fields. Error: {e}"
)
raise LLMBadRequestError(
message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
f"Please check your tool definitions. Error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
except errors.APIError as e:
                # Retry on 500, 503, and 504 errors, which are usually ephemeral from Gemini
                if e.code in (500, 503, 504):
logger.warning(f"Received {e}, retrying {retry_count}/{self.MAX_RETRIES}")
retry_count += 1
if retry_count > self.MAX_RETRIES:
raise self.handle_llm_error(e, llm_config=llm_config)
continue
raise self.handle_llm_error(e, llm_config=llm_config)
except Exception as e:
raise self.handle_llm_error(e, llm_config=llm_config)
response_data = response.model_dump()
is_malformed_function_call = self.is_malformed_function_call(response_data)
if is_malformed_function_call:
logger.warning(
f"Received FinishReason.MALFORMED_FUNCTION_CALL in response for {llm_config.model}, retrying {retry_count}/{self.MAX_RETRIES}"
)
# Modify the last message if it's a heartbeat to include warning about special characters
if request_data["contents"] and len(request_data["contents"]) > 0:
last_message = request_data["contents"][-1]
if last_message.get("role") == "user" and last_message.get("parts"):
for part in last_message["parts"]:
if "text" in part:
try:
# Try to parse as JSON to check if it's a heartbeat
message_json = json_loads(part["text"])
if message_json.get("type") == "heartbeat" and "reason" in message_json:
# Append warning to the reason
warning = f" RETRY {retry_count}/{self.MAX_RETRIES} ***DO NOT USE SPECIAL CHARACTERS OR QUOTATIONS INSIDE FUNCTION CALL ARGUMENTS. IF YOU MUST, MAKE SURE TO ESCAPE THEM PROPERLY***"
message_json["reason"] = message_json["reason"] + warning
# Update the text with modified JSON
part["text"] = json_dumps(message_json)
logger.warning(
f"Modified heartbeat message with special character warning for retry {retry_count}/{self.MAX_RETRIES}"
)
except (json.JSONDecodeError, TypeError, AttributeError):
# Not a JSON message or not a heartbeat, skip modification
pass
should_retry = is_malformed_function_call
retry_count += 1
if response_data is None:
raise RuntimeError("Failed to get response data after all retries")
return response_data
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
request_data = sanitize_unicode_surrogates(request_data)
client = await self._get_client_async(llm_config)
try:
response = await client.aio.models.generate_content_stream(
model=llm_config.model,
contents=request_data["contents"],
config=request_data["config"],
)
except pydantic_core._pydantic_core.ValidationError as e:
# Handle Pydantic validation errors from the Google SDK
# This occurs when tool schemas contain unsupported fields
logger.error(
f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}"
)
raise LLMBadRequestError(
message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
f"Please check your tool definitions. Error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
)
except errors.APIError as e:
            raise self.handle_llm_error(e, llm_config=llm_config)
except Exception as e:
logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}")
raise e
# Direct yield - keeps response alive in generator's local scope throughout iteration
# This is required because the SDK's connection lifecycle is tied to the response object
try:
async for chunk in response:
yield chunk
except errors.ClientError as e:
if e.code == 499:
logger.info(f"{self._provider_prefix()} Stream cancelled by client (499): {e}")
return
raise self.handle_llm_error(e, llm_config=llm_config)
except errors.APIError as e:
raise self.handle_llm_error(e, llm_config=llm_config)
@staticmethod
def add_dummy_model_messages(messages: List[dict]) -> List[dict]:
"""Google AI API requires all function call returns are immediately followed by a 'model' role message.
In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user,
so there is no natural follow-up 'model' role message.
        To satisfy the Google AI API restrictions, we can add a dummy 'yield' message
        with role == 'model' that is placed in between a function output
        (role == 'tool') and a user message (role == 'user').
"""
dummy_yield_message = {
"role": "model",
"parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}],
}
messages_with_padding = []
for i, message in enumerate(messages):
messages_with_padding.append(message)
# Check if the current message role is 'tool' and the next message role is 'user'
if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
messages_with_padding.append(dummy_yield_message)
return messages_with_padding
def _clean_google_ai_schema_properties(self, schema_part: dict):
"""Recursively clean schema parts to remove unsupported Google AI keywords."""
if not isinstance(schema_part, dict):
return
# Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations
# * Only a subset of the OpenAPI schema is supported.
# * Supported parameter types in Python are limited.
unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties", "$schema", "const", "$ref"]
keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part]
for key_to_remove in keys_to_remove_at_this_level:
logger.debug(f"Removing unsupported keyword '{key_to_remove}' from schema part.")
del schema_part[key_to_remove]
if schema_part.get("type") == "string" and "format" in schema_part:
allowed_formats = ["enum", "date-time"]
if schema_part["format"] not in allowed_formats:
logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. Allowed: {allowed_formats}")
del schema_part["format"]
# Check properties within the current level
if "properties" in schema_part and isinstance(schema_part["properties"], dict):
for prop_name, prop_schema in schema_part["properties"].items():
self._clean_google_ai_schema_properties(prop_schema)
# Check items within arrays
if "items" in schema_part and isinstance(schema_part["items"], dict):
self._clean_google_ai_schema_properties(schema_part["items"])
# Check within anyOf, allOf, oneOf lists
for key in ["anyOf", "allOf", "oneOf"]:
if key in schema_part and isinstance(schema_part[key], list):
for item_schema in schema_part[key]:
self._clean_google_ai_schema_properties(item_schema)
def _resolve_json_schema_refs(self, schema: dict, defs: dict = None) -> dict:
"""
Recursively resolve $ref in JSON schema by inlining definitions.
Google GenAI SDK does not support $ref.
"""
if defs is None:
# Look for definitions at the top level
defs = schema.get("$defs") or schema.get("definitions") or {}
if not isinstance(schema, dict):
return schema
# If this is a ref, resolve it
if "$ref" in schema:
ref = schema["$ref"]
if isinstance(ref, str):
for prefix in ("#/$defs/", "#/definitions/"):
if ref.startswith(prefix):
ref_name = ref.split("/")[-1]
if ref_name in defs:
resolved = defs[ref_name].copy()
return self._resolve_json_schema_refs(resolved, defs)
break
logger.warning(f"Could not resolve $ref '{ref}' in schema — will be stripped by schema cleaner")
# Recursively process children
new_schema = schema.copy()
# We need to remove $defs/definitions from the output schema as Google doesn't support them
if "$defs" in new_schema:
del new_schema["$defs"]
if "definitions" in new_schema:
del new_schema["definitions"]
for k, v in new_schema.items():
if isinstance(v, dict):
new_schema[k] = self._resolve_json_schema_refs(v, defs)
elif isinstance(v, list):
new_schema[k] = [self._resolve_json_schema_refs(i, defs) if isinstance(i, dict) else i for i in v]
return new_schema
def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
"""
OpenAI style:
"tools": [{
"type": "function",
"function": {
"name": "find_movies",
"description": "find ....",
"parameters": {
"type": "object",
"properties": {
PARAM: {
"type": PARAM_TYPE, # eg "string"
"description": PARAM_DESCRIPTION,
},
...
},
"required": List[str],
}
}
}
]
Google AI style:
"tools": [{
"functionDeclarations": [{
"name": "find_movies",
"description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
"parameters": {
"type": "OBJECT",
"properties": {
"location": {
"type": "STRING",
"description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
},
"description": {
"type": "STRING",
"description": "Any kind of description including category or genre, title words, attributes, etc."
}
},
"required": ["description"]
}
}, {
"name": "find_theaters",
...
"""
function_list = [
dict(
name=t.function.name,
description=t.function.description,
# Deep copy parameters to avoid modifying the original Tool object
parameters=copy.deepcopy(t.function.parameters) if t.function.parameters else {},
)
for t in tools
]
# Add inner thoughts if needed
for func in function_list:
# Note: Google AI API used to have weird casing requirements, but not any more
# Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned
if "parameters" in func and isinstance(func["parameters"], dict):
# Resolve $ref in schema because Google AI SDK doesn't support them
func["parameters"] = self._resolve_json_schema_refs(func["parameters"])
self._clean_google_ai_schema_properties(func["parameters"])
# Add inner thoughts
if llm_config.put_inner_thoughts_in_kwargs:
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX
func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = {
"type": "string",
"description": INNER_THOUGHTS_KWARG_DESCRIPTION,
}
func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX)
return [{"functionDeclarations": function_list}]
@trace_method
def build_request_data(
self,
agent_type: AgentType, # if react, use native content + strip heartbeats
messages: List[PydanticMessage],
llm_config: LLMConfig,
tools: List[dict],
force_tool_call: Optional[str] = None,
requires_subsequent_tool_call: bool = False,
tool_return_truncation_chars: Optional[int] = None,
) -> dict:
"""
Constructs a request object in the expected data format for this client.
"""
# NOTE: forcing inner thoughts in kwargs off
if agent_type == AgentType.letta_v1_agent:
llm_config.put_inner_thoughts_in_kwargs = False
if tools:
tool_objs = [Tool(type="function", function=t) for t in tools]
tool_names = [t.function.name for t in tool_objs]
# Convert to the exact payload style Google expects
formatted_tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)
else:
formatted_tools = []
tool_names = []
contents = self.add_dummy_model_messages(
PydanticMessage.to_google_dicts_from_list(
messages,
current_model=llm_config.model,
put_inner_thoughts_in_kwargs=False if agent_type == AgentType.letta_v1_agent else True,
native_content=True if agent_type == AgentType.letta_v1_agent else False,
),
)
request_data = {
"contents": contents,
"config": {
"temperature": llm_config.temperature,
"tools": formatted_tools,
},
}
        # Max tokens is optional
if llm_config.max_tokens:
request_data["config"]["max_output_tokens"] = llm_config.max_tokens
if len(tool_names) == 1 and settings.use_vertex_structured_outputs_experimental:
request_data["config"]["response_mime_type"] = "application/json"
request_data["config"]["response_schema"] = self.get_function_call_response_schema(tools[0])
del request_data["config"]["tools"]
elif tools:
if agent_type == AgentType.letta_v1_agent:
# don't require tools
tool_call_mode = FunctionCallingConfigMode.AUTO
tool_config = ToolConfig(
function_calling_config=FunctionCallingConfig(
mode=tool_call_mode,
)
)
else:
# require tools
tool_call_mode = FunctionCallingConfigMode.ANY
tool_config = ToolConfig(
function_calling_config=FunctionCallingConfig(
mode=tool_call_mode,
                        # Provide the list of tool names (an empty list should also work, but in practice it does not)
allowed_function_names=tool_names,
)
)
request_data["config"]["tool_config"] = tool_config.model_dump()
# https://ai.google.dev/gemini-api/docs/thinking#set-budget
# 2.5 Pro
# - Default: dynamic thinking
# - Dynamic thinking that cannot be disabled
# - Range: -1 (for dynamic), or 128-32768
# 2.5 Flash
# - Default: dynamic thinking
# - Dynamic thinking that *can* be disabled
# - Range: -1, 0, or 0-24576
# 2.5 Flash Lite
# - Default: no thinking
# - Dynamic thinking that *can* be disabled
# - Range: -1, 0, or 512-24576
# TODO when using v3 agent loop, properly support the native thinking in Gemini
# Add thinking_config for all Gemini reasoning models (2.5 series)
# If enable_reasoner is False, set thinking_budget to 0
# Otherwise, use the value from max_reasoning_tokens
if self.is_reasoning_model(llm_config) or "flash" in llm_config.model:
if llm_config.model.startswith("gemini-3"):
                # let thinking_level default to high by not specifying thinking_budget
thinking_config = ThinkingConfig(include_thoughts=True)
else:
# Gemini reasoning models may fail to call tools even with FunctionCallingConfigMode.ANY if thinking is fully disabled, set to minimum to prevent tool call failure
thinking_budget = (
llm_config.max_reasoning_tokens if llm_config.enable_reasoner else self.get_thinking_budget(llm_config.model)
)
if thinking_budget <= 0:
logger.warning(
f"Thinking budget of {thinking_budget} for Gemini reasoning model {llm_config.model}, this will likely cause tool call failures"
)
# For models that require thinking mode (2.5 Pro, 3.x), override with minimum valid budget
if llm_config.model.startswith("gemini-2.5-pro"):
thinking_budget = 128
logger.warning(
f"Overriding thinking_budget to {thinking_budget} for model {llm_config.model} which requires thinking mode"
)
thinking_config = ThinkingConfig(
thinking_budget=(thinking_budget),
include_thoughts=(thinking_budget > 1),
)
request_data["config"]["thinking_config"] = thinking_config.model_dump()
return request_data
def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
"""Extract usage statistics from Gemini response and return as LettaUsageStatistics."""
if not response_data:
return LettaUsageStatistics()
response = GenerateContentResponse(**response_data)
if not response.usage_metadata:
return LettaUsageStatistics()
cached_tokens = None
if (
hasattr(response.usage_metadata, "cached_content_token_count")
and response.usage_metadata.cached_content_token_count is not None
):
cached_tokens = response.usage_metadata.cached_content_token_count
reasoning_tokens = None
if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count is not None:
reasoning_tokens = response.usage_metadata.thoughts_token_count
return LettaUsageStatistics(
prompt_tokens=response.usage_metadata.prompt_token_count or 0,
completion_tokens=response.usage_metadata.candidates_token_count or 0,
total_tokens=response.usage_metadata.total_token_count or 0,
cached_input_tokens=cached_tokens,
reasoning_tokens=reasoning_tokens,
)
@trace_method
async def convert_response_to_chat_completion(
self,
response_data: dict,
input_messages: List[PydanticMessage],
llm_config: LLMConfig,
) -> ChatCompletionResponse:
"""
Converts custom response format from llm client into an OpenAI
ChatCompletionsResponse object.
Example:
{
"candidates": [
{
"content": {
"parts": [
{
"text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
}
]
}
}
],
"usageMetadata": {
"promptTokenCount": 9,
"candidatesTokenCount": 27,
"totalTokenCount": 36
}
}
"""
response = GenerateContentResponse(**response_data)
try:
choices = []
index = 0
for candidate in response.candidates:
content = candidate.content
if content is None or content.role is None or content.parts is None:
# This means the response is malformed like MALFORMED_FUNCTION_CALL
if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
raise LLMServerError(f"Malformed response from {self._provider_name()}: {candidate.finish_reason}")
else:
raise LLMServerError(f"Invalid response data from {self._provider_name()}: {candidate.model_dump()}")
role = content.role
assert role == "model", f"Unknown role in response: {role}"
parts = content.parts
            # NOTE: we aren't properly supporting multi-part responses here anyway (we're just appending choices),
            # so let's disable it for now
# NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
# {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
# To patch this, if we have multiple parts we can take the last one
# Unless parallel tool calling is enabled, in which case multiple parts may be intentional
if len(parts) > 1 and not llm_config.enable_reasoner and not llm_config.parallel_tool_calls:
logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
# only truncate if reasoning is off and parallel tool calls are disabled
parts = [parts[-1]]
# TODO support parts / multimodal
# Parallel tool calling is now supported when llm_config.parallel_tool_calls is enabled
openai_response_message = None
for response_message in parts:
# Convert the actual message style to OpenAI style
if response_message.function_call:
function_call = response_message.function_call
function_name = function_call.name
function_args = function_call.args
assert isinstance(function_args, dict), function_args
# TODO this is kind of funky - really, we should be passing 'native_content' as a kwarg to fork behavior
inner_thoughts = response_message.text
if llm_config.put_inner_thoughts_in_kwargs:
# NOTE: this also involves stripping the inner monologue out of the function
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
f"Couldn't find inner thoughts in function args:\n{function_call}"
)
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
else:
pass
# inner_thoughts = None
# inner_thoughts = response_message.text
# Google AI API doesn't generate tool call IDs
tool_call = ToolCall(
id=get_tool_call_id(),
type="function",
function=FunctionCall(
name=function_name,
arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
),
)
if openai_response_message is None:
openai_response_message = Message(
role="assistant", # NOTE: "model" -> "assistant"
content=inner_thoughts,
tool_calls=[tool_call],
)
if response_message.thought_signature:
thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
openai_response_message.reasoning_content_signature = thought_signature
else:
openai_response_message.content = inner_thoughts
if openai_response_message.tool_calls is None:
openai_response_message.tool_calls = []
openai_response_message.tool_calls.append(tool_call)
if response_message.thought_signature:
thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
openai_response_message.reasoning_content_signature = thought_signature
else:
if response_message.thought:
if openai_response_message is None:
openai_response_message = Message(
role="assistant", # NOTE: "model" -> "assistant"
reasoning_content=response_message.text,
)
else:
openai_response_message.reasoning_content = response_message.text
try:
# Structured output tool call
function_call = json_loads(response_message.text)
# Access dict keys - will raise TypeError/KeyError if not a dict or missing keys
function_name = function_call["name"]
function_args = function_call["args"]
assert isinstance(function_args, dict), function_args
# NOTE: this also involves stripping the inner monologue out of the function
if llm_config.put_inner_thoughts_in_kwargs:
from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX
assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
f"Couldn't find inner thoughts in function args:\n{function_call}"
)
inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
else:
inner_thoughts = None
# Google AI API doesn't generate tool call IDs
tool_call = ToolCall(
id=get_tool_call_id(),
type="function",
function=FunctionCall(
name=function_name,
arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
),
)
if openai_response_message is None:
openai_response_message = Message(
role="assistant", # NOTE: "model" -> "assistant"
content=inner_thoughts,
tool_calls=[tool_call],
)
else:
openai_response_message.content = inner_thoughts
if openai_response_message.tool_calls is None:
openai_response_message.tool_calls = []
openai_response_message.tool_calls.append(tool_call)
except (json.decoder.JSONDecodeError, ValueError, TypeError, KeyError, AssertionError) as e:
if candidate.finish_reason == "MAX_TOKENS":
raise LLMServerError("Could not parse response data from LLM: exceeded max token limit")
# Log the parsing error for debugging
logger.warning(
f"Failed to parse structured output from LLM response: {e}. Response text: {response_message.text[:500]}"
)
# Inner thoughts are the content by default
inner_thoughts = response_message.text
# Google AI API doesn't generate tool call IDs
if openai_response_message is None:
openai_response_message = Message(
role="assistant", # NOTE: "model" -> "assistant"
content=inner_thoughts,
)
else:
openai_response_message.content = inner_thoughts
if response_message.thought_signature:
thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
openai_response_message.reasoning_content_signature = thought_signature
# Google AI API uses different finish reason strings than OpenAI
# OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
# see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
# Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
# see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
finish_reason = candidate.finish_reason.value
if finish_reason == "STOP":
openai_finish_reason = (
"function_call"
if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
else "stop"
)
elif finish_reason == "MAX_TOKENS":
openai_finish_reason = "length"
elif finish_reason == "SAFETY":
openai_finish_reason = "content_filter"
elif finish_reason == "RECITATION":
openai_finish_reason = "content_filter"
else:
raise LLMServerError(f"Unrecognized finish reason in Google AI response: {finish_reason}")
choices.append(
Choice(
finish_reason=openai_finish_reason,
index=index,
message=openai_response_message,
)
)
index += 1
# if len(choices) > 1:
# raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")
            # NOTE: some of the Google AI APIs document usageMetadata in the response, but it is sometimes missing
# "usageMetadata": {
# "promptTokenCount": 9,
# "candidatesTokenCount": 27,
# "totalTokenCount": 36
# }
usage = None
if response.usage_metadata:
# Extract usage via centralized method
from letta.schemas.enums import ProviderType
usage = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.google_ai)
else:
# Count it ourselves using the Gemini token counting API
assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"
google_messages = PydanticMessage.to_google_dicts_from_list(input_messages, current_model=llm_config.model)
prompt_tokens = await self.count_tokens(messages=google_messages, model=llm_config.model)
# For completion tokens, wrap the response content in Google format
completion_content = [{"role": "model", "parts": [{"text": json_dumps(openai_response_message.model_dump())}]}]
completion_tokens = await self.count_tokens(messages=completion_content, model=llm_config.model)
total_tokens = prompt_tokens + completion_tokens
usage = UsageStatistics(
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
)
response_id = str(uuid.uuid4())
return ChatCompletionResponse(
id=response_id,
choices=choices,
model=llm_config.model, # NOTE: Google API doesn't pass back model in the response
created=get_utc_time_int(),
usage=usage,
)
except KeyError as e:
raise e
def get_function_call_response_schema(self, tool: dict) -> dict:
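        """Build the structured-output response schema for a single tool.

        Used when use_vertex_structured_outputs_experimental is enabled and exactly
        one tool is available: the model must emit a {"name": ..., "args": ...} object.
        """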
return {
"type": "OBJECT",
"properties": {
"name": {"type": "STRING", "enum": [tool["name"]]},
"args": {
"type": "OBJECT",
"properties": tool["parameters"]["properties"],
"required": tool["parameters"]["required"],
},
},
"propertyOrdering": ["name", "args"],
"required": ["name", "args"],
}
# https://ai.google.dev/gemini-api/docs/thinking#set-budget
# | Model | Default setting | Range | Disable thinking | Turn on dynamic thinking|
# |-----------------|-------------------------------------------------------------------|--------------|----------------------------|-------------------------|
# | 2.5 Pro | Dynamic thinking: Model decides when and how much to think | 128-32768 | N/A: Cannot disable | thinkingBudget = -1 |
# | 2.5 Flash | Dynamic thinking: Model decides when and how much to think | 0-24576 | thinkingBudget = 0 | thinkingBudget = -1 |
# | 2.5 Flash Lite | Model does not think | 512-24576 | thinkingBudget = 0 | thinkingBudget = -1 |
# | 3.x | Dynamic thinking: Model decides when and how much to think | 128-? | N/A: Cannot disable | thinkingBudget = -1 |
    def get_thinking_budget(self, model: str) -> int:
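        """Return the thinking budget (in tokens) for the given Gemini model; see the table above."""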
if model_settings.gemini_force_minimum_thinking_budget:
if all(substring in model for substring in ["2.5", "flash", "lite"]):
return 512
elif all(substring in model for substring in ["2.5", "flash"]):
return 1
# Gemini 3 and 2.5 Pro require thinking mode and cannot have budget 0
if model.startswith("gemini-3") or model.startswith("gemini-2.5-pro"):
return 128 # Minimum valid budget for models that require thinking
return 0
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
return (
llm_config.model.startswith("gemini-2.5-flash")
or llm_config.model.startswith("gemini-2.5-pro")
or llm_config.model.startswith("gemini-3")
)
    def is_malformed_function_call(self, response_data: dict) -> bool:
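        """Return True if a content-less candidate finished with MALFORMED_FUNCTION_CALL."""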
response = GenerateContentResponse(**response_data)
for candidate in response.candidates:
content = candidate.content
if content is None or content.role is None or content.parts is None:
return candidate.finish_reason == "MALFORMED_FUNCTION_CALL"
return False
@trace_method
def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None
# Handle Google GenAI specific errors
if isinstance(e, errors.ClientError):
if e.code == 499:
logger.info(f"{self._provider_prefix()} Request cancelled by client (499): {e}")
return LLMConnectionError(
message=f"Request to {self._provider_name()} was cancelled (client disconnected): {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"status_code": 499, "cause": "client_cancelled", "is_byok": is_byok},
)
logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}")
# Handle specific error codes
if e.code == 400:
error_str = str(e).lower()
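                # Google's token-limit errors read "input token count exceeds the
                # maximum number of tokens allowed" and never say "context", so we
                # also match "token count" / "tokens allowed" here; otherwise they
                # would fall through to LLMBadRequestError and compaction would not
                # auto-trigger.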
if ("context" in error_str or "token count" in error_str or "tokens allowed" in error_str) and (
"exceed" in error_str or "limit" in error_str or "too long" in error_str
):
return ContextWindowExceededError(
message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}",
details={"is_byok": is_byok},
)
else:
return LLMBadRequestError(
message=f"Bad request to {self._provider_name()}: {str(e)}",
code=ErrorCode.INVALID_ARGUMENT,
details={"is_byok": is_byok},
)
elif e.code == 401:
return LLMAuthenticationError(
message=f"Authentication failed with {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
elif e.code == 403:
return LLMPermissionDeniedError(
message=f"Permission denied by {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
elif e.code == 404:
return LLMNotFoundError(
message=f"Resource not found in {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
elif e.code == 408:
return LLMTimeoutError(
message=f"Request to {self._provider_name()} timed out: {str(e)}",
code=ErrorCode.TIMEOUT,
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
elif e.code == 422:
return LLMUnprocessableEntityError(
message=f"Invalid request content for {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"is_byok": is_byok},
)
elif e.code == 429:
logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.")
return LLMRateLimitError(
message=f"Rate limited by {self._provider_name()}: {str(e)}",
code=ErrorCode.RATE_LIMIT_EXCEEDED,
details={"is_byok": is_byok},
)
else:
return LLMServerError(
message=f"{self._provider_name()} client error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
if isinstance(e, errors.ServerError):
logger.warning(f"{self._provider_prefix()} Server error ({e.code}): {e}")
# Handle specific server error codes
if e.code == 500:
return LLMServerError(
message=f"{self._provider_name()} internal server error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
elif e.code == 502:
return LLMConnectionError(
message=f"Bad gateway from {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
elif e.code == 503:
return LLMServerError(
message=f"{self._provider_name()} service unavailable: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
elif e.code == 504:
return LLMTimeoutError(
message=f"Gateway timeout from {self._provider_name()}: {str(e)}",
code=ErrorCode.TIMEOUT,
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
else:
return LLMServerError(
message=f"{self._provider_name()} server error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
if isinstance(e, errors.APIError):
logger.warning(f"{self._provider_prefix()} API error ({e.code}): {e}")
return LLMServerError(
message=f"{self._provider_name()} API error: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={
"status_code": e.code,
"response_json": getattr(e, "response_json", None),
"is_byok": is_byok,
},
)
# Handle httpx.RemoteProtocolError which can occur during streaming
# when the remote server closes the connection unexpectedly
# (e.g., "peer closed connection without sending complete message body")
if isinstance(e, httpx.RemoteProtocolError):
logger.warning(f"{self._provider_prefix()} Remote protocol error during streaming: {e}")
return LLMConnectionError(
message=f"Connection error during {self._provider_name()} streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
# Handle httpx network errors which can occur during streaming
# when the connection is unexpectedly closed while reading/writing
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
logger.warning(f"{self._provider_prefix()} Network error during streaming: {type(e).__name__}: {e}")
return LLMConnectionError(
message=f"Network error during {self._provider_name()} streaming: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
)
# Handle connection-related errors
if "connection" in str(e).lower() or "timeout" in str(e).lower():
logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}")
return LLMConnectionError(
message=f"Failed to connect to {self._provider_name()}: {str(e)}",
code=ErrorCode.INTERNAL_SERVER_ERROR,
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
)
# Fallback to base implementation for other errors
return super().handle_llm_error(e, llm_config=llm_config)
    async def count_tokens(self, messages: Optional[List[dict]] = None, model: Optional[str] = None, tools: Optional[List[OpenAITool]] = None) -> int:
"""
Count tokens for the given messages and tools using the Gemini token counting API.
Args:
messages: List of message dicts in Google AI format (with 'role' and 'parts' keys)
model: The model to use for token counting (defaults to gemini-2.0-flash-lite)
tools: List of OpenAI-style Tool objects to include in the count
Returns:
The total token count for the input
"""
from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
client = self._get_client()
# Default model for token counting if not specified
count_model = model or GOOGLE_MODEL_FOR_API_KEY_CHECK
# Build the contents parameter
# If no messages provided, use empty string (like the API key check)
if messages is None or len(messages) == 0:
contents = ""
else:
# Messages should already be in Google format (role + parts)
contents = messages
try:
# Count message tokens
result = await client.aio.models.count_tokens(
model=count_model,
contents=contents,
)
total_tokens = result.total_tokens
# Count tool tokens separately by serializing to text
# The Gemini count_tokens API doesn't support a tools parameter directly
if tools and len(tools) > 0:
# Serialize tools to JSON text and count those tokens
tools_text = json.dumps([t.model_dump() for t in tools])
tools_result = await client.aio.models.count_tokens(
model=count_model,
contents=tools_text,
)
total_tokens += tools_result.total_tokens
except Exception as e:
raise self.handle_llm_error(e)
return total_tokens