fix: fix bug triggered by using ada embeddings (#1915)

This commit is contained in:
Charles Packer
2024-10-22 16:42:02 -07:00
committed by GitHub
parent 7871eeb9c2
commit a70dea15a4
3 changed files with 177 additions and 12 deletions

View File

@@ -314,11 +314,17 @@ def openai_chat_completions_process_stream(
for _ in range(len(tool_calls_delta))
]
# There may be many tool calls in a tool calls delta (e.g. parallel tool calls)
for tool_call_delta in tool_calls_delta:
if tool_call_delta.id is not None:
# TODO assert that we're not overwriting?
# TODO += instead of =?
accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id
if tool_call_delta.index not in range(len(accum_message.tool_calls)):
warnings.warn(
f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}"
)
else:
accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id
if tool_call_delta.function is not None:
if tool_call_delta.function.name is not None:
# TODO assert that we're not overwriting?

View File

@@ -312,11 +312,20 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs)
self.function_name_buffer = None
self.function_args_buffer = None
self.function_id_buffer = None
# extra prints
self.debug = False
self.timeout = 30
def _reset_inner_thoughts_json_reader(self):
# A buffer for accumulating function arguments (we want to buffer keys and run checks on each one)
self.function_args_reader = JSONInnerThoughtsExtractor(inner_thoughts_key=self.inner_thoughts_kwarg, wait_for_first_key=True)
# Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs)
self.function_name_buffer = None
self.function_args_buffer = None
self.function_id_buffer = None
async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
"""An asynchronous generator that yields chunks as they become available."""
while self._active:
@@ -376,6 +385,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
self._push_to_buffer(self.multi_step_gen_indicator)
# Wipe the inner thoughts buffers
self._reset_inner_thoughts_json_reader()
def step_complete(self):
"""Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
if not self.multi_step:
@@ -386,6 +398,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# signal that a new step has started in the stream
self._push_to_buffer(self.multi_step_indicator)
# Wipe the inner thoughts buffers
self._reset_inner_thoughts_json_reader()
def step_yield(self):
"""If multi_step, this is the true 'stream_end' function."""
self._active = False
@@ -498,6 +513,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
else:
self.function_name_buffer += tool_call.function.name
if tool_call.id:
# Buffer until next time
if self.function_id_buffer is None:
self.function_id_buffer = tool_call.id
else:
self.function_id_buffer += tool_call.id
if tool_call.function.arguments:
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
@@ -518,6 +540,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# If we have main_json, we should output a FunctionCallMessage
elif updates_main_json:
# If there's something in the function_name buffer, we should release it first
# NOTE: we could output it as part of a chunk that has both name and args,
# however the frontend may expect name first, then args, so to be
@@ -526,18 +549,23 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=self.function_name_buffer, arguments=None),
function_call=FunctionCallDelta(
name=self.function_name_buffer,
arguments=None,
function_call_id=self.function_id_buffer,
),
)
# Clear the buffer
self.function_name_buffer = None
self.function_id_buffer = None
# Since we're clearing the name buffer, we should store
# any updates to the arguments inside a separate buffer
if updates_main_json:
# Add any main_json updates to the arguments buffer
if self.function_args_buffer is None:
self.function_args_buffer = updates_main_json
else:
self.function_args_buffer += updates_main_json
# Add any main_json updates to the arguments buffer
if self.function_args_buffer is None:
self.function_args_buffer = updates_main_json
else:
self.function_args_buffer += updates_main_json
# If there was nothing in the name buffer, we can proceed to
# output the arguments chunk as a FunctionCallMessage
@@ -550,17 +578,27 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=None, arguments=combined_chunk),
function_call=FunctionCallDelta(
name=None,
arguments=combined_chunk,
function_call_id=self.function_id_buffer,
),
)
# clear buffer
self.function_args_buffer = None
self.function_id_buffer = None
else:
# If there's no buffer to clear, just output a new chunk with new data
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=None, arguments=updates_main_json),
function_call=FunctionCallDelta(
name=None,
arguments=updates_main_json,
function_call_id=self.function_id_buffer,
),
)
self.function_id_buffer = None
# # If there's something in the main_json buffer, we should add if to the arguments and release it together
# tool_call_delta = {}

View File

@@ -1,10 +1,12 @@
import json
import pytest
from letta.streaming_utils import JSONInnerThoughtsExtractor
@pytest.mark.parametrize("wait_for_first_key", [True, False])
def test_inner_thoughts_in_args(wait_for_first_key):
def test_inner_thoughts_in_args_simple(wait_for_first_key):
"""Test case where the function_delta.arguments contains inner_thoughts
Correct output should be inner_thoughts VALUE (not KEY) being written to one buffer
@@ -17,7 +19,7 @@ def test_inner_thoughts_in_args(wait_for_first_key):
""""inner_thoughts":"Chad's x2 tradition""",
" is going strong! 😂 I love the enthusiasm!",
" Time to delve into something imaginative:",
""" If you could swap lives with any fictional character for a day, who would it be?",""",
""" If you could swap lives with any fictional character for a day, who would it be?\"""",
",",
""""message":"Here we are again, with 'x2'!""",
" 🎉 Let's take this chance: If you could swap",
@@ -25,6 +27,9 @@ def test_inner_thoughts_in_args(wait_for_first_key):
''' who would it be?"''',
"}",
]
print("Basic inner thoughts testcase:", fragments1, "".join(fragments1))
# Make sure the string is valid JSON
_ = json.loads("".join(fragments1))
if wait_for_first_key:
# If we're waiting for the first key, then the first opening brace should be buffered/held back
@@ -78,6 +83,119 @@ def test_inner_thoughts_in_args(wait_for_first_key):
), f"Test Case 1, Fragment {idx+1}: Inner Thoughts update mismatch.\nExpected: '{expected['inner_thoughts_update']}'\nGot: '{updates_inner_thoughts}'"
@pytest.mark.parametrize("wait_for_first_key", [True, False])
def test_inner_thoughts_in_args_trailing_quote(wait_for_first_key):
# Another test case where there's a function call that has a chunk that ends with a double quote
print("Running Test Case: chunk ends with double quote")
handler1 = JSONInnerThoughtsExtractor(inner_thoughts_key="inner_thoughts", wait_for_first_key=wait_for_first_key)
fragments1 = [
# 1
"{",
# 2
"""\"inner_thoughts\":\"User wants to add 'banana' again for a fourth time; I'll track another addition.""",
# 3
'",',
# 4
"""\"content\":\"banana""",
# 5
"""\",\"""",
# 6
"""request_heartbeat\":\"""",
# 7
"""true\"""",
# 8
"}",
]
print("Double quote test case:", fragments1, "".join(fragments1))
# Make sure the string is valid JSON
_ = json.loads("".join(fragments1))
if wait_for_first_key:
# If we're waiting for the first key, then the first opening brace should be buffered/held back
# until after the inner thoughts are finished
expected_updates1 = [
{"main_json_update": "", "inner_thoughts_update": ""}, # Fragment 1 (NOTE: different)
{
"main_json_update": "",
"inner_thoughts_update": "User wants to add 'banana' again for a fourth time; I'll track another addition.",
}, # Fragment 2
{"main_json_update": "", "inner_thoughts_update": ""}, # Fragment 3
{
"main_json_update": '{"content":"banana',
"inner_thoughts_update": "",
}, # Fragment 4
{
# "main_json_update": '","',
"main_json_update": '",',
"inner_thoughts_update": "",
}, # Fragment 5
{
# "main_json_update": 'request_heartbeat":"',
"main_json_update": '"request_heartbeat":"',
"inner_thoughts_update": "",
}, # Fragment 6
{
"main_json_update": 'true"',
"inner_thoughts_update": "",
}, # Fragment 7
{
"main_json_update": "}",
"inner_thoughts_update": "",
}, # Fragment 8
]
else:
pass
# If we're not waiting for the first key, then the first opening brace should be written immediately
expected_updates1 = [
{"main_json_update": "{", "inner_thoughts_update": ""}, # Fragment 1 (NOTE: different)
{
"main_json_update": "",
"inner_thoughts_update": "User wants to add 'banana' again for a fourth time; I'll track another addition.",
}, # Fragment 2
{"main_json_update": "", "inner_thoughts_update": ""}, # Fragment 3
{
"main_json_update": '"content":"banana',
"inner_thoughts_update": "",
}, # Fragment 4
{
# "main_json_update": '","',
"main_json_update": '",',
"inner_thoughts_update": "",
}, # Fragment 5
{
# "main_json_update": 'request_heartbeat":"',
"main_json_update": '"request_heartbeat":"',
"inner_thoughts_update": "",
}, # Fragment 6
{
"main_json_update": 'true"',
"inner_thoughts_update": "",
}, # Fragment 7
{
"main_json_update": "}",
"inner_thoughts_update": "",
}, # Fragment 8
]
current_inner_thoughts = ""
current_main_json = ""
for idx, (fragment, expected) in enumerate(zip(fragments1, expected_updates1)):
updates_main_json, updates_inner_thoughts = handler1.process_fragment(fragment)
# Assertions
assert (
updates_main_json == expected["main_json_update"]
), f"Test Case 1, Fragment {idx+1}: Main JSON update mismatch.\nFragment: '{fragment}'\nExpected: '{expected['main_json_update']}'\nGot: '{updates_main_json}'\nCurrent JSON: '{current_main_json}'\nCurrent Inner Thoughts: '{current_inner_thoughts}'"
assert (
updates_inner_thoughts == expected["inner_thoughts_update"]
), f"Test Case 1, Fragment {idx+1}: Inner Thoughts update mismatch.\nExpected: '{expected['inner_thoughts_update']}'\nGot: '{updates_inner_thoughts}'\nCurrent JSON: '{current_main_json}'\nCurrent Inner Thoughts: '{current_inner_thoughts}'"
current_main_json += updates_main_json
current_inner_thoughts += updates_inner_thoughts
print(f"Final JSON: '{current_main_json}'")
print(f"Final Inner Thoughts: '{current_inner_thoughts}'")
_ = json.loads(current_main_json)
def test_inner_thoughts_not_in_args():
"""Test case where the function_delta.arguments does not contain inner_thoughts
@@ -93,6 +211,9 @@ def test_inner_thoughts_not_in_args():
''' who would it be?"''',
"}",
]
print("Basic inner thoughts not in kwargs testcase:", fragments2, "".join(fragments2))
# Make sure the string is valid JSON
_ = json.loads("".join(fragments2))
expected_updates2 = [
{"main_json_update": "{", "inner_thoughts_update": ""}, # Fragment 1