feat: Adjust message conversion to support multiple tool calls [LET-5336] (#5270)

* Adjust message conversion

* Make collection logic cleaner
This commit is contained in:
Matthew Zhou
2025-10-08 17:27:30 -07:00
committed by Caren Thomas
parent e5657bac5d
commit 609e63cb12
2 changed files with 305 additions and 19 deletions

View File

@@ -492,13 +492,60 @@ class Message(BaseMessage):
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> List[LettaMessage]:
messages = []
# This is type FunctionCall
# If assistant mode is off, just create one ToolCallMessage with all tool calls
if not use_assistant_message:
all_tool_call_objs = [
ToolCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
tool_call_id=tool_call.id,
)
for tool_call in self.tool_calls
]
if all_tool_call_objs:
otid = Message.generate_otid_from_id(self.id, current_message_count)
messages.append(
ToolCallMessage(
id=self.id,
date=self.created_at,
# use first tool call for the deprecated field
tool_call=all_tool_call_objs[0],
tool_calls=all_tool_call_objs,
name=self.name,
otid=otid,
sender_id=self.sender_id,
step_id=self.step_id,
is_err=self.is_err,
run_id=self.run_id,
)
)
return messages
collected_tool_calls = []
for tool_call in self.tool_calls:
otid = Message.generate_otid_from_id(self.id, current_message_count + len(messages))
# If we're supporting using assistant message,
# then we want to treat certain function calls as a special case
if use_assistant_message and tool_call.function.name == assistant_message_tool_name:
# We need to unpack the actual message contents from the function call
if tool_call.function.name == assistant_message_tool_name:
if collected_tool_calls:
tool_call_message = ToolCallMessage(
id=self.id,
date=self.created_at,
# use first tool call for the deprecated field
tool_call=collected_tool_calls[0],
tool_calls=collected_tool_calls.copy(),
name=self.name,
otid=Message.generate_otid_from_id(self.id, current_message_count + len(messages)),
sender_id=self.sender_id,
step_id=self.step_id,
is_err=self.is_err,
run_id=self.run_id,
)
messages.append(tool_call_message)
collected_tool_calls = [] # reset the collection
try:
func_args = parse_json(tool_call.function.arguments)
message_string = validate_function_response(func_args[assistant_message_tool_kwarg], 0, truncate=False)
@@ -518,25 +565,31 @@ class Message(BaseMessage):
)
)
else:
# non-assistant tool call, collect it
tool_call_obj = ToolCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
tool_call_id=tool_call.id,
)
messages.append(
ToolCallMessage(
id=self.id,
date=self.created_at,
tool_call=tool_call_obj,
tool_calls=[tool_call_obj],
name=self.name,
otid=otid,
sender_id=self.sender_id,
step_id=self.step_id,
is_err=self.is_err,
run_id=self.run_id,
)
)
collected_tool_calls.append(tool_call_obj)
# flush any remaining collected tool calls
if collected_tool_calls:
tool_call_message = ToolCallMessage(
id=self.id,
date=self.created_at,
# use first tool call for the deprecated field
tool_call=collected_tool_calls[0],
tool_calls=collected_tool_calls,
name=self.name,
otid=Message.generate_otid_from_id(self.id, current_message_count + len(messages)),
sender_id=self.sender_id,
step_id=self.step_id,
is_err=self.is_err,
run_id=self.run_id,
)
messages.append(tool_call_message)
return messages
def _convert_tool_return_message(self) -> List[ToolReturnMessage]:

View File

