fix: retry on MALFORMED_FUNCTION_CALL for gemini [LET-4089]

---------

Co-authored-by: Letta Bot <noreply@letta.com>
This commit is contained in:
jnjpng
2025-09-01 07:26:13 -07:00
committed by GitHub
parent 86b073d726
commit c9c9e727b8
3 changed files with 64 additions and 11 deletions

View File

@@ -31,6 +31,8 @@ logger = get_logger(__name__)
class GoogleVertexClient(LLMClientBase):
MAX_RETRIES = model_settings.gemini_max_retries
def _get_client(self):
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
return genai.Client(
@@ -59,12 +61,49 @@ class GoogleVertexClient(LLMClientBase):
Performs underlying request to llm and returns raw response.
"""
client = self._get_client()
response = await client.aio.models.generate_content(
model=llm_config.model,
contents=request_data["contents"],
config=request_data["config"],
)
return response.model_dump()
# Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
# https://github.com/googleapis/python-aiplatform/issues/4472
retry_count = 1
should_retry = True
while should_retry and retry_count <= self.MAX_RETRIES:
response = await client.aio.models.generate_content(
model=llm_config.model,
contents=request_data["contents"],
config=request_data["config"],
)
response_data = response.model_dump()
is_malformed_function_call = self.is_malformed_function_call(response_data)
if is_malformed_function_call:
logger.warning(
f"Received FinishReason.MALFORMED_FUNCTION_CALL in response for {llm_config.model}, retrying {retry_count}/{self.MAX_RETRIES}"
)
# Modify the last message if it's a heartbeat to include warning about special characters
if request_data["contents"] and len(request_data["contents"]) > 0:
last_message = request_data["contents"][-1]
if last_message.get("role") == "user" and last_message.get("parts"):
for part in last_message["parts"]:
if "text" in part:
try:
# Try to parse as JSON to check if it's a heartbeat
message_json = json_loads(part["text"])
if message_json.get("type") == "heartbeat" and "reason" in message_json:
# Append warning to the reason
warning = f" RETRY {retry_count}/{self.MAX_RETRIES} ***DO NOT USE SPECIAL CHARACTERS OR QUOTATIONS INSIDE FUNCTION CALL ARGUMENTS. IF YOU MUST, MAKE SURE TO ESCAPE THEM PROPERLY***"
message_json["reason"] = message_json["reason"] + warning
# Update the text with modified JSON
part["text"] = json_dumps(message_json)
logger.warning(
f"Modified heartbeat message with special character warning for retry {retry_count}/{self.MAX_RETRIES}"
)
except (json.JSONDecodeError, TypeError):
# Not a JSON message or not a heartbeat, skip modification
pass
should_retry = is_malformed_function_call
retry_count += 1
return response_data
@staticmethod
def add_dummy_model_messages(messages: List[dict]) -> List[dict]:
@@ -299,7 +338,6 @@ class GoogleVertexClient(LLMClientBase):
}
}
"""
response = GenerateContentResponse(**response_data)
try:
choices = []
@@ -517,6 +555,14 @@ class GoogleVertexClient(LLMClientBase):
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
return llm_config.model.startswith("gemini-2.5-flash") or llm_config.model.startswith("gemini-2.5-pro")
def is_malformed_function_call(self, response_data: dict) -> bool:
    """Check whether a Gemini response was truncated with FinishReason.MALFORMED_FUNCTION_CALL.

    Gemini 2.5 models intermittently emit this finish reason with no usable
    content, in which case the request should be retried
    (https://github.com/googleapis/python-aiplatform/issues/4472).

    Args:
        response_data: The raw ``model_dump()`` dict of a GenerateContentResponse.

    Returns:
        True if the first candidate is missing content/role/parts AND its
        finish reason is MALFORMED_FUNCTION_CALL; False otherwise.
    """
    response = GenerateContentResponse(**response_data)
    # Guard against a missing/None candidates list (e.g. fully blocked responses),
    # which would otherwise raise TypeError when iterated.
    for candidate in response.candidates or []:
        content = candidate.content
        if content is None or content.role is None or content.parts is None:
            # A candidate with no usable content is malformed only when the model
            # explicitly reported MALFORMED_FUNCTION_CALL as the finish reason.
            # NOTE(review): only the FIRST candidate is inspected — we return from
            # the first candidate lacking content, preserving the original behavior.
            return candidate.finish_reason == "MALFORMED_FUNCTION_CALL"
    return False
@trace_method
def handle_llm_error(self, e: Exception) -> Exception:
# Fallback to base implementation

View File

@@ -145,6 +145,7 @@ class ModelSettings(BaseSettings):
gemini_api_key: Optional[str] = None
gemini_base_url: str = "https://generativelanguage.googleapis.com/"
gemini_force_minimum_thinking_budget: bool = False
gemini_max_retries: int = 5
# google vertex
google_cloud_project: Optional[str] = None

View File

@@ -752,10 +752,11 @@ def test_tool_call(
response = client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
request_options={"timeout_in_seconds": 300},
)
except Exception as e:
if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e):
pytest.skip("Skipping test for flash model due to malformed function call from llm")
# if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e):
# pytest.skip("Skipping test for flash model due to malformed function call from llm")
raise e
assert_tool_call_response(response.messages, llm_config=llm_config)
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
@@ -967,6 +968,7 @@ def test_step_streaming_tool_call(
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
request_options={"timeout_in_seconds": 300},
)
messages = accumulate_chunks(list(response))
assert_tool_call_response(messages, streaming=True, llm_config=llm_config)
@@ -1115,6 +1117,7 @@ def test_token_streaming_tool_call(
agent_id=agent_state.id,
messages=messages_to_send,
stream_tokens=True,
request_options={"timeout_in_seconds": 300},
)
verify_token_streaming = (
llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model
@@ -1183,6 +1186,7 @@ def test_background_token_streaming_greeting_with_assistant_message(
messages=messages_to_send,
stream_tokens=True,
background=True,
request_options={"timeout_in_seconds": 300},
)
verify_token_streaming = (
llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model
@@ -1418,6 +1422,7 @@ def test_async_tool_call(
run = client.agents.messages.create_async(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
request_options={"timeout_in_seconds": 300},
)
run = wait_for_run_completion(client, run.id)
@@ -1639,10 +1644,11 @@ def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLM
client.agents.messages.create(
agent_id=temp_agent_state.id,
messages=[MessageCreate(role="user", content=philosophical_question)],
request_options={"timeout_in_seconds": 300},
)
except Exception as e:
if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e):
pytest.skip("Skipping test for flash model due to malformed function call from llm")
# if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e):
# pytest.skip("Skipping test for flash model due to malformed function call from llm")
raise e
temp_agent_state = client.agents.retrieve(agent_id=temp_agent_state.id)