diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index b6842e2f..bb32662a 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -314,11 +314,17 @@ def openai_chat_completions_process_stream( for _ in range(len(tool_calls_delta)) ] + # There may be many tool calls in a tool calls delta (e.g. parallel tool calls) for tool_call_delta in tool_calls_delta: if tool_call_delta.id is not None: # TODO assert that we're not overwriting? # TODO += instead of =? - accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id + if tool_call_delta.index not in range(len(accum_message.tool_calls)): + warnings.warn( + f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}" + ) + else: + accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id if tool_call_delta.function is not None: if tool_call_delta.function.name is not None: # TODO assert that we're not overwriting? diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index a2898e5d..cb5ef669 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -312,11 +312,20 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs) self.function_name_buffer = None self.function_args_buffer = None + self.function_id_buffer = None # extra prints self.debug = False self.timeout = 30 + def _reset_inner_thoughts_json_reader(self): + # A buffer for accumulating function arguments (we want to buffer keys and run checks on each one) + self.function_args_reader = JSONInnerThoughtsExtractor(inner_thoughts_key=self.inner_thoughts_kwarg, wait_for_first_key=True) + # Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs) + self.function_name_buffer = None + self.function_args_buffer = None + self.function_id_buffer = None + async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]: """An asynchronous generator that yields chunks as they become available.""" while self._active: @@ -376,6 +385,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface): if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode: self._push_to_buffer(self.multi_step_gen_indicator) + # Wipe the inner thoughts buffers + self._reset_inner_thoughts_json_reader() + def step_complete(self): """Signal from the agent that one 'step' finished (step = LLM response + tool execution)""" if not self.multi_step: @@ -386,6 +398,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # signal that a new step has started in the stream self._push_to_buffer(self.multi_step_indicator) + # Wipe the inner thoughts buffers + self._reset_inner_thoughts_json_reader() + def step_yield(self): """If multi_step, this is the true 'stream_end' function.""" self._active = False @@ -498,6 +513,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface): else: self.function_name_buffer += tool_call.function.name + if tool_call.id: + # Buffer until next time + if self.function_id_buffer is None: + self.function_id_buffer = tool_call.id + else: + self.function_id_buffer += tool_call.id + if tool_call.function.arguments: updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) @@ -518,6 +540,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # If we have main_json, we should output a FunctionCallMessage elif updates_main_json: + # If there's something in the function_name buffer, we should release it first # NOTE: we could output it as part of a chunk that has both name and args, # however the frontend may expect name first, then args, so to be @@ -526,18 +549,23 @@ class StreamingServerInterface(AgentChunkStreamingInterface): processed_chunk = FunctionCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta(name=self.function_name_buffer, arguments=None), + function_call=FunctionCallDelta( + name=self.function_name_buffer, + arguments=None, + function_call_id=self.function_id_buffer, + ), ) # Clear the buffer self.function_name_buffer = None + self.function_id_buffer = None # Since we're clearing the name buffer, we should store # any updates to the arguments inside a separate buffer - if updates_main_json: - # Add any main_json updates to the arguments buffer - if self.function_args_buffer is None: - self.function_args_buffer = updates_main_json - else: - self.function_args_buffer += updates_main_json + + # Add any main_json updates to the arguments buffer + if self.function_args_buffer is None: + self.function_args_buffer = updates_main_json + else: + self.function_args_buffer += updates_main_json # If there was nothing in the name buffer, we can proceed to # output the arguments chunk as a FunctionCallMessage @@ -550,17 +578,27 @@ class StreamingServerInterface(AgentChunkStreamingInterface): processed_chunk = FunctionCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta(name=None, arguments=combined_chunk), + function_call=FunctionCallDelta( + name=None, + arguments=combined_chunk, + function_call_id=self.function_id_buffer, + ), ) # clear buffer self.function_args_buffer = None + self.function_id_buffer = None else: # If there's no buffer to clear, just output a new chunk with new data processed_chunk = FunctionCallMessage( id=message_id, date=message_date, - function_call=FunctionCallDelta(name=None, arguments=updates_main_json), + function_call=FunctionCallDelta( + name=None, + arguments=updates_main_json, + function_call_id=self.function_id_buffer, + ), ) + self.function_id_buffer = None # # If there's something in the main_json buffer, we should add if to the arguments and release it together # tool_call_delta = {} diff --git a/tests/test_stream_buffer_readers.py b/tests/test_stream_buffer_readers.py index 9351ca54..9a0bb5e8 100644 --- a/tests/test_stream_buffer_readers.py +++ b/tests/test_stream_buffer_readers.py @@ -1,10 +1,12 @@ +import json + import pytest from letta.streaming_utils import JSONInnerThoughtsExtractor @pytest.mark.parametrize("wait_for_first_key", [True, False]) -def test_inner_thoughts_in_args(wait_for_first_key): +def test_inner_thoughts_in_args_simple(wait_for_first_key): """Test case where the function_delta.arguments contains inner_thoughts Correct output should be inner_thoughts VALUE (not KEY) being written to one buffer @@ -17,7 +19,7 @@ def test_inner_thoughts_in_args(wait_for_first_key): """"inner_thoughts":"Chad's x2 tradition""", " is going strong! 😂 I love the enthusiasm!", " Time to delve into something imaginative:", - """ If you could swap lives with any fictional character for a day, who would it be?",""", + """ If you could swap lives with any fictional character for a day, who would it be?\"""", ",", """"message":"Here we are again, with 'x2'!""", " 🎉 Let's take this chance: If you could swap", @@ -25,6 +27,9 @@ def test_inner_thoughts_in_args(wait_for_first_key): ''' who would it be?"''', "}", ] + print("Basic inner thoughts testcase:", fragments1, "".join(fragments1)) + # Make sure the string is valid JSON + _ = json.loads("".join(fragments1)) if wait_for_first_key: # If we're waiting for the first key, then the first opening brace should be buffered/held back @@ -78,6 +83,119 @@ def test_inner_thoughts_in_args(wait_for_first_key): ), f"Test Case 1, Fragment {idx+1}: Inner Thoughts update mismatch.\nExpected: '{expected['inner_thoughts_update']}'\nGot: '{updates_inner_thoughts}'" +@pytest.mark.parametrize("wait_for_first_key", [True, False]) +def test_inner_thoughts_in_args_trailing_quote(wait_for_first_key): + # Another test case where there's a function call that has a chunk that ends with a double quote + print("Running Test Case: chunk ends with double quote") + handler1 = JSONInnerThoughtsExtractor(inner_thoughts_key="inner_thoughts", wait_for_first_key=wait_for_first_key) + fragments1 = [ + # 1 + "{", + # 2 + """\"inner_thoughts\":\"User wants to add 'banana' again for a fourth time; I'll track another addition.""", + # 3 + '",', + # 4 + """\"content\":\"banana""", + # 5 + """\",\"""", + # 6 + """request_heartbeat\":\"""", + # 7 + """true\"""", + # 8 + "}", + ] + print("Double quote test case:", fragments1, "".join(fragments1)) + # Make sure the string is valid JSON + _ = json.loads("".join(fragments1)) + + if wait_for_first_key: + # If we're waiting for the first key, then the first opening brace should be buffered/held back + # until after the inner thoughts are finished + expected_updates1 = [ + {"main_json_update": "", "inner_thoughts_update": ""}, # Fragment 1 (NOTE: different) + { + "main_json_update": "", + "inner_thoughts_update": "User wants to add 'banana' again for a fourth time; I'll track another addition.", + }, # Fragment 2 + {"main_json_update": "", "inner_thoughts_update": ""}, # Fragment 3 + { + "main_json_update": '{"content":"banana', + "inner_thoughts_update": "", + }, # Fragment 4 + { + # "main_json_update": '","', + "main_json_update": '",', + "inner_thoughts_update": "", + }, # Fragment 5 + { + # "main_json_update": 'request_heartbeat":"', + "main_json_update": '"request_heartbeat":"', + "inner_thoughts_update": "", + }, # Fragment 6 + { + "main_json_update": 'true"', + "inner_thoughts_update": "", + }, # Fragment 7 + { + "main_json_update": "}", + "inner_thoughts_update": "", + }, # Fragment 8 + ] + else: + pass + # If we're not waiting for the first key, then the first opening brace should be written immediately + expected_updates1 = [ + {"main_json_update": "{", "inner_thoughts_update": ""}, # Fragment 1 (NOTE: different) + { + "main_json_update": "", + "inner_thoughts_update": "User wants to add 'banana' again for a fourth time; I'll track another addition.", + }, # Fragment 2 + {"main_json_update": "", "inner_thoughts_update": ""}, # Fragment 3 + { + "main_json_update": '"content":"banana', + "inner_thoughts_update": "", + }, # Fragment 4 + { + # "main_json_update": '","', + "main_json_update": '",', + "inner_thoughts_update": "", + }, # Fragment 5 + { + # "main_json_update": 'request_heartbeat":"', + "main_json_update": '"request_heartbeat":"', + "inner_thoughts_update": "", + }, # Fragment 6 + { + "main_json_update": 'true"', + "inner_thoughts_update": "", + }, # Fragment 7 + { + "main_json_update": "}", + "inner_thoughts_update": "", + }, # Fragment 8 + ] + + current_inner_thoughts = "" + current_main_json = "" + for idx, (fragment, expected) in enumerate(zip(fragments1, expected_updates1)): + updates_main_json, updates_inner_thoughts = handler1.process_fragment(fragment) + # Assertions + assert ( + updates_main_json == expected["main_json_update"] + ), f"Test Case 1, Fragment {idx+1}: Main JSON update mismatch.\nFragment: '{fragment}'\nExpected: '{expected['main_json_update']}'\nGot: '{updates_main_json}'\nCurrent JSON: '{current_main_json}'\nCurrent Inner Thoughts: '{current_inner_thoughts}'" + assert ( + updates_inner_thoughts == expected["inner_thoughts_update"] + ), f"Test Case 1, Fragment {idx+1}: Inner Thoughts update mismatch.\nExpected: '{expected['inner_thoughts_update']}'\nGot: '{updates_inner_thoughts}'\nCurrent JSON: '{current_main_json}'\nCurrent Inner Thoughts: '{current_inner_thoughts}'" + current_main_json += updates_main_json + current_inner_thoughts += updates_inner_thoughts + + print(f"Final JSON: '{current_main_json}'") + print(f"Final Inner Thoughts: '{current_inner_thoughts}'") + _ = json.loads(current_main_json) + + def test_inner_thoughts_not_in_args(): """Test case where the function_delta.arguments does not contain inner_thoughts @@ -93,6 +211,9 @@ def test_inner_thoughts_not_in_args(): ''' who would it be?"''', "}", ] + print("Basic inner thoughts not in kwargs testcase:", fragments2, "".join(fragments2)) + # Make sure the string is valid JSON + _ = json.loads("".join(fragments2)) expected_updates2 = [ {"main_json_update": "{", "inner_thoughts_update": ""}, # Fragment 1