@@ -784,3 +784,236 @@ async def test_create_many_messages_async_with_turbopuffer(server: SyncServer, s
for msg in created_messages:
assert msg.id is not None
assert msg.agent_id == sarah_agent.id
# ======================================================================================================================
# Pydantic Object Tests - Tool Call Message Conversion
# ======================================================================================================================
@pytest.mark.asyncio
async def test_convert_tool_call_messages_no_assistant_mode(server: SyncServer, sarah_agent, default_user):
"""Test that when assistant mode is off, all tool calls go into a single ToolCallMessage"""
from letta.schemas.letta_message import ToolCall
# create a message with multiple tool calls
tool_calls = [
OpenAIToolCall(
id="call_1", type="function", function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "test memory 1"}')
),
OpenAIToolCall(
id="call_2", type="function", function=OpenAIFunction(name="conversation_search", arguments='{"query": "test search"}')
),
OpenAIToolCall(id="call_3", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Hello there!"}')),
]
message = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.assistant,
content=[TextContent(text="Let me help you with that.")],
tool_calls=tool_calls,
)
# convert without assistant mode (reverse=True by default)
letta_messages = message.to_letta_messages(use_assistant_message=False)
# should have 2 messages in reverse order: tool call message, then reasoning message
assert len(letta_messages) == 2
assert letta_messages[0].message_type == "tool_call_message"
assert letta_messages[1].message_type == "reasoning_message"
# check the tool call message has all tool calls in the new field
tool_call_msg = letta_messages[0]
assert tool_call_msg.tool_calls is not None
assert len(tool_call_msg.tool_calls) == 3
# check backwards compatibility - first tool call in deprecated field
assert tool_call_msg.tool_call is not None
assert tool_call_msg.tool_call.name == "archival_memory_insert"
assert tool_call_msg.tool_call.tool_call_id == "call_1"
# verify all tool calls are present in the list
tool_names = [tc.name for tc in tool_call_msg.tool_calls]
assert "archival_memory_insert" in tool_names
assert "conversation_search" in tool_names
assert "send_message" in tool_names
@pytest.mark.asyncio
async def test_convert_tool_call_messages_with_assistant_mode(server: SyncServer, sarah_agent, default_user):
"""Test that with assistant mode, send_message becomes AssistantMessage and others are grouped"""
# create a message with tool calls including send_message
tool_calls = [
OpenAIToolCall(
id="call_1", type="function", function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "test memory 1"}')
),
OpenAIToolCall(id="call_2", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Hello there!"}')),
OpenAIToolCall(
id="call_3", type="function", function=OpenAIFunction(name="conversation_search", arguments='{"query": "test search"}')
),
]
message = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.assistant,
content=[TextContent(text="Let me help you with that.")],
tool_calls=tool_calls,
)
# convert with assistant mode (reverse=True by default)
letta_messages = message.to_letta_messages(use_assistant_message=True)
# should have 4 messages in reverse order:
# conversation_search tool call, assistant message, archival_memory tool call, reasoning
assert len(letta_messages) == 4
assert letta_messages[0].message_type == "tool_call_message"
assert letta_messages[1].message_type == "assistant_message"
assert letta_messages[2].message_type == "tool_call_message"
assert letta_messages[3].message_type == "reasoning_message"
# check first tool call message (actually the last in forward order) has conversation_search
first_tool_msg = letta_messages[0]
assert len(first_tool_msg.tool_calls) == 1
assert first_tool_msg.tool_calls[0].name == "conversation_search"
assert first_tool_msg.tool_call.name == "conversation_search" # backwards compat
# check assistant message content
assistant_msg = letta_messages[1]
assert assistant_msg.content == "Hello there!"
# check last tool call message (actually the first in forward order) has archival_memory_insert
last_tool_msg = letta_messages[2]
assert len(last_tool_msg.tool_calls) == 1
assert last_tool_msg.tool_calls[0].name == "archival_memory_insert"
assert last_tool_msg.tool_call.name == "archival_memory_insert" # backwards compat
@pytest.mark.asyncio
async def test_convert_tool_call_messages_multiple_non_assistant_tools(server: SyncServer, sarah_agent, default_user):
"""Test that multiple non-assistant tools are batched together until assistant tool is reached"""
tool_calls = [
OpenAIToolCall(
id="call_1", type="function", function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "memory 1"}')
),
OpenAIToolCall(
id="call_2", type="function", function=OpenAIFunction(name="conversation_search", arguments='{"query": "search 1"}')
),
OpenAIToolCall(
id="call_3", type="function", function=OpenAIFunction(name="archival_memory_search", arguments='{"query": "archive search"}')
),
OpenAIToolCall(
id="call_4", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Results found!"}')
),
]
message = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.assistant,
content=[TextContent(text="Processing...")],
tool_calls=tool_calls,
)
# convert with assistant mode (reverse=True by default)
letta_messages = message.to_letta_messages(use_assistant_message=True)
# should have 3 messages in reverse order: assistant, tool call (with 3 tools), reasoning
assert len(letta_messages) == 3
assert letta_messages[0].message_type == "assistant_message"
assert letta_messages[1].message_type == "tool_call_message"
assert letta_messages[2].message_type == "reasoning_message"
# check the tool call message has all 3 non-assistant tools
tool_msg = letta_messages[1]
assert len(tool_msg.tool_calls) == 3
tool_names = [tc.name for tc in tool_msg.tool_calls]
assert "archival_memory_insert" in tool_names
assert "conversation_search" in tool_names
assert "archival_memory_search" in tool_names
# check backwards compat field has first tool
assert tool_msg.tool_call.name == "archival_memory_insert"
# check assistant message
assert letta_messages[0].content == "Results found!"
@pytest.mark.asyncio
async def test_convert_single_tool_call_both_fields(server: SyncServer, sarah_agent, default_user):
"""Test that a single tool call is written to both tool_call and tool_calls fields"""
tool_calls = [
OpenAIToolCall(
id="call_1",
type="function",
function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "single tool call"}'),
),
]
message = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.assistant,
content=[TextContent(text="Saving to memory...")],
tool_calls=tool_calls,
)
# test without assistant mode (reverse=True by default)
letta_messages = message.to_letta_messages(use_assistant_message=False)
assert len(letta_messages) == 2 # tool call + reasoning (reversed)
tool_msg = letta_messages[0] # tool call is first due to reverse
# both fields should be populated
assert tool_msg.tool_call is not None
assert tool_msg.tool_call.name == "archival_memory_insert"
assert tool_msg.tool_calls is not None
assert len(tool_msg.tool_calls) == 1
assert tool_msg.tool_calls[0].name == "archival_memory_insert"
assert tool_msg.tool_calls[0].tool_call_id == "call_1"
# test with assistant mode (reverse=True by default)
letta_messages_assist = message.to_letta_messages(use_assistant_message=True)
assert len(letta_messages_assist) == 2 # tool call + reasoning (reversed)
tool_msg_assist = letta_messages_assist[0] # tool call is first due to reverse
# both fields should still be populated
assert tool_msg_assist.tool_call is not None
assert tool_msg_assist.tool_calls is not None
assert len(tool_msg_assist.tool_calls) == 1
@pytest.mark.asyncio
async def test_convert_tool_calls_only_assistant_tools(server: SyncServer, sarah_agent, default_user):
"""Test that only send_message tools are converted to AssistantMessages"""
tool_calls = [
OpenAIToolCall(
id="call_1", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "First message"}')
),
OpenAIToolCall(
id="call_2", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Second message"}')
),
]
message = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.assistant,
content=[TextContent(text="Sending messages...")],
tool_calls=tool_calls,
)
# convert with assistant mode (reverse=True by default)
letta_messages = message.to_letta_messages(use_assistant_message=True)
# should have 3 messages in reverse order: 2 assistant messages, then reasoning
assert len(letta_messages) == 3
assert letta_messages[0].message_type == "assistant_message"
assert letta_messages[1].message_type == "assistant_message"
assert letta_messages[2].message_type == "reasoning_message"
# check assistant messages content (they appear in reverse order)
assert letta_messages[0].content == "Second message"
assert letta_messages[1].content == "First message"