feat: Support parallel tool calling streaming for OpenAI chat completions [LET-4594] (#5865)

* Finish chat completions parallel tool calling

* Undo comments

* Add comments

* Remove test file
This commit is contained in:
Matthew Zhou
2025-10-30 17:02:47 -07:00
committed by Caren Thomas
parent 599adb4c26
commit ff81f4153b
3 changed files with 73 additions and 84 deletions

View File

@@ -524,10 +524,9 @@ class SimpleOpenAIStreamingInterface:
self.messages = messages or []
self.tools = tools or []
# Buffers to hold accumulating tools
self.tool_call_name = ""
self.tool_call_args = ""
self.tool_call_id = ""
# Accumulate per-index tool call fragments and preserve order
self._tool_calls_acc: dict[int, dict[str, str]] = {}
self._tool_call_start_order: list[int] = []
self.content_messages = []
self.emitted_hidden_reasoning = False # Track if we've emitted hidden reasoning message
@@ -561,19 +560,27 @@ class SimpleOpenAIStreamingInterface:
return merged_messages
def get_tool_call_object(self) -> ToolCall:
"""Useful for agent loop"""
if not self.tool_call_name:
raise ValueError("No tool call name available")
if not self.tool_call_args:
raise ValueError("No tool call arguments available")
if not self.tool_call_id:
raise ValueError("No tool call ID available")
def get_tool_call_objects(self) -> list[ToolCall]:
    """Build one ToolCall per accumulated tool call (parallel calls supported).

    Walks ``self._tool_call_start_order`` and, for every index that also has
    an entry in ``self._tool_calls_acc``, assembles a ``ToolCall`` from the
    accumulated ``"id"``, ``"name"`` and ``"arguments"`` strings.

    Returns:
        The assembled tool calls, in the order of ``_tool_call_start_order``.
        Empty when nothing has been accumulated. Entries whose ``"id"`` or
        ``"name"`` is missing or empty are left out of the result.
    """
    accumulated = self._tool_calls_acc
    if not accumulated:
        return []

    finalized: list[ToolCall] = []
    for index in self._tool_call_start_order:
        if index not in accumulated:
            continue
        fragments = accumulated[index]
        fn_name = fragments.get("name", "")
        fn_args = fragments.get("arguments", "") or ""
        tool_call_id = fragments.get("id", "")
        # A call is only usable once both its id and its function name exist.
        if not (tool_call_id and fn_name):
            continue
        function = FunctionCall(arguments=fn_args, name=fn_name)
        finalized.append(ToolCall(id=tool_call_id, function=function))
    return finalized
return ToolCall(
id=self.tool_call_id,
function=FunctionCall(arguments=self.tool_call_args, name=self.tool_call_name),
)
def get_tool_call_object(self) -> ToolCall:
    """Return the first finalized tool call (backwards-compatible accessor).

    Delegates to ``get_tool_call_objects()`` and hands back its first element,
    so callers that expect a single tool call keep working when several
    were produced.

    Raises:
        ValueError: If ``get_tool_call_objects()`` returns nothing.
    """
    finalized = self.get_tool_call_objects()
    if finalized:
        return finalized[0]
    raise ValueError("No tool calls available")
async def process(
self,
@@ -718,70 +725,61 @@ class SimpleOpenAIStreamingInterface:
yield reasoning_msg
if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
tool_call = message_delta.tool_calls[0]
# Accumulate per-index tool call fragments and emit deltas
for tool_call in message_delta.tool_calls:
if (
not (tool_call.function and (tool_call.function.name or tool_call.function.arguments))
and not tool_call.id
and getattr(tool_call, "index", None) is None
):
continue
# For OpenAI reasoning models, emit a hidden reasoning message before the first tool call
# if not self.emitted_hidden_reasoning and is_openai_reasoning_model(self.model):
# self.emitted_hidden_reasoning = True
# if prev_message_type and prev_message_type != "hidden_reasoning_message":
# message_index += 1
# hidden_message = HiddenReasoningMessage(
# id=self.letta_message_id,
# date=datetime.now(timezone.utc),
# state="omitted",
# hidden_reasoning=None,
# otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
# )
# self.content_messages.append(hidden_message)
# prev_message_type = hidden_message.message_type
# message_index += 1 # Increment for the next message
# yield hidden_message
idx = getattr(tool_call, "index", None)
if idx is None:
idx = 0
if not tool_call.function.name and not tool_call.function.arguments and not tool_call.id:
# No chunks to process, exit
return
if idx not in self._tool_call_start_order:
self._tool_call_start_order.append(idx)
if idx not in self._tool_calls_acc:
self._tool_calls_acc[idx] = {"name": "", "arguments": "", "id": ""}
acc = self._tool_calls_acc[idx]
if tool_call.function.name:
self.tool_call_name += tool_call.function.name
if tool_call.function.arguments:
self.tool_call_args += tool_call.function.arguments
if tool_call.id:
self.tool_call_id += tool_call.id
if tool_call.function and tool_call.function.name:
acc["name"] += tool_call.function.name
if tool_call.function and tool_call.function.arguments:
acc["arguments"] += tool_call.function.arguments
if tool_call.id:
acc["id"] += tool_call.id
if self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
date=datetime.now(timezone.utc),
tool_call=ToolCallDelta(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
tool_call_id=tool_call.id,
),
# name=name,
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
run_id=self.run_id,
step_id=self.step_id,
delta = ToolCallDelta(
name=tool_call.function.name if (tool_call.function and tool_call.function.name) else None,
arguments=tool_call.function.arguments if (tool_call.function and tool_call.function.arguments) else None,
tool_call_id=tool_call.id if tool_call.id else None,
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_delta = ToolCallDelta(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
tool_call_id=tool_call.id,
)
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=tool_call_delta,
tool_calls=tool_call_delta,
# name=name,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
if acc.get("name") and acc["name"] in self.requires_approval_tools:
tool_call_msg = ApprovalRequestMessage(
id=decrement_message_uuid(self.letta_message_id),
date=datetime.now(timezone.utc),
tool_call=delta,
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
run_id=self.run_id,
step_id=self.step_id,
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
date=datetime.now(timezone.utc),
tool_call=delta,
tool_calls=delta,
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
run_id=self.run_id,
step_id=self.step_id,
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
class SimpleOpenAIResponsesStreamingInterface: