feat: message orm migration (#2144)
Co-authored-by: Mindy Long <mindy@letta.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com> Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
@@ -17,7 +17,6 @@ from letta.schemas.letta_message import (
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.user import User
|
||||
|
||||
from .test_managers import DEFAULT_EMBEDDING_CONFIG
|
||||
@@ -91,7 +90,7 @@ def agent_id(server, user_id):
|
||||
|
||||
def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||
try:
|
||||
fake_agent_id = uuid.uuid4()
|
||||
fake_agent_id = str(uuid.uuid4())
|
||||
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
|
||||
raise Exception("user_message call should have failed")
|
||||
except (KeyError, ValueError) as e:
|
||||
@@ -388,7 +387,7 @@ def _test_get_messages_letta_format(
|
||||
agent_id,
|
||||
reverse=False,
|
||||
):
|
||||
"""Reverse is off by default, the GET goes in chronological order"""
|
||||
"""Test mapping between messages and letta_messages with reverse=False."""
|
||||
|
||||
messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
@@ -397,7 +396,6 @@ def _test_get_messages_letta_format(
|
||||
reverse=reverse,
|
||||
return_message_object=True,
|
||||
)
|
||||
# messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
|
||||
letta_messages = server.get_agent_recall_cursor(
|
||||
@@ -407,140 +405,96 @@ def _test_get_messages_letta_format(
|
||||
reverse=reverse,
|
||||
return_message_object=False,
|
||||
)
|
||||
# letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False)
|
||||
assert all(isinstance(m, LettaMessage) for m in letta_messages)
|
||||
|
||||
# Loop through `messages` while also looping through `letta_messages`
|
||||
# Each message in `messages` should have 1+ corresponding messages in `letta_messages`
|
||||
# If role of message (in `messages`) is `assistant`,
|
||||
# then there should be two messages in `letta_messages`, one which is type InternalMonologue and one which is type FunctionCallMessage.
|
||||
# If role of message (in `messages`) is `user`, then there should be one message in `letta_messages` which is type UserMessage.
|
||||
# If role of message (in `messages`) is `system`, then there should be one message in `letta_messages` which is type SystemMessage.
|
||||
# If role of message (in `messages`) is `tool`, then there should be one message in `letta_messages` which is type FunctionReturn.
|
||||
|
||||
print("MESSAGES (obj):")
|
||||
for i, m in enumerate(messages):
|
||||
# print(m)
|
||||
print(f"{i}: {m.role}, {m.text[:50]}...")
|
||||
# print(m.role)
|
||||
|
||||
print("MEMGPT_MESSAGES:")
|
||||
for i, m in enumerate(letta_messages):
|
||||
print(f"{i}: {type(m)} ...{str(m)[-50:]}")
|
||||
|
||||
# Collect system messages and their texts
|
||||
system_messages = [m for m in messages if m.role == MessageRole.system]
|
||||
system_texts = [m.text for m in system_messages]
|
||||
|
||||
# If there are multiple system messages, print the diff
|
||||
if len(system_messages) > 1:
|
||||
print("Differences between system messages:")
|
||||
for i in range(len(system_texts) - 1):
|
||||
for j in range(i + 1, len(system_texts)):
|
||||
import difflib
|
||||
|
||||
diff = difflib.unified_diff(
|
||||
system_texts[i].splitlines(),
|
||||
system_texts[j].splitlines(),
|
||||
fromfile=f"System Message {i+1}",
|
||||
tofile=f"System Message {j+1}",
|
||||
lineterm="",
|
||||
)
|
||||
print("\n".join(diff))
|
||||
else:
|
||||
print("There is only one or no system message.")
|
||||
print(f"Messages: {len(messages)}, LettaMessages: {len(letta_messages)}")
|
||||
|
||||
letta_message_index = 0
|
||||
for i, message in enumerate(messages):
|
||||
assert isinstance(message, Message)
|
||||
|
||||
print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
|
||||
# Defensive bounds check for letta_messages
|
||||
if letta_message_index >= len(letta_messages):
|
||||
print(f"Error: letta_message_index out of range. Expected more letta_messages for message {i}: {message.role}")
|
||||
raise ValueError(f"Mismatch in letta_messages length. Index: {letta_message_index}, Length: {len(letta_messages)}")
|
||||
|
||||
print(f"Processing message {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
|
||||
while letta_message_index < len(letta_messages):
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
print(f"letta_message {letta_message_index}: {str(letta_message)[:50]}")
|
||||
|
||||
# Validate mappings for assistant role
|
||||
if message.role == MessageRole.assistant:
|
||||
print(f"i={i}, M=assistant, MM={type(letta_message)}")
|
||||
print(f"Assistant Message at {i}: {type(letta_message)}")
|
||||
|
||||
# If reverse, function call will come first
|
||||
if reverse:
|
||||
|
||||
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
|
||||
if message.tool_calls is not None:
|
||||
# Reverse handling: FunctionCallMessages come first
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
|
||||
# Try to parse the tool call args
|
||||
try:
|
||||
json.loads(tool_call.function.arguments)
|
||||
except:
|
||||
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
|
||||
assert isinstance(letta_message, FunctionCallMessage)
|
||||
letta_message_index += 1
|
||||
if letta_message_index >= len(letta_messages):
|
||||
break
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
|
||||
if message.text is not None:
|
||||
if message.text:
|
||||
assert isinstance(letta_message, InternalMonologue)
|
||||
letta_message_index += 1
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
else:
|
||||
# If there's no inner thoughts then there needs to be a tool call
|
||||
assert message.tool_calls is not None
|
||||
|
||||
else:
|
||||
|
||||
if message.text is not None:
|
||||
else: # Non-reverse handling
|
||||
if message.text:
|
||||
assert isinstance(letta_message, InternalMonologue)
|
||||
letta_message_index += 1
|
||||
if letta_message_index >= len(letta_messages):
|
||||
break
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
else:
|
||||
# If there's no inner thoughts then there needs to be a tool call
|
||||
assert message.tool_calls is not None
|
||||
|
||||
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
|
||||
if message.tool_calls is not None:
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
|
||||
# Try to parse the tool call args
|
||||
try:
|
||||
json.loads(tool_call.function.arguments)
|
||||
except:
|
||||
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
|
||||
|
||||
except json.JSONDecodeError:
|
||||
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
|
||||
assert isinstance(letta_message, FunctionCallMessage)
|
||||
assert tool_call.function.name == letta_message.function_call.name
|
||||
assert tool_call.function.arguments == letta_message.function_call.arguments
|
||||
letta_message_index += 1
|
||||
if letta_message_index >= len(letta_messages):
|
||||
break
|
||||
letta_message = letta_messages[letta_message_index]
|
||||
|
||||
elif message.role == MessageRole.user:
|
||||
print(f"i={i}, M=user, MM={type(letta_message)}")
|
||||
assert isinstance(letta_message, UserMessage)
|
||||
assert message.text == letta_message.message
|
||||
letta_message_index += 1
|
||||
|
||||
elif message.role == MessageRole.system:
|
||||
print(f"i={i}, M=system, MM={type(letta_message)}")
|
||||
assert isinstance(letta_message, SystemMessage)
|
||||
assert message.text == letta_message.message
|
||||
letta_message_index += 1
|
||||
|
||||
elif message.role == MessageRole.tool:
|
||||
print(f"i={i}, M=tool, MM={type(letta_message)}")
|
||||
assert isinstance(letta_message, FunctionReturn)
|
||||
# Check the the value in `text` is the same
|
||||
assert message.text == letta_message.function_return
|
||||
letta_message_index += 1
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unexpected message role: {message.role}")
|
||||
|
||||
# Move to the next message in the original messages list
|
||||
break
|
||||
break # Exit the letta_messages loop after processing one mapping
|
||||
|
||||
if letta_message_index < len(letta_messages):
|
||||
warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}")
|
||||
|
||||
|
||||
def test_get_messages_letta_format(server, user_id, agent_id):
|
||||
for reverse in [False, True]:
|
||||
# for reverse in [False, True]:
|
||||
for reverse in [False]:
|
||||
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
|
||||
|
||||
|
||||
@@ -586,7 +540,7 @@ def ingest(message: str):
|
||||
'''
|
||||
|
||||
|
||||
def test_tool_run(server, user_id, agent_id):
|
||||
def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
|
||||
"""Test that the server can run tools"""
|
||||
|
||||
result = server.run_tool_from_source(
|
||||
@@ -672,7 +626,7 @@ def test_composio_client_simple(server):
|
||||
assert len(actions) > 0
|
||||
|
||||
|
||||
def test_memory_rebuild_count(server, user_id):
|
||||
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
|
||||
# create agent
|
||||
@@ -712,7 +666,6 @@ def test_memory_rebuild_count(server, user_id):
|
||||
return len(system_messages), letta_messages
|
||||
|
||||
try:
|
||||
|
||||
# At this stage, there should only be 1 system message inside of recall storage
|
||||
num_system_messages, all_messages = count_system_messages_in_recall()
|
||||
# assert num_system_messages == 1, (num_system_messages, all_messages)
|
||||
|
||||
Reference in New Issue
Block a user