From c9c9e727b8897866b87c6a4db0a2153be91336e2 Mon Sep 17 00:00:00 2001 From: jnjpng Date: Mon, 1 Sep 2025 07:26:13 -0700 Subject: [PATCH] fix: retry on MALFORMED_FUNCTION_CALL for gemini [LET-4089] --------- Co-authored-by: Letta Bot --- letta/llm_api/google_vertex_client.py | 60 +++++++++++++++++++++++--- letta/settings.py | 1 + tests/integration_test_send_message.py | 14 ++++-- 3 files changed, 64 insertions(+), 11 deletions(-) diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 773d9599..ff7a37b2 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -31,6 +31,8 @@ logger = get_logger(__name__) class GoogleVertexClient(LLMClientBase): + MAX_RETRIES = model_settings.gemini_max_retries + def _get_client(self): timeout_ms = int(settings.llm_request_timeout_seconds * 1000) return genai.Client( @@ -59,12 +61,49 @@ class GoogleVertexClient(LLMClientBase): Performs underlying request to llm and returns raw response. 
""" client = self._get_client() - response = await client.aio.models.generate_content( - model=llm_config.model, - contents=request_data["contents"], - config=request_data["config"], - ) - return response.model_dump() + + # Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry + # https://github.com/googleapis/python-aiplatform/issues/4472 + retry_count = 1 + should_retry = True + while should_retry and retry_count <= self.MAX_RETRIES: + response = await client.aio.models.generate_content( + model=llm_config.model, + contents=request_data["contents"], + config=request_data["config"], + ) + response_data = response.model_dump() + is_malformed_function_call = self.is_malformed_function_call(response_data) + if is_malformed_function_call: + logger.warning( + f"Received FinishReason.MALFORMED_FUNCTION_CALL in response for {llm_config.model}, retrying {retry_count}/{self.MAX_RETRIES}" + ) + # Modify the last message if it's a heartbeat to include warning about special characters + if request_data["contents"] and len(request_data["contents"]) > 0: + last_message = request_data["contents"][-1] + if last_message.get("role") == "user" and last_message.get("parts"): + for part in last_message["parts"]: + if "text" in part: + try: + # Try to parse as JSON to check if it's a heartbeat + message_json = json_loads(part["text"]) + if message_json.get("type") == "heartbeat" and "reason" in message_json: + # Append warning to the reason + warning = f" RETRY {retry_count}/{self.MAX_RETRIES} ***DO NOT USE SPECIAL CHARACTERS OR QUOTATIONS INSIDE FUNCTION CALL ARGUMENTS. 
IF YOU MUST, MAKE SURE TO ESCAPE THEM PROPERLY***" + message_json["reason"] = message_json["reason"] + warning + # Update the text with modified JSON + part["text"] = json_dumps(message_json) + logger.warning( + f"Modified heartbeat message with special character warning for retry {retry_count}/{self.MAX_RETRIES}" + ) + except (json.JSONDecodeError, TypeError): + # Not a JSON message or not a heartbeat, skip modification + pass + + should_retry = is_malformed_function_call + retry_count += 1 + + return response_data @staticmethod def add_dummy_model_messages(messages: List[dict]) -> List[dict]: @@ -299,7 +338,6 @@ class GoogleVertexClient(LLMClientBase): } } """ - response = GenerateContentResponse(**response_data) try: choices = [] @@ -517,6 +555,14 @@ class GoogleVertexClient(LLMClientBase): def is_reasoning_model(self, llm_config: LLMConfig) -> bool: return llm_config.model.startswith("gemini-2.5-flash") or llm_config.model.startswith("gemini-2.5-pro") + def is_malformed_function_call(self, response_data: dict) -> bool: + response = GenerateContentResponse(**response_data) + for candidate in response.candidates: + content = candidate.content + if content is None or content.role is None or content.parts is None: + return candidate.finish_reason == "MALFORMED_FUNCTION_CALL" + return False + @trace_method def handle_llm_error(self, e: Exception) -> Exception: # Fallback to base implementation diff --git a/letta/settings.py b/letta/settings.py index 1cf8a62e..81b54da6 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -145,6 +145,7 @@ class ModelSettings(BaseSettings): gemini_api_key: Optional[str] = None gemini_base_url: str = "https://generativelanguage.googleapis.com/" gemini_force_minimum_thinking_budget: bool = False + gemini_max_retries: int = 5 # google vertex google_cloud_project: Optional[str] = None diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index a63c6d43..35a50f1b 100644 --- 
a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -752,10 +752,11 @@ def test_tool_call( response = client.agents.messages.create( agent_id=agent_state.id, messages=USER_MESSAGE_ROLL_DICE, + request_options={"timeout_in_seconds": 300}, ) except Exception as e: - if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e): - pytest.skip("Skipping test for flash model due to malformed function call from llm") + # if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e): + # pytest.skip("Skipping test for flash model due to malformed function call from llm") raise e assert_tool_call_response(response.messages, llm_config=llm_config) messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id) @@ -967,6 +968,7 @@ def test_step_streaming_tool_call( response = client.agents.messages.create_stream( agent_id=agent_state.id, messages=USER_MESSAGE_ROLL_DICE, + request_options={"timeout_in_seconds": 300}, ) messages = accumulate_chunks(list(response)) assert_tool_call_response(messages, streaming=True, llm_config=llm_config) @@ -1115,6 +1117,7 @@ def test_token_streaming_tool_call( agent_id=agent_state.id, messages=messages_to_send, stream_tokens=True, + request_options={"timeout_in_seconds": 300}, ) verify_token_streaming = ( llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model @@ -1183,6 +1186,7 @@ def test_background_token_streaming_greeting_with_assistant_message( messages=messages_to_send, stream_tokens=True, background=True, + request_options={"timeout_in_seconds": 300}, ) verify_token_streaming = ( llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model @@ -1418,6 +1422,7 @@ def test_async_tool_call( run = client.agents.messages.create_async( agent_id=agent_state.id, messages=USER_MESSAGE_ROLL_DICE, + 
request_options={"timeout_in_seconds": 300}, ) run = wait_for_run_completion(client, run.id) @@ -1639,10 +1644,11 @@ def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLM client.agents.messages.create( agent_id=temp_agent_state.id, messages=[MessageCreate(role="user", content=philosophical_question)], + request_options={"timeout_in_seconds": 300}, ) except Exception as e: - if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e): - pytest.skip("Skipping test for flash model due to malformed function call from llm") + # if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e): + # pytest.skip("Skipping test for flash model due to malformed function call from llm") raise e temp_agent_state = client.agents.retrieve(agent_id=temp_agent_state.id)