fix: fix bug triggered by using ada embeddings (#1915)
This commit is contained in:
@@ -314,11 +314,17 @@ def openai_chat_completions_process_stream(
|
||||
for _ in range(len(tool_calls_delta))
|
||||
]
|
||||
|
||||
# There may be many tool calls in a tool calls delta (e.g. parallel tool calls)
|
||||
for tool_call_delta in tool_calls_delta:
|
||||
if tool_call_delta.id is not None:
|
||||
# TODO assert that we're not overwriting?
|
||||
# TODO += instead of =?
|
||||
accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id
|
||||
if tool_call_delta.index not in range(len(accum_message.tool_calls)):
|
||||
warnings.warn(
|
||||
f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}"
|
||||
)
|
||||
else:
|
||||
accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id
|
||||
if tool_call_delta.function is not None:
|
||||
if tool_call_delta.function.name is not None:
|
||||
# TODO assert that we're not overwriting?
|
||||
|
||||
@@ -312,11 +312,20 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs)
|
||||
self.function_name_buffer = None
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
|
||||
# extra prints
|
||||
self.debug = False
|
||||
self.timeout = 30
|
||||
|
||||
def _reset_inner_thoughts_json_reader(self):
    """Install a fresh function-arguments reader and empty the name/args/id buffers."""
    # New reader for accumulating function arguments (keys are buffered so checks can run on each one)
    self.function_args_reader = JSONInnerThoughtsExtractor(
        inner_thoughts_key=self.inner_thoughts_kwarg,
        wait_for_first_key=True,
    )
    # Buffers used to make sure the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs)
    self.function_name_buffer = self.function_args_buffer = self.function_id_buffer = None
