diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py index 31dc7a8d..ffdd4547 100644 --- a/letta/adapters/simple_llm_stream_adapter.py +++ b/letta/adapters/simple_llm_stream_adapter.py @@ -78,12 +78,6 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): use_responses = "input" in request_data and "messages" not in request_data # No support for Responses API proxy is_proxy = self.llm_config.provider_name == "lmstudio_openai" - # Use parallel tool calling interface if enabled in config - use_parallel = self.llm_config.parallel_tool_calls and tools and not use_responses and not is_proxy - - # TODO: Temp, remove - if use_parallel: - raise RuntimeError("Parallel tool calling not supported for OpenAI streaming") if use_responses and not is_proxy: self.interface = SimpleOpenAIResponsesStreamingInterface( diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 416e28c2..69416202 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -524,10 +524,9 @@ class SimpleOpenAIStreamingInterface: self.messages = messages or [] self.tools = tools or [] - # Buffers to hold accumulating tools - self.tool_call_name = "" - self.tool_call_args = "" - self.tool_call_id = "" + # Accumulate per-index tool call fragments and preserve order + self._tool_calls_acc: dict[int, dict[str, str]] = {} + self._tool_call_start_order: list[int] = [] self.content_messages = [] self.emitted_hidden_reasoning = False # Track if we've emitted hidden reasoning message @@ -561,19 +560,27 @@ class SimpleOpenAIStreamingInterface: return merged_messages - def get_tool_call_object(self) -> ToolCall: - """Useful for agent loop""" - if not self.tool_call_name: - raise ValueError("No tool call name available") - if not self.tool_call_args: - raise ValueError("No tool call arguments available") - if not self.tool_call_id: - raise ValueError("No tool call ID available") + def get_tool_call_objects(self) -> list[ToolCall]: + """Return finalized tool calls (parallel supported).""" + if not self._tool_calls_acc: + return [] + ordered_indices = [i for i in self._tool_call_start_order if i in self._tool_calls_acc] + result: list[ToolCall] = [] + for idx in ordered_indices: + ctx = self._tool_calls_acc[idx] + name = ctx.get("name", "") + args = ctx.get("arguments", "") + call_id = ctx.get("id", "") + if call_id and name: + result.append(ToolCall(id=call_id, function=FunctionCall(arguments=args or "", name=name))) + return result - return ToolCall( - id=self.tool_call_id, - function=FunctionCall(arguments=self.tool_call_args, name=self.tool_call_name), - ) + def get_tool_call_object(self) -> ToolCall: + """Backwards-compatible single tool call accessor (first tool if multiple).""" + calls = self.get_tool_call_objects() + if not calls: + raise ValueError("No tool calls available") + return calls[0] async def process( self, @@ -718,70 +725,61 @@ class SimpleOpenAIStreamingInterface: yield reasoning_msg if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0: - tool_call = message_delta.tool_calls[0] + # Accumulate per-index tool call fragments and emit deltas + for tool_call in message_delta.tool_calls: + if ( + not (tool_call.function and (tool_call.function.name or tool_call.function.arguments)) + and not tool_call.id + and getattr(tool_call, "index", None) is None + ): + continue - # For OpenAI reasoning models, emit a hidden reasoning message before the first tool call - # if not self.emitted_hidden_reasoning and is_openai_reasoning_model(self.model): - # self.emitted_hidden_reasoning = True - # if prev_message_type and prev_message_type != "hidden_reasoning_message": - # message_index += 1 - # hidden_message = HiddenReasoningMessage( - # id=self.letta_message_id, - # date=datetime.now(timezone.utc), - # state="omitted", - # hidden_reasoning=None, - # otid=Message.generate_otid_from_id(self.letta_message_id, message_index), - # ) - # self.content_messages.append(hidden_message) - # prev_message_type = hidden_message.message_type - # message_index += 1 # Increment for the next message - # yield hidden_message + idx = getattr(tool_call, "index", None) + if idx is None: + idx = 0 - if not tool_call.function.name and not tool_call.function.arguments and not tool_call.id: - # No chunks to process, exit - return + if idx not in self._tool_call_start_order: + self._tool_call_start_order.append(idx) + if idx not in self._tool_calls_acc: + self._tool_calls_acc[idx] = {"name": "", "arguments": "", "id": ""} + acc = self._tool_calls_acc[idx] - if tool_call.function.name: - self.tool_call_name += tool_call.function.name - if tool_call.function.arguments: - self.tool_call_args += tool_call.function.arguments - if tool_call.id: - self.tool_call_id += tool_call.id + if tool_call.function and tool_call.function.name: + acc["name"] += tool_call.function.name + if tool_call.function and tool_call.function.arguments: + acc["arguments"] += tool_call.function.arguments + if tool_call.id: + acc["id"] += tool_call.id - if self.requires_approval_tools: - tool_call_msg = ApprovalRequestMessage( - id=decrement_message_uuid(self.letta_message_id), - date=datetime.now(timezone.utc), - tool_call=ToolCallDelta( - name=tool_call.function.name, - arguments=tool_call.function.arguments, - tool_call_id=tool_call.id, - ), - # name=name, - otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1), - run_id=self.run_id, - step_id=self.step_id, + delta = ToolCallDelta( + name=tool_call.function.name if (tool_call.function and tool_call.function.name) else None, + arguments=tool_call.function.arguments if (tool_call.function and tool_call.function.arguments) else None, + tool_call_id=tool_call.id if tool_call.id else None, ) - else: - if prev_message_type and prev_message_type != "tool_call_message": - message_index += 1 - tool_call_delta = ToolCallDelta( - name=tool_call.function.name, - arguments=tool_call.function.arguments, - tool_call_id=tool_call.id, - ) - tool_call_msg = ToolCallMessage( - id=self.letta_message_id, - date=datetime.now(timezone.utc), - tool_call=tool_call_delta, - tool_calls=tool_call_delta, - # name=name, - otid=Message.generate_otid_from_id(self.letta_message_id, message_index), - run_id=self.run_id, - step_id=self.step_id, - ) - prev_message_type = tool_call_msg.message_type - yield tool_call_msg + + if acc.get("name") and acc["name"] in self.requires_approval_tools: + tool_call_msg = ApprovalRequestMessage( + id=decrement_message_uuid(self.letta_message_id), + date=datetime.now(timezone.utc), + tool_call=delta, + otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1), + run_id=self.run_id, + step_id=self.step_id, + ) + else: + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 + tool_call_msg = ToolCallMessage( + id=self.letta_message_id, + date=datetime.now(timezone.utc), + tool_call=delta, + tool_calls=delta, + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), + run_id=self.run_id, + step_id=self.step_id, + ) + prev_message_type = tool_call_msg.message_type + yield tool_call_msg class SimpleOpenAIResponsesStreamingInterface: diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py index 52bd4a63..5135256c 100644 --- a/tests/integration_test_send_message_v2.py +++ b/tests/integration_test_send_message_v2.py @@ -543,9 +543,6 @@ async def test_parallel_tool_calls( if llm_config.model_endpoint_type != "anthropic" and llm_config.model_endpoint_type != "openai": pytest.skip("Parallel tool calling test only applies to Anthropic and OpenAI models.") - if llm_config.model_endpoint_type == "openai" and send_type not in {"step", "stream_steps"}: - pytest.skip(f"OpenAI reasoning model {llm_config.model} does not support streaming parallel tool calling for now.") - # change llm_config to support parallel tool calling llm_config.parallel_tool_calls = True agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)