From 1a93b85bfd3789a743f6bdf49639f4d9caf25f4b Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Mon, 21 Oct 2024 17:07:20 -0700 Subject: [PATCH] feat: fix streaming `put_inner_thoughts_in_kwargs` (#1913) --- letta/llm_api/helpers.py | 66 +++-- letta/llm_api/openai.py | 21 +- letta/server/rest_api/interface.py | 307 ++++++++++++++------- letta/server/rest_api/routers/v1/agents.py | 14 +- letta/streaming_utils.py | 270 ++++++++++++++++++ tests/test_stream_buffer_readers.py | 125 +++++++++ 6 files changed, 677 insertions(+), 126 deletions(-) create mode 100644 letta/streaming_utils.py create mode 100644 tests/test_stream_buffer_readers.py diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index f35c9a91..2ebc7ae1 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -1,6 +1,7 @@ import copy import json import warnings +from collections import OrderedDict from typing import Any, List, Union import requests @@ -10,6 +11,30 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.utils import json_dumps, printd +def convert_to_structured_output(openai_function: dict) -> dict: + """Convert function call objects to structured output objects + + See: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas + """ + structured_output = { + "name": openai_function["name"], + "description": openai_function["description"], + "strict": True, + "parameters": {"type": "object", "properties": {}, "additionalProperties": False, "required": []}, + } + + for param, details in openai_function["parameters"]["properties"].items(): + structured_output["parameters"]["properties"][param] = {"type": details["type"], "description": details["description"]} + + if "enum" in details: + structured_output["parameters"]["properties"][param]["enum"] = details["enum"] + + # Add all properties to required list + structured_output["parameters"]["required"] = 
list(structured_output["parameters"]["properties"].keys()) + + return structured_output + + def make_post_request(url: str, headers: dict[str, str], data: dict[str, Any]) -> dict[str, Any]: printd(f"Sending request to {url}") try: @@ -78,33 +103,34 @@ def add_inner_thoughts_to_functions( inner_thoughts_key: str, inner_thoughts_description: str, inner_thoughts_required: bool = True, - # inner_thoughts_to_front: bool = True, TODO support sorting somewhere, probably in the to_dict? ) -> List[dict]: - """Add an inner_thoughts kwarg to every function in the provided list""" - # return copies + """Add an inner_thoughts kwarg to every function in the provided list, ensuring it's the first parameter""" new_functions = [] - - # functions is a list of dicts in the OpenAI schema (https://platform.openai.com/docs/api-reference/chat/create) for function_object in functions: - function_params = function_object["parameters"]["properties"] - required_params = list(function_object["parameters"]["required"]) - - # if the inner thoughts arg doesn't exist, add it - if inner_thoughts_key not in function_params: - function_params[inner_thoughts_key] = { - "type": "string", - "description": inner_thoughts_description, - } - - # make sure it's tagged as required new_function_object = copy.deepcopy(function_object) - if inner_thoughts_required and inner_thoughts_key not in required_params: - required_params.append(inner_thoughts_key) - new_function_object["parameters"]["required"] = required_params + + # Create a new OrderedDict with inner_thoughts as the first item + new_properties = OrderedDict() + new_properties[inner_thoughts_key] = { + "type": "string", + "description": inner_thoughts_description, + } + + # Add the rest of the properties + new_properties.update(function_object["parameters"]["properties"]) + + # Cast OrderedDict back to a regular dict + new_function_object["parameters"]["properties"] = dict(new_properties) + + # Update required parameters if necessary + if 
inner_thoughts_required: + required_params = new_function_object["parameters"].get("required", []) + if inner_thoughts_key not in required_params: + required_params.insert(0, inner_thoughts_key) + new_function_object["parameters"]["required"] = required_params new_functions.append(new_function_object) - # return a list of copies return new_functions diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 29ba9cfe..b6842e2f 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -9,7 +9,11 @@ from httpx_sse._exceptions import SSEError from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING from letta.errors import LLMError -from letta.llm_api.helpers import add_inner_thoughts_to_functions, make_post_request +from letta.llm_api.helpers import ( + add_inner_thoughts_to_functions, + convert_to_structured_output, + make_post_request, +) from letta.local_llm.constants import ( INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION, @@ -112,7 +116,7 @@ def build_openai_chat_completions_request( use_tool_naming: bool, max_tokens: Optional[int], ) -> ChatCompletionRequest: - if llm_config.put_inner_thoughts_in_kwargs: + if functions and llm_config.put_inner_thoughts_in_kwargs: functions = add_inner_thoughts_to_functions( functions=functions, inner_thoughts_key=INNER_THOUGHTS_KWARG, @@ -154,8 +158,8 @@ def build_openai_chat_completions_request( ) # https://platform.openai.com/docs/guides/text-generation/json-mode # only supported by gpt-4o, gpt-4-turbo, or gpt-3.5-turbo - if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model: - data.response_format = {"type": "json_object"} + # if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model: + # data.response_format = {"type": "json_object"} if "inference.memgpt.ai" in llm_config.model_endpoint: # override user id for inference.memgpt.ai @@ -362,6 +366,8 @@ def 
openai_chat_completions_process_stream( chat_completion_response.usage.completion_tokens = n_chunks chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks + assert len(chat_completion_response.choices) > 0, chat_completion_response + # printd(chat_completion_response) return chat_completion_response @@ -461,6 +467,13 @@ def openai_chat_completions_request_stream( data.pop("tools") data.pop("tool_choice", None) # extra safe, should exist always (default="auto") + if "tools" in data: + for tool in data["tools"]: + # tool["strict"] = True + tool["function"] = convert_to_structured_output(tool["function"]) + + # print(f"\n\n\n\nData[tools]: {json.dumps(data['tools'], indent=2)}") + printd(f"Sending request to {url}") try: return _sse_post(url=url, data=data, headers=headers) diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 17731f15..a2898e5d 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -8,6 +8,7 @@ from typing import AsyncGenerator, Literal, Optional, Union from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.interface import AgentInterface +from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( AssistantMessage, @@ -23,9 +24,14 @@ from letta.schemas.letta_message import ( from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse from letta.streaming_interface import AgentChunkStreamingInterface +from letta.streaming_utils import ( + FunctionArgumentsStreamHandler, + JSONInnerThoughtsExtractor, +) from letta.utils import is_utc_datetime +# TODO strip from code / deprecate class QueuingInterface(AgentInterface): """Messages are queued inside an internal buffer and manually flushed""" @@ -248,58 +254,6 @@ class QueuingInterface(AgentInterface): 
self._queue_push(message_api=new_message, message_obj=msg_obj) -class FunctionArgumentsStreamHandler: - """State machine that can process a stream of""" - - def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG): - self.json_key = json_key - self.reset() - - def reset(self): - self.in_message = False - self.key_buffer = "" - self.accumulating = False - self.message_started = False - - def process_json_chunk(self, chunk: str) -> Optional[str]: - """Process a chunk from the function arguments and return the plaintext version""" - - # Use strip to handle only leading and trailing whitespace in control structures - if self.accumulating: - clean_chunk = chunk.strip() - if self.json_key in self.key_buffer: - if ":" in clean_chunk: - self.in_message = True - self.accumulating = False - return None - self.key_buffer += clean_chunk - return None - - if self.in_message: - if chunk.strip() == '"' and self.message_started: - self.in_message = False - self.message_started = False - return None - if not self.message_started and chunk.strip() == '"': - self.message_started = True - return None - if self.message_started: - if chunk.strip().endswith('"'): - self.in_message = False - return chunk.rstrip('"\n') - return chunk - - if chunk.strip() == "{": - self.key_buffer = "" - self.accumulating = True - return None - if chunk.strip() == "}": - self.in_message = False - self.message_started = False - return None - return None - - class StreamingServerInterface(AgentChunkStreamingInterface): """Maintain a generator that is a proxy for self.process_chunk() @@ -316,9 +270,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def __init__( self, multi_step=True, + # Related to if we want to try and pass back the AssistantMessage as a special case function use_assistant_message=False, assistant_message_function_name=DEFAULT_MESSAGE_TOOL, assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, + # Related to if we expect inner_thoughts to be in the kwargs + 
inner_thoughts_in_kwargs=True, + inner_thoughts_kwarg=INNER_THOUGHTS_KWARG, ): # If streaming mode, ignores base interface calls like .assistant_message, etc self.streaming_mode = False @@ -346,6 +304,15 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.assistant_message_function_name = assistant_message_function_name self.assistant_message_function_kwarg = assistant_message_function_kwarg + # Support for inner_thoughts_in_kwargs + self.inner_thoughts_in_kwargs = inner_thoughts_in_kwargs + self.inner_thoughts_kwarg = inner_thoughts_kwarg + # A buffer for accumulating function arguments (we want to buffer keys and run checks on each one) + self.function_args_reader = JSONInnerThoughtsExtractor(inner_thoughts_key=inner_thoughts_kwarg, wait_for_first_key=True) + # Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs) + self.function_name_buffer = None + self.function_args_buffer = None + # extra prints self.debug = False self.timeout = 30 @@ -365,16 +332,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # Reset the event until a new item is pushed self._event.clear() - # while self._active: - # # Wait until there is an item in the deque or the stream is deactivated - # await self._event.wait() - - # while self._chunks: - # yield self._chunks.popleft() - - # # Reset the event until a new item is pushed - # self._event.clear() - def get_generator(self) -> AsyncGenerator: """Get the generator that yields processed chunks.""" if not self._active: @@ -419,18 +376,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface): if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode: self._push_to_buffer(self.multi_step_gen_indicator) - # self._active = False - # self._event.set() # Unblock the generator if it's waiting to allow it to complete - - # if not self.multi_step: - # # end the stream - # self._active = False - # self._event.set() # Unblock 
the generator if it's waiting to allow it to complete - # else: - # # signal that a new step has started in the stream - # self._chunks.append(self.multi_step_indicator) - # self._event.set() # Signal that new data is available - def step_complete(self): """Signal from the agent that one 'step' finished (step = LLM response + tool execution)""" if not self.multi_step: @@ -443,8 +388,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def step_yield(self): """If multi_step, this is the true 'stream_end' function.""" - # if self.multi_step: - # end the stream self._active = False self._event.set() # Unblock the generator if it's waiting to allow it to complete @@ -479,8 +422,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface): elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0: tool_call = message_delta.tool_calls[0] + # TODO(charles) merge into logic for internal_monologue # special case for trapping `send_message` if self.use_assistant_message and tool_call.function: + if self.inner_thoughts_in_kwargs: + raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported") # If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode @@ -538,6 +484,181 @@ class StreamingServerInterface(AgentChunkStreamingInterface): ), ) + elif self.inner_thoughts_in_kwargs and tool_call.function: + if self.use_assistant_message: + raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported") + + processed_chunk = None + + if tool_call.function.name: + # If we're waiting for the first key, then we should hold back the name + # ie add it to a buffer instead of returning it as a chunk + if self.function_name_buffer is None: + self.function_name_buffer = tool_call.function.name + else: + self.function_name_buffer += tool_call.function.name + + if tool_call.function.arguments: + 
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) + + # If we have inner thoughts, we should output them as a chunk + if updates_inner_thoughts: + processed_chunk = InternalMonologue( + id=message_id, + date=message_date, + internal_monologue=updates_inner_thoughts, + ) + # Additionally inner thoughts may stream back with a chunk of main JSON + # In that case, since we can only return a chunk at a time, we should buffer it + if updates_main_json: + if self.function_args_buffer is None: + self.function_args_buffer = updates_main_json + else: + self.function_args_buffer += updates_main_json + + # If we have main_json, we should output a FunctionCallMessage + elif updates_main_json: + # If there's something in the function_name buffer, we should release it first + # NOTE: we could output it as part of a chunk that has both name and args, + # however the frontend may expect name first, then args, so to be + # safe we'll output name first in a separate chunk + if self.function_name_buffer: + processed_chunk = FunctionCallMessage( + id=message_id, + date=message_date, + function_call=FunctionCallDelta(name=self.function_name_buffer, arguments=None), + ) + # Clear the buffer + self.function_name_buffer = None + # Since we're clearing the name buffer, we should store + # any updates to the arguments inside a separate buffer + if updates_main_json: + # Add any main_json updates to the arguments buffer + if self.function_args_buffer is None: + self.function_args_buffer = updates_main_json + else: + self.function_args_buffer += updates_main_json + + # If there was nothing in the name buffer, we can proceed to + # output the arguments chunk as a FunctionCallMessage + else: + # There may be a buffer from a previous chunk, for example + # if the previous chunk had arguments but we needed to flush name + if self.function_args_buffer: + # In this case, we should release the buffer + new data at once + 
combined_chunk = self.function_args_buffer + updates_main_json + processed_chunk = FunctionCallMessage( + id=message_id, + date=message_date, + function_call=FunctionCallDelta(name=None, arguments=combined_chunk), + ) + # clear buffer + self.function_args_buffer = None + else: + # If there's no buffer to clear, just output a new chunk with new data + processed_chunk = FunctionCallMessage( + id=message_id, + date=message_date, + function_call=FunctionCallDelta(name=None, arguments=updates_main_json), + ) + + # # If there's something in the main_json buffer, we should add if to the arguments and release it together + # tool_call_delta = {} + # if tool_call.id: + # tool_call_delta["id"] = tool_call.id + # if tool_call.function: + # if tool_call.function.arguments: + # # tool_call_delta["arguments"] = tool_call.function.arguments + # # NOTE: using the stripped one + # tool_call_delta["arguments"] = updates_main_json + # # We use the buffered name + # if self.function_name_buffer: + # tool_call_delta["name"] = self.function_name_buffer + # # if tool_call.function.name: + # # tool_call_delta["name"] = tool_call.function.name + + # processed_chunk = FunctionCallMessage( + # id=message_id, + # date=message_date, + # function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + # ) + + else: + processed_chunk = None + + return processed_chunk + + # # NOTE: this is a simplified version of the parsing code that: + # # (1) assumes that the inner_thoughts key will always come first + # # (2) assumes that there's no extra spaces in the stringified JSON + # # i.e., the prefix will look exactly like: "{\"variable\":\"}" + # if tool_call.function.arguments: + # self.function_args_buffer += tool_call.function.arguments + + # # prefix_str = f'{{"\\"{self.inner_thoughts_kwarg}\\":\\"}}' + # prefix_str = f'{{"{self.inner_thoughts_kwarg}":' + # if self.function_args_buffer.startswith(prefix_str): + # print(f"Found prefix!!!: 
{self.function_args_buffer}") + # else: + # print(f"No prefix found: {self.function_args_buffer}") + + # tool_call_delta = {} + # if tool_call.id: + # tool_call_delta["id"] = tool_call.id + # if tool_call.function: + # if tool_call.function.arguments: + # tool_call_delta["arguments"] = tool_call.function.arguments + # if tool_call.function.name: + # tool_call_delta["name"] = tool_call.function.name + + # processed_chunk = FunctionCallMessage( + # id=message_id, + # date=message_date, + # function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + # ) + + # elif False and self.inner_thoughts_in_kwargs and tool_call.function: + # if self.use_assistant_message: + # raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported") + + # if tool_call.function.arguments: + + # Maintain a state machine to track if we're reading a key vs reading a value + # Technically we can we pre-key, post-key, pre-value, post-value + + # for c in tool_call.function.arguments: + # if self.function_chunks_parsing_state == FunctionChunksParsingState.PRE_KEY: + # if c == '"': + # self.function_chunks_parsing_state = FunctionChunksParsingState.READING_KEY + # elif self.function_chunks_parsing_state == FunctionChunksParsingState.READING_KEY: + # if c == '"': + # self.function_chunks_parsing_state = FunctionChunksParsingState.POST_KEY + + # If we're reading a key: + # if self.function_chunks_parsing_state == FunctionChunksParsingState.READING_KEY: + + # We need to buffer the function arguments until we get complete keys + # We are reading stringified-JSON, so we need to check for keys in data that looks like: + # "arguments":"{\"" + # "arguments":"inner" + # "arguments":"_th" + # "arguments":"ought" + # "arguments":"s" + # "arguments":"\":\"" + + # Once we get a complete key, check if the key matches + + # If it does match, start processing the value (stringified-JSON string + # And with each new chunk, 
output it as a chunk of type InternalMonologue + + # If the key doesn't match, then flush the buffer as a single FunctionCallMessage chunk + + # If we're reading a value + + # If we're reading the inner thoughts value, we output chunks of type InternalMonologue + + # Otherwise, do simple chunks of FunctionCallMessage + else: tool_call_delta = {} @@ -563,7 +684,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # skip if there's a finish return None else: - raise ValueError(f"Couldn't find delta in chunk: {chunk}") + # Example case that would trigger here: + # id='chatcmpl-AKtUvREgRRvgTW6n8ZafiKuV0mxhQ' + # choices=[ChunkChoice(finish_reason=None, index=0, delta=MessageDelta(content=None, tool_calls=None, function_call=None), logprobs=None)] + # created=datetime.datetime(2024, 10, 21, 20, 40, 57, tzinfo=TzInfo(UTC)) + # model='gpt-4o-mini-2024-07-18' + # object='chat.completion.chunk' + warnings.warn(f"Couldn't find delta in chunk: {chunk}") + return None return processed_chunk @@ -663,6 +791,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(), # "id": str(msg_obj.id) if msg_obj is not None else None, # } + assert msg_obj is not None, "Internal monologue requires msg_obj references for metadata" processed_chunk = InternalMonologue( id=msg_obj.id, date=msg_obj.created_at, @@ -676,18 +805,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): """Letta uses send_message""" - # if not self.streaming_mode and self.send_message_special_case: - - # # create a fake "chunk" of a stream - # processed_chunk = { - # "assistant_message": msg, - # "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(), - # "id": str(msg_obj.id) if msg_obj is not None else None, - # } - - # self._chunks.append(processed_chunk) - # self._event.set() 
# Signal that new data is available - + # NOTE: this is a no-op, we handle this special case in function_message instead return def function_message(self, msg: str, msg_obj: Optional[Message] = None): @@ -699,6 +817,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): if msg.startswith("Running "): if not self.streaming_mode: # create a fake "chunk" of a stream + assert msg_obj.tool_calls is not None and len(msg_obj.tool_calls) > 0, "Function call required for function_message" function_call = msg_obj.tool_calls[0] if self.nonstreaming_legacy_mode: @@ -784,13 +903,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface): return else: return - # msg = msg.replace("Running ", "") - # new_message = {"function_call": msg} elif msg.startswith("Ran "): return - # msg = msg.replace("Ran ", "Function call returned: ") - # new_message = {"function_call": msg} elif msg.startswith("Success: "): msg = msg.replace("Success: ", "") @@ -821,10 +936,4 @@ class StreamingServerInterface(AgentChunkStreamingInterface): raise ValueError(msg) new_message = {"function_message": msg} - # add extra metadata - # if msg_obj is not None: - # new_message["id"] = str(msg_obj.id) - # assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at - # new_message["date"] = msg_obj.created_at.isoformat() - self._push_to_buffer(new_message) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index b928509f..15ea2109 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -430,9 +430,6 @@ async def send_message_to_agent( # Get the generator object off of the agent's streaming interface # This will be attached to the POST SSE request used under-the-hood letta_agent = server._get_or_load_agent(agent_id=agent_id) - streaming_interface = letta_agent.interface - if not isinstance(streaming_interface, StreamingServerInterface): - raise ValueError(f"Agent has wrong type of interface: 
{type(streaming_interface)}") # Disable token streaming if not OpenAI # TODO: cleanup this logic @@ -441,6 +438,12 @@ async def send_message_to_agent( print("Warning: token streaming is only supported for OpenAI models. Setting to False.") stream_tokens = False + # Create a new interface per request + letta_agent.interface = StreamingServerInterface() + streaming_interface = letta_agent.interface + if not isinstance(streaming_interface, StreamingServerInterface): + raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}") + # Enable token-streaming within the request if desired streaming_interface.streaming_mode = stream_tokens # "chatcompletion mode" does some remapping and ignores inner thoughts @@ -454,6 +457,11 @@ async def send_message_to_agent( streaming_interface.assistant_message_function_name = assistant_message_function_name streaming_interface.assistant_message_function_kwarg = assistant_message_function_kwarg + # Related to JSON buffer reader + streaming_interface.inner_thoughts_in_kwargs = ( + llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False + ) + # Offload the synchronous message_func to a separate thread streaming_interface.stream_start() task = asyncio.create_task( diff --git a/letta/streaming_utils.py b/letta/streaming_utils.py new file mode 100644 index 00000000..61b6fa7a --- /dev/null +++ b/letta/streaming_utils.py @@ -0,0 +1,270 @@ +from typing import Optional + +from letta.constants import DEFAULT_MESSAGE_TOOL_KWARG + + +class JSONInnerThoughtsExtractor: + """ + A class to process incoming JSON fragments and extract 'inner_thoughts' separately from the main JSON. + + This handler processes JSON fragments incrementally, parsing out the value associated with a specified key (default is 'inner_thoughts'). It maintains two separate buffers: + + - `main_json`: Accumulates the JSON data excluding the 'inner_thoughts' key-value pair. 
+ - `inner_thoughts`: Accumulates the value associated with the 'inner_thoughts' key. + + **Parameters:** + + - `inner_thoughts_key` (str): The key to extract from the JSON (default is 'inner_thoughts'). + - `wait_for_first_key` (bool): If `True`, holds back main JSON output until after the 'inner_thoughts' value is processed. + + **Functionality:** + + - **Stateful Parsing:** Maintains parsing state across fragments. + - **String Handling:** Correctly processes strings, escape sequences, and quotation marks. + - **Selective Extraction:** Identifies and extracts the value of the specified key. + - **Fragment Processing:** Handles data that arrives in chunks. + + **Usage:** + + ```python + extractor = JSONInnerThoughtsExtractor(wait_for_first_key=True) + for fragment in fragments: + updates_main_json, updates_inner_thoughts = extractor.process_fragment(fragment) + ``` + + """ + + def __init__(self, inner_thoughts_key="inner_thoughts", wait_for_first_key=False): + self.inner_thoughts_key = inner_thoughts_key + self.wait_for_first_key = wait_for_first_key + self.main_buffer = "" + self.inner_thoughts_buffer = "" + self.state = "start" # Possible states: start, key, colon, value, comma_or_end, end + self.in_string = False + self.escaped = False + self.current_key = "" + self.is_inner_thoughts_value = False + self.inner_thoughts_processed = False + self.hold_main_json = wait_for_first_key + self.main_json_held_buffer = "" + + def process_fragment(self, fragment): + updates_main_json = "" + updates_inner_thoughts = "" + i = 0 + while i < len(fragment): + c = fragment[i] + if self.escaped: + self.escaped = False + if self.in_string: + if self.state == "key": + self.current_key += c + elif self.state == "value": + if self.is_inner_thoughts_value: + updates_inner_thoughts += c + self.inner_thoughts_buffer += c + else: + if self.hold_main_json: + self.main_json_held_buffer += c + else: + updates_main_json += c + self.main_buffer += c + else: + if not 
self.is_inner_thoughts_value: + if self.hold_main_json: + self.main_json_held_buffer += c + else: + updates_main_json += c + self.main_buffer += c + elif c == "\\": + self.escaped = True + if self.in_string: + if self.state == "key": + self.current_key += c + elif self.state == "value": + if self.is_inner_thoughts_value: + updates_inner_thoughts += c + self.inner_thoughts_buffer += c + else: + if self.hold_main_json: + self.main_json_held_buffer += c + else: + updates_main_json += c + self.main_buffer += c + else: + if not self.is_inner_thoughts_value: + if self.hold_main_json: + self.main_json_held_buffer += c + else: + updates_main_json += c + self.main_buffer += c + elif c == '"': + if not self.escaped: + self.in_string = not self.in_string + if self.in_string: + if self.state in ["start", "comma_or_end"]: + self.state = "key" + self.current_key = "" + # Release held main_json when starting to process the next key + if self.wait_for_first_key and self.hold_main_json and self.inner_thoughts_processed: + updates_main_json += self.main_json_held_buffer + self.main_buffer += self.main_json_held_buffer + self.main_json_held_buffer = "" + self.hold_main_json = False + else: + if self.state == "key": + self.state = "colon" + elif self.state == "value": + # End of value + if self.is_inner_thoughts_value: + self.inner_thoughts_processed = True + # Do not release held main_json here + else: + if self.hold_main_json: + self.main_json_held_buffer += '"' + else: + updates_main_json += '"' + self.main_buffer += '"' + self.state = "comma_or_end" + else: + self.escaped = False + if self.in_string: + if self.state == "key": + self.current_key += '"' + elif self.state == "value": + if self.is_inner_thoughts_value: + updates_inner_thoughts += '"' + self.inner_thoughts_buffer += '"' + else: + if self.hold_main_json: + self.main_json_held_buffer += '"' + else: + updates_main_json += '"' + self.main_buffer += '"' + elif self.in_string: + if self.state == "key": + self.current_key += 
c + elif self.state == "value": + if self.is_inner_thoughts_value: + updates_inner_thoughts += c + self.inner_thoughts_buffer += c + else: + if self.hold_main_json: + self.main_json_held_buffer += c + else: + updates_main_json += c + self.main_buffer += c + else: + if c == ":" and self.state == "colon": + self.state = "value" + self.is_inner_thoughts_value = self.current_key == self.inner_thoughts_key + if self.is_inner_thoughts_value: + pass # Do not include 'inner_thoughts' key in main_json + else: + key_colon = f'"{self.current_key}":' + if self.hold_main_json: + self.main_json_held_buffer += key_colon + '"' + else: + updates_main_json += key_colon + '"' + self.main_buffer += key_colon + '"' + elif c == "," and self.state == "comma_or_end": + if self.is_inner_thoughts_value: + # Inner thoughts value ended + self.is_inner_thoughts_value = False + self.state = "start" + # Do not release held main_json here + else: + if self.hold_main_json: + self.main_json_held_buffer += c + else: + updates_main_json += c + self.main_buffer += c + self.state = "start" + elif c == "{": + if not self.is_inner_thoughts_value: + if self.hold_main_json: + self.main_json_held_buffer += c + else: + updates_main_json += c + self.main_buffer += c + elif c == "}": + self.state = "end" + if self.hold_main_json: + self.main_json_held_buffer += c + else: + updates_main_json += c + self.main_buffer += c + else: + if self.state == "value": + if self.is_inner_thoughts_value: + updates_inner_thoughts += c + self.inner_thoughts_buffer += c + else: + if self.hold_main_json: + self.main_json_held_buffer += c + else: + updates_main_json += c + self.main_buffer += c + i += 1 + + return updates_main_json, updates_inner_thoughts + + @property + def main_json(self): + return self.main_buffer + + @property + def inner_thoughts(self): + return self.inner_thoughts_buffer + + +class FunctionArgumentsStreamHandler: + """State machine that can process a stream of""" + + def __init__(self, 
json_key=DEFAULT_MESSAGE_TOOL_KWARG): + self.json_key = json_key + self.reset() + + def reset(self): + self.in_message = False + self.key_buffer = "" + self.accumulating = False + self.message_started = False + + def process_json_chunk(self, chunk: str) -> Optional[str]: + """Process a chunk from the function arguments and return the plaintext version""" + + # Use strip to handle only leading and trailing whitespace in control structures + if self.accumulating: + clean_chunk = chunk.strip() + if self.json_key in self.key_buffer: + if ":" in clean_chunk: + self.in_message = True + self.accumulating = False + return None + self.key_buffer += clean_chunk + return None + + if self.in_message: + if chunk.strip() == '"' and self.message_started: + self.in_message = False + self.message_started = False + return None + if not self.message_started and chunk.strip() == '"': + self.message_started = True + return None + if self.message_started: + if chunk.strip().endswith('"'): + self.in_message = False + return chunk.rstrip('"\n') + return chunk + + if chunk.strip() == "{": + self.key_buffer = "" + self.accumulating = True + return None + if chunk.strip() == "}": + self.in_message = False + self.message_started = False + return None + return None diff --git a/tests/test_stream_buffer_readers.py b/tests/test_stream_buffer_readers.py new file mode 100644 index 00000000..9351ca54 --- /dev/null +++ b/tests/test_stream_buffer_readers.py @@ -0,0 +1,125 @@ +import pytest + +from letta.streaming_utils import JSONInnerThoughtsExtractor + + +@pytest.mark.parametrize("wait_for_first_key", [True, False]) +def test_inner_thoughts_in_args(wait_for_first_key): + """Test case where the function_delta.arguments contains inner_thoughts + + Correct output should be inner_thoughts VALUE (not KEY) being written to one buffer + And everything else (omiting inner_thoughts KEY) being written to the other buffer + """ + print("Running Test Case 1: With 'inner_thoughts'") + handler1 = 
@pytest.mark.parametrize("wait_for_first_key", [True, False])
def test_inner_thoughts_in_args(wait_for_first_key):
    """Stream arguments that contain an inner_thoughts entry through the extractor.

    The inner_thoughts VALUE (not its KEY) must be emitted on the inner-thoughts
    channel, and everything else (omitting the inner_thoughts KEY) must be
    emitted on the main-JSON channel.

    The two parametrized modes differ only in where the opening brace shows up:
    with ``wait_for_first_key`` it is held back and emitted together with the
    first non-inner-thoughts key (fragment 7); without it, it is emitted
    immediately (fragment 1).
    """
    print("Running Test Case 1: With 'inner_thoughts'")
    extractor = JSONInnerThoughtsExtractor(inner_thoughts_key="inner_thoughts", wait_for_first_key=wait_for_first_key)

    stream = [
        "{",
        '"inner_thoughts":"Chad\'s x2 tradition',
        " is going strong! 😂 I love the enthusiasm!",
        " Time to delve into something imaginative:",
        ' If you could swap lives with any fictional character for a day, who would it be?",',
        ",",
        '"message":"Here we are again, with \'x2\'!',
        " 🎉 Let's take this chance: If you could swap",
        " lives with any fictional character for a day,",
        ' who would it be?"',
        "}",
    ]

    # Only fragments 1 and 7 depend on the mode: the brace is either written
    # right away or buffered until the inner thoughts are finished.
    message_head = '"message":"Here we are again, with \'x2\'!'
    if wait_for_first_key:
        main_at_open = ""
        main_at_message = "{" + message_head
    else:
        main_at_open = "{"
        main_at_message = message_head

    # (main_json_update, inner_thoughts_update) expected for each fragment, in order.
    expected_pairs = [
        (main_at_open, ""),  # Fragment 1
        ("", "Chad's x2 tradition"),  # Fragment 2
        ("", " is going strong! 😂 I love the enthusiasm!"),  # Fragment 3
        ("", " Time to delve into something imaginative:"),  # Fragment 4
        ("", " If you could swap lives with any fictional character for a day, who would it be?"),  # Fragment 5
        ("", ""),  # Fragment 6 (comma after inner_thoughts)
        (main_at_message, ""),  # Fragment 7
        (" 🎉 Let's take this chance: If you could swap", ""),  # Fragment 8
        (" lives with any fictional character for a day,", ""),  # Fragment 9
        (' who would it be?"', ""),  # Fragment 10
        ("}", ""),  # Fragment 11
    ]

    for number, (fragment, (want_main, want_inner)) in enumerate(zip(stream, expected_pairs), start=1):
        got_main, got_inner = extractor.process_fragment(fragment)
        assert (
            got_main == want_main
        ), f"Test Case 1, Fragment {number}: Main JSON update mismatch.\nExpected: '{want_main}'\nGot: '{got_main}'"
        assert (
            got_inner == want_inner
        ), f"Test Case 1, Fragment {number}: Inner Thoughts update mismatch.\nExpected: '{want_inner}'\nGot: '{got_inner}'"
def test_inner_thoughts_not_in_args():
    """Stream arguments that do not contain inner_thoughts through the extractor.

    Every fragment must be passed through unchanged on the main-JSON channel,
    and the inner-thoughts channel must stay empty, both per fragment and in
    the final accumulated buffers.
    """
    print("Running Test Case 2: Without 'inner_thoughts'")
    extractor = JSONInnerThoughtsExtractor(inner_thoughts_key="inner_thoughts")

    stream = [
        "{",
        '"message":"Here we are again, with \'x2\'!',
        " 🎉 Let's take this chance: If you could swap",
        " lives with any fictional character for a day,",
        ' who would it be?"',
        "}",
    ]

    # (main_json_update, inner_thoughts_update) expected for each fragment, in order.
    expected_pairs = [
        ("{", ""),  # Fragment 1
        ('"message":"Here we are again, with \'x2\'!', ""),  # Fragment 2
        (" 🎉 Let's take this chance: If you could swap", ""),  # Fragment 3
        (" lives with any fictional character for a day,", ""),  # Fragment 4
        (' who would it be?"', ""),  # Fragment 5
        ("}", ""),  # Fragment 6
    ]

    for number, (fragment, (want_main, want_inner)) in enumerate(zip(stream, expected_pairs), start=1):
        got_main, got_inner = extractor.process_fragment(fragment)
        assert (
            got_main == want_main
        ), f"Test Case 2, Fragment {number}: Main JSON update mismatch.\nExpected: '{want_main}'\nGot: '{got_main}'"
        assert (
            got_inner == want_inner
        ), f"Test Case 2, Fragment {number}: Inner Thoughts update mismatch.\nExpected: '{want_inner}'\nGot: '{got_inner}'"

    # The accumulated buffers must match the whole stream once it has been consumed.
    final_main = '{"message":"Here we are again, with \'x2\'! 🎉 Let\'s take this chance: If you could swap lives with any fictional character for a day, who would it be?"}'
    final_inner = ""

    assert (
        extractor.main_json == final_main
    ), f"Test Case 2: Final main_json mismatch.\nExpected: '{final_main}'\nGot: '{extractor.main_json}'"
    assert (
        extractor.inner_thoughts == final_inner
    ), f"Test Case 2: Final inner_thoughts mismatch.\nExpected: '{final_inner}'\nGot: '{extractor.inner_thoughts}'"