From 609e63cb1200ae299bc0294e4418cba4cff7af50 Mon Sep 17 00:00:00 2001
From: Matthew Zhou
Date: Wed, 8 Oct 2025 17:27:30 -0700
Subject: [PATCH] feat: Adjust message conversion to support multiple tool calls [LET-5336] (#5270)

* Adjust message conversion

* Make collection logic cleaner
---
 letta/schemas/message.py               |  94 ++++++++--
 tests/managers/test_message_manager.py | 233 +++++++++++++++++++++++++
 2 files changed, 308 insertions(+), 19 deletions(-)

diff --git a/letta/schemas/message.py b/letta/schemas/message.py
index 4594bb1b..7b9354a3 100644
--- a/letta/schemas/message.py
+++ b/letta/schemas/message.py
@@ -492,13 +492,63 @@ class Message(BaseMessage):
         assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
     ) -> List[LettaMessage]:
         messages = []
-        # This is type FunctionCall
+
+        # If assistant mode is off, just create one ToolCallMessage with all tool calls
+        if not use_assistant_message:
+            all_tool_call_objs = [
+                ToolCall(
+                    name=tool_call.function.name,
+                    arguments=tool_call.function.arguments,
+                    tool_call_id=tool_call.id,
+                )
+                for tool_call in self.tool_calls
+            ]
+
+            if all_tool_call_objs:
+                otid = Message.generate_otid_from_id(self.id, current_message_count)
+                messages.append(
+                    ToolCallMessage(
+                        id=self.id,
+                        date=self.created_at,
+                        # use first tool call for the deprecated field
+                        tool_call=all_tool_call_objs[0],
+                        tool_calls=all_tool_call_objs,
+                        name=self.name,
+                        otid=otid,
+                        sender_id=self.sender_id,
+                        step_id=self.step_id,
+                        is_err=self.is_err,
+                        run_id=self.run_id,
+                    )
+                )
+            return messages
+
+        collected_tool_calls = []
+
         for tool_call in self.tool_calls:
             otid = Message.generate_otid_from_id(self.id, current_message_count + len(messages))
-            # If we're supporting using assistant message,
-            # then we want to treat certain function calls as a special case
-            if use_assistant_message and tool_call.function.name == assistant_message_tool_name:
-                # We need to unpack the actual message contents from the function call
+
+            if tool_call.function.name == assistant_message_tool_name:
+                if collected_tool_calls:
+                    tool_call_message = ToolCallMessage(
+                        id=self.id,
+                        date=self.created_at,
+                        # use first tool call for the deprecated field
+                        tool_call=collected_tool_calls[0],
+                        tool_calls=collected_tool_calls.copy(),
+                        name=self.name,
+                        otid=Message.generate_otid_from_id(self.id, current_message_count + len(messages)),
+                        sender_id=self.sender_id,
+                        step_id=self.step_id,
+                        is_err=self.is_err,
+                        run_id=self.run_id,
+                    )
+                    messages.append(tool_call_message)
+                    collected_tool_calls = []  # reset the collection
+                    # The flushed ToolCallMessage used the same index the loop-level otid was built from.
+                    # NOTE(review): rebuild it so any later use of `otid` in this branch gets the next index - confirm against the elided code.
+                    otid = Message.generate_otid_from_id(self.id, current_message_count + len(messages))
+
                 try:
                     func_args = parse_json(tool_call.function.arguments)
                     message_string = validate_function_response(func_args[assistant_message_tool_kwarg], 0, truncate=False)
@@ -518,25 +568,31 @@ class Message(BaseMessage):
                     )
                 )
             else:
+                # non-assistant tool call, collect it
                 tool_call_obj = ToolCall(
                     name=tool_call.function.name,
                     arguments=tool_call.function.arguments,
                     tool_call_id=tool_call.id,
                 )
-                messages.append(
-                    ToolCallMessage(
-                        id=self.id,
-                        date=self.created_at,
-                        tool_call=tool_call_obj,
-                        tool_calls=[tool_call_obj],
-                        name=self.name,
-                        otid=otid,
-                        sender_id=self.sender_id,
-                        step_id=self.step_id,
-                        is_err=self.is_err,
-                        run_id=self.run_id,
-                    )
-                )
+                collected_tool_calls.append(tool_call_obj)
+
+        # flush any remaining collected tool calls
+        if collected_tool_calls:
+            tool_call_message = ToolCallMessage(
+                id=self.id,
+                date=self.created_at,
+                # use first tool call for the deprecated field
+                tool_call=collected_tool_calls[0],
+                tool_calls=collected_tool_calls,
+                name=self.name,
+                otid=Message.generate_otid_from_id(self.id, current_message_count + len(messages)),
+                sender_id=self.sender_id,
+                step_id=self.step_id,
+                is_err=self.is_err,
+                run_id=self.run_id,
+            )
+            messages.append(tool_call_message)
+
         return messages

     def _convert_tool_return_message(self) -> List[ToolReturnMessage]:
