feat: Adjust message conversion to support multiple tool calls [LET-5336] (#5270)
* Adjust message conversion * Make collection logic cleaner
This commit is contained in:
committed by
Caren Thomas
parent
e5657bac5d
commit
609e63cb12
@@ -492,13 +492,60 @@ class Message(BaseMessage):
|
||||
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
) -> List[LettaMessage]:
|
||||
messages = []
|
||||
# This is type FunctionCall
|
||||
|
||||
# If assistant mode is off, just create one ToolCallMessage with all tool calls
|
||||
if not use_assistant_message:
|
||||
all_tool_call_objs = [
|
||||
ToolCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
for tool_call in self.tool_calls
|
||||
]
|
||||
|
||||
if all_tool_call_objs:
|
||||
otid = Message.generate_otid_from_id(self.id, current_message_count)
|
||||
messages.append(
|
||||
ToolCallMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
# use first tool call for the deprecated field
|
||||
tool_call=all_tool_call_objs[0],
|
||||
tool_calls=all_tool_call_objs,
|
||||
name=self.name,
|
||||
otid=otid,
|
||||
sender_id=self.sender_id,
|
||||
step_id=self.step_id,
|
||||
is_err=self.is_err,
|
||||
run_id=self.run_id,
|
||||
)
|
||||
)
|
||||
return messages
|
||||
|
||||
collected_tool_calls = []
|
||||
|
||||
for tool_call in self.tool_calls:
|
||||
otid = Message.generate_otid_from_id(self.id, current_message_count + len(messages))
|
||||
# If we're supporting using assistant message,
|
||||
# then we want to treat certain function calls as a special case
|
||||
if use_assistant_message and tool_call.function.name == assistant_message_tool_name:
|
||||
# We need to unpack the actual message contents from the function call
|
||||
|
||||
if tool_call.function.name == assistant_message_tool_name:
|
||||
if collected_tool_calls:
|
||||
tool_call_message = ToolCallMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
# use first tool call for the deprecated field
|
||||
tool_call=collected_tool_calls[0],
|
||||
tool_calls=collected_tool_calls.copy(),
|
||||
name=self.name,
|
||||
otid=Message.generate_otid_from_id(self.id, current_message_count + len(messages)),
|
||||
sender_id=self.sender_id,
|
||||
step_id=self.step_id,
|
||||
is_err=self.is_err,
|
||||
run_id=self.run_id,
|
||||
)
|
||||
messages.append(tool_call_message)
|
||||
collected_tool_calls = [] # reset the collection
|
||||
|
||||
try:
|
||||
func_args = parse_json(tool_call.function.arguments)
|
||||
message_string = validate_function_response(func_args[assistant_message_tool_kwarg], 0, truncate=False)
|
||||
@@ -518,25 +565,31 @@ class Message(BaseMessage):
|
||||
)
|
||||
)
|
||||
else:
|
||||
# non-assistant tool call, collect it
|
||||
tool_call_obj = ToolCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
tool_call_id=tool_call.id,
|
||||
)
|
||||
messages.append(
|
||||
ToolCallMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
tool_call=tool_call_obj,
|
||||
tool_calls=[tool_call_obj],
|
||||
name=self.name,
|
||||
otid=otid,
|
||||
sender_id=self.sender_id,
|
||||
step_id=self.step_id,
|
||||
is_err=self.is_err,
|
||||
run_id=self.run_id,
|
||||
)
|
||||
)
|
||||
collected_tool_calls.append(tool_call_obj)
|
||||
|
||||
# flush any remaining collected tool calls
|
||||
if collected_tool_calls:
|
||||
tool_call_message = ToolCallMessage(
|
||||
id=self.id,
|
||||
date=self.created_at,
|
||||
# use first tool call for the deprecated field
|
||||
tool_call=collected_tool_calls[0],
|
||||
tool_calls=collected_tool_calls,
|
||||
name=self.name,
|
||||
otid=Message.generate_otid_from_id(self.id, current_message_count + len(messages)),
|
||||
sender_id=self.sender_id,
|
||||
step_id=self.step_id,
|
||||
is_err=self.is_err,
|
||||
run_id=self.run_id,
|
||||
)
|
||||
messages.append(tool_call_message)
|
||||
|
||||
return messages
|
||||
|
||||
def _convert_tool_return_message(self) -> List[ToolReturnMessage]:
|
||||
|
||||
@@ -784,3 +784,236 @@ async def test_create_many_messages_async_with_turbopuffer(server: SyncServer, s
|
||||
for msg in created_messages:
|
||||
assert msg.id is not None
|
||||
assert msg.agent_id == sarah_agent.id
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Pydantic Object Tests - Tool Call Message Conversion
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_tool_call_messages_no_assistant_mode(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that when assistant mode is off, all tool calls go into a single ToolCallMessage"""
|
||||
from letta.schemas.letta_message import ToolCall
|
||||
|
||||
# create a message with multiple tool calls
|
||||
tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id="call_1", type="function", function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "test memory 1"}')
|
||||
),
|
||||
OpenAIToolCall(
|
||||
id="call_2", type="function", function=OpenAIFunction(name="conversation_search", arguments='{"query": "test search"}')
|
||||
),
|
||||
OpenAIToolCall(id="call_3", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Hello there!"}')),
|
||||
]
|
||||
|
||||
message = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.assistant,
|
||||
content=[TextContent(text="Let me help you with that.")],
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
# convert without assistant mode (reverse=True by default)
|
||||
letta_messages = message.to_letta_messages(use_assistant_message=False)
|
||||
|
||||
# should have 2 messages in reverse order: tool call message, then reasoning message
|
||||
assert len(letta_messages) == 2
|
||||
assert letta_messages[0].message_type == "tool_call_message"
|
||||
assert letta_messages[1].message_type == "reasoning_message"
|
||||
|
||||
# check the tool call message has all tool calls in the new field
|
||||
tool_call_msg = letta_messages[0]
|
||||
assert tool_call_msg.tool_calls is not None
|
||||
assert len(tool_call_msg.tool_calls) == 3
|
||||
|
||||
# check backwards compatibility - first tool call in deprecated field
|
||||
assert tool_call_msg.tool_call is not None
|
||||
assert tool_call_msg.tool_call.name == "archival_memory_insert"
|
||||
assert tool_call_msg.tool_call.tool_call_id == "call_1"
|
||||
|
||||
# verify all tool calls are present in the list
|
||||
tool_names = [tc.name for tc in tool_call_msg.tool_calls]
|
||||
assert "archival_memory_insert" in tool_names
|
||||
assert "conversation_search" in tool_names
|
||||
assert "send_message" in tool_names
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_tool_call_messages_with_assistant_mode(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that with assistant mode, send_message becomes AssistantMessage and others are grouped"""
|
||||
|
||||
# create a message with tool calls including send_message
|
||||
tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id="call_1", type="function", function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "test memory 1"}')
|
||||
),
|
||||
OpenAIToolCall(id="call_2", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Hello there!"}')),
|
||||
OpenAIToolCall(
|
||||
id="call_3", type="function", function=OpenAIFunction(name="conversation_search", arguments='{"query": "test search"}')
|
||||
),
|
||||
]
|
||||
|
||||
message = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.assistant,
|
||||
content=[TextContent(text="Let me help you with that.")],
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
# convert with assistant mode (reverse=True by default)
|
||||
letta_messages = message.to_letta_messages(use_assistant_message=True)
|
||||
|
||||
# should have 4 messages in reverse order:
|
||||
# conversation_search tool call, assistant message, archival_memory tool call, reasoning
|
||||
assert len(letta_messages) == 4
|
||||
assert letta_messages[0].message_type == "tool_call_message"
|
||||
assert letta_messages[1].message_type == "assistant_message"
|
||||
assert letta_messages[2].message_type == "tool_call_message"
|
||||
assert letta_messages[3].message_type == "reasoning_message"
|
||||
|
||||
# check first tool call message (actually the last in forward order) has conversation_search
|
||||
first_tool_msg = letta_messages[0]
|
||||
assert len(first_tool_msg.tool_calls) == 1
|
||||
assert first_tool_msg.tool_calls[0].name == "conversation_search"
|
||||
assert first_tool_msg.tool_call.name == "conversation_search" # backwards compat
|
||||
|
||||
# check assistant message content
|
||||
assistant_msg = letta_messages[1]
|
||||
assert assistant_msg.content == "Hello there!"
|
||||
|
||||
# check last tool call message (actually the first in forward order) has archival_memory_insert
|
||||
last_tool_msg = letta_messages[2]
|
||||
assert len(last_tool_msg.tool_calls) == 1
|
||||
assert last_tool_msg.tool_calls[0].name == "archival_memory_insert"
|
||||
assert last_tool_msg.tool_call.name == "archival_memory_insert" # backwards compat
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_tool_call_messages_multiple_non_assistant_tools(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that multiple non-assistant tools are batched together until assistant tool is reached"""
|
||||
|
||||
tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id="call_1", type="function", function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "memory 1"}')
|
||||
),
|
||||
OpenAIToolCall(
|
||||
id="call_2", type="function", function=OpenAIFunction(name="conversation_search", arguments='{"query": "search 1"}')
|
||||
),
|
||||
OpenAIToolCall(
|
||||
id="call_3", type="function", function=OpenAIFunction(name="archival_memory_search", arguments='{"query": "archive search"}')
|
||||
),
|
||||
OpenAIToolCall(
|
||||
id="call_4", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Results found!"}')
|
||||
),
|
||||
]
|
||||
|
||||
message = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.assistant,
|
||||
content=[TextContent(text="Processing...")],
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
# convert with assistant mode (reverse=True by default)
|
||||
letta_messages = message.to_letta_messages(use_assistant_message=True)
|
||||
|
||||
# should have 3 messages in reverse order: assistant, tool call (with 3 tools), reasoning
|
||||
assert len(letta_messages) == 3
|
||||
assert letta_messages[0].message_type == "assistant_message"
|
||||
assert letta_messages[1].message_type == "tool_call_message"
|
||||
assert letta_messages[2].message_type == "reasoning_message"
|
||||
|
||||
# check the tool call message has all 3 non-assistant tools
|
||||
tool_msg = letta_messages[1]
|
||||
assert len(tool_msg.tool_calls) == 3
|
||||
tool_names = [tc.name for tc in tool_msg.tool_calls]
|
||||
assert "archival_memory_insert" in tool_names
|
||||
assert "conversation_search" in tool_names
|
||||
assert "archival_memory_search" in tool_names
|
||||
|
||||
# check backwards compat field has first tool
|
||||
assert tool_msg.tool_call.name == "archival_memory_insert"
|
||||
|
||||
# check assistant message
|
||||
assert letta_messages[0].content == "Results found!"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_single_tool_call_both_fields(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that a single tool call is written to both tool_call and tool_calls fields"""
|
||||
|
||||
tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id="call_1",
|
||||
type="function",
|
||||
function=OpenAIFunction(name="archival_memory_insert", arguments='{"content": "single tool call"}'),
|
||||
),
|
||||
]
|
||||
|
||||
message = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.assistant,
|
||||
content=[TextContent(text="Saving to memory...")],
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
# test without assistant mode (reverse=True by default)
|
||||
letta_messages = message.to_letta_messages(use_assistant_message=False)
|
||||
|
||||
assert len(letta_messages) == 2 # tool call + reasoning (reversed)
|
||||
tool_msg = letta_messages[0] # tool call is first due to reverse
|
||||
|
||||
# both fields should be populated
|
||||
assert tool_msg.tool_call is not None
|
||||
assert tool_msg.tool_call.name == "archival_memory_insert"
|
||||
|
||||
assert tool_msg.tool_calls is not None
|
||||
assert len(tool_msg.tool_calls) == 1
|
||||
assert tool_msg.tool_calls[0].name == "archival_memory_insert"
|
||||
assert tool_msg.tool_calls[0].tool_call_id == "call_1"
|
||||
|
||||
# test with assistant mode (reverse=True by default)
|
||||
letta_messages_assist = message.to_letta_messages(use_assistant_message=True)
|
||||
|
||||
assert len(letta_messages_assist) == 2 # tool call + reasoning (reversed)
|
||||
tool_msg_assist = letta_messages_assist[0] # tool call is first due to reverse
|
||||
|
||||
# both fields should still be populated
|
||||
assert tool_msg_assist.tool_call is not None
|
||||
assert tool_msg_assist.tool_calls is not None
|
||||
assert len(tool_msg_assist.tool_calls) == 1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_convert_tool_calls_only_assistant_tools(server: SyncServer, sarah_agent, default_user):
|
||||
"""Test that only send_message tools are converted to AssistantMessages"""
|
||||
|
||||
tool_calls = [
|
||||
OpenAIToolCall(
|
||||
id="call_1", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "First message"}')
|
||||
),
|
||||
OpenAIToolCall(
|
||||
id="call_2", type="function", function=OpenAIFunction(name="send_message", arguments='{"message": "Second message"}')
|
||||
),
|
||||
]
|
||||
|
||||
message = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.assistant,
|
||||
content=[TextContent(text="Sending messages...")],
|
||||
tool_calls=tool_calls,
|
||||
)
|
||||
|
||||
# convert with assistant mode (reverse=True by default)
|
||||
letta_messages = message.to_letta_messages(use_assistant_message=True)
|
||||
|
||||
# should have 3 messages in reverse order: 2 assistant messages, then reasoning
|
||||
assert len(letta_messages) == 3
|
||||
assert letta_messages[0].message_type == "assistant_message"
|
||||
assert letta_messages[1].message_type == "assistant_message"
|
||||
assert letta_messages[2].message_type == "reasoning_message"
|
||||
|
||||
# check assistant messages content (they appear in reverse order)
|
||||
assert letta_messages[0].content == "Second message"
|
||||
assert letta_messages[1].content == "First message"
|
||||
|
||||
Reference in New Issue
Block a user