From a699aca626e6bab6c11cb5299c5bd2e8f115b0f2 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 4 Nov 2025 18:04:39 -0800 Subject: [PATCH] fix: Eliminate O(n^2) string growth for OpenAI [LET-6065] (#5973) Finish --- .../interfaces/openai_streaming_interface.py | 148 ++++++++++-------- 1 file changed, 83 insertions(+), 65 deletions(-) diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 834a130a..32fa9933 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -93,14 +93,15 @@ class OpenAIStreamingInterface: self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=put_inner_thoughts_in_kwarg) # Reader that extracts only the assistant message value from send_message args self.assistant_message_json_reader = FunctionArgumentsStreamHandler(json_key=self.assistant_message_tool_kwarg) - self.function_name_buffer = None - self.function_args_buffer = None - self.function_id_buffer = None + # Switch to list-based accumulation to avoid O(n^2) string growth + self._function_name_parts: list[str] = [] + self._function_args_buffer_parts: list[str] | None = None + self._function_id_parts: list[str] = [] self.last_flushed_function_name = None self.last_flushed_function_id = None # Buffer to hold function arguments until inner thoughts are complete - self.current_function_arguments = "" + self._current_function_arguments_parts: list[str] = [] self.current_json_parse_result = {} # Premake IDs for database writes @@ -140,17 +141,39 @@ class OpenAIStreamingInterface: else: return [TextContent(text=content)] + def _get_function_name_buffer(self) -> str | None: + return "".join(self._function_name_parts) if self._function_name_parts else None + + def _get_function_id_buffer(self) -> str | None: + return "".join(self._function_id_parts) if self._function_id_parts else None + + def _clear_function_buffers(self) -> None: + self._function_name_parts = [] + self._function_id_parts = [] + + def _append_function_name(self, s: str) -> None: + self._function_name_parts.append(s) + + def _append_function_id(self, s: str) -> None: + self._function_id_parts.append(s) + + def _append_current_function_arguments(self, s: str) -> None: + self._current_function_arguments_parts.append(s) + + def _get_current_function_arguments(self) -> str: + return "".join(self._current_function_arguments_parts) + def get_tool_call_object(self) -> ToolCall: """Useful for agent loop""" - function_name = self.last_flushed_function_name if self.last_flushed_function_name else self.function_name_buffer + function_name = self.last_flushed_function_name if self.last_flushed_function_name else self._get_function_name_buffer() if not function_name: raise ValueError("No tool call ID available") - tool_call_id = self.last_flushed_function_id if self.last_flushed_function_id else self.function_id_buffer + tool_call_id = self.last_flushed_function_id if self.last_flushed_function_id else self._get_function_id_buffer() if not tool_call_id: raise ValueError("No tool call ID available") return ToolCall( id=tool_call_id, - function=FunctionCall(arguments=self.current_function_arguments, name=function_name), + function=FunctionCall(arguments=self._get_current_function_arguments(), name=function_name), ) async def process( @@ -261,21 +284,15 @@ class OpenAIStreamingInterface: if tool_call.function.name: # If we're waiting for the first key, then we should hold back the name # ie add it to a buffer instead of returning it as a chunk - if self.function_name_buffer is None: - self.function_name_buffer = tool_call.function.name - else: - self.function_name_buffer += tool_call.function.name + self._append_function_name(tool_call.function.name) if tool_call.id: # Buffer until next time - if self.function_id_buffer is None: - self.function_id_buffer = tool_call.id - else: - self.function_id_buffer += tool_call.id + self._append_function_id(tool_call.id) if tool_call.function.arguments: # updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) - self.current_function_arguments += tool_call.function.arguments + self._append_current_function_arguments(tool_call.function.arguments) updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) if self.is_openai_proxy: @@ -301,10 +318,10 @@ class OpenAIStreamingInterface: # Additionally inner thoughts may stream back with a chunk of main JSON # In that case, since we can only return a chunk at a time, we should buffer it if updates_main_json: - if self.function_args_buffer is None: - self.function_args_buffer = updates_main_json + if self._function_args_buffer_parts is None: + self._function_args_buffer_parts = [updates_main_json] else: - self.function_args_buffer += updates_main_json + self._function_args_buffer_parts.append(updates_main_json) # If we have main_json, we should output a ToolCallMessage elif updates_main_json: @@ -312,27 +329,27 @@ class OpenAIStreamingInterface: # NOTE: we could output it as part of a chunk that has both name and args, # however the frontend may expect name first, then args, so to be # safe we'll output name first in a separate chunk - if self.function_name_buffer: + if self._get_function_name_buffer(): # use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..." - if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name: + if self.use_assistant_message and self._get_function_name_buffer() == self.assistant_message_tool_name: # Store the ID of the tool call so allow skipping the corresponding response - if self.function_id_buffer: - self.prev_assistant_message_id = self.function_id_buffer + if self._get_function_id_buffer(): + self.prev_assistant_message_id = self._get_function_id_buffer() # Reset message reader at the start of a new send_message stream self.assistant_message_json_reader.reset() else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 - self.tool_call_name = str(self.function_name_buffer) + self.tool_call_name = str(self._get_function_name_buffer()) if self.tool_call_name in self.requires_approval_tools: tool_call_msg = ApprovalRequestMessage( id=decrement_message_uuid(self.letta_message_id), date=datetime.now(timezone.utc), tool_call=ToolCallDelta( - name=self.function_name_buffer, + name=self._get_function_name_buffer(), arguments=None, - tool_call_id=self.function_id_buffer, + tool_call_id=self._get_function_id_buffer(), ), otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1), run_id=self.run_id, @@ -340,9 +357,9 @@ class OpenAIStreamingInterface: ) else: tool_call_delta = ToolCallDelta( - name=self.function_name_buffer, + name=self._get_function_name_buffer(), arguments=None, - tool_call_id=self.function_id_buffer, + tool_call_id=self._get_function_id_buffer(), ) tool_call_msg = ToolCallMessage( id=self.letta_message_id, @@ -357,20 +374,19 @@ class OpenAIStreamingInterface: yield tool_call_msg # Record what the last function name we flushed was - self.last_flushed_function_name = self.function_name_buffer + self.last_flushed_function_name = self._get_function_name_buffer() if self.last_flushed_function_id is None: - self.last_flushed_function_id = self.function_id_buffer + self.last_flushed_function_id = self._get_function_id_buffer() # Clear the buffer - self.function_name_buffer = None - self.function_id_buffer = None + self._clear_function_buffers() # Since we're clearing the name buffer, we should store # any updates to the arguments inside a separate buffer # Add any main_json updates to the arguments buffer - if self.function_args_buffer is None: - self.function_args_buffer = updates_main_json + if self._function_args_buffer_parts is None: + self._function_args_buffer_parts = [updates_main_json] else: - self.function_args_buffer += updates_main_json + self._function_args_buffer_parts.append(updates_main_json) # If there was nothing in the name buffer, we can proceed to # output the arguments chunk as a ToolCallMessage @@ -382,9 +398,9 @@ class OpenAIStreamingInterface: ): # Minimal, robust extraction: only emit the value of "message". # If we buffered a prefix while name was streaming, feed it first. - if self.function_args_buffer: - payload = self.function_args_buffer + tool_call.function.arguments - self.function_args_buffer = None + if self._function_args_buffer_parts: + payload = "".join(self._function_args_buffer_parts + [tool_call.function.arguments]) + self._function_args_buffer_parts = None else: payload = tool_call.function.arguments extracted = self.assistant_message_json_reader.process_json_chunk(payload) @@ -403,24 +419,24 @@ class OpenAIStreamingInterface: prev_message_type = assistant_message.message_type yield assistant_message # Store the ID of the tool call so allow skipping the corresponding response - if self.function_id_buffer: - self.prev_assistant_message_id = self.function_id_buffer + if self._get_function_id_buffer(): + self.prev_assistant_message_id = self._get_function_id_buffer() else: # There may be a buffer from a previous chunk, for example # if the previous chunk had arguments but we needed to flush name - if self.function_args_buffer: + if self._function_args_buffer_parts: # In this case, we should release the buffer + new data at once - combined_chunk = self.function_args_buffer + updates_main_json + combined_chunk = "".join(self._function_args_buffer_parts + [updates_main_json]) if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 - if self.function_name_buffer in self.requires_approval_tools: + if self._get_function_name_buffer() in self.requires_approval_tools: tool_call_msg = ApprovalRequestMessage( id=decrement_message_uuid(self.letta_message_id), date=datetime.now(timezone.utc), tool_call=ToolCallDelta( - name=self.function_name_buffer, + name=self._get_function_name_buffer(), arguments=combined_chunk, - tool_call_id=self.function_id_buffer, + tool_call_id=self._get_function_id_buffer(), ), # name=name, otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1), @@ -429,9 +445,9 @@ class OpenAIStreamingInterface: ) else: tool_call_delta = ToolCallDelta( - name=self.function_name_buffer, + name=self._get_function_name_buffer(), arguments=combined_chunk, - tool_call_id=self.function_id_buffer, + tool_call_id=self._get_function_id_buffer(), ) tool_call_msg = ToolCallMessage( id=self.letta_message_id, @@ -446,20 +462,20 @@ class OpenAIStreamingInterface: prev_message_type = tool_call_msg.message_type yield tool_call_msg # clear buffer - self.function_args_buffer = None - self.function_id_buffer = None + self._function_args_buffer_parts = None + self._function_id_parts = [] else: # If there's no buffer to clear, just output a new chunk with new data if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 - if self.function_name_buffer in self.requires_approval_tools: + if self._get_function_name_buffer() in self.requires_approval_tools: tool_call_msg = ApprovalRequestMessage( id=decrement_message_uuid(self.letta_message_id), date=datetime.now(timezone.utc), tool_call=ToolCallDelta( name=None, arguments=updates_main_json, - tool_call_id=self.function_id_buffer, + tool_call_id=self._get_function_id_buffer(), ), # name=name, otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1), @@ -470,7 +486,7 @@ class OpenAIStreamingInterface: tool_call_delta = ToolCallDelta( name=None, arguments=updates_main_json, - tool_call_id=self.function_id_buffer, + tool_call_id=self._get_function_id_buffer(), ) tool_call_msg = ToolCallMessage( id=self.letta_message_id, @@ -484,7 +500,7 @@ class OpenAIStreamingInterface: ) prev_message_type = tool_call_msg.message_type yield tool_call_msg - self.function_id_buffer = None + self._function_id_parts = [] class SimpleOpenAIStreamingInterface: @@ -539,6 +555,7 @@ class SimpleOpenAIStreamingInterface: concat_content = "" merged_messages = [] reasoning_content = [] + concat_content_parts: list[str] = [] for msg in self.content_messages: if isinstance(msg, HiddenReasoningMessage) and not shown_omitted: @@ -548,16 +565,16 @@ class SimpleOpenAIStreamingInterface: reasoning_content.append(msg.reasoning) elif isinstance(msg, AssistantMessage): if isinstance(msg.content, list): - concat_content += "".join([c.text for c in msg.content]) + concat_content_parts.append("".join([c.text for c in msg.content])) else: - concat_content += msg.content + concat_content_parts.append(msg.content) if reasoning_content: combined_reasoning = "".join(reasoning_content) merged_messages.append(ReasoningContent(is_native=True, reasoning=combined_reasoning, signature=None)) - if concat_content: - merged_messages.append(TextContent(text=concat_content)) + if concat_content_parts: + merged_messages.append(TextContent(text="".join(concat_content_parts))) return merged_messages @@ -569,9 +586,9 @@ class SimpleOpenAIStreamingInterface: result: list[ToolCall] = [] for idx in ordered_indices: ctx = self._tool_calls_acc[idx] - name = ctx.get("name", "") - args = ctx.get("arguments", "") - call_id = ctx.get("id", "") + name = "".join(ctx.get("name_parts", [])) if "name_parts" in ctx else ctx.get("name", "") + args = "".join(ctx.get("arguments_parts", [])) if "arguments_parts" in ctx else ctx.get("arguments", "") + call_id = "".join(ctx.get("id_parts", [])) if "id_parts" in ctx else ctx.get("id", "") if call_id and name: result.append(ToolCall(id=call_id, function=FunctionCall(arguments=args or "", name=name))) return result @@ -742,15 +759,15 @@ class SimpleOpenAIStreamingInterface: if idx not in self._tool_call_start_order: self._tool_call_start_order.append(idx) if idx not in self._tool_calls_acc: - self._tool_calls_acc[idx] = {"name": "", "arguments": "", "id": ""} + self._tool_calls_acc[idx] = {"name_parts": [], "arguments_parts": [], "id_parts": []} acc = self._tool_calls_acc[idx] if tool_call.function and tool_call.function.name: - acc["name"] += tool_call.function.name + acc["name_parts"].append(tool_call.function.name) if tool_call.function and tool_call.function.arguments: - acc["arguments"] += tool_call.function.arguments + acc["arguments_parts"].append(tool_call.function.arguments) if tool_call.id: - acc["id"] += tool_call.id + acc["id_parts"].append(tool_call.id) delta = ToolCallDelta( name=tool_call.function.name if (tool_call.function and tool_call.function.name) else None, @@ -758,7 +775,8 @@ class SimpleOpenAIStreamingInterface: tool_call_id=tool_call.id if tool_call.id else None, ) - if acc.get("name") and acc["name"] in self.requires_approval_tools: + _curr_name = "".join(acc.get("name_parts", [])) if "name_parts" in acc else acc.get("name", "") + if _curr_name and _curr_name in self.requires_approval_tools: tool_call_msg = ApprovalRequestMessage( id=decrement_message_uuid(self.letta_message_id), date=datetime.now(timezone.utc),