feat: fix streaming put_inner_thoughts_in_kwargs (#1913)

This commit is contained in:
Charles Packer
2024-10-21 17:07:20 -07:00
committed by GitHub
parent e940511a6f
commit 1a93b85bfd
6 changed files with 677 additions and 126 deletions

View File

@@ -1,6 +1,7 @@
import copy
import json
import warnings
from collections import OrderedDict
from typing import Any, List, Union
import requests
@@ -10,6 +11,30 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.utils import json_dumps, printd
def convert_to_structured_output(openai_function: dict) -> dict:
    """Convert an OpenAI function-call schema into a structured-output schema.

    Structured outputs require ``strict: True``, ``additionalProperties: False``,
    and every property listed in ``required``.
    See: https://platform.openai.com/docs/guides/structured-outputs/supported-schemas

    Args:
        openai_function: Function definition in the OpenAI tools format with
            ``name``, ``description``, and ``parameters`` keys.

    Returns:
        A new dict in the structured-output format; the input is not mutated.
    """
    structured_output = {
        "name": openai_function["name"],
        "description": openai_function["description"],
        "strict": True,
        "parameters": {"type": "object", "properties": {}, "additionalProperties": False, "required": []},
    }

    # A function that takes no arguments may omit "properties" entirely,
    # so don't assume the key exists (the original raised KeyError here).
    properties = openai_function["parameters"].get("properties", {})
    for param, details in properties.items():
        prop = {"type": details["type"]}
        # "description" is optional in the source schema; copy only if present
        if "description" in details:
            prop["description"] = details["description"]
        if "enum" in details:
            prop["enum"] = details["enum"]
        # Array-typed params must carry their item schema through, otherwise
        # the structured-output schema is rejected by the API
        if "items" in details:
            prop["items"] = details["items"]
        structured_output["parameters"]["properties"][param] = prop

    # Structured outputs require every property to be listed as required
    structured_output["parameters"]["required"] = list(structured_output["parameters"]["properties"].keys())

    return structured_output
def make_post_request(url: str, headers: dict[str, str], data: dict[str, Any]) -> dict[str, Any]:
printd(f"Sending request to {url}")
try:
@@ -78,33 +103,34 @@ def add_inner_thoughts_to_functions(
inner_thoughts_key: str,
inner_thoughts_description: str,
inner_thoughts_required: bool = True,
# inner_thoughts_to_front: bool = True, TODO support sorting somewhere, probably in the to_dict?
) -> List[dict]:
"""Add an inner_thoughts kwarg to every function in the provided list"""
# return copies
"""Add an inner_thoughts kwarg to every function in the provided list, ensuring it's the first parameter"""
new_functions = []
# functions is a list of dicts in the OpenAI schema (https://platform.openai.com/docs/api-reference/chat/create)
for function_object in functions:
function_params = function_object["parameters"]["properties"]
required_params = list(function_object["parameters"]["required"])
# if the inner thoughts arg doesn't exist, add it
if inner_thoughts_key not in function_params:
function_params[inner_thoughts_key] = {
"type": "string",
"description": inner_thoughts_description,
}
# make sure it's tagged as required
new_function_object = copy.deepcopy(function_object)
if inner_thoughts_required and inner_thoughts_key not in required_params:
required_params.append(inner_thoughts_key)
new_function_object["parameters"]["required"] = required_params
# Create a new OrderedDict with inner_thoughts as the first item
new_properties = OrderedDict()
new_properties[inner_thoughts_key] = {
"type": "string",
"description": inner_thoughts_description,
}
# Add the rest of the properties
new_properties.update(function_object["parameters"]["properties"])
# Cast OrderedDict back to a regular dict
new_function_object["parameters"]["properties"] = dict(new_properties)
# Update required parameters if necessary
if inner_thoughts_required:
required_params = new_function_object["parameters"].get("required", [])
if inner_thoughts_key not in required_params:
required_params.insert(0, inner_thoughts_key)
new_function_object["parameters"]["required"] = required_params
new_functions.append(new_function_object)
# return a list of copies
return new_functions

View File

@@ -9,7 +9,11 @@ from httpx_sse._exceptions import SSEError
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.errors import LLMError
from letta.llm_api.helpers import add_inner_thoughts_to_functions, make_post_request
from letta.llm_api.helpers import (
add_inner_thoughts_to_functions,
convert_to_structured_output,
make_post_request,
)
from letta.local_llm.constants import (
INNER_THOUGHTS_KWARG,
INNER_THOUGHTS_KWARG_DESCRIPTION,
@@ -112,7 +116,7 @@ def build_openai_chat_completions_request(
use_tool_naming: bool,
max_tokens: Optional[int],
) -> ChatCompletionRequest:
if llm_config.put_inner_thoughts_in_kwargs:
if functions and llm_config.put_inner_thoughts_in_kwargs:
functions = add_inner_thoughts_to_functions(
functions=functions,
inner_thoughts_key=INNER_THOUGHTS_KWARG,
@@ -154,8 +158,8 @@ def build_openai_chat_completions_request(
)
# https://platform.openai.com/docs/guides/text-generation/json-mode
# only supported by gpt-4o, gpt-4-turbo, or gpt-3.5-turbo
if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model:
data.response_format = {"type": "json_object"}
# if "gpt-4o" in llm_config.model or "gpt-4-turbo" in llm_config.model or "gpt-3.5-turbo" in llm_config.model:
# data.response_format = {"type": "json_object"}
if "inference.memgpt.ai" in llm_config.model_endpoint:
# override user id for inference.memgpt.ai
@@ -362,6 +366,8 @@ def openai_chat_completions_process_stream(
chat_completion_response.usage.completion_tokens = n_chunks
chat_completion_response.usage.total_tokens = prompt_tokens + n_chunks
assert len(chat_completion_response.choices) > 0, chat_completion_response
# printd(chat_completion_response)
return chat_completion_response
@@ -461,6 +467,13 @@ def openai_chat_completions_request_stream(
data.pop("tools")
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")
if "tools" in data:
for tool in data["tools"]:
# tool["strict"] = True
tool["function"] = convert_to_structured_output(tool["function"])
# print(f"\n\n\n\nData[tools]: {json.dumps(data['tools'], indent=2)}")
printd(f"Sending request to {url}")
try:
return _sse_post(url=url, data=data, headers=headers)

View File

@@ -8,6 +8,7 @@ from typing import AsyncGenerator, Literal, Optional, Union
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.interface import AgentInterface
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
AssistantMessage,
@@ -23,9 +24,14 @@ from letta.schemas.letta_message import (
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.streaming_utils import (
FunctionArgumentsStreamHandler,
JSONInnerThoughtsExtractor,
)
from letta.utils import is_utc_datetime
# TODO strip from code / deprecate
class QueuingInterface(AgentInterface):
"""Messages are queued inside an internal buffer and manually flushed"""
@@ -248,58 +254,6 @@ class QueuingInterface(AgentInterface):
self._queue_push(message_api=new_message, message_obj=msg_obj)
class FunctionArgumentsStreamHandler:
    """State machine that filters a streamed JSON function-arguments object,
    surfacing only the plaintext value of ``json_key`` chunk by chunk."""

    def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG):
        # Key whose string value should be emitted as plaintext
        self.json_key = json_key
        self.reset()

    def reset(self):
        """Reset parser state (used at construction and when a new object starts)."""
        self.in_message = False  # inside the target key's value
        self.key_buffer = ""  # accumulates key text until json_key appears
        self.accumulating = False  # currently reading key characters
        self.message_started = False  # opening quote of the value has been seen

    def process_json_chunk(self, chunk: str) -> Optional[str]:
        """Process a chunk from the function arguments and return the plaintext version"""
        # Use strip to handle only leading and trailing whitespace in control structures
        if self.accumulating:
            clean_chunk = chunk.strip()
            if self.json_key in self.key_buffer:
                # Key matched; once the colon arrives the value begins
                if ":" in clean_chunk:
                    self.in_message = True
                    self.accumulating = False
                return None
            self.key_buffer += clean_chunk
            return None

        if self.in_message:
            if chunk.strip() == '"' and self.message_started:
                # Lone closing quote: value is finished
                self.in_message = False
                self.message_started = False
                return None
            if not self.message_started and chunk.strip() == '"':
                # Lone opening quote: value text starts with the next chunk
                self.message_started = True
                return None
            if self.message_started:
                if chunk.strip().endswith('"'):
                    # Chunk carries the closing quote: emit it minus the quote
                    self.in_message = False
                    return chunk.rstrip('"\n')
                return chunk

        if chunk.strip() == "{":
            # New JSON object: start scanning for the key
            self.key_buffer = ""
            self.accumulating = True
            return None

        if chunk.strip() == "}":
            # Object closed: stop emitting
            self.in_message = False
            self.message_started = False
            return None

        # Everything outside the target value is swallowed
        return None
class StreamingServerInterface(AgentChunkStreamingInterface):
"""Maintain a generator that is a proxy for self.process_chunk()
@@ -316,9 +270,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def __init__(
self,
multi_step=True,
# Related to if we want to try and pass back the AssistantMessage as a special case function
use_assistant_message=False,
assistant_message_function_name=DEFAULT_MESSAGE_TOOL,
assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
# Related to if we expect inner_thoughts to be in the kwargs
inner_thoughts_in_kwargs=True,
inner_thoughts_kwarg=INNER_THOUGHTS_KWARG,
):
# If streaming mode, ignores base interface calls like .assistant_message, etc
self.streaming_mode = False
@@ -346,6 +304,15 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.assistant_message_function_name = assistant_message_function_name
self.assistant_message_function_kwarg = assistant_message_function_kwarg
# Support for inner_thoughts_in_kwargs
self.inner_thoughts_in_kwargs = inner_thoughts_in_kwargs
self.inner_thoughts_kwarg = inner_thoughts_kwarg
# A buffer for accumulating function arguments (we want to buffer keys and run checks on each one)
self.function_args_reader = JSONInnerThoughtsExtractor(inner_thoughts_key=inner_thoughts_kwarg, wait_for_first_key=True)
# Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs)
self.function_name_buffer = None
self.function_args_buffer = None
# extra prints
self.debug = False
self.timeout = 30
@@ -365,16 +332,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# Reset the event until a new item is pushed
self._event.clear()
# while self._active:
# # Wait until there is an item in the deque or the stream is deactivated
# await self._event.wait()
# while self._chunks:
# yield self._chunks.popleft()
# # Reset the event until a new item is pushed
# self._event.clear()
def get_generator(self) -> AsyncGenerator:
"""Get the generator that yields processed chunks."""
if not self._active:
@@ -419,18 +376,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
self._push_to_buffer(self.multi_step_gen_indicator)
# self._active = False
# self._event.set() # Unblock the generator if it's waiting to allow it to complete
# if not self.multi_step:
# # end the stream
# self._active = False
# self._event.set() # Unblock the generator if it's waiting to allow it to complete
# else:
# # signal that a new step has started in the stream
# self._chunks.append(self.multi_step_indicator)
# self._event.set() # Signal that new data is available
def step_complete(self):
"""Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
if not self.multi_step:
@@ -443,8 +388,6 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def step_yield(self):
"""If multi_step, this is the true 'stream_end' function."""
# if self.multi_step:
# end the stream
self._active = False
self._event.set() # Unblock the generator if it's waiting to allow it to complete
@@ -479,8 +422,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
tool_call = message_delta.tool_calls[0]
# TODO(charles) merge into logic for internal_monologue
# special case for trapping `send_message`
if self.use_assistant_message and tool_call.function:
if self.inner_thoughts_in_kwargs:
raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported")
# If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode
@@ -538,6 +484,181 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
),
)
elif self.inner_thoughts_in_kwargs and tool_call.function:
if self.use_assistant_message:
raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported")
processed_chunk = None
if tool_call.function.name:
# If we're waiting for the first key, then we should hold back the name
# ie add it to a buffer instead of returning it as a chunk
if self.function_name_buffer is None:
self.function_name_buffer = tool_call.function.name
else:
self.function_name_buffer += tool_call.function.name
if tool_call.function.arguments:
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
# If we have inner thoughts, we should output them as a chunk
if updates_inner_thoughts:
processed_chunk = InternalMonologue(
id=message_id,
date=message_date,
internal_monologue=updates_inner_thoughts,
)
# Additionally inner thoughts may stream back with a chunk of main JSON
# In that case, since we can only return a chunk at a time, we should buffer it
if updates_main_json:
if self.function_args_buffer is None:
self.function_args_buffer = updates_main_json
else:
self.function_args_buffer += updates_main_json
# If we have main_json, we should output a FunctionCallMessage
elif updates_main_json:
# If there's something in the function_name buffer, we should release it first
# NOTE: we could output it as part of a chunk that has both name and args,
# however the frontend may expect name first, then args, so to be
# safe we'll output name first in a separate chunk
if self.function_name_buffer:
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=self.function_name_buffer, arguments=None),
)
# Clear the buffer
self.function_name_buffer = None
# Since we're clearing the name buffer, we should store
# any updates to the arguments inside a separate buffer
if updates_main_json:
# Add any main_json updates to the arguments buffer
if self.function_args_buffer is None:
self.function_args_buffer = updates_main_json
else:
self.function_args_buffer += updates_main_json
# If there was nothing in the name buffer, we can proceed to
# output the arguments chunk as a FunctionCallMessage
else:
# There may be a buffer from a previous chunk, for example
# if the previous chunk had arguments but we needed to flush name
if self.function_args_buffer:
# In this case, we should release the buffer + new data at once
combined_chunk = self.function_args_buffer + updates_main_json
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=None, arguments=combined_chunk),
)
# clear buffer
self.function_args_buffer = None
else:
# If there's no buffer to clear, just output a new chunk with new data
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=None, arguments=updates_main_json),
)
# # If there's something in the main_json buffer, we should add if to the arguments and release it together
# tool_call_delta = {}
# if tool_call.id:
# tool_call_delta["id"] = tool_call.id
# if tool_call.function:
# if tool_call.function.arguments:
# # tool_call_delta["arguments"] = tool_call.function.arguments
# # NOTE: using the stripped one
# tool_call_delta["arguments"] = updates_main_json
# # We use the buffered name
# if self.function_name_buffer:
# tool_call_delta["name"] = self.function_name_buffer
# # if tool_call.function.name:
# # tool_call_delta["name"] = tool_call.function.name
# processed_chunk = FunctionCallMessage(
# id=message_id,
# date=message_date,
# function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# )
else:
processed_chunk = None
return processed_chunk
# # NOTE: this is a simplified version of the parsing code that:
# # (1) assumes that the inner_thoughts key will always come first
# # (2) assumes that there's no extra spaces in the stringified JSON
# # i.e., the prefix will look exactly like: "{\"variable\":\"}"
# if tool_call.function.arguments:
# self.function_args_buffer += tool_call.function.arguments
# # prefix_str = f'{{"\\"{self.inner_thoughts_kwarg}\\":\\"}}'
# prefix_str = f'{{"{self.inner_thoughts_kwarg}":'
# if self.function_args_buffer.startswith(prefix_str):
# print(f"Found prefix!!!: {self.function_args_buffer}")
# else:
# print(f"No prefix found: {self.function_args_buffer}")
# tool_call_delta = {}
# if tool_call.id:
# tool_call_delta["id"] = tool_call.id
# if tool_call.function:
# if tool_call.function.arguments:
# tool_call_delta["arguments"] = tool_call.function.arguments
# if tool_call.function.name:
# tool_call_delta["name"] = tool_call.function.name
# processed_chunk = FunctionCallMessage(
# id=message_id,
# date=message_date,
# function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# )
# elif False and self.inner_thoughts_in_kwargs and tool_call.function:
# if self.use_assistant_message:
# raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported")
# if tool_call.function.arguments:
# Maintain a state machine to track if we're reading a key vs reading a value
# Technically we can we pre-key, post-key, pre-value, post-value
# for c in tool_call.function.arguments:
# if self.function_chunks_parsing_state == FunctionChunksParsingState.PRE_KEY:
# if c == '"':
# self.function_chunks_parsing_state = FunctionChunksParsingState.READING_KEY
# elif self.function_chunks_parsing_state == FunctionChunksParsingState.READING_KEY:
# if c == '"':
# self.function_chunks_parsing_state = FunctionChunksParsingState.POST_KEY
# If we're reading a key:
# if self.function_chunks_parsing_state == FunctionChunksParsingState.READING_KEY:
# We need to buffer the function arguments until we get complete keys
# We are reading stringified-JSON, so we need to check for keys in data that looks like:
# "arguments":"{\""
# "arguments":"inner"
# "arguments":"_th"
# "arguments":"ought"
# "arguments":"s"
# "arguments":"\":\""
# Once we get a complete key, check if the key matches
# If it does match, start processing the value (stringified-JSON string
# And with each new chunk, output it as a chunk of type InternalMonologue
# If the key doesn't match, then flush the buffer as a single FunctionCallMessage chunk
# If we're reading a value
# If we're reading the inner thoughts value, we output chunks of type InternalMonologue
# Otherwise, do simple chunks of FunctionCallMessage
else:
tool_call_delta = {}
@@ -563,7 +684,14 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# skip if there's a finish
return None
else:
raise ValueError(f"Couldn't find delta in chunk: {chunk}")
# Example case that would trigger here:
# id='chatcmpl-AKtUvREgRRvgTW6n8ZafiKuV0mxhQ'
# choices=[ChunkChoice(finish_reason=None, index=0, delta=MessageDelta(content=None, tool_calls=None, function_call=None), logprobs=None)]
# created=datetime.datetime(2024, 10, 21, 20, 40, 57, tzinfo=TzInfo(UTC))
# model='gpt-4o-mini-2024-07-18'
# object='chat.completion.chunk'
warnings.warn(f"Couldn't find delta in chunk: {chunk}")
return None
return processed_chunk
@@ -663,6 +791,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
# "id": str(msg_obj.id) if msg_obj is not None else None,
# }
assert msg_obj is not None, "Internal monologue requires msg_obj references for metadata"
processed_chunk = InternalMonologue(
id=msg_obj.id,
date=msg_obj.created_at,
@@ -676,18 +805,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
"""Letta uses send_message"""
# if not self.streaming_mode and self.send_message_special_case:
# # create a fake "chunk" of a stream
# processed_chunk = {
# "assistant_message": msg,
# "date": msg_obj.created_at.isoformat() if msg_obj is not None else get_utc_time().isoformat(),
# "id": str(msg_obj.id) if msg_obj is not None else None,
# }
# self._chunks.append(processed_chunk)
# self._event.set() # Signal that new data is available
# NOTE: this is a no-op, we handle this special case in function_message instead
return
def function_message(self, msg: str, msg_obj: Optional[Message] = None):
@@ -699,6 +817,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if msg.startswith("Running "):
if not self.streaming_mode:
# create a fake "chunk" of a stream
assert msg_obj.tool_calls is not None and len(msg_obj.tool_calls) > 0, "Function call required for function_message"
function_call = msg_obj.tool_calls[0]
if self.nonstreaming_legacy_mode:
@@ -784,13 +903,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
return
else:
return
# msg = msg.replace("Running ", "")
# new_message = {"function_call": msg}
elif msg.startswith("Ran "):
return
# msg = msg.replace("Ran ", "Function call returned: ")
# new_message = {"function_call": msg}
elif msg.startswith("Success: "):
msg = msg.replace("Success: ", "")
@@ -821,10 +936,4 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
raise ValueError(msg)
new_message = {"function_message": msg}
# add extra metadata
# if msg_obj is not None:
# new_message["id"] = str(msg_obj.id)
# assert is_utc_datetime(msg_obj.created_at), msg_obj.created_at
# new_message["date"] = msg_obj.created_at.isoformat()
self._push_to_buffer(new_message)

View File

@@ -430,9 +430,6 @@ async def send_message_to_agent(
# Get the generator object off of the agent's streaming interface
# This will be attached to the POST SSE request used under-the-hood
letta_agent = server._get_or_load_agent(agent_id=agent_id)
streaming_interface = letta_agent.interface
if not isinstance(streaming_interface, StreamingServerInterface):
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
# Disable token streaming if not OpenAI
# TODO: cleanup this logic
@@ -441,6 +438,12 @@ async def send_message_to_agent(
print("Warning: token streaming is only supported for OpenAI models. Setting to False.")
stream_tokens = False
# Create a new interface per request
letta_agent.interface = StreamingServerInterface()
streaming_interface = letta_agent.interface
if not isinstance(streaming_interface, StreamingServerInterface):
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
# Enable token-streaming within the request if desired
streaming_interface.streaming_mode = stream_tokens
# "chatcompletion mode" does some remapping and ignores inner thoughts
@@ -454,6 +457,11 @@ async def send_message_to_agent(
streaming_interface.assistant_message_function_name = assistant_message_function_name
streaming_interface.assistant_message_function_kwarg = assistant_message_function_kwarg
# Related to JSON buffer reader
streaming_interface.inner_thoughts_in_kwargs = (
llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
)
# Offload the synchronous message_func to a separate thread
streaming_interface.stream_start()
task = asyncio.create_task(

270
letta/streaming_utils.py Normal file
View File

@@ -0,0 +1,270 @@
from typing import Optional
from letta.constants import DEFAULT_MESSAGE_TOOL_KWARG
class JSONInnerThoughtsExtractor:
    """
    A class to process incoming JSON fragments and extract 'inner_thoughts' separately from the main JSON.

    This handler processes JSON fragments incrementally, parsing out the value associated with a specified key (default is 'inner_thoughts'). It maintains two separate buffers:

    - `main_json`: Accumulates the JSON data excluding the 'inner_thoughts' key-value pair.
    - `inner_thoughts`: Accumulates the value associated with the 'inner_thoughts' key.

    **Parameters:**

    - `inner_thoughts_key` (str): The key to extract from the JSON (default is 'inner_thoughts').
    - `wait_for_first_key` (bool): If `True`, holds back main JSON output until after the 'inner_thoughts' value is processed.

    **Functionality:**

    - **Stateful Parsing:** Maintains parsing state across fragments.
    - **String Handling:** Correctly processes strings, escape sequences, and quotation marks.
    - **Selective Extraction:** Identifies and extracts the value of the specified key.
    - **Fragment Processing:** Handles data that arrives in chunks.

    **Usage:**

    ```python
    extractor = JSONInnerThoughtsExtractor(wait_for_first_key=True)
    for fragment in fragments:
        updates_main_json, updates_inner_thoughts = extractor.process_fragment(fragment)
    ```

    NOTE(review): the value-emission path assumes every value is a JSON string
    (the key handler pre-appends the opening quote) — confirm non-string values
    are not expected here.
    """

    def __init__(self, inner_thoughts_key="inner_thoughts", wait_for_first_key=False):
        self.inner_thoughts_key = inner_thoughts_key
        self.wait_for_first_key = wait_for_first_key
        # Accumulated main JSON (everything except the inner_thoughts pair)
        self.main_buffer = ""
        # Accumulated inner_thoughts value (unquoted)
        self.inner_thoughts_buffer = ""
        self.state = "start"  # Possible states: start, key, colon, value, comma_or_end, end
        self.in_string = False  # currently inside a JSON string literal
        self.escaped = False  # previous character was a backslash
        self.current_key = ""  # key currently being read
        self.is_inner_thoughts_value = False  # current value belongs to inner_thoughts_key
        self.inner_thoughts_processed = False  # inner_thoughts value fully consumed
        # While True, main-JSON output is parked in main_json_held_buffer and
        # released only once the key after inner_thoughts begins
        self.hold_main_json = wait_for_first_key
        self.main_json_held_buffer = ""

    def process_fragment(self, fragment):
        """Consume one streamed JSON fragment.

        Returns:
            Tuple ``(updates_main_json, updates_inner_thoughts)`` — the text
            newly released by this fragment to the main-JSON output and to the
            inner-thoughts output, respectively.
        """
        updates_main_json = ""
        updates_inner_thoughts = ""
        i = 0
        while i < len(fragment):
            c = fragment[i]
            if self.escaped:
                # Character following a backslash: emit verbatim to whichever
                # buffer the current state routes to
                self.escaped = False
                if self.in_string:
                    if self.state == "key":
                        self.current_key += c
                    elif self.state == "value":
                        if self.is_inner_thoughts_value:
                            updates_inner_thoughts += c
                            self.inner_thoughts_buffer += c
                        else:
                            if self.hold_main_json:
                                self.main_json_held_buffer += c
                            else:
                                updates_main_json += c
                                self.main_buffer += c
                else:
                    if not self.is_inner_thoughts_value:
                        if self.hold_main_json:
                            self.main_json_held_buffer += c
                        else:
                            updates_main_json += c
                            self.main_buffer += c
            elif c == "\\":
                # Begin an escape sequence; the backslash itself is also emitted
                self.escaped = True
                if self.in_string:
                    if self.state == "key":
                        self.current_key += c
                    elif self.state == "value":
                        if self.is_inner_thoughts_value:
                            updates_inner_thoughts += c
                            self.inner_thoughts_buffer += c
                        else:
                            if self.hold_main_json:
                                self.main_json_held_buffer += c
                            else:
                                updates_main_json += c
                                self.main_buffer += c
                else:
                    if not self.is_inner_thoughts_value:
                        if self.hold_main_json:
                            self.main_json_held_buffer += c
                        else:
                            updates_main_json += c
                            self.main_buffer += c
            elif c == '"':
                # NOTE(review): self.escaped is always False here (escapes are
                # consumed by the first branch), so the else arm looks unreachable
                if not self.escaped:
                    self.in_string = not self.in_string
                    if self.in_string:
                        if self.state in ["start", "comma_or_end"]:
                            self.state = "key"
                            self.current_key = ""
                            # Release held main_json when starting to process the next key
                            if self.wait_for_first_key and self.hold_main_json and self.inner_thoughts_processed:
                                updates_main_json += self.main_json_held_buffer
                                self.main_buffer += self.main_json_held_buffer
                                self.main_json_held_buffer = ""
                                self.hold_main_json = False
                    else:
                        if self.state == "key":
                            self.state = "colon"
                        elif self.state == "value":
                            # End of value
                            if self.is_inner_thoughts_value:
                                self.inner_thoughts_processed = True
                                # Do not release held main_json here
                            else:
                                if self.hold_main_json:
                                    self.main_json_held_buffer += '"'
                                else:
                                    updates_main_json += '"'
                                    self.main_buffer += '"'
                            self.state = "comma_or_end"
                else:
                    # Escaped quote: emit literally without toggling in_string
                    self.escaped = False
                    if self.in_string:
                        if self.state == "key":
                            self.current_key += '"'
                        elif self.state == "value":
                            if self.is_inner_thoughts_value:
                                updates_inner_thoughts += '"'
                                self.inner_thoughts_buffer += '"'
                            else:
                                if self.hold_main_json:
                                    self.main_json_held_buffer += '"'
                                else:
                                    updates_main_json += '"'
                                    self.main_buffer += '"'
            elif self.in_string:
                # Ordinary character inside a string literal
                if self.state == "key":
                    self.current_key += c
                elif self.state == "value":
                    if self.is_inner_thoughts_value:
                        updates_inner_thoughts += c
                        self.inner_thoughts_buffer += c
                    else:
                        if self.hold_main_json:
                            self.main_json_held_buffer += c
                        else:
                            updates_main_json += c
                            self.main_buffer += c
            else:
                # Structural characters outside any string
                if c == ":" and self.state == "colon":
                    self.state = "value"
                    self.is_inner_thoughts_value = self.current_key == self.inner_thoughts_key
                    if self.is_inner_thoughts_value:
                        pass  # Do not include 'inner_thoughts' key in main_json
                    else:
                        # Emit the quoted key, the colon, and the opening quote
                        # of the (assumed string) value in one go
                        key_colon = f'"{self.current_key}":'
                        if self.hold_main_json:
                            self.main_json_held_buffer += key_colon + '"'
                        else:
                            updates_main_json += key_colon + '"'
                            self.main_buffer += key_colon + '"'
                elif c == "," and self.state == "comma_or_end":
                    if self.is_inner_thoughts_value:
                        # Inner thoughts value ended; drop the separating comma
                        self.is_inner_thoughts_value = False
                        self.state = "start"
                        # Do not release held main_json here
                    else:
                        if self.hold_main_json:
                            self.main_json_held_buffer += c
                        else:
                            updates_main_json += c
                            self.main_buffer += c
                        self.state = "start"
                elif c == "{":
                    if not self.is_inner_thoughts_value:
                        if self.hold_main_json:
                            self.main_json_held_buffer += c
                        else:
                            updates_main_json += c
                            self.main_buffer += c
                elif c == "}":
                    self.state = "end"
                    if self.hold_main_json:
                        self.main_json_held_buffer += c
                    else:
                        updates_main_json += c
                        self.main_buffer += c
                else:
                    # Non-string value characters (digits, true/false/null, whitespace)
                    if self.state == "value":
                        if self.is_inner_thoughts_value:
                            updates_inner_thoughts += c
                            self.inner_thoughts_buffer += c
                        else:
                            if self.hold_main_json:
                                self.main_json_held_buffer += c
                            else:
                                updates_main_json += c
                                self.main_buffer += c
            i += 1

        return updates_main_json, updates_inner_thoughts

    @property
    def main_json(self):
        # Main JSON released so far (held-back text is not included)
        return self.main_buffer

    @property
    def inner_thoughts(self):
        # Inner-thoughts value accumulated so far
        return self.inner_thoughts_buffer
class FunctionArgumentsStreamHandler:
    """Stream filter that extracts the plaintext value of one JSON key.

    Feed string chunks of a streamed JSON arguments object through
    ``process_json_chunk``; chunks belonging to the value of ``json_key``
    come back as plaintext, everything else yields ``None``.
    """

    def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG):
        self.json_key = json_key
        self.reset()

    def reset(self):
        """Return the state machine to its idle state."""
        self.in_message = False
        self.key_buffer = ""
        self.accumulating = False
        self.message_started = False

    def process_json_chunk(self, chunk: str) -> Optional[str]:
        """Process a chunk from the function arguments and return the plaintext version"""
        # Strip only for control checks; emitted text keeps its original spacing
        stripped = chunk.strip()

        if self.accumulating:
            # Reading key characters until the target key plus its colon appear
            if self.json_key in self.key_buffer:
                if ":" in stripped:
                    self.in_message = True
                    self.accumulating = False
                return None
            self.key_buffer += stripped
            return None

        if self.in_message:
            if stripped == '"':
                if self.message_started:
                    # Lone closing quote: the value has ended
                    self.in_message = False
                    self.message_started = False
                else:
                    # Lone opening quote: value text begins with the next chunk
                    self.message_started = True
                return None
            if self.message_started:
                if stripped.endswith('"'):
                    # This chunk carries the closing quote: emit it without the quote
                    self.in_message = False
                    return chunk.rstrip('"\n')
                return chunk

        if stripped == "{":
            # Fresh object: begin scanning for the key
            self.key_buffer = ""
            self.accumulating = True
            return None
        if stripped == "}":
            # Object finished: back to idle
            self.in_message = False
            self.message_started = False
            return None

        # Anything outside the target value is swallowed
        return None

View File

@@ -0,0 +1,125 @@
import pytest
from letta.streaming_utils import JSONInnerThoughtsExtractor
@pytest.mark.parametrize("wait_for_first_key", [True, False])
def test_inner_thoughts_in_args(wait_for_first_key):
    """Test case where the function_delta.arguments contains inner_thoughts

    Correct output should be inner_thoughts VALUE (not KEY) being written to one buffer
    And everything else (omiting inner_thoughts KEY) being written to the other buffer
    """
    print("Running Test Case 1: With 'inner_thoughts'")
    extractor = JSONInnerThoughtsExtractor(inner_thoughts_key="inner_thoughts", wait_for_first_key=wait_for_first_key)

    # Streamed fragments of the arguments JSON, inner_thoughts first.
    fragments = [
        "{",
        '"inner_thoughts":"Chad\'s x2 tradition',
        " is going strong! 😂 I love the enthusiasm!",
        " Time to delve into something imaginative:",
        ' If you could swap lives with any fictional character for a day, who would it be?",',
        ",",
        '"message":"Here we are again, with \'x2\'!',
        " 🎉 Let's take this chance: If you could swap",
        " lives with any fictional character for a day,",
        ' who would it be?"',
        "}",
    ]

    # Per-fragment expected inner-thoughts output (only the VALUE, never the key).
    inner_updates = [
        "",
        "Chad's x2 tradition",
        " is going strong! 😂 I love the enthusiasm!",
        " Time to delve into something imaginative:",
        " If you could swap lives with any fictional character for a day, who would it be?",
        "",  # comma separator after inner_thoughts yields nothing
        "",
        "",
        "",
        "",
        "",
    ]

    # Per-fragment expected main-JSON output in the non-waiting mode.
    main_updates = [
        "{",
        "",
        "",
        "",
        "",
        "",
        '"message":"Here we are again, with \'x2\'!',
        " 🎉 Let's take this chance: If you could swap",
        " lives with any fictional character for a day,",
        ' who would it be?"',
        "}",
    ]
    if wait_for_first_key:
        # When waiting for the first key, the opening brace is held back until the
        # inner thoughts finish, so it surfaces prepended to fragment 7 instead.
        main_updates[0] = ""
        main_updates[6] = "{" + main_updates[6]

    for idx, (fragment, expected_main, expected_thoughts) in enumerate(zip(fragments, main_updates, inner_updates)):
        main_update, thoughts_update = extractor.process_fragment(fragment)
        assert (
            main_update == expected_main
        ), f"Test Case 1, Fragment {idx+1}: Main JSON update mismatch.\nExpected: '{expected_main}'\nGot: '{main_update}'"
        assert (
            thoughts_update == expected_thoughts
        ), f"Test Case 1, Fragment {idx+1}: Inner Thoughts update mismatch.\nExpected: '{expected_thoughts}'\nGot: '{thoughts_update}'"
def test_inner_thoughts_not_in_args():
    """Test case where the function_delta.arguments does not contain inner_thoughts

    Correct output should be everything being written to the main_json buffer
    """
    print("Running Test Case 2: Without 'inner_thoughts'")
    extractor = JSONInnerThoughtsExtractor(inner_thoughts_key="inner_thoughts")

    fragments = [
        "{",
        '"message":"Here we are again, with \'x2\'!',
        " 🎉 Let's take this chance: If you could swap",
        " lives with any fictional character for a day,",
        ' who would it be?"',
        "}",
    ]

    # With no inner_thoughts key present, every fragment passes straight through
    # to the main-JSON buffer and the inner-thoughts updates stay empty.
    for idx, fragment in enumerate(fragments):
        main_update, thoughts_update = extractor.process_fragment(fragment)
        assert (
            main_update == fragment
        ), f"Test Case 2, Fragment {idx+1}: Main JSON update mismatch.\nExpected: '{fragment}'\nGot: '{main_update}'"
        assert (
            thoughts_update == ""
        ), f"Test Case 2, Fragment {idx+1}: Inner Thoughts update mismatch.\nExpected: ''\nGot: '{thoughts_update}'"

    # Final buffers: main_json is the concatenation of all fragments; inner_thoughts is empty.
    expected_final_main = "".join(fragments)
    assert (
        extractor.main_json == expected_final_main
    ), f"Test Case 2: Final main_json mismatch.\nExpected: '{expected_final_main}'\nGot: '{extractor.main_json}'"
    assert (
        extractor.inner_thoughts == ""
    ), f"Test Case 2: Final inner_thoughts mismatch.\nExpected: ''\nGot: '{extractor.inner_thoughts}'"