fix: Eliminate O(n^2) string growth for OpenAI [LET-6065] (#5973)
Finish
This commit is contained in:
committed by
Caren Thomas
parent
162661457a
commit
a699aca626
@@ -93,14 +93,15 @@ class OpenAIStreamingInterface:
|
||||
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg)
|
||||
# Reader that extracts only the assistant message value from send_message args
|
||||
self.assistant_message_json_reader = FunctionArgumentsStreamHandler(json_key=self.assistant_message_tool_kwarg)
|
||||
self.function_name_buffer = None
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
# Switch to list-based accumulation to avoid O(n^2) string growth
|
||||
self._function_name_parts: list[str] = []
|
||||
self._function_args_buffer_parts: list[str] | None = None
|
||||
self._function_id_parts: list[str] = []
|
||||
self.last_flushed_function_name = None
|
||||
self.last_flushed_function_id = None
|
||||
|
||||
# Buffer to hold function arguments until inner thoughts are complete
|
||||
self.current_function_arguments = ""
|
||||
self._current_function_arguments_parts: list[str] = []
|
||||
self.current_json_parse_result = {}
|
||||
|
||||
# Premake IDs for database writes
|
||||
@@ -140,17 +141,39 @@ class OpenAIStreamingInterface:
|
||||
else:
|
||||
return [TextContent(text=content)]
|
||||
|
||||
def _get_function_name_buffer(self) -> str | None:
    """Return the buffered function-name fragments joined into one string.

    Returns:
        The concatenation of ``self._function_name_parts``, or ``None`` when
        the fragment list is empty.
    """
    fragments = self._function_name_parts
    if not fragments:
        return None
    return "".join(fragments)
|
||||
|
||||
def _get_function_id_buffer(self) -> str | None:
    """Return the buffered tool-call ID fragments joined into one string.

    Returns:
        The concatenation of ``self._function_id_parts``, or ``None`` when
        the fragment list is empty.
    """
    fragments = self._function_id_parts
    if not fragments:
        return None
    return "".join(fragments)
|
||||
|
||||
def _clear_function_buffers(self) -> None:
    """Reset the function-name and tool-call-ID fragment lists.

    Both attributes are rebound to fresh empty lists; the previous list
    objects are left untouched.
    """
    self._function_name_parts, self._function_id_parts = [], []
|
||||
|
||||
def _append_function_name(self, s: str) -> None:
    """Add one function-name fragment to the end of the name fragment list.

    Args:
        s: The fragment to buffer.
    """
    name_parts = self._function_name_parts
    name_parts.append(s)
|
||||
|
||||
def _append_function_id(self, s: str) -> None:
    """Add one tool-call ID fragment to the end of the ID fragment list.

    Args:
        s: The fragment to buffer.
    """
    id_parts = self._function_id_parts
    id_parts.append(s)
|
||||
|
||||
def _append_current_function_arguments(self, s: str) -> None:
    """Add one raw arguments fragment to the current-arguments fragment list.

    Args:
        s: The fragment to buffer.
    """
    argument_parts = self._current_function_arguments_parts
    argument_parts.append(s)
|
||||
|
||||
def _get_current_function_arguments(self) -> str:
    """Return every buffered arguments fragment joined into one string.

    Returns:
        The concatenation of ``self._current_function_arguments_parts``;
        an empty string when the list is empty.
    """
    argument_parts = self._current_function_arguments_parts
    joined = "".join(argument_parts)
    return joined
