feat: gemini parallel tool calling non streaming [LET-5993] (#5889)
* first hack
* just test non streaming
* stream_steps should pass too
* clean up

---------

Co-authored-by: Ari Webb <ari@letta.com>
This commit is contained in:
@@ -455,6 +455,13 @@ class LettaAgentV3(LettaAgentV2):
|
||||
request_data["parallel_tool_calls"] = True
|
||||
else:
|
||||
request_data["parallel_tool_calls"] = False
|
||||
|
||||
# Gemini (Google AI/Vertex) parallel tool use
|
||||
elif self.agent_state.llm_config.model_endpoint_type in ["google_ai", "google_vertex"]:
|
||||
# Gemini supports parallel tool calling natively through multiple parts in the response
|
||||
# We just need to ensure the config flag is set for tracking purposes
|
||||
# The actual handling happens in GoogleVertexClient.convert_response_to_chat_completion
|
||||
pass # No specific request_data field needed for Gemini
|
||||
except Exception:
|
||||
# if this fails, we simply don't enable parallel tool use
|
||||
pass
|
||||
|
||||
@@ -444,14 +444,14 @@ class GoogleVertexClient(LLMClientBase):
|
||||
# NOTE(Apr 9, 2025): there's a very strange bug on 2.5 where the response has a part with broken text
|
||||
# {'candidates': [{'content': {'parts': [{'functionCall': {'name': 'send_message', 'args': {'request_heartbeat': False, 'message': 'Hello! How can I make your day better?', 'inner_thoughts': 'User has initiated contact. Sending a greeting.'}}}], 'role': 'model'}, 'finishReason': 'STOP', 'avgLogprobs': -0.25891534213362066}], 'usageMetadata': {'promptTokenCount': 2493, 'candidatesTokenCount': 29, 'totalTokenCount': 2522, 'promptTokensDetails': [{'modality': 'TEXT', 'tokenCount': 2493}], 'candidatesTokensDetails': [{'modality': 'TEXT', 'tokenCount': 29}]}, 'modelVersion': 'gemini-1.5-pro-002'}
|
||||
# To patch this, if we have multiple parts we can take the last one
|
||||
if len(parts) > 1 and not llm_config.enable_reasoner:
|
||||
# Unless parallel tool calling is enabled, in which case multiple parts may be intentional
|
||||
if len(parts) > 1 and not llm_config.enable_reasoner and not llm_config.parallel_tool_calls:
|
||||
logger.warning(f"Unexpected multiple parts in response from Google AI: {parts}")
|
||||
# only truncate if reasoning is off
|
||||
# only truncate if reasoning is off and parallel tool calls are disabled
|
||||
parts = [parts[-1]]
|
||||
|
||||
# TODO support parts / multimodal
|
||||
# TODO support parallel tool calling natively
|
||||
# TODO Alternative here is to throw away everything else except for the first part
|
||||
# Parallel tool calling is now supported when llm_config.parallel_tool_calls is enabled
|
||||
openai_response_message = None
|
||||
for response_message in parts:
|
||||
# Convert the actual message style to OpenAI style
|
||||
|
||||
@@ -540,8 +540,11 @@ async def test_parallel_tool_calls(
|
||||
llm_config: LLMConfig,
|
||||
send_type: str,
|
||||
) -> None:
|
||||
if llm_config.model_endpoint_type != "anthropic" and llm_config.model_endpoint_type != "openai":
|
||||
pytest.skip("Parallel tool calling test only applies to Anthropic and OpenAI models.")
|
||||
if llm_config.model_endpoint_type not in ["anthropic", "openai", "google_ai", "google_vertex"]:
|
||||
pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.")
|
||||
|
||||
if llm_config.model_endpoint_type in ["google_ai", "google_vertex"] and send_type not in ["step", "async", "stream_steps"]:
|
||||
pytest.skip("Gemini parallel tool calling test only for non streaming scenarios. FIX WHEN STREAMING IS IMPLEMENTED")
|
||||
|
||||
# change llm_config to support parallel tool calling
|
||||
llm_config.parallel_tool_calls = True
|
||||
@@ -587,10 +590,14 @@ async def test_parallel_tool_calls(
|
||||
# verify each tool call
|
||||
for tc in tool_call_msg.tool_calls:
|
||||
assert tc["name"] == "roll_dice"
|
||||
# Support both Anthropic (toolu_) and OpenAI (call_) tool call ID formats
|
||||
assert tc["tool_call_id"].startswith("toolu_") or tc["tool_call_id"].startswith("call_"), (
|
||||
f"Unexpected tool call ID format: {tc['tool_call_id']}"
|
||||
# Support Anthropic (toolu_), OpenAI (call_), and Gemini (UUID) tool call ID formats
|
||||
# Gemini uses UUID format which could start with any alphanumeric character
|
||||
valid_id_format = (
|
||||
tc["tool_call_id"].startswith("toolu_")
|
||||
or tc["tool_call_id"].startswith("call_")
|
||||
or (len(tc["tool_call_id"]) > 0 and tc["tool_call_id"][0].isalnum()) # UUID format for Gemini
|
||||
)
|
||||
assert valid_id_format, f"Unexpected tool call ID format: {tc['tool_call_id']}"
|
||||
assert "num_sides" in tc["arguments"]
|
||||
|
||||
# assert tool returns match the tool calls
|
||||
|
||||
Reference in New Issue
Block a user