diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py
index 47671398..f1d8e091 100644
--- a/letta/llm_api/google_ai_client.py
+++ b/letta/llm_api/google_ai_client.py
@@ -7,7 +7,10 @@ from letta.errors import ErrorCode, LLMAuthenticationError, LLMError
 from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK
 from letta.llm_api.google_vertex_client import GoogleVertexClient
 from letta.log import get_logger
+from letta.schemas.llm_config import LLMConfig
+from letta.schemas.message import Message as PydanticMessage
 from letta.settings import model_settings
+from letta.tracing import trace_method
 
 logger = get_logger(__name__)
 
@@ -17,6 +20,18 @@ class GoogleAIClient(GoogleVertexClient):
     def _get_client(self):
         return genai.Client(api_key=model_settings.gemini_api_key)
 
+    @trace_method
+    def build_request_data(
+        self,
+        messages: List[PydanticMessage],
+        llm_config: LLMConfig,
+        tools: List[dict],
+        force_tool_call: Optional[str] = None,
+    ) -> dict:
+        # Build the request via the Vertex parent implementation, then strip
+        # thinking_config: the Gemini (AI Studio) endpoint does not accept it.
+        # pop() with a default (rather than `del`) avoids a KeyError if the
+        # parent ever stops emitting the key.
+        request = super().build_request_data(messages, llm_config, tools, force_tool_call)
+        request["config"].pop("thinking_config", None)
+        return request
+
 
 def get_gemini_endpoint_and_headers(
     base_url: str, model: Optional[str], api_key: str, key_in_header: bool = True, generate_content: bool = False
diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py
index e8215813..afc80ebd 100644
--- a/letta/llm_api/google_vertex_client.py
+++ b/letta/llm_api/google_vertex_client.py
@@ -244,11 +244,10 @@ class GoogleVertexClient(LLMClientBase):
 
         # Add thinking_config
         # If enable_reasoner is False, set thinking_budget to 0
         # Otherwise, use the value from max_reasoning_tokens
-        if llm_config.enable_reasoner:
-            thinking_config = ThinkingConfig(
-                thinking_budget=llm_config.max_reasoning_tokens,
-            )
-            request_data["config"]["thinking_config"] = thinking_config.model_dump()
+        thinking_config = ThinkingConfig(
+            thinking_budget=llm_config.max_reasoning_tokens if llm_config.enable_reasoner else 0,
+        )
+        request_data["config"]["thinking_config"] = thinking_config.model_dump()
 
         return request_data