fix: retry on MALFORMED_FUNCTION_CALL for gemini [LET-4089]
--------- Co-authored-by: Letta Bot <noreply@letta.com>
This commit is contained in:
@@ -31,6 +31,8 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class GoogleVertexClient(LLMClientBase):
|
||||
MAX_RETRIES = model_settings.gemini_max_retries
|
||||
|
||||
def _get_client(self):
|
||||
timeout_ms = int(settings.llm_request_timeout_seconds * 1000)
|
||||
return genai.Client(
|
||||
@@ -59,12 +61,49 @@ class GoogleVertexClient(LLMClientBase):
|
||||
Performs underlying request to llm and returns raw response.
|
||||
"""
|
||||
client = self._get_client()
|
||||
response = await client.aio.models.generate_content(
|
||||
model=llm_config.model,
|
||||
contents=request_data["contents"],
|
||||
config=request_data["config"],
|
||||
)
|
||||
return response.model_dump()
|
||||
|
||||
# Gemini 2.5 models will often return MALFORMED_FUNCTION_CALL, force a retry
|
||||
# https://github.com/googleapis/python-aiplatform/issues/4472
|
||||
retry_count = 1
|
||||
should_retry = True
|
||||
while should_retry and retry_count <= self.MAX_RETRIES:
|
||||
response = await client.aio.models.generate_content(
|
||||
model=llm_config.model,
|
||||
contents=request_data["contents"],
|
||||
config=request_data["config"],
|
||||
)
|
||||
response_data = response.model_dump()
|
||||
is_malformed_function_call = self.is_malformed_function_call(response_data)
|
||||
if is_malformed_function_call:
|
||||
logger.warning(
|
||||
f"Received FinishReason.MALFORMED_FUNCTION_CALL in response for {llm_config.model}, retrying {retry_count}/{self.MAX_RETRIES}"
|
||||
)
|
||||
# Modify the last message if it's a heartbeat to include warning about special characters
|
||||
if request_data["contents"] and len(request_data["contents"]) > 0:
|
||||
last_message = request_data["contents"][-1]
|
||||
if last_message.get("role") == "user" and last_message.get("parts"):
|
||||
for part in last_message["parts"]:
|
||||
if "text" in part:
|
||||
try:
|
||||
# Try to parse as JSON to check if it's a heartbeat
|
||||
message_json = json_loads(part["text"])
|
||||
if message_json.get("type") == "heartbeat" and "reason" in message_json:
|
||||
# Append warning to the reason
|
||||
warning = f" RETRY {retry_count}/{self.MAX_RETRIES} ***DO NOT USE SPECIAL CHARACTERS OR QUOTATIONS INSIDE FUNCTION CALL ARGUMENTS. IF YOU MUST, MAKE SURE TO ESCAPE THEM PROPERLY***"
|
||||
message_json["reason"] = message_json["reason"] + warning
|
||||
# Update the text with modified JSON
|
||||
part["text"] = json_dumps(message_json)
|
||||
logger.warning(
|
||||
f"Modified heartbeat message with special character warning for retry {retry_count}/{self.MAX_RETRIES}"
|
||||
)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
# Not a JSON message or not a heartbeat, skip modification
|
||||
pass
|
||||
|
||||
should_retry = is_malformed_function_call
|
||||
retry_count += 1
|
||||
|
||||
return response_data
|
||||
|
||||
@staticmethod
|
||||
def add_dummy_model_messages(messages: List[dict]) -> List[dict]:
|
||||
@@ -299,7 +338,6 @@ class GoogleVertexClient(LLMClientBase):
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
response = GenerateContentResponse(**response_data)
|
||||
try:
|
||||
choices = []
|
||||
@@ -517,6 +555,14 @@ class GoogleVertexClient(LLMClientBase):
|
||||
def is_reasoning_model(self, llm_config: LLMConfig) -> bool:
|
||||
return llm_config.model.startswith("gemini-2.5-flash") or llm_config.model.startswith("gemini-2.5-pro")
|
||||
|
||||
def is_malformed_function_call(self, response_data: dict) -> bool:
    """
    Check whether a raw Gemini response was cut short by
    FinishReason.MALFORMED_FUNCTION_CALL.

    A candidate hit by this failure mode comes back without usable content
    (missing content, role, or parts), so only content-less candidates are
    inspected for the finish reason.

    Args:
        response_data: Raw response dict (as produced by ``response.model_dump()``).

    Returns:
        True if any content-less candidate finished with MALFORMED_FUNCTION_CALL.
    """
    response = GenerateContentResponse(**response_data)
    # candidates may be None on the SDK model; guard before iterating.
    for candidate in response.candidates or []:
        content = candidate.content
        if content is None or content.role is None or content.parts is None:
            # Keep scanning remaining candidates instead of returning on the
            # first content-less one — it may have a different finish reason.
            if candidate.finish_reason == "MALFORMED_FUNCTION_CALL":
                return True
    return False
|
||||
|
||||
@trace_method
|
||||
def handle_llm_error(self, e: Exception) -> Exception:
|
||||
# Fallback to base implementation
|
||||
|
||||
@@ -145,6 +145,7 @@ class ModelSettings(BaseSettings):
|
||||
gemini_api_key: Optional[str] = None
|
||||
gemini_base_url: str = "https://generativelanguage.googleapis.com/"
|
||||
gemini_force_minimum_thinking_budget: bool = False
|
||||
gemini_max_retries: int = 5
|
||||
|
||||
# google vertex
|
||||
google_cloud_project: Optional[str] = None
|
||||
|
||||
@@ -752,10 +752,11 @@ def test_tool_call(
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
request_options={"timeout_in_seconds": 300},
|
||||
)
|
||||
except Exception as e:
|
||||
if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e):
|
||||
pytest.skip("Skipping test for flash model due to malformed function call from llm")
|
||||
# if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e):
|
||||
# pytest.skip("Skipping test for flash model due to malformed function call from llm")
|
||||
raise e
|
||||
assert_tool_call_response(response.messages, llm_config=llm_config)
|
||||
messages_from_db = client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
|
||||
@@ -967,6 +968,7 @@ def test_step_streaming_tool_call(
|
||||
response = client.agents.messages.create_stream(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
request_options={"timeout_in_seconds": 300},
|
||||
)
|
||||
messages = accumulate_chunks(list(response))
|
||||
assert_tool_call_response(messages, streaming=True, llm_config=llm_config)
|
||||
@@ -1115,6 +1117,7 @@ def test_token_streaming_tool_call(
|
||||
agent_id=agent_state.id,
|
||||
messages=messages_to_send,
|
||||
stream_tokens=True,
|
||||
request_options={"timeout_in_seconds": 300},
|
||||
)
|
||||
verify_token_streaming = (
|
||||
llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model
|
||||
@@ -1183,6 +1186,7 @@ def test_background_token_streaming_greeting_with_assistant_message(
|
||||
messages=messages_to_send,
|
||||
stream_tokens=True,
|
||||
background=True,
|
||||
request_options={"timeout_in_seconds": 300},
|
||||
)
|
||||
verify_token_streaming = (
|
||||
llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model
|
||||
@@ -1418,6 +1422,7 @@ def test_async_tool_call(
|
||||
run = client.agents.messages.create_async(
|
||||
agent_id=agent_state.id,
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
request_options={"timeout_in_seconds": 300},
|
||||
)
|
||||
run = wait_for_run_completion(client, run.id)
|
||||
|
||||
@@ -1639,10 +1644,11 @@ def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLM
|
||||
client.agents.messages.create(
|
||||
agent_id=temp_agent_state.id,
|
||||
messages=[MessageCreate(role="user", content=philosophical_question)],
|
||||
request_options={"timeout_in_seconds": 300},
|
||||
)
|
||||
except Exception as e:
|
||||
if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e):
|
||||
pytest.skip("Skipping test for flash model due to malformed function call from llm")
|
||||
# if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e):
|
||||
# pytest.skip("Skipping test for flash model due to malformed function call from llm")
|
||||
raise e
|
||||
|
||||
temp_agent_state = client.agents.retrieve(agent_id=temp_agent_state.id)
|
||||
|
||||
Reference in New Issue
Block a user