diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index 3b87b4ff..9234b5cd 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -279,9 +279,11 @@ class AnthropicStreamingInterface: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 if self.tool_call_name not in self.requires_approval_tools: + tool_call_delta = ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id) tool_call_msg = ToolCallMessage( id=self.letta_message_id, - tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, @@ -423,15 +425,17 @@ class AnthropicStreamingInterface: tool_call_args += buffered_msg.tool_call.arguments if buffered_msg.tool_call.arguments else "" tool_call_args = tool_call_args.replace(f'"{INNER_THOUGHTS_KWARG}": "{current_inner_thoughts}"', "") + tool_call_delta = ToolCallDelta( + name=self.tool_call_name, + tool_call_id=self.tool_call_id, + arguments=tool_call_args, + ) tool_call_msg = ToolCallMessage( id=self.tool_call_buffer[0].id, otid=Message.generate_otid_from_id(self.tool_call_buffer[0].id, message_index), date=self.tool_call_buffer[0].date, - tool_call=ToolCallDelta( - name=self.tool_call_name, - tool_call_id=self.tool_call_id, - arguments=tool_call_args, - ), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, run_id=self.run_id, ) prev_message_type = tool_call_msg.message_type @@ -467,9 +471,13 @@ class AnthropicStreamingInterface: run_id=self.run_id, ) else: + tool_call_delta = ToolCallDelta( + name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json + ) tool_call_msg = ToolCallMessage( id=self.letta_message_id, - tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, date=datetime.now(timezone.utc).isoformat(), run_id=self.run_id, ) @@ -778,9 +786,11 @@ class SimpleAnthropicStreamingInterface: else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tool_call_delta = ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id) tool_call_msg = ToolCallMessage( id=self.letta_message_id, - tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, @@ -860,9 +870,11 @@ class SimpleAnthropicStreamingInterface: else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tool_call_delta = ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json) tool_call_msg = ToolCallMessage( id=self.letta_message_id, - tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, diff --git a/letta/interfaces/gemini_streaming_interface.py b/letta/interfaces/gemini_streaming_interface.py index 97fd613d..9e3daf9e 100644 --- a/letta/interfaces/gemini_streaming_interface.py +++ b/letta/interfaces/gemini_streaming_interface.py @@ -273,15 +273,17 @@ class SimpleGeminiStreamingInterface: else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tool_call_delta = ToolCallDelta( + name=name, + arguments=arguments_str, + tool_call_id=call_id, + ) yield ToolCallMessage( id=self.letta_message_id, otid=Message.generate_otid_from_id(self.letta_message_id, message_index), date=datetime.now(timezone.utc), - tool_call=ToolCallDelta( - name=name, - arguments=arguments_str, - tool_call_id=call_id, - ), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, run_id=self.run_id, step_id=self.step_id, ) diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 6e539579..ee8c7bfa 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -336,14 +336,16 @@ class OpenAIStreamingInterface: step_id=self.step_id, ) else: + tool_call_delta = ToolCallDelta( + name=self.function_name_buffer, + arguments=None, + tool_call_id=self.function_id_buffer, + ) tool_call_msg = ToolCallMessage( id=self.letta_message_id, date=datetime.now(timezone.utc), - tool_call=ToolCallDelta( - name=self.function_name_buffer, - arguments=None, - tool_call_id=self.function_id_buffer, - ), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, step_id=self.step_id, @@ -423,14 +425,16 @@ class OpenAIStreamingInterface: step_id=self.step_id, ) else: + tool_call_delta = ToolCallDelta( + name=self.function_name_buffer, + arguments=combined_chunk, + tool_call_id=self.function_id_buffer, + ) tool_call_msg = ToolCallMessage( id=self.letta_message_id, date=datetime.now(timezone.utc), - tool_call=ToolCallDelta( - name=self.function_name_buffer, - arguments=combined_chunk, - tool_call_id=self.function_id_buffer, - ), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, # name=name, otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, @@ -460,14 +464,16 @@ class OpenAIStreamingInterface: step_id=self.step_id, ) else: + tool_call_delta = ToolCallDelta( + name=None, + arguments=updates_main_json, + tool_call_id=self.function_id_buffer, + ) tool_call_msg = ToolCallMessage( id=self.letta_message_id, date=datetime.now(timezone.utc), - tool_call=ToolCallDelta( - name=None, - arguments=updates_main_json, - tool_call_id=self.function_id_buffer, - ), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, # name=name, otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, @@ -717,14 +723,16 @@ class SimpleOpenAIStreamingInterface: step_id=self.step_id, ) else: + tool_call_delta = ToolCallDelta( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + tool_call_id=tool_call.id, + ) tool_call_msg = ToolCallMessage( id=self.letta_message_id, date=datetime.now(timezone.utc), - tool_call=ToolCallDelta( - name=tool_call.function.name, - arguments=tool_call.function.arguments, - tool_call_id=tool_call.id, - ), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, # name=name, otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, @@ -945,15 +953,17 @@ class SimpleOpenAIResponsesStreamingInterface: else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tool_call_delta = ToolCallDelta( + name=name, + arguments=arguments if arguments != "" else None, + tool_call_id=call_id, + ) yield ToolCallMessage( id=self.letta_message_id, otid=Message.generate_otid_from_id(self.letta_message_id, message_index), date=datetime.now(timezone.utc), - tool_call=ToolCallDelta( - name=name, - arguments=arguments if arguments != "" else None, - tool_call_id=call_id, - ), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, run_id=self.run_id, step_id=self.step_id, ) @@ -1113,15 +1123,17 @@ class SimpleOpenAIResponsesStreamingInterface: else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tool_call_delta = ToolCallDelta( + name=None, + arguments=delta, + tool_call_id=None, + ) yield ToolCallMessage( id=self.letta_message_id, otid=Message.generate_otid_from_id(self.letta_message_id, message_index), date=datetime.now(timezone.utc), - tool_call=ToolCallDelta( - name=None, - arguments=delta, - tool_call_id=None, - ), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, run_id=self.run_id, step_id=self.step_id, ) diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 71f6d328..98e3097b 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -518,15 +518,17 @@ class Message(BaseMessage): ) ) else: + tool_call_obj = ToolCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + tool_call_id=tool_call.id, + ) messages.append( ToolCallMessage( id=self.id, date=self.created_at, - tool_call=ToolCall( - name=tool_call.function.name, - arguments=tool_call.function.arguments, - tool_call_id=tool_call.id, - ), + tool_call=tool_call_obj, + tool_calls=tool_call_obj, name=self.name, otid=otid, sender_id=self.sender_id, diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index bda4567d..78593e97 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -562,14 +562,16 @@ class StreamingServerInterface(AgentChunkStreamingInterface): if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tool_call_delta = ToolCallDelta( + name=json_reasoning_content.get("name"), + arguments=json.dumps(json_reasoning_content.get("arguments")), + tool_call_id=None, + ) processed_chunk = ToolCallMessage( id=message_id, date=message_date, - tool_call=ToolCallDelta( - name=json_reasoning_content.get("name"), - arguments=json.dumps(json_reasoning_content.get("arguments")), - tool_call_id=None, - ), + tool_call=tool_call_delta, + tool_calls=tool_call_delta, name=name, otid=Message.generate_otid_from_id(message_id, message_index), ) @@ -703,14 +705,16 @@ class StreamingServerInterface(AgentChunkStreamingInterface): else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tc_delta = ToolCallDelta( + name=tool_call_delta.get("name"), + arguments=tool_call_delta.get("arguments"), + tool_call_id=tool_call_delta.get("id"), + ) processed_chunk = ToolCallMessage( id=message_id, date=message_date, - tool_call=ToolCallDelta( - name=tool_call_delta.get("name"), - arguments=tool_call_delta.get("arguments"), - tool_call_id=tool_call_delta.get("id"), - ), + tool_call=tc_delta, + tool_calls=tc_delta, name=name, otid=Message.generate_otid_from_id(message_id, message_index), ) @@ -779,14 +783,16 @@ class StreamingServerInterface(AgentChunkStreamingInterface): else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tc_delta = ToolCallDelta( + name=self.function_name_buffer, + arguments=None, + tool_call_id=self.function_id_buffer, + ) processed_chunk = ToolCallMessage( id=message_id, date=message_date, - tool_call=ToolCallDelta( - name=self.function_name_buffer, - arguments=None, - tool_call_id=self.function_id_buffer, - ), + tool_call=tc_delta, + tool_calls=tc_delta, name=name, otid=Message.generate_otid_from_id(message_id, message_index), ) @@ -843,14 +849,16 @@ class StreamingServerInterface(AgentChunkStreamingInterface): combined_chunk = self.function_args_buffer + updates_main_json if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tc_delta = ToolCallDelta( + name=None, + arguments=combined_chunk, + tool_call_id=self.function_id_buffer, + ) processed_chunk = ToolCallMessage( id=message_id, date=message_date, - tool_call=ToolCallDelta( - name=None, - arguments=combined_chunk, - tool_call_id=self.function_id_buffer, - ), + tool_call=tc_delta, + tool_calls=tc_delta, name=name, otid=Message.generate_otid_from_id(message_id, message_index), ) @@ -861,14 +869,16 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # If there's no buffer to clear, just output a new chunk with new data if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tc_delta = ToolCallDelta( + name=None, + arguments=updates_main_json, + tool_call_id=self.function_id_buffer, + ) processed_chunk = ToolCallMessage( id=message_id, date=message_date, - tool_call=ToolCallDelta( - name=None, - arguments=updates_main_json, - tool_call_id=self.function_id_buffer, - ), + tool_call=tc_delta, + tool_calls=tc_delta, name=name, otid=Message.generate_otid_from_id(message_id, message_index), ) @@ -992,14 +1002,16 @@ class StreamingServerInterface(AgentChunkStreamingInterface): else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 + tc_delta = ToolCallDelta( + name=tool_call_delta.get("name"), + arguments=tool_call_delta.get("arguments"), + tool_call_id=tool_call_delta.get("id"), + ) processed_chunk = ToolCallMessage( id=message_id, date=message_date, - tool_call=ToolCallDelta( - name=tool_call_delta.get("name"), - arguments=tool_call_delta.get("arguments"), - tool_call_id=tool_call_delta.get("id"), - ), + tool_call=tc_delta, + tool_calls=tc_delta, name=name, otid=Message.generate_otid_from_id(message_id, message_index), ) @@ -1262,14 +1274,16 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # Store the ID of the tool call so allow skipping the corresponding response self.prev_assistant_message_id = function_call.id else: + tool_call_obj = ToolCall( + name=function_call.function.name, + arguments=function_call.function.arguments, + tool_call_id=function_call.id, + ) processed_chunk = ToolCallMessage( id=msg_obj.id, date=msg_obj.created_at, - tool_call=ToolCall( - name=function_call.function.name, - arguments=function_call.function.arguments, - tool_call_id=function_call.id, - ), + tool_call=tool_call_obj, + tool_calls=tool_call_obj, name=msg_obj.name, otid=Message.generate_otid_from_id(msg_obj.id, chunk_index) if chunk_index is not None else None, )