* feat: add non-streaming option for conversation messages

  - Add ConversationMessageRequest with stream=True default (backwards compatible)
  - stream=true (default): SSE streaming via StreamingService
  - stream=false: JSON response via AgentLoop.load().step()

* chore: regenerate API schema for ConversationMessageRequest

* feat: add direct ClickHouse storage for raw LLM traces

  Adds the ability to store raw LLM request/response payloads directly in ClickHouse, bypassing OTEL span attribute size limits. This enables debugging and analytics on large LLM payloads (>10MB system prompts, large tool schemas, etc.).

  New files:
  - letta/schemas/llm_raw_trace.py: Pydantic schema with ClickHouse row helper
  - letta/services/llm_raw_trace_writer.py: Async batching writer (fire-and-forget)
  - letta/services/llm_raw_trace_reader.py: Reader with query methods
  - scripts/sql/clickhouse/llm_raw_traces.ddl: Production table DDL
  - scripts/sql/clickhouse/llm_raw_traces_local.ddl: Local dev DDL
  - apps/core/clickhouse-init.sql: Local dev initialization

  Modified:
  - letta/settings.py: Added 4 settings (store_llm_raw_traces, ttl, batch_size, flush_interval)
  - letta/llm_api/llm_client_base.py: Integration into request_async_with_telemetry
  - compose.yaml: Added ClickHouse service for local dev
  - justfile: Added clickhouse, clickhouse-cli, clickhouse-traces commands

  The feature is disabled by default (LETTA_STORE_LLM_RAW_TRACES=false). Uses ZSTD(3) compression for 10-30x reduction on JSON payloads.

* fix: address code review feedback for LLM raw traces

  1. Fix ClickHouse endpoint parsing - default to secure=False for raw host:port inputs (it was defaulting to HTTPS, which breaks local dev)
  2. Make raw trace writes truly fire-and-forget - use asyncio.create_task() instead of awaiting, so JSON serialization doesn't block the request path
  3. Add a bounded queue (maxsize=10000) - prevents unbounded memory growth under load; drops traces with a warning if the queue is full
  4. Fix deprecated asyncio usage - get_running_loop() instead of get_event_loop()
  5. Add an org_id fallback - use _telemetry_org_id if the actor doesn't have it
  6. Remove unused imports - json import in the reader

* fix: add missing asyncio import and simplify JSON serialization

  - Add the missing 'import asyncio' that was causing a 'name asyncio is not defined' error
  - Remove the unnecessary clean_double_escapes() function - the JSON is stored correctly; the clickhouse-client CLI was just adding extra escaping when displaying it
  - Update just clickhouse-trace to use the Python client for correct JSON output

* test: add clickhouse raw trace integration test

* test: simplify clickhouse trace assertions

* refactor: centralize usage parsing and stream error traces

  Use per-client usage helpers for raw trace extraction and ensure streaming errors log requests with error metadata.

* test: exercise provider usage parsing live

  Make live OpenAI/Anthropic/Gemini requests with credential gating and validate Anthropic cache usage mapping when present.

* test: fix usage parsing tests to pass

  - Use GoogleAIClient with GEMINI_API_KEY instead of GoogleVertexClient
  - Update the model to gemini-2.0-flash (1.5-flash is deprecated in v1beta)
  - Add tools=[] for Gemini/Anthropic build_request_data

* refactor: extract_usage_statistics returns LettaUsageStatistics

  Standardize on LettaUsageStatistics as the canonical usage format returned by client helpers. Inline UsageStatistics construction for ChatCompletionResponse where needed.

* feat: add is_byok and llm_config_json columns to ClickHouse traces

  Extend the llm_raw_traces table with:
  - is_byok (UInt8): Track BYOK vs base provider usage for billing analytics
  - llm_config_json (String, ZSTD): Store the full LLM config for debugging and analysis

  This enables queries like:
  - BYOK usage breakdown by provider/model
  - Config parameter analysis (temperature, max_tokens, etc.)
  - Debugging specific request configurations

* feat: add tests for error traces, llm_config_json, and cache tokens

  - Update llm_raw_trace_reader.py to query the new columns (is_byok, cached_input_tokens, cache_write_tokens, reasoning_tokens, llm_config_json)
  - Add test_error_trace_stored_in_clickhouse to verify error fields
  - Add test_cache_tokens_stored_for_anthropic to verify cache token storage
  - Update existing tests to verify llm_config_json is stored correctly
  - Make llm_config required in log_provider_trace_async()
  - Simplify provider extraction to use provider_name directly

* ci: add ClickHouse integration tests to CI pipeline

  - Add a use-clickhouse option to reusable-test-workflow.yml
  - Add a ClickHouse service container with the otel database
  - Add a schema initialization step using clickhouse-init.sql
  - Add ClickHouse env vars (CLICKHOUSE_ENDPOINT, etc.)
  - Add a separate clickhouse-integration-tests job running integration_test_clickhouse_llm_raw_traces.py

* refactor: simplify provider and org_id extraction in raw trace writer

  - Use model_endpoint_type.value for the provider (not provider_name)
  - Simplify org_id to just self.actor.organization_id (the actor is always pydantic)

* refactor: simplify LLMRawTraceWriter with _enabled flag

  - Check ClickHouse env vars once at init and set an _enabled flag
  - Early return in write_async/flush_async if not enabled
  - Remove ValueError raises (never used)
  - Simplify _get_client (no validation needed since it is already checked)

* fix: add LLMRawTraceWriter shutdown to FastAPI lifespan

  Properly flush pending traces on graceful shutdown via the lifespan instead of relying only on the atexit handler.

* feat: add agent_tags column to ClickHouse traces

  Store agent tags as Array(String) for filtering/analytics by tag.

* cleanup

* fix(ci): fix ClickHouse schema initialization in CI

  - Create the database separately before loading the SQL file
  - Remove CREATE DATABASE from the SQL file (handled in the CI step)
  - Add a verification step to confirm the table was created
  - Use the -sf flag for curl to fail on HTTP errors

* refactor: simplify LLM trace writer with ClickHouse async_insert

  - Use ClickHouse async_insert for server-side batching instead of a manual queue/flush loop (a minimal sketch of this pattern appears after the commit list)
  - Sync the cloud DDL schema with clickhouse-init.sql (add missing columns)
  - Remove the redundant llm_raw_traces_local.ddl
  - Remove the unused batch_size/flush_interval settings
  - Update tests for the simplified writer

  Key changes:
  - async_insert=1, wait_for_async_insert=1 for reliable server-side batching
  - Simple per-trace retry with exponential backoff (max 3 retries)
  - ~150 lines removed from the writer

* refactor: consolidate ClickHouse direct writes into TelemetryManager backend

  - Add a clickhouse_direct backend to provider_trace_backends
  - Remove duplicate ClickHouse write logic from llm_client_base.py
  - Configure via LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND=postgres,clickhouse_direct

  The clickhouse_direct backend:
  - Converts ProviderTrace to LLMRawTrace
  - Extracts usage stats from the response JSON
  - Writes via LLMRawTraceWriter with async_insert

* refactor: address PR review comments and fix llm_config bug

  Review comment fixes:
  - Rename clickhouse_direct -> clickhouse_analytics (clearer purpose)
  - Remove ClickHouse from the OSS compose.yaml and create a separate compose.clickhouse.yaml
  - Delete the redundant scripts/test_llm_raw_traces.py (use pytest tests)
  - Remove the unused llm_raw_traces_ttl_days setting (TTL is handled in the DDL)
  - Fix socket description leak in the telemetry_manager docstring
  - Add a cloud-only comment to clickhouse-init.sql
  - Update the justfile to use the separate compose file

  Bug fix:
  - Fix llm_config not being passed to ProviderTrace in telemetry
  - Now correctly populates provider, model, and is_byok for all LLM calls
  - Affects both request_async_with_telemetry and log_provider_trace_async

  DDL optimizations:
  - Add secondary indexes (bloom_filter for agent_id, model, step_id)
  - Add minmax indexes for is_byok, is_error
  - Change model and error_type to LowCardinality for faster GROUP BY

* refactor: rename llm_raw_traces -> llm_traces

  Address review feedback that "raw" is misleading since we denormalize fields. Renames:
  - Table: llm_raw_traces -> llm_traces
  - Schema: LLMRawTrace -> LLMTrace
  - Files: llm_raw_trace_{reader,writer}.py -> llm_trace_{reader,writer}.py
  - Setting: store_llm_raw_traces -> store_llm_traces

* fix: update workflow references to llm_traces

  Missed renaming the table name in the CI workflow files.

* fix: update clickhouse_direct -> clickhouse_analytics in docstring

* chore: remove inaccurate OTEL size limit comments

  The 4MB limit is our own truncation logic, not an OTEL protocol limit. The real benefit is denormalized columns for analytics queries.

* chore: remove local ClickHouse dev setup (cloud-only feature)

  - Delete clickhouse-init.sql and compose.clickhouse.yaml
  - Remove the local clickhouse just commands
  - Update CI to use the cloud DDL with MergeTree for testing

  clickhouse_analytics is a cloud-only feature. For local dev, use the postgres backend.

* fix: restore compose.yaml to match main

* refactor: merge clickhouse_analytics into clickhouse backend

  Per review feedback - having two separate backends was confusing. Now the clickhouse backend:
  - Writes to the llm_traces table (denormalized for cost analytics)
  - Reads from the OTEL traces table (will cut over to llm_traces later)

  Config: LETTA_TELEMETRY_PROVIDER_TRACE_BACKEND=postgres,clickhouse

* fix: correct path to DDL file in CI workflow

* chore: add provider index to DDL for faster filtering

* fix: configure telemetry backend in clickhouse tests

  Tests need to set telemetry_settings.provider_trace_backends to include 'clickhouse'; otherwise traces are routed to the default postgres backend.

* fix: set provider_trace_backend field, not property

  provider_trace_backends is a computed property, so the underlying provider_trace_backend string field must be set instead.

* fix: error trace test and error_type extraction

  - Add TelemetryManager to the error trace test so traces get written
  - Fix error_type extraction to check the top level before the nested error dict

* fix: use provider_trace.id for trace correlation across backends

  - Pass provider_trace.id to LLMTrace instead of auto-generating one
  - Log a warning if the ID is missing (shouldn't happen, helps debug)
  - Fall back to a new UUID only if it is not set

* fix: trace ID correlation and concurrency issues

  - Strip the "provider_trace-" prefix from the ID for UUID storage in ClickHouse
  - Add an asyncio.Lock to serialize writes (clickhouse_connect is not thread-safe)
  - Fix Anthropic prompt_tokens to include cached tokens for cost analytics
  - Log a warning if provider_trace.id is missing

---------

Co-authored-by: Letta <noreply@letta.com>
Co-authored-by: Caren Thomas <carenthomas@gmail.com>
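The trace-writer shape described in these commits (fire-and-forget scheduling via `asyncio.create_task`, ClickHouse server-side batching with `async_insert`, an `asyncio.Lock` around the shared client, and per-trace retry with exponential backoff) can be hard to picture from commit messages alone. Below is a minimal, hypothetical sketch of that pattern using `clickhouse_connect`; the class, table, and column names are illustrative only and do not mirror the actual `LLMTraceWriter` implementation.

```python
import asyncio
import logging

import clickhouse_connect

logger = logging.getLogger(__name__)


class TraceWriterSketch:
    """Illustrative fire-and-forget ClickHouse writer (not the real LLMTraceWriter)."""

    MAX_RETRIES = 3  # per-trace retry budget, as described in the commits

    def __init__(self, host: str, port: int, secure: bool = False):
        # Hypothetical configuration; the real writer reads CLICKHOUSE_* settings.
        self._host, self._port, self._secure = host, port, secure
        self._client = None
        self._lock = asyncio.Lock()  # serialize access: the shared client is not thread-safe

    async def _get_client(self):
        if self._client is None:
            self._client = await clickhouse_connect.get_async_client(
                host=self._host, port=self._port, secure=self._secure
            )
        return self._client

    def write(self, row: list, column_names: list[str]) -> None:
        # Fire-and-forget: schedule the insert on the running event loop so
        # serialization and network I/O never block the LLM request path.
        asyncio.create_task(self._insert_with_retry(row, column_names))

    async def _insert_with_retry(self, row: list, column_names: list[str]) -> None:
        client = await self._get_client()
        for attempt in range(1, self.MAX_RETRIES + 1):
            try:
                async with self._lock:
                    await client.insert(
                        "llm_traces",  # table name from the rename commit
                        [row],
                        column_names=column_names,
                        settings={
                            # Server-side batching replaces the manual queue/flush loop
                            "async_insert": 1,
                            "wait_for_async_insert": 1,
                        },
                    )
                return
            except Exception as exc:
                if attempt == self.MAX_RETRIES:
                    logger.warning("Dropping trace after %d attempts: %s", attempt, exc)
                    return
                await asyncio.sleep(2**attempt)  # exponential backoff between retries
```

The design trade-off mirrored here is the one the commits call out: pushing batching to the server via `async_insert` removes the in-process queue and flush loop, at the cost of holding a write lock and accepting that a trace may be dropped after the retry budget is exhausted.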
Python · 951 lines · 46 KiB
import base64
import json
import uuid
from typing import AsyncIterator, List, Optional

import httpx
from google.genai import Client, errors
from google.genai.types import (
    FunctionCallingConfig,
    FunctionCallingConfigMode,
    GenerateContentResponse,
    HttpOptions,
    ThinkingConfig,
    ToolConfig,
)

from letta.constants import NON_USER_MSG_PREFIX
from letta.errors import (
    ContextWindowExceededError,
    ErrorCode,
    LLMAuthenticationError,
    LLMBadRequestError,
    LLMConnectionError,
    LLMNotFoundError,
    LLMPermissionDeniedError,
    LLMRateLimitError,
    LLMServerError,
    LLMTimeoutError,
    LLMUnprocessableEntityError,
)
from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.openai.chat_completion_request import Tool, Tool as OpenAITool
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
from letta.schemas.usage import LettaUsageStatistics
from letta.settings import model_settings, settings
from letta.utils import get_tool_call_id

logger = get_logger(__name__)

class GoogleVertexClient(LLMClientBase):
    MAX_RETRIES = model_settings.gemini_max_retries
    provider_label = "Google Vertex"

    def _get_client(self):
        timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
        return Client(
            vertexai=True,
            project=model_settings.google_cloud_project,
            location=model_settings.google_cloud_location,
            http_options=HttpOptions(api_version="v1", timeout=timeout_ms),
        )

    def _provider_prefix(self) -> str:
        return f"[{self.provider_label}]"

    def _provider_name(self) -> str:
        return self.provider_label

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying request to the LLM and returns the raw response.
        """
        try:
            client = self._get_client()
            response = client.models.generate_content(
                model=llm_config.model,
                contents=request_data["contents"],
                config=request_data["config"],
            )
            return response.model_dump()
        except Exception as e:
            raise self.handle_llm_error(e)

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying request to the LLM and returns the raw response.
        """
        client = self._get_client()

        # Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
        # https://github.com/googleapis/python-aiplatform/issues/4472
        retry_count = 1
        should_retry = True
        response_data = None
        while should_retry and retry_count <= self.MAX_RETRIES:
            try:
                response = await client.aio.models.generate_content(
                    model=llm_config.model,
                    contents=request_data["contents"],
                    config=request_data["config"],
                )
            except errors.APIError as e:
                # Retry on 500, 503, and 504 errors as well, usually ephemeral from Gemini
                if e.code == 503 or e.code == 500 or e.code == 504:
                    logger.warning(f"Received {e}, retrying {retry_count}/{self.MAX_RETRIES}")
                    retry_count += 1
                    if retry_count > self.MAX_RETRIES:
                        raise self.handle_llm_error(e)
                    continue
                raise self.handle_llm_error(e)
            except Exception as e:
                raise self.handle_llm_error(e)
            response_data = response.model_dump()
            is_malformed_function_call = self.is_malformed_function_call(response_data)
            if is_malformed_function_call:
                logger.warning(
                    f"Received FinishReason.MALFORMED_FUNCTION_CALL in response for {llm_config.model}, retrying {retry_count}/{self.MAX_RETRIES}"
                )
                # Modify the last message if it's a heartbeat to include a warning about special characters
                if request_data["contents"] and len(request_data["contents"]) > 0:
                    last_message = request_data["contents"][-1]
                    if last_message.get("role") == "user" and last_message.get("parts"):
                        for part in last_message["parts"]:
                            if "text" in part:
                                try:
                                    # Try to parse as JSON to check if it's a heartbeat
                                    message_json = json_loads(part["text"])
                                    if message_json.get("type") == "heartbeat" and "reason" in message_json:
                                        # Append warning to the reason
                                        warning = f" RETRY {retry_count}/{self.MAX_RETRIES} ***DO NOT USE SPECIAL CHARACTERS OR QUOTATIONS INSIDE FUNCTION CALL ARGUMENTS. IF YOU MUST, MAKE SURE TO ESCAPE THEM PROPERLY***"
                                        message_json["reason"] = message_json["reason"] + warning
                                        # Update the text with modified JSON
                                        part["text"] = json_dumps(message_json)
                                        logger.warning(
                                            f"Modified heartbeat message with special character warning for retry {retry_count}/{self.MAX_RETRIES}"
                                        )
                                except (json.JSONDecodeError, TypeError, AttributeError):
                                    # Not a JSON message or not a heartbeat, skip modification
                                    pass

            should_retry = is_malformed_function_call
            retry_count += 1

        if response_data is None:
            raise RuntimeError("Failed to get response data after all retries")
        return response_data

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]:
        client = self._get_client()

        try:
            response = await client.aio.models.generate_content_stream(
                model=llm_config.model,
                contents=request_data["contents"],
                config=request_data["config"],
            )
        except Exception as e:
            logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}")
            raise e
        # Direct yield - keeps response alive in generator's local scope throughout iteration
        # This is required because the SDK's connection lifecycle is tied to the response object
        async for chunk in response:
            yield chunk

    @staticmethod
    def add_dummy_model_messages(messages: List[dict]) -> List[dict]:
        """Google AI API requires all function call returns are immediately followed by a 'model' role message.

        In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user,
        so there is no natural follow-up 'model' role message.

        To satisfy the Google AI API restrictions, we can add a dummy 'yield' message
        with role == 'model' that is placed in between the function output
        (role == 'tool') and the user message (role == 'user').
        """
        dummy_yield_message = {
            "role": "model",
            "parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}],
        }
        messages_with_padding = []
        for i, message in enumerate(messages):
            messages_with_padding.append(message)
            # Check if the current message role is 'tool' and the next message role is 'user'
            if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"):
                messages_with_padding.append(dummy_yield_message)

        return messages_with_padding

    def _clean_google_ai_schema_properties(self, schema_part: dict):
        """Recursively clean schema parts to remove unsupported Google AI keywords."""
        if not isinstance(schema_part, dict):
            return

        # Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations
        # * Only a subset of the OpenAPI schema is supported.
        # * Supported parameter types in Python are limited.
        unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties", "$schema"]
        keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part]
        for key_to_remove in keys_to_remove_at_this_level:
            logger.debug(f"Removing unsupported keyword '{key_to_remove}' from schema part.")
            del schema_part[key_to_remove]

        if schema_part.get("type") == "string" and "format" in schema_part:
            allowed_formats = ["enum", "date-time"]
            if schema_part["format"] not in allowed_formats:
                logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. Allowed: {allowed_formats}")
                del schema_part["format"]

        # Check properties within the current level
        if "properties" in schema_part and isinstance(schema_part["properties"], dict):
            for prop_name, prop_schema in schema_part["properties"].items():
                self._clean_google_ai_schema_properties(prop_schema)

        # Check items within arrays
        if "items" in schema_part and isinstance(schema_part["items"], dict):
            self._clean_google_ai_schema_properties(schema_part["items"])

        # Check within anyOf, allOf, oneOf lists
        for key in ["anyOf", "allOf", "oneOf"]:
            if key in schema_part and isinstance(schema_part[key], list):
                for item_schema in schema_part[key]:
                    self._clean_google_ai_schema_properties(item_schema)

    def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]:
        """
        OpenAI style:
          "tools": [{
            "type": "function",
            "function": {
              "name": "find_movies",
              "description": "find ....",
              "parameters": {
                "type": "object",
                "properties": {
                  PARAM: {
                    "type": PARAM_TYPE,  # eg "string"
                    "description": PARAM_DESCRIPTION,
                  },
                  ...
                },
                "required": List[str],
              }
            }
          }]

        Google AI style:
          "tools": [{
            "functionDeclarations": [{
              "name": "find_movies",
              "description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.",
              "parameters": {
                "type": "OBJECT",
                "properties": {
                  "location": {
                    "type": "STRING",
                    "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616"
                  },
                  "description": {
                    "type": "STRING",
                    "description": "Any kind of description including category or genre, title words, attributes, etc."
                  }
                },
                "required": ["description"]
              }
            }, {
              "name": "find_theaters",
              ...
        """
        function_list = [
            dict(
                name=t.function.name,
                description=t.function.description,
                parameters=t.function.parameters,  # TODO need to unpack
            )
            for t in tools
        ]

        # Add inner thoughts if needed
        for func in function_list:
            # Note: Google AI API used to have weird casing requirements, but not any more

            # Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned
            if "parameters" in func and isinstance(func["parameters"], dict):
                self._clean_google_ai_schema_properties(func["parameters"])

            # Add inner thoughts
            if llm_config.put_inner_thoughts_in_kwargs:
                from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX

                func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = {
                    "type": "string",
                    "description": INNER_THOUGHTS_KWARG_DESCRIPTION,
                }
                func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX)

        return [{"functionDeclarations": function_list}]

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,  # if react, use native content + strip heartbeats
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: List[dict],
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """
        Constructs a request object in the expected data format for this client.
        """
        # NOTE: forcing inner thoughts in kwargs off
        if agent_type == AgentType.letta_v1_agent:
            llm_config.put_inner_thoughts_in_kwargs = False

        if tools:
            tool_objs = [Tool(type="function", function=t) for t in tools]
            tool_names = [t.function.name for t in tool_objs]
            # Convert to the exact payload style Google expects
            formatted_tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config)
        else:
            formatted_tools = []
            tool_names = []

        contents = self.add_dummy_model_messages(
            PydanticMessage.to_google_dicts_from_list(
                messages,
                current_model=llm_config.model,
                put_inner_thoughts_in_kwargs=False if agent_type == AgentType.letta_v1_agent else True,
                native_content=True if agent_type == AgentType.letta_v1_agent else False,
            ),
        )

        request_data = {
            "contents": contents,
            "config": {
                "temperature": llm_config.temperature,
                "tools": formatted_tools,
            },
        }
        # Max tokens is optional
        if llm_config.max_tokens:
            request_data["config"]["max_output_tokens"] = llm_config.max_tokens

        if len(tool_names) == 1 and settings.use_vertex_structured_outputs_experimental:
            request_data["config"]["response_mime_type"] = "application/json"
            request_data["config"]["response_schema"] = self.get_function_call_response_schema(tools[0])
            del request_data["config"]["tools"]
        elif tools:
            if agent_type == AgentType.letta_v1_agent:
                # don't require tools
                tool_call_mode = FunctionCallingConfigMode.AUTO
                tool_config = ToolConfig(
                    function_calling_config=FunctionCallingConfig(
                        mode=tool_call_mode,
                    )
                )
            else:
                # require tools
                tool_call_mode = FunctionCallingConfigMode.ANY
                tool_config = ToolConfig(
                    function_calling_config=FunctionCallingConfig(
                        mode=tool_call_mode,
                        # Provide the list of tools (though empty should also work, it seems not to)
                        allowed_function_names=tool_names,
                    )
                )

            request_data["config"]["tool_config"] = tool_config.model_dump()

        # https://ai.google.dev/gemini-api/docs/thinking#set-budget
        # 2.5 Pro
        # - Default: dynamic thinking
        # - Dynamic thinking that cannot be disabled
        # - Range: -1 (for dynamic), or 128-32768
        # 2.5 Flash
        # - Default: dynamic thinking
        # - Dynamic thinking that *can* be disabled
        # - Range: -1, 0, or 0-24576
        # 2.5 Flash Lite
        # - Default: no thinking
        # - Dynamic thinking that *can* be disabled
        # - Range: -1, 0, or 512-24576
        # TODO when using v3 agent loop, properly support the native thinking in Gemini

        # Add thinking_config for all Gemini reasoning models (2.5 series)
        # If enable_reasoner is False, set thinking_budget to 0
        # Otherwise, use the value from max_reasoning_tokens
        if self.is_reasoning_model(llm_config) or "flash" in llm_config.model:
            if llm_config.model.startswith("gemini-3"):
                # let thinking_level default to high by not specifying thinking_budget
                thinking_config = ThinkingConfig(include_thoughts=True)
            else:
                # Gemini reasoning models may fail to call tools even with FunctionCallingConfigMode.ANY if thinking is fully disabled,
                # so set the budget to the minimum to prevent tool call failures
                thinking_budget = (
                    llm_config.max_reasoning_tokens if llm_config.enable_reasoner else self.get_thinking_budget(llm_config.model)
                )
                if thinking_budget <= 0:
                    logger.warning(
                        f"Thinking budget of {thinking_budget} for Gemini reasoning model {llm_config.model}, this will likely cause tool call failures"
                    )
                    # For models that require thinking mode (2.5 Pro, 3.x), override with the minimum valid budget
                    if llm_config.model.startswith("gemini-2.5-pro"):
                        thinking_budget = 128
                        logger.warning(
                            f"Overriding thinking_budget to {thinking_budget} for model {llm_config.model} which requires thinking mode"
                        )
                thinking_config = ThinkingConfig(
                    thinking_budget=(thinking_budget),
                    include_thoughts=(thinking_budget > 1),
                )
            request_data["config"]["thinking_config"] = thinking_config.model_dump()

        return request_data

    def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics:
        """Extract usage statistics from Gemini response and return as LettaUsageStatistics."""
        if not response_data:
            return LettaUsageStatistics()

        response = GenerateContentResponse(**response_data)
        if not response.usage_metadata:
            return LettaUsageStatistics()

        cached_tokens = None
        if (
            hasattr(response.usage_metadata, "cached_content_token_count")
            and response.usage_metadata.cached_content_token_count is not None
        ):
            cached_tokens = response.usage_metadata.cached_content_token_count

        reasoning_tokens = None
        if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count is not None:
            reasoning_tokens = response.usage_metadata.thoughts_token_count

        return LettaUsageStatistics(
            prompt_tokens=response.usage_metadata.prompt_token_count or 0,
            completion_tokens=response.usage_metadata.candidates_token_count or 0,
            total_tokens=response.usage_metadata.total_token_count or 0,
            cached_input_tokens=cached_tokens,
            reasoning_tokens=reasoning_tokens,
        )
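
    # Illustrative mapping (not from the source): a Gemini usage_metadata payload such as
    #   {"prompt_token_count": 2493, "candidates_token_count": 29, "total_token_count": 2522,
    #    "cached_content_token_count": 1024, "thoughts_token_count": 256}
    # would be returned by extract_usage_statistics as
    #   LettaUsageStatistics(prompt_tokens=2493, completion_tokens=29, total_tokens=2522,
    #                        cached_input_tokens=1024, reasoning_tokens=256)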

    @trace_method
    async def convert_response_to_chat_completion(
        self,
        response_data: dict,
        input_messages: List[PydanticMessage],
        llm_config: LLMConfig,
    ) -> ChatCompletionResponse:
        """
        Converts the custom response format from the LLM client into an OpenAI
        ChatCompletionResponse object.

        Example:
        {
          "candidates": [
            {
              "content": {
                "parts": [
                  {
                    "text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14."
                  }
                ]
              }
            }
          ],
          "usageMetadata": {
            "promptTokenCount": 9,
            "candidatesTokenCount": 27,
            "totalTokenCount": 36
          }
        }
        """
        response = GenerateContentResponse(**response_data)
        try:
            choices = []
            index = 0
            for candidate in response.candidates:
                content = candidate.content

                if content is None or content.role is None or content.parts is None:
                    # This means the response is malformed like MALFORMED_FUNCTION_CALL
                    if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
                        raise LLMServerError(f"Malformed response from {self._provider_name()}: {candidate.finish_reason}")
                    else:
                        raise LLMServerError(f"Invalid response data from {self._provider_name()}: {candidate.model_dump()}")

                role = content.role
                assert role == "model", f"Unknown role in response: {role}"

                parts = content.parts

                # NOTE: we aren't properly supporting multi-part responses here anyway (we're just appending choices),
                # so let's disable it for now

                # NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
                # {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
                # To patch this, if we have multiple parts we can take the last one
                # Unless parallel tool calling is enabled, in which case multiple parts may be intentional
                if len(parts) > 1 and not llm_config.enable_reasoner and not llm_config.parallel_tool_calls:
                    logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
                    # only truncate if reasoning is off and parallel tool calls are disabled
                    parts = [parts[-1]]

                # TODO support parts / multimodal
                # Parallel tool calling is now supported when llm_config.parallel_tool_calls is enabled
                openai_response_message = None
                for response_message in parts:
                    # Convert the actual message style to OpenAI style
                    if response_message.function_call:
                        function_call = response_message.function_call
                        function_name = function_call.name
                        function_args = function_call.args
                        assert isinstance(function_args, dict), function_args

                        # TODO this is kind of funky - really, we should be passing 'native_content' as a kwarg to fork behavior
                        inner_thoughts = response_message.text
                        if llm_config.put_inner_thoughts_in_kwargs:
                            # NOTE: this also involves stripping the inner monologue out of the function
                            from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX

                            assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
                                f"Couldn't find inner thoughts in function args:\n{function_call}"
                            )
                            inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
                            assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
                        else:
                            pass
                            # inner_thoughts = None
                            # inner_thoughts = response_message.text

                        # Google AI API doesn't generate tool call IDs
                        tool_call = ToolCall(
                            id=get_tool_call_id(),
                            type="function",
                            function=FunctionCall(
                                name=function_name,
                                arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
                            ),
                        )

                        if openai_response_message is None:
                            openai_response_message = Message(
                                role="assistant",  # NOTE: "model" -> "assistant"
                                content=inner_thoughts,
                                tool_calls=[tool_call],
                            )
                        else:
                            openai_response_message.content = inner_thoughts
                            if openai_response_message.tool_calls is None:
                                openai_response_message.tool_calls = []
                            openai_response_message.tool_calls.append(tool_call)
                        if response_message.thought_signature:
                            thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
                            openai_response_message.reasoning_content_signature = thought_signature

                    else:
                        if response_message.thought:
                            if openai_response_message is None:
                                openai_response_message = Message(
                                    role="assistant",  # NOTE: "model" -> "assistant"
                                    reasoning_content=response_message.text,
                                )
                            else:
                                openai_response_message.reasoning_content = response_message.text
                        try:
                            # Structured output tool call
                            function_call = json_loads(response_message.text)

                            # Access dict keys - will raise TypeError/KeyError if not a dict or missing keys
                            function_name = function_call["name"]
                            function_args = function_call["args"]
                            assert isinstance(function_args, dict), function_args

                            # NOTE: this also involves stripping the inner monologue out of the function
                            if llm_config.put_inner_thoughts_in_kwargs:
                                from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX

                                assert INNER_THOUGHTS_KWARG_VERTEX in function_args, (
                                    f"Couldn't find inner thoughts in function args:\n{function_call}"
                                )
                                inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX)
                                assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}"
                            else:
                                inner_thoughts = None

                            # Google AI API doesn't generate tool call IDs
                            tool_call = ToolCall(
                                id=get_tool_call_id(),
                                type="function",
                                function=FunctionCall(
                                    name=function_name,
                                    arguments=clean_json_string_extra_backslash(json_dumps(function_args)),
                                ),
                            )
                            if openai_response_message is None:
                                openai_response_message = Message(
                                    role="assistant",  # NOTE: "model" -> "assistant"
                                    content=inner_thoughts,
                                    tool_calls=[tool_call],
                                )
                            else:
                                openai_response_message.content = inner_thoughts
                                if openai_response_message.tool_calls is None:
                                    openai_response_message.tool_calls = []
                                openai_response_message.tool_calls.append(tool_call)

                        except (json.decoder.JSONDecodeError, ValueError, TypeError, KeyError, AssertionError) as e:
                            if candidate.finish_reason == "MAX_TOKENS":
                                raise LLMServerError("Could not parse response data from LLM: exceeded max token limit")
                            # Log the parsing error for debugging
                            logger.warning(
                                f"Failed to parse structured output from LLM response: {e}. Response text: {response_message.text[:500]}"
                            )
                            # Inner thoughts are the content by default
                            inner_thoughts = response_message.text

                            # Google AI API doesn't generate tool call IDs
                            if openai_response_message is None:
                                openai_response_message = Message(
                                    role="assistant",  # NOTE: "model" -> "assistant"
                                    content=inner_thoughts,
                                )
                            else:
                                openai_response_message.content = inner_thoughts
                        if response_message.thought_signature:
                            thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8")
                            openai_response_message.reasoning_content_signature = thought_signature

                # Google AI API uses different finish reason strings than OpenAI
                # OpenAI: 'stop', 'length', 'function_call', 'content_filter', null
                # see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api
                # Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER
                # see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason
                finish_reason = candidate.finish_reason.value
                if finish_reason == "STOP":
                    openai_finish_reason = (
                        "function_call"
                        if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0
                        else "stop"
                    )
                elif finish_reason == "MAX_TOKENS":
                    openai_finish_reason = "length"
                elif finish_reason == "SAFETY":
                    openai_finish_reason = "content_filter"
                elif finish_reason == "RECITATION":
                    openai_finish_reason = "content_filter"
                else:
                    raise LLMServerError(f"Unrecognized finish reason in Google AI response: {finish_reason}")

                choices.append(
                    Choice(
                        finish_reason=openai_finish_reason,
                        index=index,
                        message=openai_response_message,
                    )
                )
                index += 1

            # if len(choices) > 1:
            #     raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})")

            # NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist?
            # "usageMetadata": {
            #     "promptTokenCount": 9,
            #     "candidatesTokenCount": 27,
            #     "totalTokenCount": 36
            # }
            usage = None
            if response.usage_metadata:
                # Extract usage via centralized method
                from letta.schemas.enums import ProviderType

                usage = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.google_ai)
            else:
                # Count it ourselves using the Gemini token counting API
                assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"
                google_messages = PydanticMessage.to_google_dicts_from_list(input_messages, current_model=llm_config.model)
                prompt_tokens = await self.count_tokens(messages=google_messages, model=llm_config.model)
                # For completion tokens, wrap the response content in Google format
                completion_content = [{"role": "model", "parts": [{"text": json_dumps(openai_response_message.model_dump())}]}]
                completion_tokens = await self.count_tokens(messages=completion_content, model=llm_config.model)
                total_tokens = prompt_tokens + completion_tokens
                usage = UsageStatistics(
                    prompt_tokens=prompt_tokens,
                    completion_tokens=completion_tokens,
                    total_tokens=total_tokens,
                )

            response_id = str(uuid.uuid4())
            return ChatCompletionResponse(
                id=response_id,
                choices=choices,
                model=llm_config.model,  # NOTE: Google API doesn't pass back model in the response
                created=get_utc_time_int(),
                usage=usage,
            )
        except KeyError as e:
            raise e

    def get_function_call_response_schema(self, tool: dict) -> dict:
        return {
            "type": "OBJECT",
            "properties": {
                "name": {"type": "STRING", "enum": [tool["name"]]},
                "args": {
                    "type": "OBJECT",
                    "properties": tool["parameters"]["properties"],
                    "required": tool["parameters"]["required"],
                },
            },
            "propertyOrdering": ["name", "args"],
            "required": ["name", "args"],
        }
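
    # Illustrative example (not from the source): for a hypothetical tool dict like
    #   {"name": "send_message", "parameters": {"properties": {"message": {"type": "string"}}, "required": ["message"]}}
    # get_function_call_response_schema returns:
    #   {
    #       "type": "OBJECT",
    #       "properties": {
    #           "name": {"type": "STRING", "enum": ["send_message"]},
    #           "args": {"type": "OBJECT", "properties": {"message": {"type": "string"}}, "required": ["message"]},
    #       },
    #       "propertyOrdering": ["name", "args"],
    #       "required": ["name", "args"],
    #   }
    # i.e. the structured-output path constrains the model to emit a single tool call as JSON.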

    # https://ai.google.dev/gemini-api/docs/thinking#set-budget
    # | Model          | Default setting                                             | Range     | Disable thinking    | Turn on dynamic thinking |
    # |----------------|-------------------------------------------------------------|-----------|---------------------|--------------------------|
    # | 2.5 Pro        | Dynamic thinking: Model decides when and how much to think  | 128-32768 | N/A: Cannot disable | thinkingBudget = -1      |
    # | 2.5 Flash      | Dynamic thinking: Model decides when and how much to think  | 0-24576   | thinkingBudget = 0  | thinkingBudget = -1      |
    # | 2.5 Flash Lite | Model does not think                                         | 512-24576 | thinkingBudget = 0  | thinkingBudget = -1      |
    # | 3.x            | Dynamic thinking: Model decides when and how much to think  | 128-?     | N/A: Cannot disable | thinkingBudget = -1      |
    def get_thinking_budget(self, model: str) -> int:
        if model_settings.gemini_force_minimum_thinking_budget:
            if all(substring in model for substring in ["2.5", "flash", "lite"]):
                return 512
            elif all(substring in model for substring in ["2.5", "flash"]):
                return 1
        # Gemini 3 and 2.5 Pro require thinking mode and cannot have budget 0
        if model.startswith("gemini-3") or model.startswith("gemini-2.5-pro"):
            return 128  # Minimum valid budget for models that require thinking
        return 0
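
    # Illustrative examples (not from the source), assuming gemini_force_minimum_thinking_budget is enabled:
    #   get_thinking_budget("gemini-2.5-flash-lite") -> 512  (model name contains "2.5", "flash", and "lite")
    #   get_thinking_budget("gemini-2.5-flash")      -> 1    (model name contains "2.5" and "flash")
    # Regardless of that setting:
    #   get_thinking_budget("gemini-2.5-pro")        -> 128  (thinking cannot be disabled)
    #   any other model                              -> 0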

    def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
        return (
            llm_config.model.startswith("gemini-2.5-flash")
            or llm_config.model.startswith("gemini-2.5-pro")
            or llm_config.model.startswith("gemini-3")
        )

    def is_malformed_function_call(self, response_data: dict) -> bool:
        response = GenerateContentResponse(**response_data)
        for candidate in response.candidates:
            content = candidate.content
            if content is None or content.role is None or content.parts is None:
                return candidate.finish_reason == "MALFORMED_FUNCTION_CALL"
        return False

    @trace_method
    def handle_llm_error(self, e: Exception) -> Exception:
        # Handle Google GenAI specific errors
        if isinstance(e, errors.ClientError):
            logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}")

            # Handle specific error codes
            if e.code == 400:
                error_str = str(e).lower()
                if "context" in error_str and ("exceed" in error_str or "limit" in error_str or "too long" in error_str):
                    return ContextWindowExceededError(
                        message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}",
                    )
                else:
                    return LLMBadRequestError(
                        message=f"Bad request to {self._provider_name()}: {str(e)}",
                        code=ErrorCode.INTERNAL_SERVER_ERROR,
                    )
            elif e.code == 401:
                return LLMAuthenticationError(
                    message=f"Authentication failed with {self._provider_name()}: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                )
            elif e.code == 403:
                return LLMPermissionDeniedError(
                    message=f"Permission denied by {self._provider_name()}: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                )
            elif e.code == 404:
                return LLMNotFoundError(
                    message=f"Resource not found in {self._provider_name()}: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                )
            elif e.code == 408:
                return LLMTimeoutError(
                    message=f"Request to {self._provider_name()} timed out: {str(e)}",
                    code=ErrorCode.TIMEOUT,
                    details={"cause": str(e.__cause__) if e.__cause__ else None},
                )
            elif e.code == 422:
                return LLMUnprocessableEntityError(
                    message=f"Invalid request content for {self._provider_name()}: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                )
            elif e.code == 429:
                logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.")
                return LLMRateLimitError(
                    message=f"Rate limited by {self._provider_name()}: {str(e)}",
                    code=ErrorCode.RATE_LIMIT_EXCEEDED,
                )
            else:
                return LLMServerError(
                    message=f"{self._provider_name()} client error: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                    details={
                        "status_code": e.code,
                        "response_json": getattr(e, "response_json", None),
                    },
                )

        if isinstance(e, errors.ServerError):
            logger.warning(f"{self._provider_prefix()} Server error ({e.code}): {e}")

            # Handle specific server error codes
            if e.code == 500:
                return LLMServerError(
                    message=f"{self._provider_name()} internal server error: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                    details={
                        "status_code": e.code,
                        "response_json": getattr(e, "response_json", None),
                    },
                )
            elif e.code == 502:
                return LLMConnectionError(
                    message=f"Bad gateway from {self._provider_name()}: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                    details={"cause": str(e.__cause__) if e.__cause__ else None},
                )
            elif e.code == 503:
                return LLMServerError(
                    message=f"{self._provider_name()} service unavailable: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                    details={
                        "status_code": e.code,
                        "response_json": getattr(e, "response_json", None),
                    },
                )
            elif e.code == 504:
                return LLMTimeoutError(
                    message=f"Gateway timeout from {self._provider_name()}: {str(e)}",
                    code=ErrorCode.TIMEOUT,
                    details={"cause": str(e.__cause__) if e.__cause__ else None},
                )
            else:
                return LLMServerError(
                    message=f"{self._provider_name()} server error: {str(e)}",
                    code=ErrorCode.INTERNAL_SERVER_ERROR,
                    details={
                        "status_code": e.code,
                        "response_json": getattr(e, "response_json", None),
                    },
                )

        if isinstance(e, errors.APIError):
            logger.warning(f"{self._provider_prefix()} API error ({e.code}): {e}")
            return LLMServerError(
                message=f"{self._provider_name()} API error: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
                details={
                    "status_code": e.code,
                    "response_json": getattr(e, "response_json", None),
                },
            )

        # Handle httpx.RemoteProtocolError which can occur during streaming
        # when the remote server closes the connection unexpectedly
        # (e.g., "peer closed connection without sending complete message body")
        if isinstance(e, httpx.RemoteProtocolError):
            logger.warning(f"{self._provider_prefix()} Remote protocol error during streaming: {e}")
            return LLMConnectionError(
                message=f"Connection error during {self._provider_name()} streaming: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
                details={"cause": str(e.__cause__) if e.__cause__ else None},
            )

        # Handle httpx network errors which can occur during streaming
        # when the connection is unexpectedly closed while reading/writing
        if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)):
            logger.warning(f"{self._provider_prefix()} Network error during streaming: {type(e).__name__}: {e}")
            return LLMConnectionError(
                message=f"Network error during {self._provider_name()} streaming: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
                details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__},
            )

        # Handle connection-related errors
        if "connection" in str(e).lower() or "timeout" in str(e).lower():
            logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}")
            return LLMConnectionError(
                message=f"Failed to connect to {self._provider_name()}: {str(e)}",
                code=ErrorCode.INTERNAL_SERVER_ERROR,
                details={"cause": str(e.__cause__) if e.__cause__ else None},
            )

        # Fallback to base implementation for other errors
        return super().handle_llm_error(e)

    async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
        """
        Count tokens for the given messages and tools using the Gemini token counting API.

        Args:
            messages: List of message dicts in Google AI format (with 'role' and 'parts' keys)
            model: The model to use for token counting (defaults to gemini-2.0-flash-lite)
            tools: List of OpenAI-style Tool objects to include in the count

        Returns:
            The total token count for the input
        """
        from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK

        client = self._get_client()

        # Default model for token counting if not specified
        count_model = model or GOOGLE_MODEL_FOR_API_KEY_CHECK

        # Build the contents parameter
        # If no messages provided, use an empty string (like the API key check)
        if messages is None or len(messages) == 0:
            contents = ""
        else:
            # Messages should already be in Google format (role + parts)
            contents = messages

        try:
            # Count message tokens
            result = await client.aio.models.count_tokens(
                model=count_model,
                contents=contents,
            )
            total_tokens = result.total_tokens

            # Count tool tokens separately by serializing to text
            # The Gemini count_tokens API doesn't support a tools parameter directly
            if tools and len(tools) > 0:
                # Serialize tools to JSON text and count those tokens
                tools_text = json.dumps([t.model_dump() for t in tools])
                tools_result = await client.aio.models.count_tokens(
                    model=count_model,
                    contents=tools_text,
                )
                total_tokens += tools_result.total_tokens

        except Exception as e:
            raise self.handle_llm_error(e)

        return total_tokens