From 7b3cb0224a3ee6fc0f0a04f7f4c985fb847ca099 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 3 Nov 2025 12:12:44 -0800 Subject: [PATCH] feat: Add gemini parallel tool call streaming for gemini [LET-6027] (#5913) * Make changes to gemini streaming interface to support parallel tool calling * Finish send message integration test * Add comments --- .../interfaces/gemini_streaming_interface.py | 20 ++++++++++--------- tests/integration_test_send_message_v2.py | 3 --- 2 files changed, 11 insertions(+), 12 deletions(-) diff --git a/letta/interfaces/gemini_streaming_interface.py b/letta/interfaces/gemini_streaming_interface.py index fcf39036..3b2ea4e3 100644 --- a/letta/interfaces/gemini_streaming_interface.py +++ b/letta/interfaces/gemini_streaming_interface.py @@ -58,6 +58,8 @@ class SimpleGeminiStreamingInterface: self.tool_call_name: str | None = None self.tool_call_args: dict | None = None # NOTE: Not a str! + self.collected_tool_calls: list[ToolCall] = [] + # NOTE: signature only is included if tools are present self.thinking_signature: str | None = None @@ -81,6 +83,9 @@ class SimpleGeminiStreamingInterface: def get_tool_call_object(self) -> ToolCall: """Useful for agent loop""" + if self.collected_tool_calls: + return self.collected_tool_calls[-1] + if self.tool_call_id is None: raise ValueError("No tool call ID available") if self.tool_call_name is None: @@ -88,17 +93,12 @@ class SimpleGeminiStreamingInterface: if self.tool_call_args is None: raise ValueError("No tool call arguments available") - # TODO use json_dumps? tool_call_args_str = json.dumps(self.tool_call_args) + return ToolCall(id=self.tool_call_id, function=FunctionCall(name=self.tool_call_name, arguments=tool_call_args_str)) - return ToolCall( - id=self.tool_call_id, - type="function", - function=FunctionCall( - name=self.tool_call_name, - arguments=tool_call_args_str, - ), - ) + def get_tool_call_objects(self) -> list[ToolCall]: + """Return all finalized tool calls collected during this message (parallel supported).""" + return list(self.collected_tool_calls) async def process( self, @@ -255,6 +255,8 @@ class SimpleGeminiStreamingInterface: self.tool_call_name = name self.tool_call_args = arguments + self.collected_tool_calls.append(ToolCall(id=call_id, function=FunctionCall(name=name, arguments=arguments_str))) + if self.tool_call_name and self.tool_call_name in self.requires_approval_tools: yield ApprovalRequestMessage( id=decrement_message_uuid(self.letta_message_id), diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py index 802ea7f9..8a65ddf2 100644 --- a/tests/integration_test_send_message_v2.py +++ b/tests/integration_test_send_message_v2.py @@ -543,9 +543,6 @@ async def test_parallel_tool_calls( if llm_config.model_endpoint_type not in ["anthropic", "openai", "google_ai", "google_vertex"]: pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.") - if llm_config.model_endpoint_type in ["google_ai", "google_vertex"] and send_type not in ["step", "async", "stream_steps"]: - pytest.skip("Gemini parallel tool calling test only for non streaming scenarios. FIX WHEN STREAMING IS IMPLEMENTED") - # change llm_config to support parallel tool calling llm_config.parallel_tool_calls = True agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)