import base64 import copy import json import uuid from typing import AsyncIterator, List, Optional import httpx import pydantic_core from google.genai import Client, errors from google.genai.types import ( FunctionCallingConfig, FunctionCallingConfigMode, GenerateContentResponse, HttpOptions, ThinkingConfig, ToolConfig, ) from letta.constants import NON_USER_MSG_PREFIX from letta.errors import ( ContextWindowExceededError, ErrorCode, LLMAuthenticationError, LLMBadRequestError, LLMConnectionError, LLMInsufficientCreditsError, LLMNotFoundError, LLMPermissionDeniedError, LLMRateLimitError, LLMServerError, LLMTimeoutError, LLMUnprocessableEntityError, ) from letta.helpers.datetime_helpers import get_utc_time_int from letta.helpers.json_helpers import json_dumps, json_loads, sanitize_unicode_surrogates from letta.llm_api.error_utils import is_insufficient_credits_message from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.json_parser import clean_json_string_extra_backslash from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.agent import AgentType from letta.schemas.enums import ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.openai.chat_completion_request import Tool, Tool as OpenAITool from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics from letta.schemas.usage import LettaUsageStatistics from letta.settings import model_settings, settings from letta.utils import get_tool_call_id logger = get_logger(__name__) class GoogleVertexClient(LLMClientBase): MAX_RETRIES = model_settings.gemini_max_retries provider_label = "Google Vertex" def _get_client(self, llm_config: Optional[LLMConfig] = None): timeout_ms = int(settings.llm_request_timeout_seconds * 1000) if llm_config: api_key, _, _ = self.get_byok_overrides(llm_config) if api_key: return Client( api_key=api_key, http_options=HttpOptions(timeout=timeout_ms), ) return Client( vertexai=True, project=model_settings.google_cloud_project, location=model_settings.google_cloud_location, http_options=HttpOptions(api_version="v1", timeout=timeout_ms), ) async def _get_client_async(self, llm_config: Optional[LLMConfig] = None): timeout_ms = int(settings.llm_request_timeout_seconds * 1000) if llm_config: api_key, _, _ = await self.get_byok_overrides_async(llm_config) if api_key: return Client( api_key=api_key, http_options=HttpOptions(timeout=timeout_ms), ) return Client( vertexai=True, project=model_settings.google_cloud_project, location=model_settings.google_cloud_location, http_options=HttpOptions(api_version="v1", timeout=timeout_ms), ) def _provider_prefix(self) -> str: return f"[{self.provider_label}]" def _provider_name(self) -> str: return self.provider_label @trace_method def request(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying request to llm and returns raw response. """ try: client = self._get_client(llm_config) response = client.models.generate_content( model=llm_config.model, contents=request_data["contents"], config=request_data["config"], ) return response.model_dump() except pydantic_core._pydantic_core.ValidationError as e: # Handle Pydantic validation errors from the Google SDK # This occurs when tool schemas contain unsupported fields logger.error( f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}" ) raise LLMBadRequestError( message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. " f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. " f"Please check your tool definitions. Error: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, ) except Exception as e: raise self.handle_llm_error(e, llm_config=llm_config) @trace_method async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: """ Performs underlying request to llm and returns raw response. """ request_data = sanitize_unicode_surrogates(request_data) client = await self._get_client_async(llm_config) # Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry # https://github.com/googleapis/python-aiplatform/issues/4472 retry_count = 1 should_retry = True response_data = None while should_retry and retry_count <= self.MAX_RETRIES: try: response = await client.aio.models.generate_content( model=llm_config.model, contents=request_data["contents"], config=request_data["config"], ) except pydantic_core._pydantic_core.ValidationError as e: # Handle Pydantic validation errors from the Google SDK # This occurs when tool schemas contain unsupported fields logger.error( f"Pydantic validation error when calling {self._provider_name()} API. " f"Tool schema contains unsupported fields. Error: {e}" ) raise LLMBadRequestError( message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. " f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. " f"Please check your tool definitions. Error: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, ) except errors.APIError as e: # Retry on 503 and 500 errors as well, usually ephemeral from Gemini if e.code == 503 or e.code == 500 or e.code == 504: logger.warning(f"Received {e}, retrying {retry_count}/{self.MAX_RETRIES}") retry_count += 1 if retry_count > self.MAX_RETRIES: raise self.handle_llm_error(e, llm_config=llm_config) continue raise self.handle_llm_error(e, llm_config=llm_config) except Exception as e: raise self.handle_llm_error(e, llm_config=llm_config) response_data = response.model_dump() is_malformed_function_call = self.is_malformed_function_call(response_data) if is_malformed_function_call: logger.warning( f"Received FinishReason.MALFORMED_FUNCTION_CALL in response for {llm_config.model}, retrying {retry_count}/{self.MAX_RETRIES}" ) # Modify the last message if it's a heartbeat to include warning about special characters if request_data["contents"] and len(request_data["contents"]) > 0: last_message = request_data["contents"][-1] if last_message.get("role") == "user" and last_message.get("parts"): for part in last_message["parts"]: if "text" in part: try: # Try to parse as JSON to check if it's a heartbeat message_json = json_loads(part["text"]) if message_json.get("type") == "heartbeat" and "reason" in message_json: # Append warning to the reason warning = f" RETRY {retry_count}/{self.MAX_RETRIES} ***DO NOT USE SPECIAL CHARACTERS OR QUOTATIONS INSIDE FUNCTION CALL ARGUMENTS. IF YOU MUST, MAKE SURE TO ESCAPE THEM PROPERLY***" message_json["reason"] = message_json["reason"] + warning # Update the text with modified JSON part["text"] = json_dumps(message_json) logger.warning( f"Modified heartbeat message with special character warning for retry {retry_count}/{self.MAX_RETRIES}" ) except (json.JSONDecodeError, TypeError, AttributeError): # Not a JSON message or not a heartbeat, skip modification pass should_retry = is_malformed_function_call retry_count += 1 if response_data is None: raise RuntimeError("Failed to get response data after all retries") return response_data @trace_method async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncIterator[GenerateContentResponse]: request_data = sanitize_unicode_surrogates(request_data) client = await self._get_client_async(llm_config) try: response = await client.aio.models.generate_content_stream( model=llm_config.model, contents=request_data["contents"], config=request_data["config"], ) except pydantic_core._pydantic_core.ValidationError as e: # Handle Pydantic validation errors from the Google SDK # This occurs when tool schemas contain unsupported fields logger.error( f"Pydantic validation error when calling {self._provider_name()} API. Tool schema contains unsupported fields. Error: {e}" ) raise LLMBadRequestError( message=f"Invalid tool schema for {self._provider_name()}: Tool parameters contain unsupported fields. " f"Common issues: 'const', 'default', 'additionalProperties' are not supported by Google AI. " f"Please check your tool definitions. Error: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, ) except errors.APIError as e: raise self.handle_llm_error(e) except Exception as e: logger.error(f"Error streaming {self._provider_name()} request: {e} with request data: {json.dumps(request_data)}") raise e # Direct yield - keeps response alive in generator's local scope throughout iteration # This is required because the SDK's connection lifecycle is tied to the response object try: async for chunk in response: yield chunk except errors.ClientError as e: if e.code == 499: logger.info(f"{self._provider_prefix()} Stream cancelled by client (499): {e}") return raise self.handle_llm_error(e, llm_config=llm_config) except errors.APIError as e: raise self.handle_llm_error(e, llm_config=llm_config) @staticmethod def add_dummy_model_messages(messages: List[dict]) -> List[dict]: """Google AI API requires all function call returns are immediately followed by a 'model' role message. In Letta, the 'model' will often call a function (e.g. send_message) that itself yields to the user, so there is no natural follow-up 'model' role message. To satisfy the Google AI API restrictions, we can add a dummy 'yield' message with role == 'model' that is placed in-betweeen and function output (role == 'tool') and user message (role == 'user'). """ dummy_yield_message = { "role": "model", "parts": [{"text": f"{NON_USER_MSG_PREFIX}Function call returned, waiting for user response."}], } messages_with_padding = [] for i, message in enumerate(messages): messages_with_padding.append(message) # Check if the current message role is 'tool' and the next message role is 'user' if message["role"] in ["tool", "function"] and (i + 1 < len(messages) and messages[i + 1]["role"] == "user"): messages_with_padding.append(dummy_yield_message) return messages_with_padding def _clean_google_ai_schema_properties(self, schema_part: dict): """Recursively clean schema parts to remove unsupported Google AI keywords.""" if not isinstance(schema_part, dict): return # Per https://ai.google.dev/gemini-api/docs/function-calling?example=meeting#notes_and_limitations # * Only a subset of the OpenAPI schema is supported. # * Supported parameter types in Python are limited. unsupported_keys = ["default", "exclusiveMaximum", "exclusiveMinimum", "additionalProperties", "$schema", "const", "$ref"] keys_to_remove_at_this_level = [key for key in unsupported_keys if key in schema_part] for key_to_remove in keys_to_remove_at_this_level: logger.debug(f"Removing unsupported keyword '{key_to_remove}' from schema part.") del schema_part[key_to_remove] if schema_part.get("type") == "string" and "format" in schema_part: allowed_formats = ["enum", "date-time"] if schema_part["format"] not in allowed_formats: logger.warning(f"Removing unsupported format '{schema_part['format']}' for string type. Allowed: {allowed_formats}") del schema_part["format"] # Check properties within the current level if "properties" in schema_part and isinstance(schema_part["properties"], dict): for prop_name, prop_schema in schema_part["properties"].items(): self._clean_google_ai_schema_properties(prop_schema) # Check items within arrays if "items" in schema_part and isinstance(schema_part["items"], dict): self._clean_google_ai_schema_properties(schema_part["items"]) # Check within anyOf, allOf, oneOf lists for key in ["anyOf", "allOf", "oneOf"]: if key in schema_part and isinstance(schema_part[key], list): for item_schema in schema_part[key]: self._clean_google_ai_schema_properties(item_schema) def _resolve_json_schema_refs(self, schema: dict, defs: dict | None = None) -> dict: """ Recursively resolve $ref in JSON schema by inlining definitions. Google GenAI SDK does not support $ref. """ if defs is None: # Look for definitions at the top level defs = schema.get("$defs") or schema.get("definitions") or {} if not isinstance(schema, dict): return schema # If this is a ref, resolve it if "$ref" in schema: ref = schema["$ref"] if isinstance(ref, str): for prefix in ("#/$defs/", "#/definitions/"): if ref.startswith(prefix): ref_name = ref.split("/")[-1] if ref_name in defs: resolved = defs[ref_name].copy() return self._resolve_json_schema_refs(resolved, defs) break logger.warning(f"Could not resolve $ref '{ref}' in schema — will be stripped by schema cleaner") # Recursively process children new_schema = schema.copy() # We need to remove $defs/definitions from the output schema as Google doesn't support them if "$defs" in new_schema: del new_schema["$defs"] if "definitions" in new_schema: del new_schema["definitions"] for k, v in new_schema.items(): if isinstance(v, dict): new_schema[k] = self._resolve_json_schema_refs(v, defs) elif isinstance(v, list): new_schema[k] = [self._resolve_json_schema_refs(i, defs) if isinstance(i, dict) else i for i in v] return new_schema def convert_tools_to_google_ai_format(self, tools: List[Tool], llm_config: LLMConfig) -> List[dict]: """ OpenAI style: "tools": [{ "type": "function", "function": { "name": "find_movies", "description": "find ....", "parameters": { "type": "object", "properties": { PARAM: { "type": PARAM_TYPE, # eg "string" "description": PARAM_DESCRIPTION, }, ... }, "required": List[str], } } } ] Google AI style: "tools": [{ "functionDeclarations": [{ "name": "find_movies", "description": "find movie titles currently playing in theaters based on any description, genre, title words, etc.", "parameters": { "type": "OBJECT", "properties": { "location": { "type": "STRING", "description": "The city and state, e.g. San Francisco, CA or a zip code e.g. 95616" }, "description": { "type": "STRING", "description": "Any kind of description including category or genre, title words, attributes, etc." } }, "required": ["description"] } }, { "name": "find_theaters", ... """ function_list = [ dict( name=t.function.name, description=t.function.description, # Deep copy parameters to avoid modifying the original Tool object parameters=copy.deepcopy(t.function.parameters) if t.function.parameters else {}, ) for t in tools ] # Add inner thoughts if needed for func in function_list: # Note: Google AI API used to have weird casing requirements, but not any more # Google AI API only supports a subset of OpenAPI 3.0, so unsupported params must be cleaned if "parameters" in func and isinstance(func["parameters"], dict): # Resolve $ref in schema because Google AI SDK doesn't support them func["parameters"] = self._resolve_json_schema_refs(func["parameters"]) self._clean_google_ai_schema_properties(func["parameters"]) # Add inner thoughts if llm_config.put_inner_thoughts_in_kwargs: from letta.local_llm.constants import INNER_THOUGHTS_KWARG_DESCRIPTION, INNER_THOUGHTS_KWARG_VERTEX func["parameters"]["properties"][INNER_THOUGHTS_KWARG_VERTEX] = { "type": "string", "description": INNER_THOUGHTS_KWARG_DESCRIPTION, } func["parameters"]["required"].append(INNER_THOUGHTS_KWARG_VERTEX) return [{"functionDeclarations": function_list}] @trace_method def build_request_data( self, agent_type: AgentType, # if react, use native content + strip heartbeats messages: List[PydanticMessage], llm_config: LLMConfig, tools: List[dict], force_tool_call: Optional[str] = None, requires_subsequent_tool_call: bool = False, tool_return_truncation_chars: Optional[int] = None, ) -> dict: """ Constructs a request object in the expected data format for this client. """ # NOTE: forcing inner thoughts in kwargs off if agent_type == AgentType.letta_v1_agent: llm_config.put_inner_thoughts_in_kwargs = False if tools: tool_objs = [Tool(type="function", function=t) for t in tools] tool_names = [t.function.name for t in tool_objs] # Convert to the exact payload style Google expects formatted_tools = self.convert_tools_to_google_ai_format(tool_objs, llm_config) else: formatted_tools = [] tool_names = [] contents = self.add_dummy_model_messages( PydanticMessage.to_google_dicts_from_list( messages, current_model=llm_config.model, put_inner_thoughts_in_kwargs=False if agent_type == AgentType.letta_v1_agent else True, native_content=True if agent_type == AgentType.letta_v1_agent else False, ), ) request_data = { "contents": contents, "config": { "temperature": llm_config.temperature, "tools": formatted_tools, }, } # Make tokens is optional if llm_config.max_tokens: request_data["config"]["max_output_tokens"] = llm_config.max_tokens if len(tool_names) == 1 and settings.use_vertex_structured_outputs_experimental: request_data["config"]["response_mime_type"] = "application/json" request_data["config"]["response_schema"] = self.get_function_call_response_schema(tools[0]) del request_data["config"]["tools"] elif tools: if agent_type == AgentType.letta_v1_agent: # don't require tools tool_call_mode = FunctionCallingConfigMode.AUTO tool_config = ToolConfig( function_calling_config=FunctionCallingConfig( mode=tool_call_mode, ) ) else: # require tools tool_call_mode = FunctionCallingConfigMode.ANY tool_config = ToolConfig( function_calling_config=FunctionCallingConfig( mode=tool_call_mode, # Provide the list of tools (though empty should also work, it seems not to) allowed_function_names=tool_names, ) ) request_data["config"]["tool_config"] = tool_config.model_dump() # https://ai.google.dev/gemini-api/docs/thinking#set-budget # 2.5 Pro # - Default: dynamic thinking # - Dynamic thinking that cannot be disabled # - Range: -1 (for dynamic), or 128-32768 # 2.5 Flash # - Default: dynamic thinking # - Dynamic thinking that *can* be disabled # - Range: -1, 0, or 0-24576 # 2.5 Flash Lite # - Default: no thinking # - Dynamic thinking that *can* be disabled # - Range: -1, 0, or 512-24576 # TODO when using v3 agent loop, properly support the native thinking in Gemini # Add thinking_config for all Gemini reasoning models (2.5 series) # If enable_reasoner is False, set thinking_budget to 0 # Otherwise, use the value from max_reasoning_tokens if self.is_reasoning_model(llm_config) or "flash" in llm_config.model: if llm_config.model.startswith("gemini-3"): # letting thinking_level to default to high by not specifying thinking_budget thinking_config = ThinkingConfig(include_thoughts=True) else: # Gemini reasoning models may fail to call tools even with FunctionCallingConfigMode.ANY if thinking is fully disabled, set to minimum to prevent tool call failure thinking_budget = ( llm_config.max_reasoning_tokens if llm_config.enable_reasoner else self.get_thinking_budget(llm_config.model) ) if thinking_budget <= 0: logger.warning( f"Thinking budget of {thinking_budget} for Gemini reasoning model {llm_config.model}, this will likely cause tool call failures" ) # For models that require thinking mode (2.5 Pro, 3.x), override with minimum valid budget if llm_config.model.startswith("gemini-2.5-pro"): thinking_budget = 128 logger.warning( f"Overriding thinking_budget to {thinking_budget} for model {llm_config.model} which requires thinking mode" ) thinking_config = ThinkingConfig( thinking_budget=(thinking_budget), include_thoughts=(thinking_budget > 1), ) request_data["config"]["thinking_config"] = thinking_config.model_dump() return request_data def extract_usage_statistics(self, response_data: dict | None, llm_config: LLMConfig) -> LettaUsageStatistics: """Extract usage statistics from Gemini response and return as LettaUsageStatistics.""" if not response_data: return LettaUsageStatistics() response = GenerateContentResponse(**response_data) if not response.usage_metadata: return LettaUsageStatistics() cached_tokens = None if ( hasattr(response.usage_metadata, "cached_content_token_count") and response.usage_metadata.cached_content_token_count is not None ): cached_tokens = response.usage_metadata.cached_content_token_count reasoning_tokens = None if hasattr(response.usage_metadata, "thoughts_token_count") and response.usage_metadata.thoughts_token_count is not None: reasoning_tokens = response.usage_metadata.thoughts_token_count return LettaUsageStatistics( prompt_tokens=response.usage_metadata.prompt_token_count or 0, completion_tokens=response.usage_metadata.candidates_token_count or 0, total_tokens=response.usage_metadata.total_token_count or 0, cached_input_tokens=cached_tokens, reasoning_tokens=reasoning_tokens, ) @trace_method async def convert_response_to_chat_completion( self, response_data: dict, input_messages: List[PydanticMessage], llm_config: LLMConfig, ) -> ChatCompletionResponse: """ Converts custom response format from llm client into an OpenAI ChatCompletionsResponse object. Example: { "candidates": [ { "content": { "parts": [ { "text": " OK. Barbie is showing in two theaters in Mountain View, CA: AMC Mountain View 16 and Regal Edwards 14." } ] } } ], "usageMetadata": { "promptTokenCount": 9, "candidatesTokenCount": 27, "totalTokenCount": 36 } } """ response = GenerateContentResponse(**response_data) try: choices = [] index = 0 for candidate in response.candidates: content = candidate.content if content is None or content.role is None or content.parts is None: # This means the response is malformed like MALFORMED_FUNCTION_CALL if candidate.finish_reason == "MALFORMED_FUNCTION_CALL": raise LLMServerError(f"Malformed response from {self._provider_name()}: {candidate.finish_reason}") else: raise LLMServerError(f"Invalid response data from {self._provider_name()}: {candidate.model_dump()}") role = content.role assert role == "model", f"Unknown role in response: {role}" parts = content.parts # NOTE: we aren't properly supported multi-parts here anyways (we're just appending choices), # so let's disable it for now # NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text # {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'} # To patch this, if we have multiple parts we can take the last one # Unless parallel tool calling is enabled, in which case multiple parts may be intentional if len(parts) > 1 and not llm_config.enable_reasoner and not llm_config.parallel_tool_calls: logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}") # only truncate if reasoning is off and parallel tool calls are disabled parts = [parts[-1]] # TODO support parts / multimodal # Parallel tool calling is now supported when llm_config.parallel_tool_calls is enabled openai_response_message = None for response_message in parts: # Convert the actual message style to OpenAI style if response_message.function_call: function_call = response_message.function_call function_name = function_call.name function_args = function_call.args assert isinstance(function_args, dict), function_args # TODO this is kind of funky - really, we should be passing 'native_content' as a kwarg to fork behavior inner_thoughts = response_message.text if llm_config.put_inner_thoughts_in_kwargs: # NOTE: this also involves stripping the inner monologue out of the function from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX assert INNER_THOUGHTS_KWARG_VERTEX in function_args, ( f"Couldn't find inner thoughts in function args:\n{function_call}" ) inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX) assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}" else: pass # inner_thoughts = None # inner_thoughts = response_message.text # Google AI API doesn't generate tool call IDs tool_call = ToolCall( id=get_tool_call_id(), type="function", function=FunctionCall( name=function_name, arguments=clean_json_string_extra_backslash(json_dumps(function_args)), ), ) if openai_response_message is None: openai_response_message = Message( role="assistant", # NOTE: "model" -> "assistant" content=inner_thoughts, tool_calls=[tool_call], ) if response_message.thought_signature: thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8") openai_response_message.reasoning_content_signature = thought_signature else: openai_response_message.content = inner_thoughts if openai_response_message.tool_calls is None: openai_response_message.tool_calls = [] openai_response_message.tool_calls.append(tool_call) if response_message.thought_signature: thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8") openai_response_message.reasoning_content_signature = thought_signature else: if response_message.thought: if openai_response_message is None: openai_response_message = Message( role="assistant", # NOTE: "model" -> "assistant" reasoning_content=response_message.text, ) else: openai_response_message.reasoning_content = response_message.text try: # Structured output tool call function_call = json_loads(response_message.text) # Access dict keys - will raise TypeError/KeyError if not a dict or missing keys function_name = function_call["name"] function_args = function_call["args"] assert isinstance(function_args, dict), function_args # NOTE: this also involves stripping the inner monologue out of the function if llm_config.put_inner_thoughts_in_kwargs: from letta.local_llm.constants import INNER_THOUGHTS_KWARG_VERTEX assert INNER_THOUGHTS_KWARG_VERTEX in function_args, ( f"Couldn't find inner thoughts in function args:\n{function_call}" ) inner_thoughts = function_args.pop(INNER_THOUGHTS_KWARG_VERTEX) assert inner_thoughts is not None, f"Expected non-null inner thoughts function arg:\n{function_call}" else: inner_thoughts = None # Google AI API doesn't generate tool call IDs tool_call = ToolCall( id=get_tool_call_id(), type="function", function=FunctionCall( name=function_name, arguments=clean_json_string_extra_backslash(json_dumps(function_args)), ), ) if openai_response_message is None: openai_response_message = Message( role="assistant", # NOTE: "model" -> "assistant" content=inner_thoughts, tool_calls=[tool_call], ) else: openai_response_message.content = inner_thoughts if openai_response_message.tool_calls is None: openai_response_message.tool_calls = [] openai_response_message.tool_calls.append(tool_call) except (json.decoder.JSONDecodeError, ValueError, TypeError, KeyError, AssertionError) as e: if candidate.finish_reason == "MAX_TOKENS": raise LLMServerError("Could not parse response data from LLM: exceeded max token limit") # Log the parsing error for debugging logger.warning( f"Failed to parse structured output from LLM response: {e}. Response text: {response_message.text[:500]}" ) # Inner thoughts are the content by default inner_thoughts = response_message.text # Google AI API doesn't generate tool call IDs if openai_response_message is None: openai_response_message = Message( role="assistant", # NOTE: "model" -> "assistant" content=inner_thoughts, ) else: openai_response_message.content = inner_thoughts if response_message.thought_signature: thought_signature = base64.b64encode(response_message.thought_signature).decode("utf-8") openai_response_message.reasoning_content_signature = thought_signature # Google AI API uses different finish reason strings than OpenAI # OpenAI: 'stop', 'length', 'function_call', 'content_filter', null # see: https://platform.openai.com/docs/guides/text-generation/chat-completions-api # Google AI API: FINISH_REASON_UNSPECIFIED, STOP, MAX_TOKENS, SAFETY, RECITATION, OTHER # see: https://ai.google.dev/api/python/google/ai/generativelanguage/Candidate/FinishReason finish_reason = candidate.finish_reason.value if finish_reason == "STOP": openai_finish_reason = ( "function_call" if openai_response_message.tool_calls is not None and len(openai_response_message.tool_calls) > 0 else "stop" ) elif finish_reason == "MAX_TOKENS": openai_finish_reason = "length" elif finish_reason == "SAFETY": openai_finish_reason = "content_filter" elif finish_reason == "RECITATION": openai_finish_reason = "content_filter" else: raise LLMServerError(f"Unrecognized finish reason in Google AI response: {finish_reason}") choices.append( Choice( finish_reason=openai_finish_reason, index=index, message=openai_response_message, ) ) index += 1 # if len(choices) > 1: # raise UserWarning(f"Unexpected number of candidates in response (expected 1, got {len(choices)})") # NOTE: some of the Google AI APIs show UsageMetadata in the response, but it seems to not exist? # "usageMetadata": { # "promptTokenCount": 9, # "candidatesTokenCount": 27, # "totalTokenCount": 36 # } usage = None if response.usage_metadata: # Extract usage via centralized method from letta.schemas.enums import ProviderType usage = self.extract_usage_statistics(response_data, llm_config).to_usage(ProviderType.google_ai) else: # Count it ourselves using the Gemini token counting API assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required" google_messages = PydanticMessage.to_google_dicts_from_list(input_messages, current_model=llm_config.model) prompt_tokens = await self.count_tokens(messages=google_messages, model=llm_config.model) # For completion tokens, wrap the response content in Google format completion_content = [{"role": "model", "parts": [{"text": json_dumps(openai_response_message.model_dump())}]}] completion_tokens = await self.count_tokens(messages=completion_content, model=llm_config.model) total_tokens = prompt_tokens + completion_tokens usage = UsageStatistics( prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, total_tokens=total_tokens, ) response_id = str(uuid.uuid4()) return ChatCompletionResponse( id=response_id, choices=choices, model=llm_config.model, # NOTE: Google API doesn't pass back model in the response created=get_utc_time_int(), usage=usage, ) except KeyError as e: raise e def get_function_call_response_schema(self, tool: dict) -> dict: return { "type": "OBJECT", "properties": { "name": {"type": "STRING", "enum": [tool["name"]]}, "args": { "type": "OBJECT", "properties": tool["parameters"]["properties"], "required": tool["parameters"]["required"], }, }, "propertyOrdering": ["name", "args"], "required": ["name", "args"], } # https://ai.google.dev/gemini-api/docs/thinking#set-budget # | Model | Default setting | Range | Disable thinking | Turn on dynamic thinking| # |-----------------|-------------------------------------------------------------------|--------------|----------------------------|-------------------------| # | 2.5 Pro | Dynamic thinking: Model decides when and how much to think | 128-32768 | N/A: Cannot disable | thinkingBudget = -1 | # | 2.5 Flash | Dynamic thinking: Model decides when and how much to think | 0-24576 | thinkingBudget = 0 | thinkingBudget = -1 | # | 2.5 Flash Lite | Model does not think | 512-24576 | thinkingBudget = 0 | thinkingBudget = -1 | # | 3.x | Dynamic thinking: Model decides when and how much to think | 128-? | N/A: Cannot disable | thinkingBudget = -1 | def get_thinking_budget(self, model: str) -> bool: if model_settings.gemini_force_minimum_thinking_budget: if all(substring in model for substring in ["2.5", "flash", "lite"]): return 512 elif all(substring in model for substring in ["2.5", "flash"]): return 1 # Gemini 3 and 2.5 Pro require thinking mode and cannot have budget 0 if model.startswith("gemini-3") or model.startswith("gemini-2.5-pro"): return 128 # Minimum valid budget for models that require thinking return 0 def is_reasoning_model(self, llm_config: LLMConfig) -> bool: return ( llm_config.model.startswith("gemini-2.5-flash") or llm_config.model.startswith("gemini-2.5-pro") or llm_config.model.startswith("gemini-3") ) def is_malformed_function_call(self, response_data: dict) -> dict: response = GenerateContentResponse(**response_data) for candidate in response.candidates: content = candidate.content if content is None or content.role is None or content.parts is None: return candidate.finish_reason == "MALFORMED_FUNCTION_CALL" return False @trace_method def handle_llm_error(self, e: Exception, llm_config: Optional[LLMConfig] = None) -> Exception: is_byok = (llm_config.provider_category == ProviderCategory.byok) if llm_config else None # Handle Google GenAI specific errors if isinstance(e, errors.ClientError): if e.code == 499: logger.info(f"{self._provider_prefix()} Request cancelled by client (499): {e}") return LLMConnectionError( message=f"Request to {self._provider_name()} was cancelled (client disconnected): {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"status_code": 499, "cause": "client_cancelled", "is_byok": is_byok}, ) logger.warning(f"{self._provider_prefix()} Client error ({e.code}): {e}") # Handle specific error codes if e.code == 400: error_str = str(e).lower() if ("context" in error_str or "token count" in error_str or "tokens allowed" in error_str) and ( "exceed" in error_str or "limit" in error_str or "too long" in error_str ): return ContextWindowExceededError( message=f"Bad request to {self._provider_name()} (context window exceeded): {str(e)}", details={"is_byok": is_byok}, ) else: return LLMBadRequestError( message=f"Bad request to {self._provider_name()}: {str(e)}", code=ErrorCode.INVALID_ARGUMENT, details={"is_byok": is_byok}, ) elif e.code == 401: return LLMAuthenticationError( message=f"Authentication failed with {self._provider_name()}: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"is_byok": is_byok}, ) elif e.code == 403: return LLMPermissionDeniedError( message=f"Permission denied by {self._provider_name()}: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"is_byok": is_byok}, ) elif e.code == 404: return LLMNotFoundError( message=f"Resource not found in {self._provider_name()}: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"is_byok": is_byok}, ) elif e.code == 408: return LLMTimeoutError( message=f"Request to {self._provider_name()} timed out: {str(e)}", code=ErrorCode.TIMEOUT, details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok}, ) elif e.code == 402 or is_insufficient_credits_message(str(e)): msg = str(e) return LLMInsufficientCreditsError( message=f"Insufficient credits (BYOK): {msg}" if is_byok else f"Insufficient credits: {msg}", code=ErrorCode.PAYMENT_REQUIRED, details={"status_code": e.code, "is_byok": is_byok}, ) elif e.code == 422: return LLMUnprocessableEntityError( message=f"Invalid request content for {self._provider_name()}: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"is_byok": is_byok}, ) elif e.code == 429: logger.warning(f"{self._provider_prefix()} Rate limited (429). Consider backoff.") return LLMRateLimitError( message=f"Rate limited by {self._provider_name()}: {str(e)}", code=ErrorCode.RATE_LIMIT_EXCEEDED, details={"is_byok": is_byok}, ) else: return LLMServerError( message=f"{self._provider_name()} client error: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={ "status_code": e.code, "response_json": getattr(e, "response_json", None), "is_byok": is_byok, }, ) if isinstance(e, errors.ServerError): logger.warning(f"{self._provider_prefix()} Server error ({e.code}): {e}") # Handle specific server error codes if e.code == 500: return LLMServerError( message=f"{self._provider_name()} internal server error: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={ "status_code": e.code, "response_json": getattr(e, "response_json", None), "is_byok": is_byok, }, ) elif e.code == 502: return LLMConnectionError( message=f"Bad gateway from {self._provider_name()}: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok}, ) elif e.code == 503: return LLMServerError( message=f"{self._provider_name()} service unavailable: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={ "status_code": e.code, "response_json": getattr(e, "response_json", None), "is_byok": is_byok, }, ) elif e.code == 504: return LLMTimeoutError( message=f"Gateway timeout from {self._provider_name()}: {str(e)}", code=ErrorCode.TIMEOUT, details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok}, ) else: return LLMServerError( message=f"{self._provider_name()} server error: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={ "status_code": e.code, "response_json": getattr(e, "response_json", None), "is_byok": is_byok, }, ) if isinstance(e, errors.APIError): logger.warning(f"{self._provider_prefix()} API error ({e.code}): {e}") return LLMServerError( message=f"{self._provider_name()} API error: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={ "status_code": e.code, "response_json": getattr(e, "response_json", None), "is_byok": is_byok, }, ) # Handle httpx.RemoteProtocolError which can occur during streaming # when the remote server closes the connection unexpectedly # (e.g., "peer closed connection without sending complete message body") if isinstance(e, httpx.RemoteProtocolError): logger.warning(f"{self._provider_prefix()} Remote protocol error during streaming: {e}") return LLMConnectionError( message=f"Connection error during {self._provider_name()} streaming: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok}, ) # Handle httpx network errors which can occur during streaming # when the connection is unexpectedly closed while reading/writing if isinstance(e, (httpx.ReadError, httpx.WriteError, httpx.ConnectError)): logger.warning(f"{self._provider_prefix()} Network error during streaming: {type(e).__name__}: {e}") return LLMConnectionError( message=f"Network error during {self._provider_name()} streaming: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"cause": str(e.__cause__) if e.__cause__ else None, "error_type": type(e).__name__, "is_byok": is_byok}, ) # Handle connection-related errors if "connection" in str(e).lower() or "timeout" in str(e).lower(): logger.warning(f"{self._provider_prefix()} Connection/timeout error: {e}") return LLMConnectionError( message=f"Failed to connect to {self._provider_name()}: {str(e)}", code=ErrorCode.INTERNAL_SERVER_ERROR, details={"cause": str(e.__cause__) if e.__cause__ else None, "is_byok": is_byok}, ) # Fallback to base implementation for other errors return super().handle_llm_error(e, llm_config=llm_config) async def count_tokens( self, messages: List[dict] | None = None, model: str | None = None, tools: List[OpenAITool] | None = None ) -> int: """ Count tokens for the given messages and tools using the Gemini token counting API. Args: messages: List of message dicts in Google AI format (with 'role' and 'parts' keys) model: The model to use for token counting (defaults to gemini-2.0-flash-lite) tools: List of OpenAI-style Tool objects to include in the count Returns: The total token count for the input """ from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK client = self._get_client() # Default model for token counting if not specified count_model = model or GOOGLE_MODEL_FOR_API_KEY_CHECK # Build the contents parameter # If no messages provided, use empty string (like the API key check) if messages is None or len(messages) == 0: contents = "" else: # Messages should already be in Google format (role + parts) contents = messages try: # Count message tokens result = await client.aio.models.count_tokens( model=count_model, contents=contents, ) total_tokens = result.total_tokens # Count tool tokens separately by serializing to text # The Gemini count_tokens API doesn't support a tools parameter directly if tools and len(tools) > 0: # Serialize tools to JSON text and count those tokens tools_text = json.dumps([t.model_dump() for t in tools]) tools_result = await client.aio.models.count_tokens( model=count_model, contents=tools_text, ) total_tokens += tools_result.total_tokens except Exception as e: raise self.handle_llm_error(e) return total_tokens