feat: Support parallel tool calling streaming for OpenAI chat completions [LET-4594] (#5865)
* Finish chat completions parallel tool calling * Undo comments * Add comments * Remove test file
This commit is contained in:
committed by
Caren Thomas
parent
599adb4c26
commit
ff81f4153b
@@ -524,10 +524,9 @@ class SimpleOpenAIStreamingInterface:
|
||||
self.messages = messages or []
|
||||
self.tools = tools or []
|
||||
|
||||
# Buffers to hold accumulating tools
|
||||
self.tool_call_name = ""
|
||||
self.tool_call_args = ""
|
||||
self.tool_call_id = ""
|
||||
# Accumulate per-index tool call fragments and preserve order
|
||||
self._tool_calls_acc: dict[int, dict[str, str]] = {}
|
||||
self._tool_call_start_order: list[int] = []
|
||||
|
||||
self.content_messages = []
|
||||
self.emitted_hidden_reasoning = False # Track if we've emitted hidden reasoning message
|
||||
@@ -561,19 +560,27 @@ class SimpleOpenAIStreamingInterface:
|
||||
|
||||
return merged_messages
|
||||
|
||||
def get_tool_call_object(self) -> ToolCall:
|
||||
"""Useful for agent loop"""
|
||||
if not self.tool_call_name:
|
||||
raise ValueError("No tool call name available")
|
||||
if not self.tool_call_args:
|
||||
raise ValueError("No tool call arguments available")
|
||||
if not self.tool_call_id:
|
||||
raise ValueError("No tool call ID available")
|
||||
def get_tool_call_objects(self) -> list[ToolCall]:
|
||||
"""Return finalized tool calls (parallel supported)."""
|
||||
if not self._tool_calls_acc:
|
||||
return []
|
||||
ordered_indices = [i for i in self._tool_call_start_order if i in self._tool_calls_acc]
|
||||
result: list[ToolCall] = []
|
||||
for idx in ordered_indices:
|
||||
ctx = self._tool_calls_acc[idx]
|
||||
name = ctx.get("name", "")
|
||||
args = ctx.get("arguments", "")
|
||||
call_id = ctx.get("id", "")
|
||||
if call_id and name:
|
||||
result.append(ToolCall(id=call_id, function=FunctionCall(arguments=args or "", name=name)))
|
||||
return result
|
||||
|
||||
return ToolCall(
|
||||
id=self.tool_call_id,
|
||||
function=FunctionCall(arguments=self.tool_call_args, name=self.tool_call_name),
|
||||
)
|
||||
def get_tool_call_object(self) -> ToolCall:
|
||||
"""Backwards-compatible single tool call accessor (first tool if multiple)."""
|
||||
calls = self.get_tool_call_objects()
|
||||
if not calls:
|
||||
raise ValueError("No tool calls available")
|
||||
return calls[0]
|
||||
|
||||
async def process(
|
||||
self,
|
||||
@@ -718,70 +725,61 @@ class SimpleOpenAIStreamingInterface:
|
||||
yield reasoning_msg
|
||||
|
||||
if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
|
||||
tool_call = message_delta.tool_calls[0]
|
||||
# Accumulate per-index tool call fragments and emit deltas
|
||||
for tool_call in message_delta.tool_calls:
|
||||
if (
|
||||
not (tool_call.function and (tool_call.function.name or tool_call.function.arguments))
|
||||
and not tool_call.id
|
||||
and getattr(tool_call, "index", None) is None
|
||||
):
|
||||
continue
|
||||
|
||||
# For OpenAI reasoning models, emit a hidden reasoning message before the first tool call
|
||||
# if not self.emitted_hidden_reasoning and is_openai_reasoning_model(self.model):
|
||||
# self.emitted_hidden_reasoning = True
|
||||
# if prev_message_type and prev_message_type != "hidden_reasoning_message":
|
||||
# message_index += 1
|
||||
# hidden_message = HiddenReasoningMessage(
|
||||
# id=self.letta_message_id,
|
||||
# date=datetime.now(timezone.utc),
|
||||
# state="omitted",
|
||||
# hidden_reasoning=None,
|
||||
# otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
# )
|
||||
# self.content_messages.append(hidden_message)
|
||||
# prev_message_type = hidden_message.message_type
|
||||
# message_index += 1 # Increment for the next message
|
||||
# yield hidden_message
|
||||
idx = getattr(tool_call, "index", None)
|
||||
if idx is None:
|
||||
idx = 0
|
||||
|
||||
if not tool_call.function.name and not tool_call.function.arguments and not tool_call.id:
|
||||
# No chunks to process, exit
|
||||
return
|
||||
if idx not in self._tool_call_start_order:
|
||||
self._tool_call_start_order.append(idx)
|
||||
if idx not in self._tool_calls_acc:
|
||||
self._tool_calls_acc[idx] = {"name": "", "arguments": "", "id": ""}
|
||||
acc = self._tool_calls_acc[idx]
|
||||
|
||||
if tool_call.function.name:
|
||||
self.tool_call_name += tool_call.function.name
|
||||
if tool_call.function.arguments:
|
||||
self.tool_call_args += tool_call.function.arguments
|
||||
if tool_call.id:
|
||||
self.tool_call_id += tool_call.id
|
||||
if tool_call.function and tool_call.function.name:
|
||||
acc["name"] += tool_call.function.name
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
acc["arguments"] += tool_call.function.arguments
|
||||
if tool_call.id:
|
||||
acc["id"] += tool_call.id
|
||||
|
||||
if self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
tool_call_id=tool_call.id,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
delta = ToolCallDelta(
|
||||
name=tool_call.function.name if (tool_call.function and tool_call.function.name) else None,
|
||||
arguments=tool_call.function.arguments if (tool_call.function and tool_call.function.arguments) else None,
|
||||
tool_call_id=tool_call.id if tool_call.id else None,
|
||||
)
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=tool_call_delta,
|
||||
tool_calls=tool_call_delta,
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
|
||||
if acc.get("name") and acc["name"] in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=delta,
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=delta,
|
||||
tool_calls=delta,
|
||||
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
|
||||
run_id=self.run_id,
|
||||
step_id=self.step_id,
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
|
||||
|
||||
class SimpleOpenAIResponsesStreamingInterface:
|
||||
|
||||
Reference in New Issue
Block a user