|
||||
|
||||
def get_tool_call_object(self) -> ToolCall:
    """Useful for agent loop: build a ToolCall from the streamed tool-call state.

    The last flushed function name / tool-call ID take precedence; when one
    of them is unset, the still-buffered fragments are joined and used
    instead. The arguments are the join of all argument fragments
    accumulated so far.

    Returns:
        A ``ToolCall`` with the resolved ID, function name and arguments.

    Raises:
        ValueError: If no function name, or no tool-call ID, is available.
    """
    function_name = self.last_flushed_function_name or self._get_function_name_buffer()
    if not function_name:
        # This branch is about the missing *name*; the ID has its own check below.
        raise ValueError("No tool call name available")

    tool_call_id = self.last_flushed_function_id or self._get_function_id_buffer()
    if not tool_call_id:
        raise ValueError("No tool call ID available")

    return ToolCall(
        id=tool_call_id,
        function=FunctionCall(arguments=self._get_current_function_arguments(), name=function_name),
    )
|
||||
|
||||
async def process(
|
||||
@@ -261,21 +284,15 @@ class OpenAIStreamingInterface:
|
||||
if tool_call.function.name:
|
||||
# If we're waiting for the first key, then we should hold back the name
|
||||
# ie add it to a buffer instead of returning it as a chunk
|
||||
if self.function_name_buffer is None:
|
||||
self.function_name_buffer = tool_call.function.name
|
||||
else:
|
||||
self.function_name_buffer += tool_call.function.name
|
||||
self._append_function_name(tool_call.function.name)
|
||||
|
||||
if tool_call.id:
|
||||
# Buffer until next time
|
||||
if self.function_id_buffer is None:
|
||||
self.function_id_buffer = tool_call.id
|
||||
else:
|
||||
self.function_id_buffer += tool_call.id
|
||||
self._append_function_id(tool_call.id)
|
||||
|
||||
if tool_call.function.arguments:
|
||||
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
self.current_function_arguments += tool_call.function.arguments
|
||||
self._append_current_function_arguments(tool_call.function.arguments)
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
|
||||
if self.is_openai_proxy:
|
||||
@@ -301,10 +318,10 @@ class OpenAIStreamingInterface:
|
||||
# Additionally inner thoughts may stream back with a chunk of main JSON
|
||||
# In that case, since we can only return a chunk at a time, we should buffer it
|
||||
if updates_main_json:
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
if self._function_args_buffer_parts is None:
|
||||
self._function_args_buffer_parts = [updates_main_json]
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
self._function_args_buffer_parts.append(updates_main_json)
|
||||
|
||||
# If we have main_json, we should output a ToolCallMessage
|
||||
elif updates_main_json:
|
||||
@@ -312,27 +329,27 @@ class OpenAIStreamingInterface:
|
||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||
# however the frontend may expect name first, then args, so to be
|
||||
# safe we'll output name first in a separate chunk
|
||||
if self.function_name_buffer:
|
||||
if self._get_function_name_buffer():
|
||||
# use_assistant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
||||
if self.use_assistant_message and self._get_function_name_buffer() == self.assistant_message_tool_name:
|
||||
# Store the ID of the tool call to allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
if self._get_function_id_buffer():
|
||||
self.prev_assistant_message_id = self._get_function_id_buffer()
|
||||
# Reset message reader at the start of a new send_message stream
|
||||
self.assistant_message_json_reader.reset()
|
||||
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
self.tool_call_name = str(self.function_name_buffer)
|
||||
self.tool_call_name = str(self._get_function_name_buffer())
|
||||
if self.tool_call_name in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=None,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
tool_call_id=self._get_function_id_buffer(),
|
||||
),
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
run_id=self.run_id,
|
||||
@@ -340,9 +357,9 @@ class OpenAIStreamingInterface:
|
||||
)
|
||||
else:
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=None,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
tool_call_id=self._get_function_id_buffer(),
|
||||
)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
@@ -357,20 +374,19 @@ class OpenAIStreamingInterface:
|
||||
yield tool_call_msg
|
||||
|
||||
# Record what the last function name we flushed was
|
||||
self.last_flushed_function_name = self.function_name_buffer
|
||||
self.last_flushed_function_name = self._get_function_name_buffer()
|
||||
if self.last_flushed_function_id is None:
|
||||
self.last_flushed_function_id = self.function_id_buffer
|
||||
self.last_flushed_function_id = self._get_function_id_buffer()
|
||||
# Clear the buffer
|
||||
self.function_name_buffer = None
|
||||
self.function_id_buffer = None
|
||||
self._clear_function_buffers()
|
||||
# Since we're clearing the name buffer, we should store
|
||||
# any updates to the arguments inside a separate buffer
|
||||
|
||||
# Add any main_json updates to the arguments buffer
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
if self._function_args_buffer_parts is None:
|
||||
self._function_args_buffer_parts = [updates_main_json]
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
self._function_args_buffer_parts.append(updates_main_json)
|
||||
|
||||
# If there was nothing in the name buffer, we can proceed to
|
||||
# output the arguments chunk as a ToolCallMessage
|
||||
@@ -382,9 +398,9 @@ class OpenAIStreamingInterface:
|
||||
):
|
||||
# Minimal, robust extraction: only emit the value of "message".
|
||||
# If we buffered a prefix while name was streaming, feed it first.
|
||||
if self.function_args_buffer:
|
||||
payload = self.function_args_buffer + tool_call.function.arguments
|
||||
self.function_args_buffer = None
|
||||
if self._function_args_buffer_parts:
|
||||
payload = "".join(self._function_args_buffer_parts + [tool_call.function.arguments])
|
||||
self._function_args_buffer_parts = None
|
||||
else:
|
||||
payload = tool_call.function.arguments
|
||||
extracted = self.assistant_message_json_reader.process_json_chunk(payload)
|
||||
@@ -403,24 +419,24 @@ class OpenAIStreamingInterface:
|
||||
prev_message_type = assistant_message.message_type
|
||||
yield assistant_message
|
||||
# Store the ID of the tool call to allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
if self._get_function_id_buffer():
|
||||
self.prev_assistant_message_id = self._get_function_id_buffer()
|
||||
else:
|
||||
# There may be a buffer from a previous chunk, for example
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self.function_args_buffer:
|
||||
if self._function_args_buffer_parts:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = self.function_args_buffer + updates_main_json
|
||||
combined_chunk = "".join(self._function_args_buffer_parts + [updates_main_json])
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
if self.function_name_buffer in self.requires_approval_tools:
|
||||
if self._get_function_name_buffer() in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=combined_chunk,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
tool_call_id=self._get_function_id_buffer(),
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
@@ -429,9 +445,9 @@ class OpenAIStreamingInterface:
|
||||
)
|
||||
else:
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
name=self._get_function_name_buffer(),
|
||||
arguments=combined_chunk,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
tool_call_id=self._get_function_id_buffer(),
|
||||
)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
@@ -446,20 +462,20 @@ class OpenAIStreamingInterface:
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
self._function_args_buffer_parts = None
|
||||
self._function_id_parts = []
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
if self.function_name_buffer in self.requires_approval_tools:
|
||||
if self._get_function_name_buffer() in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
tool_call_id=self._get_function_id_buffer(),
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1),
|
||||
@@ -470,7 +486,7 @@ class OpenAIStreamingInterface:
|
||||
tool_call_delta = ToolCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
tool_call_id=self._get_function_id_buffer(),
|
||||
)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_message_id,
|
||||
@@ -484,7 +500,7 @@ class OpenAIStreamingInterface:
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
self.function_id_buffer = None
|
||||
self._function_id_parts = []
|
||||
|
||||
|
||||
class SimpleOpenAIStreamingInterface:
|
||||
@@ -539,6 +555,7 @@ class SimpleOpenAIStreamingInterface:
|
||||
concat_content = ""
|
||||
merged_messages = []
|
||||
reasoning_content = []
|
||||
concat_content_parts: list[str] = []
|
||||
|
||||
for msg in self.content_messages:
|
||||
if isinstance(msg, HiddenReasoningMessage) and not shown_omitted:
|
||||
@@ -548,16 +565,16 @@ class SimpleOpenAIStreamingInterface:
|
||||
reasoning_content.append(msg.reasoning)
|
||||
elif isinstance(msg, AssistantMessage):
|
||||
if isinstance(msg.content, list):
|
||||
concat_content += "".join([c.text for c in msg.content])
|
||||
concat_content_parts.append("".join([c.text for c in msg.content]))
|
||||
else:
|
||||
concat_content += msg.content
|
||||
concat_content_parts.append(msg.content)
|
||||
|
||||
if reasoning_content:
|
||||
combined_reasoning = "".join(reasoning_content)
|
||||
merged_messages.append(ReasoningContent(is_native=True, reasoning=combined_reasoning, signature=None))
|
||||
|
||||
if concat_content:
|
||||
merged_messages.append(TextContent(text=concat_content))
|
||||
if concat_content_parts:
|
||||
merged_messages.append(TextContent(text="".join(concat_content_parts)))
|
||||
|
||||
return merged_messages
|
||||
|
||||
@@ -569,9 +586,9 @@ class SimpleOpenAIStreamingInterface:
|
||||
result: list[ToolCall] = []
|
||||
for idx in ordered_indices:
|
||||
ctx = self._tool_calls_acc[idx]
|
||||
name = ctx.get("name", "")
|
||||
args = ctx.get("arguments", "")
|
||||
call_id = ctx.get("id", "")
|
||||
name = "".join(ctx.get("name_parts", [])) if "name_parts" in ctx else ctx.get("name", "")
|
||||
args = "".join(ctx.get("arguments_parts", [])) if "arguments_parts" in ctx else ctx.get("arguments", "")
|
||||
call_id = "".join(ctx.get("id_parts", [])) if "id_parts" in ctx else ctx.get("id", "")
|
||||
if call_id and name:
|
||||
result.append(ToolCall(id=call_id, function=FunctionCall(arguments=args or "", name=name)))
|
||||
return result
|
||||
@@ -742,15 +759,15 @@ class SimpleOpenAIStreamingInterface:
|
||||
if idx not in self._tool_call_start_order:
|
||||
self._tool_call_start_order.append(idx)
|
||||
if idx not in self._tool_calls_acc:
|
||||
self._tool_calls_acc[idx] = {"name": "", "arguments": "", "id": ""}
|
||||
self._tool_calls_acc[idx] = {"name_parts": [], "arguments_parts": [], "id_parts": []}
|
||||
acc = self._tool_calls_acc[idx]
|
||||
|
||||
if tool_call.function and tool_call.function.name:
|
||||
acc["name"] += tool_call.function.name
|
||||
acc["name_parts"].append(tool_call.function.name)
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
acc["arguments"] += tool_call.function.arguments
|
||||
acc["arguments_parts"].append(tool_call.function.arguments)
|
||||
if tool_call.id:
|
||||
acc["id"] += tool_call.id
|
||||
acc["id_parts"].append(tool_call.id)
|
||||
|
||||
delta = ToolCallDelta(
|
||||
name=tool_call.function.name if (tool_call.function and tool_call.function.name) else None,
|
||||
@@ -758,7 +775,8 @@ class SimpleOpenAIStreamingInterface:
|
||||
tool_call_id=tool_call.id if tool_call.id else None,
|
||||
)
|
||||
|
||||
if acc.get("name") and acc["name"] in self.requires_approval_tools:
|
||||
_curr_name = "".join(acc.get("name_parts", [])) if "name_parts" in acc else acc.get("name", "")
|
||||
if _curr_name and _curr_name in self.requires_approval_tools:
|
||||
tool_call_msg = ApprovalRequestMessage(
|
||||
id=decrement_message_uuid(self.letta_message_id),
|
||||
date=datetime.now(timezone.utc),
|
||||
|
||||
Reference in New Issue
Block a user