feat: Add parallel tool call streaming for Gemini [LET-6027] (#5913)
* Make changes to the Gemini streaming interface to support parallel tool calling
* Finish send message integration test
* Add comments
This commit is contained in:
committed by
Caren Thomas
parent
8468ef3cd7
commit
7b3cb0224a
@@ -58,6 +58,8 @@ class SimpleGeminiStreamingInterface:
|
|||||||
self.tool_call_name: str | None = None
|
self.tool_call_name: str | None = None
|
||||||
self.tool_call_args: dict | None = None # NOTE: Not a str!
|
self.tool_call_args: dict | None = None # NOTE: Not a str!
|
||||||
|
|
||||||
|
self.collected_tool_calls: list[ToolCall] = []
|
||||||
|
|
||||||
# NOTE: signature only is included if tools are present
|
# NOTE: signature only is included if tools are present
|
||||||
self.thinking_signature: str | None = None
|
self.thinking_signature: str | None = None
|
||||||
|
|
||||||
@@ -81,6 +83,9 @@ class SimpleGeminiStreamingInterface:
|
|||||||
|
|
||||||
def get_tool_call_object(self) -> ToolCall:
|
def get_tool_call_object(self) -> ToolCall:
|
||||||
"""Useful for agent loop"""
|
"""Useful for agent loop"""
|
||||||
|
if self.collected_tool_calls:
|
||||||
|
return self.collected_tool_calls[-1]
|
||||||
|
|
||||||
if self.tool_call_id is None:
|
if self.tool_call_id is None:
|
||||||
raise ValueError("No tool call ID available")
|
raise ValueError("No tool call ID available")
|
||||||
if self.tool_call_name is None:
|
if self.tool_call_name is None:
|
||||||
@@ -88,17 +93,12 @@ class SimpleGeminiStreamingInterface:
|
|||||||
if self.tool_call_args is None:
|
if self.tool_call_args is None:
|
||||||
raise ValueError("No tool call arguments available")
|
raise ValueError("No tool call arguments available")
|
||||||
|
|
||||||
# TODO use json_dumps?
|
|
||||||
tool_call_args_str = json.dumps(self.tool_call_args)
|
tool_call_args_str = json.dumps(self.tool_call_args)
|
||||||
|
return ToolCall(id=self.tool_call_id, function=FunctionCall(name=self.tool_call_name, arguments=tool_call_args_str))
|
||||||
|
|
||||||
return ToolCall(
|
def get_tool_call_objects(self) -> list[ToolCall]:
|
||||||
id=self.tool_call_id,
|
"""Return all finalized tool calls collected during this message (parallel supported)."""
|
||||||
type="function",
|
return list(self.collected_tool_calls)
|
||||||
function=FunctionCall(
|
|
||||||
name=self.tool_call_name,
|
|
||||||
arguments=tool_call_args_str,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def process(
|
async def process(
|
||||||
self,
|
self,
|
||||||
@@ -255,6 +255,8 @@ class SimpleGeminiStreamingInterface:
|
|||||||
self.tool_call_name = name
|
self.tool_call_name = name
|
||||||
self.tool_call_args = arguments
|
self.tool_call_args = arguments
|
||||||
|
|
||||||
|
self.collected_tool_calls.append(ToolCall(id=call_id, function=FunctionCall(name=name, arguments=arguments_str)))
|
||||||
|
|
||||||
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
|
if self.tool_call_name and self.tool_call_name in self.requires_approval_tools:
|
||||||
yield ApprovalRequestMessage(
|
yield ApprovalRequestMessage(
|
||||||
id=decrement_message_uuid(self.letta_message_id),
|
id=decrement_message_uuid(self.letta_message_id),
|
||||||
|
|||||||
@@ -543,9 +543,6 @@ async def test_parallel_tool_calls(
|
|||||||
if llm_config.model_endpoint_type not in ["anthropic", "openai", "google_ai", "google_vertex"]:
|
if llm_config.model_endpoint_type not in ["anthropic", "openai", "google_ai", "google_vertex"]:
|
||||||
pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.")
|
pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.")
|
||||||
|
|
||||||
if llm_config.model_endpoint_type in ["google_ai", "google_vertex"] and send_type not in ["step", "async", "stream_steps"]:
|
|
||||||
pytest.skip("Gemini parallel tool calling test only for non streaming scenarios. FIX WHEN STREAMING IS IMPLEMENTED")
|
|
||||||
|
|
||||||
# change llm_config to support parallel tool calling
|
# change llm_config to support parallel tool calling
|
||||||
llm_config.parallel_tool_calls = True
|
llm_config.parallel_tool_calls = True
|
||||||
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
|
||||||
|
|||||||
Reference in New Issue
Block a user