|
||||
|
||||
async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
|
||||
"""An asynchronous generator that yields chunks as they become available."""
|
||||
while self._active:
|
||||
@@ -376,6 +385,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
|
||||
self._push_to_buffer(self.multi_step_gen_indicator)
|
||||
|
||||
# Wipe the inner thoughts buffers
|
||||
self._reset_inner_thoughts_json_reader()
|
||||
|
||||
def step_complete(self):
|
||||
"""Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
|
||||
if not self.multi_step:
|
||||
@@ -386,6 +398,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# signal that a new step has started in the stream
|
||||
self._push_to_buffer(self.multi_step_indicator)
|
||||
|
||||
# Wipe the inner thoughts buffers
|
||||
self._reset_inner_thoughts_json_reader()
|
||||
|
||||
def step_yield(self):
|
||||
"""If multi_step, this is the true 'stream_end' function."""
|
||||
self._active = False
|
||||
@@ -498,6 +513,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
else:
|
||||
self.function_name_buffer += tool_call.function.name
|
||||
|
||||
if tool_call.id:
|
||||
# Buffer until next time
|
||||
if self.function_id_buffer is None:
|
||||
self.function_id_buffer = tool_call.id
|
||||
else:
|
||||
self.function_id_buffer += tool_call.id
|
||||
|
||||
if tool_call.function.arguments:
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
|
||||
@@ -518,6 +540,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
|
||||
# If we have main_json, we should output a FunctionCallMessage
|
||||
elif updates_main_json:
|
||||
|
||||
# If there's something in the function_name buffer, we should release it first
|
||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||
# however the frontend may expect name first, then args, so to be
|
||||
@@ -526,18 +549,23 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = FunctionCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
function_call=FunctionCallDelta(name=self.function_name_buffer, arguments=None),
|
||||
function_call=FunctionCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
arguments=None,
|
||||
function_call_id=self.function_id_buffer,
|
||||
),
|
||||
)
|
||||
# Clear the buffer
|
||||
self.function_name_buffer = None
|
||||
self.function_id_buffer = None
|
||||
# Since we're clearing the name buffer, we should store
|
||||
# any updates to the arguments inside a separate buffer
|
||||
if updates_main_json:
|
||||
# Add any main_json updates to the arguments buffer
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# Add any main_json updates to the arguments buffer
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# If there was nothing in the name buffer, we can proceed to
|
||||
# output the arguments chunk as a FunctionCallMessage
|
||||
@@ -550,17 +578,27 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = FunctionCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
function_call=FunctionCallDelta(name=None, arguments=combined_chunk),
|
||||
function_call=FunctionCallDelta(
|
||||
name=None,
|
||||
arguments=combined_chunk,
|
||||
function_call_id=self.function_id_buffer,
|
||||
),
|
||||
)
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
processed_chunk = FunctionCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
function_call=FunctionCallDelta(name=None, arguments=updates_main_json),
|
||||
function_call=FunctionCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
function_call_id=self.function_id_buffer,
|
||||
),
|
||||
)
|
||||
self.function_id_buffer = None
|
||||
|
||||
# # If there's something in the main_json buffer, we should add it to the arguments and release it together
|
||||
# tool_call_delta = {}
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.streaming_utils import JSONInnerThoughtsExtractor
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wait_for_first_key", [True, False])
|
||||
def test_inner_thoughts_in_args(wait_for_first_key):
|
||||
def test_inner_thoughts_in_args_simple(wait_for_first_key):
|
||||
"""Test case where the function_delta.arguments contains inner_thoughts
|
||||
|
||||
Correct output should be inner_thoughts VALUE (not KEY) being written to one buffer
|
||||
@@ -17,7 +19,7 @@ def test_inner_thoughts_in_args(wait_for_first_key):
|
||||
""""inner_thoughts":"Chad's x2 tradition""",
|
||||
" is going strong! 😂 I love the enthusiasm!",
|
||||
" Time to delve into something imaginative:",
|
||||
""" If you could swap lives with any fictional character for a day, who would it be?",""",
|
||||
""" If you could swap lives with any fictional character for a day, who would it be?\"""",
|
||||
",",
|
||||
""""message":"Here we are again, with 'x2'!""",
|
||||
" 🎉 Let's take this chance: If you could swap",
|
||||
@@ -25,6 +27,9 @@ def test_inner_thoughts_in_args(wait_for_first_key):
|
||||
''' who would it be?"''',
|
||||
"}",
|
||||
]
|
||||
print("Basic inner thoughts testcase:", fragments1, "".join(fragments1))
|
||||
# Make sure the string is valid JSON
|
||||
_ = json.loads("".join(fragments1))
|
||||
|
||||
if wait_for_first_key:
|
||||
# If we're waiting for the first key, then the first opening brace should be buffered/held back
|
||||
@@ -78,6 +83,119 @@ def test_inner_thoughts_in_args(wait_for_first_key):
|
||||
), f"Test Case 1, Fragment {idx+1}: Inner Thoughts update mismatch.\nExpected: '{expected['inner_thoughts_update']}'\nGot: '{updates_inner_thoughts}'"
|
||||
|
||||
|
||||
@pytest.mark.parametrize("wait_for_first_key", [True, False])
def test_inner_thoughts_in_args_trailing_quote(wait_for_first_key):
    """Stream a function call in which some chunks end with a double quote.

    Eight fragments are fed to a JSONInnerThoughtsExtractor one at a time. For each
    fragment the returned (main JSON update, inner thoughts update) pair is compared
    with the expected pair, and at the end the accumulated main JSON must parse.
    """
    print("Running Test Case: chunk ends with double quote")
    extractor = JSONInnerThoughtsExtractor(inner_thoughts_key="inner_thoughts", wait_for_first_key=wait_for_first_key)

    inner_thoughts_text = "User wants to add 'banana' again for a fourth time; I'll track another addition."
    fragments = [
        "{",  # 1
        '"inner_thoughts":"' + inner_thoughts_text,  # 2
        '",',  # 3
        '"content":"banana',  # 4
        '","',  # 5
        'request_heartbeat":"',  # 6
        'true"',  # 7
        "}",  # 8
    ]
    print("Double quote test case:", fragments, "".join(fragments))
    # Make sure the string is valid JSON
    json.loads("".join(fragments))

    if wait_for_first_key:
        # When waiting for the first key, the opening brace is buffered/held back
        # until after the inner thoughts are finished, so it shows up with fragment 4
        brace_update, content_update = "", '{"content":"banana'
    else:
        # Otherwise the opening brace is written immediately with fragment 1
        brace_update, content_update = "{", '"content":"banana'

    main_json_updates = [
        brace_update,  # Fragment 1 (NOTE: differs between the two modes)
        "",  # Fragment 2
        "",  # Fragment 3
        content_update,  # Fragment 4 (NOTE: differs between the two modes)
        '",',  # Fragment 5: the chunk's trailing quote is not part of this update...
        '"request_heartbeat":"',  # Fragment 6: ...it is expected at the front of this one instead
        'true"',  # Fragment 7
        "}",  # Fragment 8
    ]
    # Only fragment 2 carries inner thoughts text
    inner_thoughts_updates = ["", inner_thoughts_text, "", "", "", "", "", ""]
    expected_updates = [
        {"main_json_update": main_part, "inner_thoughts_update": thoughts_part}
        for main_part, thoughts_part in zip(main_json_updates, inner_thoughts_updates)
    ]

    current_inner_thoughts = ""
    current_main_json = ""
    for idx, fragment in enumerate(fragments):
        expected = expected_updates[idx]
        updates_main_json, updates_inner_thoughts = extractor.process_fragment(fragment)
        # Assertions
        assert (
            updates_main_json == expected["main_json_update"]
        ), f"Test Case 1, Fragment {idx+1}: Main JSON update mismatch.\nFragment: '{fragment}'\nExpected: '{expected['main_json_update']}'\nGot: '{updates_main_json}'\nCurrent JSON: '{current_main_json}'\nCurrent Inner Thoughts: '{current_inner_thoughts}'"
        assert (
            updates_inner_thoughts == expected["inner_thoughts_update"]
        ), f"Test Case 1, Fragment {idx+1}: Inner Thoughts update mismatch.\nExpected: '{expected['inner_thoughts_update']}'\nGot: '{updates_inner_thoughts}'\nCurrent JSON: '{current_main_json}'\nCurrent Inner Thoughts: '{current_inner_thoughts}'"
        current_main_json += updates_main_json
        current_inner_thoughts += updates_inner_thoughts

    print(f"Final JSON: '{current_main_json}'")
    print(f"Final Inner Thoughts: '{current_inner_thoughts}'")
    # The reassembled main JSON (inner thoughts stripped out) must itself be valid JSON
    json.loads(current_main_json)
|
||||
|
||||
|
||||
def test_inner_thoughts_not_in_args():
|
||||
"""Test case where the function_delta.arguments does not contain inner_thoughts
|
||||
|
||||
@@ -93,6 +211,9 @@ def test_inner_thoughts_not_in_args():
|
||||
''' who would it be?"''',
|
||||
"}",
|
||||
]
|
||||
print("Basic inner thoughts not in kwargs testcase:", fragments2, "".join(fragments2))
|
||||
# Make sure the string is valid JSON
|
||||
_ = json.loads("".join(fragments2))
|
||||
|
||||
expected_updates2 = [
|
||||
{"main_json_update": "{", "inner_thoughts_update": ""}, # Fragment 1
|
||||
|
||||
Reference in New Issue
Block a user