diff --git a/tests/managers/test_message_manager.py b/tests/managers/test_message_manager.py
index 8a4c12fd..d4a7708d 100644
--- a/tests/managers/test_message_manager.py
+++ b/tests/managers/test_message_manager.py
@@ -784,3 +784,236 @@ async def test_create_many_messages_async_with_turbopuffer(server: SyncServer, s
     for msg in created_messages:
         assert msg.id is not None
         assert msg.agent_id == sarah_agent.id
+
+
+# ======================================================================================================================
+# Pydantic Object Tests - Tool Call Message Conversion
+# ======================================================================================================================
+
+
+@pytest.mark.asyncio
+async def test_convert_tool_call_messages_no_assistant_mode(server: SyncServer, sarah_agent, default_user):
+    """Test that when assistant mode is off, all tool calls go into a single ToolCallMessage"""
+    from letta.schemas.letta_message import ToolCall
+
+    # create a message with multiple tool calls
+    tool_calls = [
+        OpenAIToolCall(
+            id="call_1", type="function", function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "test memory 1"}')
+        ),
+        OpenAIToolCall(
+            id="call_2", type="function", function=OpenAIFunction(name="conversation_search", arguments='{"query": "test search"}')
+        ),
+        OpenAIToolCall(id="call_3", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Hello there!"}')),
+    ]
+
+    message = PydanticMessage(
+        agent_id=sarah_agent.id,
+        role=MessageRole.assistant,
+        content=[TextContent(text="Let me help you with that.")],
+        tool_calls=tool_calls,
+    )
+
+    # convert without assistant mode (reverse=True by default)
+    letta_messages = message.to_letta_messages(use_assistant_message=False)
+
+    # should have 2 messages in reverse order: tool call message, then reasoning message
+    assert len(letta_messages) == 2
+    assert letta_messages[0].message_type == "tool_call_message"
+    assert letta_messages[1].message_type == "reasoning_message"
+
+    # check the tool call message has all tool calls in the new field
+    tool_call_msg = letta_messages[0]
+    assert tool_call_msg.tool_calls is not None
+    assert len(tool_call_msg.tool_calls) == 3
+
+    # check backwards compatibility - first tool call in deprecated field
+    assert tool_call_msg.tool_call is not None
+    assert tool_call_msg.tool_call.name == "archival_memory_insert"
+    assert tool_call_msg.tool_call.tool_call_id == "call_1"
+
+    # verify all tool calls are present in the list
+    tool_names = [tc.name for tc in tool_call_msg.tool_calls]
+    assert "archival_memory_insert" in tool_names
+    assert "conversation_search" in tool_names
+    assert "send_message" in tool_names
+
+
+@pytest.mark.asyncio
+async def test_convert_tool_call_messages_with_assistant_mode(server: SyncServer, sarah_agent, default_user):
+    """Test that with assistant mode, send_message becomes AssistantMessage and others are grouped"""
+
+    # create a message with tool calls including send_message
+    tool_calls = [
+        OpenAIToolCall(
+            id="call_1", type="function", function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "test memory 1"}')
+        ),
+        OpenAIToolCall(id="call_2", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Hello there!"}')),
+        OpenAIToolCall(
+            id="call_3", type="function", function=OpenAIFunction(name="conversation_search", arguments='{"query": "test search"}')
+        ),
+    ]
+
+    message = PydanticMessage(
+        agent_id=sarah_agent.id,
+        role=MessageRole.assistant,
+        content=[TextContent(text="Let me help you with that.")],
+        tool_calls=tool_calls,
+    )
+
+    # convert with assistant mode (reverse=True by default)
+    letta_messages = message.to_letta_messages(use_assistant_message=True)
+
+    # should have 4 messages in reverse order:
+    # conversation_search tool call, assistant message, archival_memory tool call, reasoning
+    assert len(letta_messages) == 4
+    assert letta_messages[0].message_type == "tool_call_message"
+    assert letta_messages[1].message_type == "assistant_message"
+    assert letta_messages[2].message_type == "tool_call_message"
+    assert letta_messages[3].message_type == "reasoning_message"
+
+    # check first tool call message (actually the last in forward order) has conversation_search
+    first_tool_msg = letta_messages[0]
+    assert len(first_tool_msg.tool_calls) == 1
+    assert first_tool_msg.tool_calls[0].name == "conversation_search"
+    assert first_tool_msg.tool_call.name == "conversation_search"  # backwards compat
+
+    # check assistant message content
+    assistant_msg = letta_messages[1]
+    assert assistant_msg.content == "Hello there!"
+
+    # check last tool call message (actually the first in forward order) has archival_memory_insert
+    last_tool_msg = letta_messages[2]
+    assert len(last_tool_msg.tool_calls) == 1
+    assert last_tool_msg.tool_calls[0].name == "archival_memory_insert"
+    assert last_tool_msg.tool_call.name == "archival_memory_insert"  # backwards compat
+
+
+@pytest.mark.asyncio
+async def test_convert_tool_call_messages_multiple_non_assistant_tools(server: SyncServer, sarah_agent, default_user):
+    """Test that multiple non-assistant tools are batched together until assistant tool is reached"""
+
+    tool_calls = [
+        OpenAIToolCall(
+            id="call_1", type="function", function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "memory 1"}')
+        ),
+        OpenAIToolCall(
+            id="call_2", type="function", function=OpenAIFunction(name="conversation_search", arguments='{"query": "search 1"}')
+        ),
+        OpenAIToolCall(
+            id="call_3", type="function", function=OpenAIFunction(name="archival_memory_search", arguments='{"query": "archive search"}')
+        ),
+        OpenAIToolCall(
+            id="call_4", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Results found!"}')
+        ),
+    ]
+
+    message = PydanticMessage(
+        agent_id=sarah_agent.id,
+        role=MessageRole.assistant,
+        content=[TextContent(text="Processing...")],
+        tool_calls=tool_calls,
+    )
+
+    # convert with assistant mode (reverse=True by default)
+    letta_messages = message.to_letta_messages(use_assistant_message=True)
+
+    # should have 3 messages in reverse order: assistant, tool call (with 3 tools), reasoning
+    assert len(letta_messages) == 3
+    assert letta_messages[0].message_type == "assistant_message"
+    assert letta_messages[1].message_type == "tool_call_message"
+    assert letta_messages[2].message_type == "reasoning_message"
+
+    # check the tool call message has all 3 non-assistant tools
+    tool_msg = letta_messages[1]
+    assert len(tool_msg.tool_calls) == 3
+    tool_names = [tc.name for tc in tool_msg.tool_calls]
+    assert "archival_memory_insert" in tool_names
+    assert "conversation_search" in tool_names
+    assert "archival_memory_search" in tool_names
+
+    # check backwards compat field has first tool
+    assert tool_msg.tool_call.name == "archival_memory_insert"
+
+    # check assistant message
+    assert letta_messages[0].content == "Results found!"
+
+
+@pytest.mark.asyncio
+async def test_convert_single_tool_call_both_fields(server: SyncServer, sarah_agent, default_user):
+    """Test that a single tool call is written to both tool_call and tool_calls fields"""
+
+    tool_calls = [
+        OpenAIToolCall(
+            id="call_1",
+            type="function",
+            function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "single tool call"}'),
+        ),
+    ]
+
+    message = PydanticMessage(
+        agent_id=sarah_agent.id,
+        role=MessageRole.assistant,
+        content=[TextContent(text="Saving to memory...")],
+        tool_calls=tool_calls,
+    )
+
+    # test without assistant mode (reverse=True by default)
+    letta_messages = message.to_letta_messages(use_assistant_message=False)
+
+    assert len(letta_messages) == 2  # tool call + reasoning (reversed)
+    tool_msg = letta_messages[0]  # tool call is first due to reverse
+
+    # both fields should be populated
+    assert tool_msg.tool_call is not None
+    assert tool_msg.tool_call.name == "archival_memory_insert"
+
+    assert tool_msg.tool_calls is not None
+    assert len(tool_msg.tool_calls) == 1
+    assert tool_msg.tool_calls[0].name == "archival_memory_insert"
+    assert tool_msg.tool_calls[0].tool_call_id == "call_1"
+
+    # test with assistant mode (reverse=True by default)
+    letta_messages_assist = message.to_letta_messages(use_assistant_message=True)
+
+    assert len(letta_messages_assist) == 2  # tool call + reasoning (reversed)
+    tool_msg_assist = letta_messages_assist[0]  # tool call is first due to reverse
+
+    # both fields should still be populated
+    assert tool_msg_assist.tool_call is not None
+    assert tool_msg_assist.tool_calls is not None
+    assert len(tool_msg_assist.tool_calls) == 1
+
+
+@pytest.mark.asyncio
+async def test_convert_tool_calls_only_assistant_tools(server: SyncServer, sarah_agent, default_user):
+    """Test that only send_message tools are converted to AssistantMessages"""
+
+    tool_calls = [
+        OpenAIToolCall(
+            id="call_1", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "First message"}')
+        ),
+        OpenAIToolCall(
+            id="call_2", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Second message"}')
+        ),
+    ]
+
+    message = PydanticMessage(
+        agent_id=sarah_agent.id,
+        role=MessageRole.assistant,
+        content=[TextContent(text="Sending messages...")],
+        tool_calls=tool_calls,
+    )
+
+    # convert with assistant mode (reverse=True by default)
+    letta_messages = message.to_letta_messages(use_assistant_message=True)
+
+    # should have 3 messages in reverse order: 2 assistant messages, then reasoning
+    assert len(letta_messages) == 3
+    assert letta_messages[0].message_type == "assistant_message"
+    assert letta_messages[1].message_type == "assistant_message"
+    assert letta_messages[2].message_type == "reasoning_message"
+
+    # check assistant messages content (they appear in reverse order)
+    assert letta_messages[0].content == "Second message"
+    assert letta_messages[1].content == "First message"