feat: message orm migration (#2144)

Co-authored-by: Mindy Long <mindy@letta.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: Matt Zhou <mattzh1314@gmail.com>
This commit is contained in:
mlong93
2024-12-06 11:50:15 -08:00
committed by GitHub
parent 370a0e68dd
commit 6c2c7231ab
45 changed files with 984 additions and 1265 deletions

View File

@@ -17,7 +17,6 @@ from letta.schemas.letta_message import (
SystemMessage,
UserMessage,
)
from letta.schemas.message import Message
from letta.schemas.user import User
from .test_managers import DEFAULT_EMBEDDING_CONFIG
@@ -91,7 +90,7 @@ def agent_id(server, user_id):
def test_error_on_nonexistent_agent(server, user_id, agent_id):
try:
fake_agent_id = uuid.uuid4()
fake_agent_id = str(uuid.uuid4())
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
raise Exception("user_message call should have failed")
except (KeyError, ValueError) as e:
@@ -388,7 +387,7 @@ def _test_get_messages_letta_format(
agent_id,
reverse=False,
):
"""Reverse is off by default, the GET goes in chronological order"""
"""Test mapping between messages and letta_messages with reverse=False."""
messages = server.get_agent_recall_cursor(
user_id=user_id,
@@ -397,7 +396,6 @@ def _test_get_messages_letta_format(
reverse=reverse,
return_message_object=True,
)
# messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
assert all(isinstance(m, Message) for m in messages)
letta_messages = server.get_agent_recall_cursor(
@@ -407,140 +405,96 @@ def _test_get_messages_letta_format(
reverse=reverse,
return_message_object=False,
)
# letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False)
assert all(isinstance(m, LettaMessage) for m in letta_messages)
# Loop through `messages` while also looping through `letta_messages`
# Each message in `messages` should have 1+ corresponding messages in `letta_messages`
# If role of message (in `messages`) is `assistant`,
# then there should be two messages in `letta_messages`, one which is type InternalMonologue and one which is type FunctionCallMessage.
# If role of message (in `messages`) is `user`, then there should be one message in `letta_messages` which is type UserMessage.
# If role of message (in `messages`) is `system`, then there should be one message in `letta_messages` which is type SystemMessage.
# If role of message (in `messages`) is `tool`, then there should be one message in `letta_messages` which is type FunctionReturn.
print("MESSAGES (obj):")
for i, m in enumerate(messages):
# print(m)
print(f"{i}: {m.role}, {m.text[:50]}...")
# print(m.role)
print("MEMGPT_MESSAGES:")
for i, m in enumerate(letta_messages):
print(f"{i}: {type(m)} ...{str(m)[-50:]}")
# Collect system messages and their texts
system_messages = [m for m in messages if m.role == MessageRole.system]
system_texts = [m.text for m in system_messages]
# If there are multiple system messages, print the diff
if len(system_messages) > 1:
print("Differences between system messages:")
for i in range(len(system_texts) - 1):
for j in range(i + 1, len(system_texts)):
import difflib
diff = difflib.unified_diff(
system_texts[i].splitlines(),
system_texts[j].splitlines(),
fromfile=f"System Message {i+1}",
tofile=f"System Message {j+1}",
lineterm="",
)
print("\n".join(diff))
else:
print("There is only one or no system message.")
print(f"Messages: {len(messages)}, LettaMessages: {len(letta_messages)}")
letta_message_index = 0
for i, message in enumerate(messages):
assert isinstance(message, Message)
print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
# Defensive bounds check for letta_messages
if letta_message_index >= len(letta_messages):
print(f"Error: letta_message_index out of range. Expected more letta_messages for message {i}: {message.role}")
raise ValueError(f"Mismatch in letta_messages length. Index: {letta_message_index}, Length: {len(letta_messages)}")
print(f"Processing message {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
while letta_message_index < len(letta_messages):
letta_message = letta_messages[letta_message_index]
print(f"letta_message {letta_message_index}: {str(letta_message)[:50]}")
# Validate mappings for assistant role
if message.role == MessageRole.assistant:
print(f"i={i}, M=assistant, MM={type(letta_message)}")
print(f"Assistant Message at {i}: {type(letta_message)}")
# If reverse, function call will come first
if reverse:
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
# Reverse handling: FunctionCallMessages come first
if message.tool_calls:
for tool_call in message.tool_calls:
# Try to parse the tool call args
try:
json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
except json.JSONDecodeError:
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
assert isinstance(letta_message, FunctionCallMessage)
letta_message_index += 1
if letta_message_index >= len(letta_messages):
break
letta_message = letta_messages[letta_message_index]
if message.text is not None:
if message.text:
assert isinstance(letta_message, InternalMonologue)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
else:
# If there are no inner thoughts, then there needs to be a tool call
assert message.tool_calls is not None
else:
if message.text is not None:
else: # Non-reverse handling
if message.text:
assert isinstance(letta_message, InternalMonologue)
letta_message_index += 1
if letta_message_index >= len(letta_messages):
break
letta_message = letta_messages[letta_message_index]
else:
# If there are no inner thoughts, then there needs to be a tool call
assert message.tool_calls is not None
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
if message.tool_calls:
for tool_call in message.tool_calls:
# Try to parse the tool call args
try:
json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
except json.JSONDecodeError:
warnings.warn(f"Invalid JSON in function arguments: {tool_call.function.arguments}")
assert isinstance(letta_message, FunctionCallMessage)
assert tool_call.function.name == letta_message.function_call.name
assert tool_call.function.arguments == letta_message.function_call.arguments
letta_message_index += 1
if letta_message_index >= len(letta_messages):
break
letta_message = letta_messages[letta_message_index]
elif message.role == MessageRole.user:
print(f"i={i}, M=user, MM={type(letta_message)}")
assert isinstance(letta_message, UserMessage)
assert message.text == letta_message.message
letta_message_index += 1
elif message.role == MessageRole.system:
print(f"i={i}, M=system, MM={type(letta_message)}")
assert isinstance(letta_message, SystemMessage)
assert message.text == letta_message.message
letta_message_index += 1
elif message.role == MessageRole.tool:
print(f"i={i}, M=tool, MM={type(letta_message)}")
assert isinstance(letta_message, FunctionReturn)
# Check that the value in `text` is the same
assert message.text == letta_message.function_return
letta_message_index += 1
else:
raise ValueError(f"Unexpected message role: {message.role}")
# Move to the next message in the original messages list
break
break # Exit the letta_messages loop after processing one mapping
if letta_message_index < len(letta_messages):
warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}")
def test_get_messages_letta_format(server, user_id, agent_id):
for reverse in [False, True]:
# for reverse in [False, True]:
for reverse in [False]:
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
@@ -586,7 +540,7 @@ def ingest(message: str):
'''
def test_tool_run(server, user_id, agent_id):
def test_tool_run(server, mock_e2b_api_key_none, user_id, agent_id):
"""Test that the server can run tools"""
result = server.run_tool_from_source(
@@ -672,7 +626,7 @@ def test_composio_client_simple(server):
assert len(actions) > 0
def test_memory_rebuild_count(server, user_id):
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
# create agent
@@ -712,7 +666,6 @@ def test_memory_rebuild_count(server, user_id):
return len(system_messages), letta_messages
try:
# At this stage, there should only be 1 system message inside of recall storage
num_system_messages, all_messages = count_system_messages_in_recall()
# assert num_system_messages == 1, (num_system_messages, all_messages)