* catch contextwindowexceeded error

* fix(core): detect Google token limit errors as ContextWindowExceededError

  Google's error message says "input token count exceeds the maximum number of tokens allowed", which doesn't contain the word "context", so it was falling through to the generic LLMBadRequestError instead of ContextWindowExceededError. This means compaction won't auto-trigger.

  Expands the detection to also match "token count" and "tokens allowed" in addition to the existing "context" keyword.

  🐾 Generated with [Letta Code](https://letta.com)

  Co-Authored-By: Letta <noreply@letta.com>

* fix(core): add missing message arg to LLMBadRequestError in OpenAI client

  The generic 400 path in handle_llm_error was constructing LLMBadRequestError without the required message positional arg, causing a TypeError in prod during summarization.

  🐾 Generated with [Letta Code](https://letta.com)

  Co-Authored-By: Letta <noreply@letta.com>

* ci: add adapters/ test suite to core unit test matrix

  🐾 Generated with [Letta Code](https://letta.com)

  Co-Authored-By: Letta <noreply@letta.com>

* fix(tests): update adapter error handling test expectations to match actual behavior

  The streaming adapter's error handling double-wraps errors: the AnthropicStreamingInterface calls handle_llm_error first, then the adapter catches the result and calls handle_llm_error again, which falls through to the base class LLMError. Updated test expectations to match this behavior.

  🐾 Generated with [Letta Code](https://letta.com)

  Co-Authored-By: Letta <noreply@letta.com>

* fix(core): prevent double-wrapping of LLMError in stream adapter

  The AnthropicStreamingInterface.process() already transforms raw provider errors into LLMError subtypes via handle_llm_error. The adapter was catching the result and calling handle_llm_error again, which didn't recognize the already-transformed LLMError and wrapped it in a generic LLMError("Unhandled LLM error"). This downgraded specific error types (LLMConnectionError, LLMServerError, etc.) and broke retry logic that matches on specific subtypes.

  Now the adapter checks whether the error is already an LLMError and re-raises it as-is. Tests restored to the original, correct expectations.

  🐾 Generated with [Letta Code](https://letta.com)

  Co-Authored-By: Letta <noreply@letta.com>

---------

Co-authored-by: Letta <noreply@letta.com>
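For reference, the expanded detection described above reduces to a keyword match on the lowered error string (a sketch mirroring the check in `handle_llm_error` below):

```python
error_str = str(e).lower()
is_context_overflow = ("context" in error_str or "token count" in error_str or "tokens allowed" in error_str) and (
    "exceed" in error_str or "limit" in error_str or "too long" in error_str
)
```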
import base64
import copy
import json
import uuid
from typing import AsyncIterator, List, Optional

import httpx
import pydantic_core
from google.genai import Client, errors
from google.genai.types import (
    FunctionCallingConfig,
    FunctionCallingConfigMode,
    GenerateContentResponse,
    HttpOptions,
    ThinkingConfig,
    ToolConfig,
)

from letta.constants import NON_USER_MSG_PREFIX
from letta.errors import (
    ContextWindowExceededError,
    ErrorCode,
    LLMAuthenticationError,
    LLMBadRequestError,
    LLMConnectionError,
    LLMNotFoundError,
    LLMPermissionDeniedError,
    LLMRateLimitError,
    LLMServerError,
    LLMTimeoutError,
    LLMUnprocessableEntityError,
)
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps, json_loads, sanitize_unicode_surrogates
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentType
from letta.schemas.enums import ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool, Tool as OpenAITool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings, settings
from letta.utils import get_tool_call_id

logger = get_logger(__name__)

class GoogleVertexClient(LLMClientBase):
    MAX_RETRIES = model_settings.gemini_max_retries
    provider_label = "Google Vertex"

    def _get_client(self, llm_config: Optional[LLMConfig] = None):
        timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
        if llm_config:
            api_key, _, _ = self.get_byok_overrides(llm_config)
            if api_key:
                return Client(
                    api_key=api_key,
                    http_options=HttpOptions(timeout=timeout_ms),
                )
        return Client(
            vertexai=True,
            project=model_settings.google_cloud_project,
            location=model_settings.google_cloud_location,
            http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
        )

    async def _get_client_async(self, llm_config: Optional[LLMConfig] = None):
        timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
        if llm_config:
            api_key, _, _ = await self.get_byok_overrides_async(llm_config)
            if api_key:
                return Client(
                    api_key=api_key,
                    http_options=HttpOptions(timeout=timeout_ms),
                )
        return Client(
            vertexai=True,
            project=model_settings.google_cloud_project,
            location=model_settings.google_cloud_location,
            http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
        )

    def _provider_prefix(self) -> str:
        return f"[{self.provider_label}]"

    def _provider_name(self) -> str:
        return self.provider_label

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying request to the LLM and returns the raw response.
        """
        try:
            client = self._get_client(llm_config)
            response = client.models.generate_content(
                model=llm_config.model,
                contents=request_data["contents"],
                config=request_data["config"],
            )
            return response.model_dump()
        except pydantic_core._pydantic_core.ValidationError as e:
            # Handle Pydantic validation errors from the Google SDK
            # This occurs when tool schemas contain unsupported fields
            logger.error(
                f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}"
            )
            raise LLMBadRequestError(
                message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
                f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
                f"Please check your tool definitions. Error: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
            )
        except Exception as e:
            raise self.handle_llm_error(e, llm_config=llm_config)

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying request to the LLM and returns the raw response.
        """
        request_data = sanitize_unicode_surrogates(request_data)

        client = await self._get_client_async(llm_config)

        # Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
        # https://github.com/googleapis/python-aiplatform/issues/4472
        retry_count = 1
        should_retry = True
        response_data = None
        while should_retry and retry_count <= self.MAX_RETRIES:
            try:
                response = await client.aio.models.generate_content(
                    model=llm_config.model,
                    contents=request_data["contents"],
                    config=request_data["config"],
                )
            except pydantic_core._pydantic_core.ValidationError as e:
                # Handle Pydantic validation errors from the Google SDK
                # This occurs when tool schemas contain unsupported fields
                logger.error(
                    f"Pydantic validation error when calling {self._provider_name()} API. "
                    f"Tool schema contains unsupported fields. Error: {e}"
                )
                raise LLMBadRequestError(
                    message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
                    f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
                    f"Please check your tool definitions. Error: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                )
            except errors.APIError as e:
                # Retry on 500, 503, and 504 errors, which are usually ephemeral from Gemini
                if e.code in (500, 503, 504):
                    logger.warning(f"Received {e}, retrying {retry_count}/{self.MAX_RETRIES}")
                    retry_count += 1
                    if retry_count > self.MAX_RETRIES:
                        raise self.handle_llm_error(e, llm_config=llm_config)
                    continue
                raise self.handle_llm_error(e, llm_config=llm_config)
            except Exception as e:
                raise self.handle_llm_error(e, llm_config=llm_config)

            response_data = response.model_dump()
            is_malformed_function_call = self.is_malformed_function_call(response_data)
            if is_malformed_function_call:
                logger.warning(
                    f"Received FinishReason.MALFORMED_FUNCTION_CALL in response for {llm_config.model}, retrying {retry_count}/{self.MAX_RETRIES}"
                )
                # Modify the last message if it's a heartbeat to include a warning about special characters
                if request_data["contents"] and len(request_data["contents"]) > 0:
                    last_message = request_data["contents"][-1]
                    if last_message.get("role") == "user" and last_message.get("parts"):
                        for part in last_message["parts"]:
                            if "text" in part:
                                try:
                                    # Try to parse as JSON to check if it's a heartbeat
                                    message_json = json_loads(part["text"])
                                    if message_json.get("type") == "heartbeat" and "reason" in message_json:
                                        # Append warning to the reason
                                        warning = f" RETRY {retry_count}/{self.MAX_RETRIES} ***DO NOT USE SPECIAL CHARACTERS OR QUOTATIONS INSIDE FUNCTION CALL ARGUMENTS. IF YOU MUST, MAKE SURE TO ESCAPE THEM PROPERLY***"
                                        message_json["reason"] = message_json["reason"] + warning
                                        # Update the text with the modified JSON
                                        part["text"] = json_dumps(message_json)
                                        logger.warning(
                                            f"Modified heartbeat message with special character warning for retry {retry_count}/{self.MAX_RETRIES}"
                                        )
                                except (json.JSONDecodeError, TypeError, AttributeError):
                                    # Not a JSON message or not a heartbeat, skip modification
                                    pass

            should_retry = is_malformed_function_call
            retry_count += 1

        if response_data is None:
            raise RuntimeError("Failed to get response data after all retries")
        return response_data

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
        request_data = sanitize_unicode_surrogates(request_data)

        client = await self._get_client_async(llm_config)

        try:
            response = await client.aio.models.generate_content_stream(
                model=llm_config.model,
                contents=request_data["contents"],
                config=request_data["config"],
            )
        except pydantic_core._pydantic_core.ValidationError as e:
            # Handle Pydantic validation errors from the Google SDK
            # This occurs when tool schemas contain unsupported fields
            logger.error(
                f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}"
            )
            raise LLMBadRequestError(
                message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. "
                f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. "
                f"Please check your tool definitions. Error: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
            )
        except errors.APIError as e:
            raise self.handle_llm_error(e, llm_config=llm_config)
        except Exception as e:
            logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}")
            raise e
        # Direct yield - keeps response alive in the generator's local scope throughout iteration
        # This is required because the SDK's connection lifecycle is tied to the response object
        try:
            async for chunk in response:
                yield chunk
        except errors.ClientError as e:
            if e.code == 499:
                logger.info(f"{self._provider_prefix()} Stream cancelled by client (499): {e}")
                return
            raise self.handle_llm_error(e, llm_config=llm_config)
        except errors.APIError as e:
            raise self.handle_llm_error(e, llm_config=llm_config)

    @staticmethod
    def add_dummy_model_messages(messages: List[dict]) -> List[dict]:
        """Google AI API requires that all function call returns are immediately followed by a 'model' role message.

        In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user,
        so there is no natural follow-up 'model' role message.

        To satisfy the Google AI API restrictions, we can add a dummy 'yield' message
        with role == 'model' that is placed in between the function output
        (role == 'tool') and the user message (role == 'user').
        """
        dummy_yield_message = {
            "role": "model",
            "parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}],
        }
        messages_with_padding = []
        for i, message in enumerate(messages):
            messages_with_padding.append(message)
            # Check if the current message role is 'tool' and the next message role is 'user'
            if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
                messages_with_padding.append(dummy_yield_message)

        return messages_with_padding

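    # Illustrative example (hypothetical messages) of the padding above:
    #   [{"role": "tool", ...}, {"role": "user", ...}]
    # becomes
    #   [{"role": "tool", ...},
    #    {"role": "model", "parts": [{"text": "<NON_USER_MSG_PREFIX>Function call returned, waiting for user response."}]},
    #    {"role": "user", ...}]
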
    def _clean_google_ai_schema_properties(self, schema_part: dict):
        """Recursively clean schema parts to remove unsupported Google AI keywords."""
        if not isinstance(schema_part, dict):
            return

        # Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations
        # * Only a subset of the OpenAPI schema is supported.
        # * Supported parameter types in Python are limited.
        unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties", "$schema", "const", "$ref"]
        keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part]
        for key_to_remove in keys_to_remove_at_this_level:
            logger.debug(f"Removing unsupported keyword '{key_to_remove}' from schema part.")
            del schema_part[key_to_remove]

        if schema_part.get("type") == "string" and "format" in schema_part:
            allowed_formats = ["enum", "date-time"]
            if schema_part["format"] not in allowed_formats:
                logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. Allowed: {allowed_formats}")
                del schema_part["format"]

        # Check properties within the current level
        if "properties" in schema_part and isinstance(schema_part["properties"], dict):
            for prop_name, prop_schema in schema_part["properties"].items():
                self._clean_google_ai_schema_properties(prop_schema)

        # Check items within arrays
        if "items" in schema_part and isinstance(schema_part["items"], dict):
            self._clean_google_ai_schema_properties(schema_part["items"])

        # Check within anyOf, allOf, oneOf lists
        for key in ["anyOf", "allOf", "oneOf"]:
            if key in schema_part and isinstance(schema_part[key], list):
                for item_schema in schema_part[key]:
                    self._clean_google_ai_schema_properties(item_schema)

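    # Illustrative example (hypothetical schema) of the cleaning above:
    #   {"type": "string", "const": "a", "default": "a", "format": "uuid"}
    # becomes
    #   {"type": "string"}   # 'const'/'default' are unsupported keys; 'uuid' is not an allowed string format
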
    def _resolve_json_schema_refs(self, schema: dict, defs: dict = None) -> dict:
        """
        Recursively resolve $ref in JSON schema by inlining definitions.
        Google GenAI SDK does not support $ref.
        """
        if defs is None:
            # Look for definitions at the top level
            defs = schema.get("$defs") or schema.get("definitions") or {}

        if not isinstance(schema, dict):
            return schema

        # If this is a ref, resolve it
        if "$ref" in schema:
            ref = schema["$ref"]
            if isinstance(ref, str):
                for prefix in ("#/$defs/", "#/definitions/"):
                    if ref.startswith(prefix):
                        ref_name = ref.split("/")[-1]
                        if ref_name in defs:
                            resolved = defs[ref_name].copy()
                            return self._resolve_json_schema_refs(resolved, defs)
                        break

            logger.warning(f"Could not resolve $ref '{ref}' in schema — will be stripped by schema cleaner")

        # Recursively process children
        new_schema = schema.copy()

        # We need to remove $defs/definitions from the output schema as Google doesn't support them
        if "$defs" in new_schema:
            del new_schema["$defs"]
        if "definitions" in new_schema:
            del new_schema["definitions"]

        for k, v in new_schema.items():
            if isinstance(v, dict):
                new_schema[k] = self._resolve_json_schema_refs(v, defs)
            elif isinstance(v, list):
                new_schema[k] = [self._resolve_json_schema_refs(i, defs) if isinstance(i, dict) else i for i in v]

        return new_schema

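    # Illustrative example (hypothetical schema) of the $ref inlining above:
    #   {"$defs": {"Point": {"type": "object"}}, "properties": {"p": {"$ref": "#/$defs/Point"}}}
    # becomes
    #   {"properties": {"p": {"type": "object"}}}   # $defs removed, ref inlined
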
    def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
        """
        OpenAI style:
          "tools": [{
            "type": "function",
            "function": {
                "name": "find_movies",
                "description": "find ....",
                "parameters": {
                  "type": "object",
                  "properties": {
                     PARAM: {
                       "type": PARAM_TYPE,  # eg "string"
                       "description": PARAM_DESCRIPTION,
                     },
                     ...
                  },
                  "required": List[str],
                }
            }
          }]

        Google AI style:
          "tools": [{
            "functionDeclarations": [{
              "name": "find_movies",
              "description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
              "parameters": {
                "type": "OBJECT",
                "properties": {
                  "location": {
                    "type": "STRING",
                    "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                  },
                  "description": {
                    "type": "STRING",
                    "description": "Any kind of description including category or genre, title words, attributes, etc."
                  }
                },
                "required": ["description"]
              }
            }, {
              "name": "find_theaters",
              ...
        """
        function_list = [
            dict(
                name=t.function.name,
                description=t.function.description,
                # Deep copy parameters to avoid modifying the original Tool object
                parameters=copy.deepcopy(t.function.parameters) if t.function.parameters else {},
            )
            for t in tools
        ]

        # Add inner thoughts if needed
        for func in function_list:
            # Note: Google AI API used to have weird casing requirements, but not any more

            # Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned
            if "parameters" in func and isinstance(func["parameters"], dict):
                # Resolve $ref in schema because Google AI SDK doesn't support them
                func["parameters"] = self._resolve_json_schema_refs(func["parameters"])
                self._clean_google_ai_schema_properties(func["parameters"])

            # Add inner thoughts
            if llm_config.put_inner_thoughts_in_kwargs:
                from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX

                func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = {
                    "type": "string",
                    "description": INNER_THOUGHTS_KWARG_DESCRIPTION,
                }
                func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX)

        return [{"functionDeclarations": function_list}]

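    # Illustrative sketch of the inner-thoughts injection above (the actual kwarg name
    # comes from letta.local_llm.constants.INNER_THOUGHTS_KWARG_VERTEX; the literal below is assumed):
    #   func["parameters"]["properties"]["<inner_thoughts_kwarg>"] = {"type": "string", "description": "..."}
    #   func["parameters"]["required"].append("<inner_thoughts_kwarg>")
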
    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,  # if react, use native content + strip heartbeats
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: List[dict],
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """
        Constructs a request object in the expected data format for this client.
        """
        # NOTE: forcing inner thoughts in kwargs off
        if agent_type == AgentType.letta_v1_agent:
            llm_config.put_inner_thoughts_in_kwargs = False

        if tools:
            tool_objs = [Tool(type="function", function=t) for t in tools]
            tool_names = [t.function.name for t in tool_objs]
            # Convert to the exact payload style Google expects
            formatted_tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)
        else:
            formatted_tools = []
            tool_names = []

        contents = self.add_dummy_model_messages(
            PydanticMessage.to_google_dicts_from_list(
                messages,
                current_model=llm_config.model,
                put_inner_thoughts_in_kwargs=False if agent_type == AgentType.letta_v1_agent else True,
                native_content=True if agent_type == AgentType.letta_v1_agent else False,
            ),
        )

        request_data = {
            "contents": contents,
            "config": {
                "temperature": llm_config.temperature,
                "tools": formatted_tools,
            },
        }
        # Max tokens is optional
        if llm_config.max_tokens:
            request_data["config"]["max_output_tokens"] = llm_config.max_tokens

        if len(tool_names) == 1 and settings.use_vertex_structured_outputs_experimental:
            request_data["config"]["response_mime_type"] = "application/json"
            request_data["config"]["response_schema"] = self.get_function_call_response_schema(tools[0])
            del request_data["config"]["tools"]
        elif tools:
            if agent_type == AgentType.letta_v1_agent:
                # don't require tools
                tool_call_mode = FunctionCallingConfigMode.AUTO
                tool_config = ToolConfig(
                    function_calling_config=FunctionCallingConfig(
                        mode=tool_call_mode,
                    )
                )
            else:
                # require tools
                tool_call_mode = FunctionCallingConfigMode.ANY
                tool_config = ToolConfig(
                    function_calling_config=FunctionCallingConfig(
                        mode=tool_call_mode,
                        # Provide the list of tools (though empty should also work, it seems not to)
                        allowed_function_names=tool_names,
                    )
                )

            request_data["config"]["tool_config"] = tool_config.model_dump()

        # https://ai.google.dev/gemini-api/docs/thinking#set-budget
        # 2.5 Pro
        # - Default: dynamic thinking
        # - Dynamic thinking that cannot be disabled
        # - Range: -1 (for dynamic), or 128-32768
        # 2.5 Flash
        # - Default: dynamic thinking
        # - Dynamic thinking that *can* be disabled
        # - Range: -1, 0, or 0-24576
        # 2.5 Flash Lite
        # - Default: no thinking
        # - Dynamic thinking that *can* be disabled
        # - Range: -1, 0, or 512-24576
        # TODO when using v3 agent loop, properly support the native thinking in Gemini

        # Add thinking_config for all Gemini reasoning models (2.5 series)
        # If enable_reasoner is False, set thinking_budget to 0
        # Otherwise, use the value from max_reasoning_tokens
        if self.is_reasoning_model(llm_config) or "flash" in llm_config.model:
            if llm_config.model.startswith("gemini-3"):
                # Let thinking_level default to high by not specifying thinking_budget
                thinking_config = ThinkingConfig(include_thoughts=True)
            else:
                # Gemini reasoning models may fail to call tools even with FunctionCallingConfigMode.ANY if thinking is fully disabled, so set the budget to the minimum to prevent tool call failures
                thinking_budget = (
                    llm_config.max_reasoning_tokens if llm_config.enable_reasoner else self.get_thinking_budget(llm_config.model)
                )
                if thinking_budget <= 0:
                    logger.warning(
                        f"Thinking budget of {thinking_budget} for Gemini reasoning model {llm_config.model}, this will likely cause tool call failures"
                    )
                # For models that require thinking mode (2.5 Pro, 3.x), override with minimum valid budget
                if llm_config.model.startswith("gemini-2.5-pro"):
                    thinking_budget = 128
                    logger.warning(
                        f"Overriding thinking_budget to {thinking_budget} for model {llm_config.model} which requires thinking mode"
                    )
                thinking_config = ThinkingConfig(
                    thinking_budget=(thinking_budget),
                    include_thoughts=(thinking_budget > 1),
                )
            request_data["config"]["thinking_config"] = thinking_config.model_dump()

        return request_data

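    # Illustrative shape (hypothetical values) of the request_data built above:
    #   {
    #       "contents": [{"role": "user", "parts": [{"text": "..."}]}],
    #       "config": {
    #           "temperature": 0.7,
    #           "tools": [{"functionDeclarations": [...]}],
    #           "tool_config": {"function_calling_config": {"mode": "ANY", "allowed_function_names": [...]}},
    #           "thinking_config": {"thinking_budget": 128, "include_thoughts": True},
    #       },
    #   }
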
    def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
        """Extract usage statistics from a Gemini response and return them as LettaUsageStatistics."""
        if not response_data:
            return LettaUsageStatistics()

        response = GenerateContentResponse(**response_data)
        if not response.usage_metadata:
            return LettaUsageStatistics()

        cached_tokens = None
        if (
            hasattr(response.usage_metadata, "cached_content_token_count")
            and response.usage_metadata.cached_content_token_count is not None
        ):
            cached_tokens = response.usage_metadata.cached_content_token_count

        reasoning_tokens = None
        if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count is not None:
            reasoning_tokens = response.usage_metadata.thoughts_token_count

        return LettaUsageStatistics(
            prompt_tokens=response.usage_metadata.prompt_token_count or 0,
            completion_tokens=response.usage_metadata.candidates_token_count or 0,
            total_tokens=response.usage_metadata.total_token_count or 0,
            cached_input_tokens=cached_tokens,
            reasoning_tokens=reasoning_tokens,
        )

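    # Illustrative mapping (hypothetical counts): a response carrying
    #   "usageMetadata": {"promptTokenCount": 9, "candidatesTokenCount": 27, "totalTokenCount": 36}
    # yields LettaUsageStatistics(prompt_tokens=9, completion_tokens=27, total_tokens=36)
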
    @trace_method
    async def convert_response_to_chat_completion(
        self,
        response_data: dict,
        input_messages: List[PydanticMessage],
        llm_config: LLMConfig,
    ) -> ChatCompletionResponse:
        """
        Converts the custom response format from the LLM client into an OpenAI
        ChatCompletionsResponse object.

        Example:
        {
            "candidates": [
                {
                    "content": {
                        "parts": [
                            {
                                "text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
                            }
                        ]
                    }
                }
            ],
            "usageMetadata": {
                "promptTokenCount": 9,
                "candidatesTokenCount": 27,
                "totalTokenCount": 36
            }
        }
        """
        response = GenerateContentResponse(**response_data)
        try:
            choices = []
            index = 0
            for candidate in response.candidates:
                content = candidate.content

                if content is None or content.role is None or content.parts is None:
                    # This means the response is malformed like MALFORMED_FUNCTION_CALL
                    if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
                        raise LLMServerError(f"Malformed response from {self._provider_name()}: {candidate.finish_reason}")
                    else:
                        raise LLMServerError(f"Invalid response data from {self._provider_name()}: {candidate.model_dump()}")

                role = content.role
                assert role == "model", f"Unknown role in response: {role}"

                parts = content.parts

                # NOTE: we aren't properly supporting multi-part content here anyway (we're just appending choices),
                # so let's disable it for now

                # NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
                # {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
                # To patch this, if we have multiple parts we can take the last one
                # Unless parallel tool calling is enabled, in which case multiple parts may be intentional
                if len(parts) > 1 and not llm_config.enable_reasoner and not llm_config.parallel_tool_calls:
                    logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
                    # only truncate if reasoning is off and parallel tool calls are disabled
                    parts = [parts[-1]]

                # TODO support parts / multimodal
                # Parallel tool calling is now supported when llm_config.parallel_tool_calls is enabled
                openai_response_message = None
                for response_message in parts:
                    # Convert the actual message style to OpenAI style
                    if response_message.function_call:
                        function_call = response_message.function_call
                        function_name = function_call.name
                        function_args = function_call.args
                        assert isinstance(function_args, dict), function_args

                        # TODO this is kind of funky - really, we should be passing 'native_content' as a kwarg to fork behavior
                        inner_thoughts = response_message.text
                        if llm_config.put_inner_thoughts_in_kwargs:
                            # NOTE: this also involves stripping the inner monologue out of the function
                            from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX

                            assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
                                f"Couldn't find inner thoughts in function args:\n{function_call}"
                            )
                            inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
                            assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
                        else:
                            pass
                            # inner_thoughts = None
                            # inner_thoughts = response_message.text

                        # Google AI API doesn't generate tool call IDs
                        tool_call = ToolCall(
                            id=get_tool_call_id(),
                            type="function",
                            function=FunctionCall(
                                name=function_name,
                                arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
                            ),
                        )

                        if openai_response_message is None:
                            openai_response_message = Message(
                                role="assistant",  # NOTE: "model" -> "assistant"
                                content=inner_thoughts,
                                tool_calls=[tool_call],
                            )
                            if response_message.thought_signature:
                                thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
                                openai_response_message.reasoning_content_signature = thought_signature
                        else:
                            openai_response_message.content = inner_thoughts
                            if openai_response_message.tool_calls is None:
                                openai_response_message.tool_calls = []
                            openai_response_message.tool_calls.append(tool_call)
                            if response_message.thought_signature:
                                thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
                                openai_response_message.reasoning_content_signature = thought_signature

                    else:
                        if response_message.thought:
                            if openai_response_message is None:
                                openai_response_message = Message(
                                    role="assistant",  # NOTE: "model" -> "assistant"
                                    reasoning_content=response_message.text,
                                )
                            else:
                                openai_response_message.reasoning_content = response_message.text
                        else:
                            try:
                                # Structured output tool call
                                function_call = json_loads(response_message.text)

                                # Access dict keys - will raise TypeError/KeyError if not a dict or missing keys
                                function_name = function_call["name"]
                                function_args = function_call["args"]
                                assert isinstance(function_args, dict), function_args

                                # NOTE: this also involves stripping the inner monologue out of the function
                                if llm_config.put_inner_thoughts_in_kwargs:
                                    from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX

                                    assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
                                        f"Couldn't find inner thoughts in function args:\n{function_call}"
                                    )
                                    inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
                                    assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
                                else:
                                    inner_thoughts = None

                                # Google AI API doesn't generate tool call IDs
                                tool_call = ToolCall(
                                    id=get_tool_call_id(),
                                    type="function",
                                    function=FunctionCall(
                                        name=function_name,
                                        arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
                                    ),
                                )
                                if openai_response_message is None:
                                    openai_response_message = Message(
                                        role="assistant",  # NOTE: "model" -> "assistant"
                                        content=inner_thoughts,
                                        tool_calls=[tool_call],
                                    )
                                else:
                                    openai_response_message.content = inner_thoughts
                                    if openai_response_message.tool_calls is None:
                                        openai_response_message.tool_calls = []
                                    openai_response_message.tool_calls.append(tool_call)

                            except (json.decoder.JSONDecodeError, ValueError, TypeError, KeyError, AssertionError) as e:
                                if candidate.finish_reason == "MAX_TOKENS":
                                    raise LLMServerError("Could not parse response data from LLM: exceeded max token limit")
                                # Log the parsing error for debugging
                                logger.warning(
                                    f"Failed to parse structured output from LLM response: {e}. Response text: {response_message.text[:500]}"
                                )
                                # Inner thoughts are the content by default
                                inner_thoughts = response_message.text

                                # Google AI API doesn't generate tool call IDs
                                if openai_response_message is None:
                                    openai_response_message = Message(
                                        role="assistant",  # NOTE: "model" -> "assistant"
                                        content=inner_thoughts,
                                    )
                                else:
                                    openai_response_message.content = inner_thoughts
                            if response_message.thought_signature:
                                thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
                                openai_response_message.reasoning_content_signature = thought_signature

                # Google AI API uses different finish reason strings than OpenAI
                # OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
                # see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
                # Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
                # see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
                finish_reason = candidate.finish_reason.value
                if finish_reason == "STOP":
                    openai_finish_reason = (
                        "function_call"
                        if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
                        else "stop"
                    )
                elif finish_reason == "MAX_TOKENS":
                    openai_finish_reason = "length"
                elif finish_reason == "SAFETY":
                    openai_finish_reason = "content_filter"
                elif finish_reason == "RECITATION":
                    openai_finish_reason = "content_filter"
                else:
                    raise LLMServerError(f"Unrecognized finish reason in Google AI response: {finish_reason}")

                choices.append(
                    Choice(
                        finish_reason=openai_finish_reason,
                        index=index,
                        message=openai_response_message,
                    )
                )
                index += 1

            # if len(choices) > 1:
            #     raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")

            # NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist?
            # "usageMetadata": {
            #     "promptTokenCount": 9,
            #     "candidatesTokenCount": 27,
            #     "totalTokenCount": 36
            # }
            usage = None
            if response.usage_metadata:
                # Extract usage via centralized method
                from letta.schemas.enums import ProviderType

                usage = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.google_ai)
            else:
                # Count it ourselves using the Gemini token counting API
                assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"
                google_messages = PydanticMessage.to_google_dicts_from_list(input_messages, current_model=llm_config.model)
                prompt_tokens = await self.count_tokens(messages=google_messages, model=llm_config.model)
                # For completion tokens, wrap the response content in Google format
                completion_content = [{"role": "model", "parts": [{"text": json_dumps(openai_response_message.model_dump())}]}]
                completion_tokens = await self.count_tokens(messages=completion_content, model=llm_config.model)
                total_tokens = prompt_tokens + completion_tokens
                usage = UsageStatistics(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                )

            response_id = str(uuid.uuid4())
            return ChatCompletionResponse(
                id=response_id,
                choices=choices,
                model=llm_config.model,  # NOTE: Google API doesn't pass back model in the response
                created=get_utc_time_int(),
                usage=usage,
            )
        except KeyError as e:
            raise e

    def get_function_call_response_schema(self, tool: dict) -> dict:
        return {
            "type": "OBJECT",
            "properties": {
                "name": {"type": "STRING", "enum": [tool["name"]]},
                "args": {
                    "type": "OBJECT",
                    "properties": tool["parameters"]["properties"],
                    "required": tool["parameters"]["required"],
                },
            },
            "propertyOrdering": ["name", "args"],
            "required": ["name", "args"],
        }

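    # Illustrative output (for a hypothetical tool named "send_message" with one "message" param):
    #   {
    #       "type": "OBJECT",
    #       "properties": {
    #           "name": {"type": "STRING", "enum": ["send_message"]},
    #           "args": {"type": "OBJECT", "properties": {"message": {"type": "string"}}, "required": ["message"]},
    #       },
    #       "propertyOrdering": ["name", "args"],
    #       "required": ["name", "args"],
    #   }
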
    # https://ai.google.dev/gemini-api/docs/thinking#set-budget
    # | Model          | Default setting                                            | Range     | Disable thinking    | Turn on dynamic thinking |
    # |----------------|------------------------------------------------------------|-----------|---------------------|--------------------------|
    # | 2.5 Pro        | Dynamic thinking: Model decides when and how much to think | 128-32768 | N/A: Cannot disable | thinkingBudget = -1      |
    # | 2.5 Flash      | Dynamic thinking: Model decides when and how much to think | 0-24576   | thinkingBudget = 0  | thinkingBudget = -1      |
    # | 2.5 Flash Lite | Model does not think                                       | 512-24576 | thinkingBudget = 0  | thinkingBudget = -1      |
    # | 3.x            | Dynamic thinking: Model decides when and how much to think | 128-?     | N/A: Cannot disable | thinkingBudget = -1      |
    def get_thinking_budget(self, model: str) -> int:
        if model_settings.gemini_force_minimum_thinking_budget:
            if all(substring in model for substring in ["2.5", "flash", "lite"]):
                return 512
            elif all(substring in model for substring in ["2.5", "flash"]):
                return 1
        # Gemini 3 and 2.5 Pro require thinking mode and cannot have budget 0
        if model.startswith("gemini-3") or model.startswith("gemini-2.5-pro"):
            return 128  # Minimum valid budget for models that require thinking
        return 0

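    # Illustrative budgets returned above (with gemini_force_minimum_thinking_budget enabled):
    #   "gemini-2.5-flash-lite" -> 512, "gemini-2.5-flash" -> 1,
    #   "gemini-2.5-pro" -> 128, any "gemini-3" model -> 128, anything else -> 0
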
    def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
        return (
            llm_config.model.startswith("gemini-2.5-flash")
            or llm_config.model.startswith("gemini-2.5-pro")
            or llm_config.model.startswith("gemini-3")
        )

    def is_malformed_function_call(self, response_data: dict) -> bool:
        response = GenerateContentResponse(**response_data)
        for candidate in response.candidates:
            content = candidate.content
            if content is None or content.role is None or content.parts is None:
                return candidate.finish_reason == "MALFORMED_FUNCTION_CALL"
        return False

    @trace_method
    def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception:
        is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None

        # Handle Google GenAI specific errors
        if isinstance(e, errors.ClientError):
            if e.code == 499:
                logger.info(f"{self._provider_prefix()} Request cancelled by client (499): {e}")
                return LLMConnectionError(
                    message=f"Request to {self._provider_name()} was cancelled (client disconnected): {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                    details={"status_code": 499, "cause": "client_cancelled", "is_byok": is_byok},
                )

            logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}")

            # Handle specific error codes
            if e.code == 400:
                error_str = str(e).lower()
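                # Google's context overflow message reads "input token count exceeds the maximum
                # number of tokens allowed" and never contains the word "context", so the match
                # below also checks "token count" / "tokens allowed" to classify it correctly.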
if ("context" in error_str or "token count" in error_str or "tokens allowed" in error_str) and (
|
|
"exceed" in error_str or "limit" in error_str or "too long" in error_str
|
|
):
|
|
return ContextWindowExceededError(
|
|
message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}",
|
|
details={"is_byok": is_byok},
|
|
)
|
|
else:
|
|
return LLMBadRequestError(
|
|
message=f"Bad request to {self._provider_name()}: {str(e)}",
|
|
code=ErrorCode.INVALID_ARGUMENT,
|
|
details={"is_byok": is_byok},
|
|
)
|
|
elif e.code == 401:
|
|
return LLMAuthenticationError(
|
|
message=f"Authentication failed with {self._provider_name()}: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={"is_byok": is_byok},
|
|
)
|
|
elif e.code == 403:
|
|
return LLMPermissionDeniedError(
|
|
message=f"Permission denied by {self._provider_name()}: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={"is_byok": is_byok},
|
|
)
|
|
elif e.code == 404:
|
|
return LLMNotFoundError(
|
|
message=f"Resource not found in {self._provider_name()}: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={"is_byok": is_byok},
|
|
)
|
|
elif e.code == 408:
|
|
return LLMTimeoutError(
|
|
message=f"Request to {self._provider_name()} timed out: {str(e)}",
|
|
code=ErrorCode.TIMEOUT,
|
|
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
|
)
|
|
elif e.code == 422:
|
|
return LLMUnprocessableEntityError(
|
|
message=f"Invalid request content for {self._provider_name()}: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={"is_byok": is_byok},
|
|
)
|
|
elif e.code == 429:
|
|
logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.")
|
|
return LLMRateLimitError(
|
|
message=f"Rate limited by {self._provider_name()}: {str(e)}",
|
|
code=ErrorCode.RATE_LIMIT_EXCEEDED,
|
|
details={"is_byok": is_byok},
|
|
)
|
|
else:
|
|
return LLMServerError(
|
|
message=f"{self._provider_name()} client error: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={
|
|
"status_code": e.code,
|
|
"response_json": getattr(e, "response_json", None),
|
|
"is_byok": is_byok,
|
|
},
|
|
)
|
|
|
|
if isinstance(e, errors.ServerError):
|
|
logger.warning(f"{self._provider_prefix()} Server error ({e.code}): {e}")
|
|
|
|
# Handle specific server error codes
|
|
if e.code == 500:
|
|
return LLMServerError(
|
|
message=f"{self._provider_name()} internal server error: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={
|
|
"status_code": e.code,
|
|
"response_json": getattr(e, "response_json", None),
|
|
"is_byok": is_byok,
|
|
},
|
|
)
|
|
elif e.code == 502:
|
|
return LLMConnectionError(
|
|
message=f"Bad gateway from {self._provider_name()}: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
|
)
|
|
elif e.code == 503:
|
|
return LLMServerError(
|
|
message=f"{self._provider_name()} service unavailable: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={
|
|
"status_code": e.code,
|
|
"response_json": getattr(e, "response_json", None),
|
|
"is_byok": is_byok,
|
|
},
|
|
)
|
|
elif e.code == 504:
|
|
return LLMTimeoutError(
|
|
message=f"Gateway timeout from {self._provider_name()}: {str(e)}",
|
|
code=ErrorCode.TIMEOUT,
|
|
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
|
)
|
|
else:
|
|
return LLMServerError(
|
|
message=f"{self._provider_name()} server error: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={
|
|
"status_code": e.code,
|
|
"response_json": getattr(e, "response_json", None),
|
|
"is_byok": is_byok,
|
|
},
|
|
)
|
|
|
|
if isinstance(e, errors.APIError):
|
|
logger.warning(f"{self._provider_prefix()} API error ({e.code}): {e}")
|
|
return LLMServerError(
|
|
message=f"{self._provider_name()} API error: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={
|
|
"status_code": e.code,
|
|
"response_json": getattr(e, "response_json", None),
|
|
"is_byok": is_byok,
|
|
},
|
|
)
|
|
|
|
# Handle httpx.RemoteProtocolError which can occur during streaming
|
|
# when the remote server closes the connection unexpectedly
|
|
# (e.g., "peer closed connection without sending complete message body")
|
|
if isinstance(e, httpx.RemoteProtocolError):
|
|
logger.warning(f"{self._provider_prefix()} Remote protocol error during streaming: {e}")
|
|
return LLMConnectionError(
|
|
message=f"Connection error during {self._provider_name()} streaming: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
|
)
|
|
|
|
# Handle httpx network errors which can occur during streaming
|
|
# when the connection is unexpectedly closed while reading/writing
|
|
if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
|
|
logger.warning(f"{self._provider_prefix()} Network error during streaming: {type(e).__name__}: {e}")
|
|
return LLMConnectionError(
|
|
message=f"Network error during {self._provider_name()} streaming: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok},
|
|
)
|
|
|
|
# Handle connection-related errors
|
|
if "connection" in str(e).lower() or "timeout" in str(e).lower():
|
|
logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}")
|
|
return LLMConnectionError(
|
|
message=f"Failed to connect to {self._provider_name()}: {str(e)}",
|
|
code=ErrorCode.INTERNAL_SERVER_ERROR,
|
|
details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok},
|
|
)
|
|
|
|
# Fallback to base implementation for other errors
|
|
return super().handle_llm_error(e, llm_config=llm_config)
|
|
|
|
    async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
        """
        Count tokens for the given messages and tools using the Gemini token counting API.

        Args:
            messages: List of message dicts in Google AI format (with 'role' and 'parts' keys)
            model: The model to use for token counting (defaults to gemini-2.0-flash-lite)
            tools: List of OpenAI-style Tool objects to include in the count

        Returns:
            The total token count for the input
        """
        from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK

        client = self._get_client()

        # Default model for token counting if not specified
        count_model = model or GOOGLE_MODEL_FOR_API_KEY_CHECK

        # Build the contents parameter
        # If no messages provided, use empty string (like the API key check)
        if messages is None or len(messages) == 0:
            contents = ""
        else:
            # Messages should already be in Google format (role + parts)
            contents = messages

        try:
            # Count message tokens
            result = await client.aio.models.count_tokens(
                model=count_model,
                contents=contents,
            )
            total_tokens = result.total_tokens

            # Count tool tokens separately by serializing to text
            # The Gemini count_tokens API doesn't support a tools parameter directly
            if tools and len(tools) > 0:
                # Serialize tools to JSON text and count those tokens
                tools_text = json.dumps([t.model_dump() for t in tools])
                tools_result = await client.aio.models.count_tokens(
                    model=count_model,
                    contents=tools_text,
                )
                total_tokens += tools_result.total_tokens

        except Exception as e:
            raise self.handle_llm_error(e)

        return total_tokens
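    # Illustrative usage (hypothetical values):
    #   tokens = await client.count_tokens(
    #       messages=[{"role": "user", "parts": [{"text": "hello"}]}],
    #       model="gemini-2.0-flash-lite",
    #   )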