feat: gemini parallel tool calling non-streaming [LET-5993] (#5889)

* first hack

* just test non-streaming

* stream_steps should pass too

* clean up

---------

Co-authored-by: Ari Webb <ari@letta.com>
This commit is contained in:
Ari Webb
2025-10-31 17:03:26 -07:00
committed by Caren Thomas
parent da11d80bf4
commit 7427c0998e
3 changed files with 23 additions and 9 deletions

View File

@@ -455,6 +455,13 @@ class LettaAgentV3(LettaAgentV2):
request_data["parallel_tool_calls"] = True
else:
request_data["parallel_tool_calls"] = False
# Gemini (Google AI/Vertex) parallel tool use
elif self.agent_state.llm_config.model_endpoint_type in ["google_ai", "google_vertex"]:
# Gemini supports parallel tool calling natively through multiple parts in the response
# We just need to ensure the config flag is set for tracking purposes
# The actual handling happens in GoogleVertexClient.convert_response_to_chat_completion
pass # No specific request_data field needed for Gemini
except Exception:
# if this fails, we simply don't enable parallel tool use
pass

View File

@@ -444,14 +444,14 @@ class GoogleVertexClient(LLMClientBase):
# NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
# {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
# To patch this, if we have multiple parts we can take the last one
if len(parts) > 1 and not llm_config.enable_reasoner:
# Unless parallel tool calling is enabled, in which case multiple parts may be intentional
if len(parts) > 1 and not llm_config.enable_reasoner and not llm_config.parallel_tool_calls:
logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
# only truncate if reasoning is off
# only truncate if reasoning is off and parallel tool calls are disabled
parts = [parts[-1]]
# TODO support parts / multimodal
# TODO support parallel tool calling natively
# TODO Alternative here is to throw away everything else except for the first part
# Parallel tool calling is now supported when llm_config.parallel_tool_calls is enabled
openai_response_message = None
for response_message in parts:
# Convert the actual message style to OpenAI style

View File

@@ -540,8 +540,11 @@ async def test_parallel_tool_calls(
llm_config: LLMConfig,
send_type: str,
) -> None:
if llm_config.model_endpoint_type != "anthropic" and llm_config.model_endpoint_type != "openai":
pytest.skip("Parallel tool calling test only applies to Anthropic and OpenAI models.")
if llm_config.model_endpoint_type not in ["anthropic", "openai", "google_ai", "google_vertex"]:
pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.")
if llm_config.model_endpoint_type in ["google_ai", "google_vertex"] and send_type not in ["step", "async", "stream_steps"]:
pytest.skip("Gemini parallel tool calling test only for non streaming scenarios. FIX WHEN STREAMING IS IMPLEMENTED")
# change llm_config to support parallel tool calling
llm_config.parallel_tool_calls = True
@@ -587,10 +590,14 @@ async def test_parallel_tool_calls(
# verify each tool call
for tc in tool_call_msg.tool_calls:
assert tc["name"] == "roll_dice"
# Support both Anthropic (toolu_) and OpenAI (call_) tool call ID formats
assert tc["tool_call_id"].startswith("toolu_") or tc["tool_call_id"].startswith("call_"), (
f"Unexpected tool call ID format: {tc['tool_call_id']}"
# Support Anthropic (toolu_), OpenAI (call_), and Gemini (UUID) tool call ID formats
# Gemini uses UUID format which could start with any alphanumeric character
valid_id_format = (
tc["tool_call_id"].startswith("toolu_")
or tc["tool_call_id"].startswith("call_")
or (len(tc["tool_call_id"]) > 0 and tc["tool_call_id"][0].isalnum()) # UUID format for Gemini
)
assert valid_id_format, f"Unexpected tool call ID format: {tc['tool_call_id']}"
assert "num_sides" in tc["arguments"]
# assert tool returns match the tool calls