feat: Separate out streaming route (#2111)
This commit is contained in:
@@ -5,8 +5,17 @@ import warnings
|
||||
import pytest
|
||||
|
||||
import letta.utils as utils
|
||||
from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.constants import BASE_TOOLS
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import (
|
||||
FunctionCallMessage,
|
||||
FunctionReturn,
|
||||
InternalMonologue,
|
||||
LettaMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.user import User
|
||||
|
||||
from .test_managers import DEFAULT_EMBEDDING_CONFIG
|
||||
@@ -15,18 +24,8 @@ utils.DEBUG = True
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message import (
|
||||
AssistantMessage,
|
||||
FunctionCallMessage,
|
||||
FunctionReturn,
|
||||
InternalMonologue,
|
||||
LettaMessage,
|
||||
SystemMessage,
|
||||
UserMessage,
|
||||
)
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.source import Source
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
@@ -174,27 +173,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
messages_2[-1].id
|
||||
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
|
||||
messages_3[-1].id
|
||||
# [m["id"] for m in messages_3]
|
||||
# [m["id"] for m in messages_2]
|
||||
timestamps = [m.created_at for m in messages_3]
|
||||
print("timestamps", timestamps)
|
||||
assert messages_3[-1].created_at >= messages_3[0].created_at
|
||||
assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
# test in-context message ids
|
||||
all_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
|
||||
in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
|
||||
# TODO: doesn't pass since recall memory also logs all system message changes
|
||||
# print("IN CONTEXT:", [m.text for m in server.get_in_context_messages(agent_id=agent_id)])
|
||||
# print("ALL:", [m.text for m in all_messages])
|
||||
# print()
|
||||
# for message in all_messages:
|
||||
# if message.id not in in_context_ids:
|
||||
# print("NOT IN CONTEXT:", message.id, message.created_at, message.text[-100:])
|
||||
# print()
|
||||
# assert len(in_context_ids) == len(messages_3)
|
||||
message_ids = [m.id for m in messages_3]
|
||||
for message_id in in_context_ids:
|
||||
assert message_id in message_ids, f"{message_id} not in {message_ids}"
|
||||
@@ -248,201 +233,6 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
assert len(passage_none) == 0
|
||||
|
||||
|
||||
def _test_get_messages_letta_format(
    server,
    user_id,
    agent_id,
    reverse=False,
    # flag that determines whether or not to use AssistantMessage, or just FunctionCallMessage universally
    use_assistant_message=False,
):
    """Check that the `LettaMessage` view of recall memory lines up with the `Message` view.

    The same recall query is issued twice through ``server.get_agent_recall_cursor``,
    once with ``return_message_object=True`` and once with ``False``, and the two
    result lists are walked in lockstep. Each `Message` must correspond to one or
    more consecutive `LettaMessage` entries:

    * ``assistant`` -> an `InternalMonologue` (when the message has text) plus one
      entry per tool call. A tool call is expected to be an `AssistantMessage` when
      ``use_assistant_message`` is set, the tool is ``DEFAULT_MESSAGE_TOOL`` and its
      JSON arguments contain ``DEFAULT_MESSAGE_TOOL_KWARG``; otherwise it is expected
      to be a `FunctionCallMessage`.
    * ``user`` -> one `UserMessage`
    * ``system`` -> one `SystemMessage`
    * ``tool`` -> one `FunctionReturn`

    Args:
        server: Object exposing ``get_agent_recall_cursor``.
        user_id: Forwarded to the recall query.
        agent_id: Forwarded to the recall query.
        reverse: Forwarded to the recall query. Reverse is off by default, in which
            case the GET goes in chronological order. When set, the tool-call
            entries of an assistant message are expected before its inner monologue.
        use_assistant_message: Forwarded to the recall query; also controls whether
            `AssistantMessage` entries are expected (see above).

    Raises:
        AssertionError: If the two views disagree.
        ValueError: If a message has a role other than assistant/user/system/tool.
    """
    # Only used for the diagnostic diff; imported once rather than inside the loops.
    import difflib

    query = dict(
        user_id=user_id,
        agent_id=agent_id,
        limit=1000,
        reverse=reverse,
        use_assistant_message=use_assistant_message,
    )

    messages = server.get_agent_recall_cursor(return_message_object=True, **query)
    assert all(isinstance(m, Message) for m in messages)

    letta_messages = server.get_agent_recall_cursor(return_message_object=False, **query)
    assert all(isinstance(m, LettaMessage) for m in letta_messages)

    print("MESSAGES (obj):")
    for i, m in enumerate(messages):
        # `text` may be None (e.g. an assistant message that only carries tool calls).
        print(f"{i}: {m.role}, {m.text[:50] if m.text else 'null'}...")

    print("MEMGPT_MESSAGES:")
    for i, m in enumerate(letta_messages):
        print(f"{i}: {type(m)} ...{str(m)[-50:]}")

    # Collect the system message texts; if there are several, print pairwise diffs.
    system_texts = [m.text for m in messages if m.role == MessageRole.system]
    if len(system_texts) > 1:
        print("Differences between system messages:")
        for i in range(len(system_texts) - 1):
            for j in range(i + 1, len(system_texts)):
                diff = difflib.unified_diff(
                    system_texts[i].splitlines(),
                    system_texts[j].splitlines(),
                    fromfile=f"System Message {i+1}",
                    tofile=f"System Message {j+1}",
                    lineterm="",
                )
                print("\n".join(diff))
    else:
        print("There is only one or no system message.")

    # Index of the next not-yet-matched entry in `letta_messages`.
    cursor = 0

    def current():
        """Return the LettaMessage at the cursor, or None once the list is exhausted.

        Returning None (instead of indexing past the end) lets the isinstance
        assertions fail cleanly when an entry is expected but missing, and avoids
        an IndexError when an assistant message consumes the final entry.
        """
        return letta_messages[cursor] if cursor < len(letta_messages) else None

    def expect_inner_thoughts(message):
        """Consume the InternalMonologue entry of an assistant message, if it has text."""
        nonlocal cursor
        if message.text is not None:
            assert isinstance(current(), InternalMonologue)
            cursor += 1
        else:
            # If there's no inner thoughts then there needs to be a tool call
            assert message.tool_calls is not None

    def expect_tool_calls(message, compare_payload):
        """Consume one entry per tool call of an assistant message.

        If there are multiple tool calls, we should have multiple back to back
        entries. `compare_payload` additionally checks name and raw arguments of
        each FunctionCallMessage against the tool call.
        """
        nonlocal cursor
        for tool_call in message.tool_calls or []:
            # Try to parse the tool call args
            try:
                func_args = json.loads(tool_call.function.arguments)
            except (ValueError, TypeError):
                warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
                func_args = {}

            entry = current()
            if (
                use_assistant_message
                and tool_call.function.name == DEFAULT_MESSAGE_TOOL
                and isinstance(func_args, dict)
                and DEFAULT_MESSAGE_TOOL_KWARG in func_args
            ):
                # The assistant-message tool is surfaced as an AssistantMessage
                assert isinstance(entry, AssistantMessage)
                assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == entry.assistant_message
            else:
                # Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage
                assert isinstance(entry, FunctionCallMessage)
                if compare_payload:
                    assert tool_call.function.name == entry.function_call.name
                    assert tool_call.function.arguments == entry.function_call.arguments
            cursor += 1

    for i, message in enumerate(messages):
        assert isinstance(message, Message)

        print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")

        letta_message = current()
        if letta_message is None:
            # Nothing left to compare against; remaining messages are not checked.
            continue
        print(f"letta_message {cursor}: {str(letta_message)[:50]}")

        if message.role == MessageRole.assistant:
            print(f"i={i}, M=assistant, MM={type(letta_message)}")
            if reverse:
                # If reverse, function call will come first (payload is not compared here)
                expect_tool_calls(message, compare_payload=False)
                expect_inner_thoughts(message)
            else:
                expect_inner_thoughts(message)
                expect_tool_calls(message, compare_payload=True)

        elif message.role == MessageRole.user:
            print(f"i={i}, M=user, MM={type(letta_message)}")
            assert isinstance(letta_message, UserMessage)
            assert message.text == letta_message.message
            cursor += 1

        elif message.role == MessageRole.system:
            print(f"i={i}, M=system, MM={type(letta_message)}")
            assert isinstance(letta_message, SystemMessage)
            assert message.text == letta_message.message
            cursor += 1

        elif message.role == MessageRole.tool:
            print(f"i={i}, M=tool, MM={type(letta_message)}")
            assert isinstance(letta_message, FunctionReturn)
            # Check that the value in `text` is the same
            assert message.text == letta_message.function_return
            cursor += 1

        else:
            raise ValueError(f"Unexpected message role: {message.role}")
|
||||
|
||||
|
||||
def test_get_messages_letta_format(server, user_id, agent_id):
    """Run the message-format check for every (reverse, use_assistant_message) combination."""
    flag_pairs = [(rev, as_assistant) for rev in (False, True) for as_assistant in (False, True)]
    for rev, as_assistant in flag_pairs:
        _test_get_messages_letta_format(
            server,
            user_id,
            agent_id,
            reverse=rev,
            use_assistant_message=as_assistant,
        )
|
||||
|
||||
|
||||
def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
|
||||
"""Test the /rethink, /rewrite, and /retry commands in the CLI
|
||||
|
||||
@@ -597,3 +387,165 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
|
||||
|
||||
# test that another user in the same org can delete the agent
|
||||
server.delete_agent(another_user.id, agent_state.id)
|
||||
|
||||
|
||||
def _test_get_messages_letta_format(
    server,
    user_id,
    agent_id,
    reverse=False,
):
    """Check that the `LettaMessage` view of recall memory lines up with the `Message` view.

    The same recall query is issued twice through ``server.get_agent_recall_cursor``,
    once with ``return_message_object=True`` and once with ``False``, and the two
    result lists are walked in lockstep. Each `Message` must correspond to one or
    more consecutive `LettaMessage` entries:

    * ``assistant`` -> an `InternalMonologue` (when the message has text) plus one
      `FunctionCallMessage` per tool call
    * ``user`` -> one `UserMessage`
    * ``system`` -> one `SystemMessage`
    * ``tool`` -> one `FunctionReturn`

    Args:
        server: Object exposing ``get_agent_recall_cursor``.
        user_id: Forwarded to the recall query.
        agent_id: Forwarded to the recall query.
        reverse: Forwarded to the recall query. Reverse is off by default, in which
            case the GET goes in chronological order. When set, the function-call
            entries of an assistant message are expected before its inner monologue.

    Raises:
        AssertionError: If the two views disagree.
        ValueError: If a message has a role other than assistant/user/system/tool.
    """
    # Only used for the diagnostic diff; imported once rather than inside the loops.
    import difflib

    query = dict(user_id=user_id, agent_id=agent_id, limit=1000, reverse=reverse)

    messages = server.get_agent_recall_cursor(return_message_object=True, **query)
    assert all(isinstance(m, Message) for m in messages)

    letta_messages = server.get_agent_recall_cursor(return_message_object=False, **query)
    assert all(isinstance(m, LettaMessage) for m in letta_messages)

    print("MESSAGES (obj):")
    for i, m in enumerate(messages):
        # `text` may be None (e.g. an assistant message that only carries tool calls).
        print(f"{i}: {m.role}, {m.text[:50] if m.text else 'null'}...")

    print("MEMGPT_MESSAGES:")
    for i, m in enumerate(letta_messages):
        print(f"{i}: {type(m)} ...{str(m)[-50:]}")

    # Collect the system message texts; if there are several, print pairwise diffs.
    system_texts = [m.text for m in messages if m.role == MessageRole.system]
    if len(system_texts) > 1:
        print("Differences between system messages:")
        for i in range(len(system_texts) - 1):
            for j in range(i + 1, len(system_texts)):
                diff = difflib.unified_diff(
                    system_texts[i].splitlines(),
                    system_texts[j].splitlines(),
                    fromfile=f"System Message {i+1}",
                    tofile=f"System Message {j+1}",
                    lineterm="",
                )
                print("\n".join(diff))
    else:
        print("There is only one or no system message.")

    # Index of the next not-yet-matched entry in `letta_messages`.
    cursor = 0

    def current():
        """Return the LettaMessage at the cursor, or None once the list is exhausted.

        Returning None (instead of indexing past the end) lets the isinstance
        assertions fail cleanly when an entry is expected but missing, and avoids
        an IndexError when an assistant message consumes the final entry.
        """
        return letta_messages[cursor] if cursor < len(letta_messages) else None

    def expect_inner_thoughts(message):
        """Consume the InternalMonologue entry of an assistant message, if it has text."""
        nonlocal cursor
        if message.text is not None:
            assert isinstance(current(), InternalMonologue)
            cursor += 1
        else:
            # If there's no inner thoughts then there needs to be a tool call
            assert message.tool_calls is not None

    def expect_tool_calls(message, compare_payload):
        """Consume one FunctionCallMessage per tool call of an assistant message.

        If there are multiple tool calls, we should have multiple back to back
        FunctionCallMessages. `compare_payload` additionally checks name and raw
        arguments of each entry against the tool call.
        """
        nonlocal cursor
        for tool_call in message.tool_calls or []:
            # Parsing is only a sanity check: malformed arguments produce a warning.
            try:
                json.loads(tool_call.function.arguments)
            except (ValueError, TypeError):
                warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")

            entry = current()
            assert isinstance(entry, FunctionCallMessage)
            if compare_payload:
                assert tool_call.function.name == entry.function_call.name
                assert tool_call.function.arguments == entry.function_call.arguments
            cursor += 1

    for i, message in enumerate(messages):
        assert isinstance(message, Message)

        print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")

        letta_message = current()
        if letta_message is None:
            # Nothing left to compare against; remaining messages are not checked.
            continue
        print(f"letta_message {cursor}: {str(letta_message)[:50]}")

        if message.role == MessageRole.assistant:
            print(f"i={i}, M=assistant, MM={type(letta_message)}")
            if reverse:
                # If reverse, function call will come first (payload is not compared here)
                expect_tool_calls(message, compare_payload=False)
                expect_inner_thoughts(message)
            else:
                expect_inner_thoughts(message)
                expect_tool_calls(message, compare_payload=True)

        elif message.role == MessageRole.user:
            print(f"i={i}, M=user, MM={type(letta_message)}")
            assert isinstance(letta_message, UserMessage)
            assert message.text == letta_message.message
            cursor += 1

        elif message.role == MessageRole.system:
            print(f"i={i}, M=system, MM={type(letta_message)}")
            assert isinstance(letta_message, SystemMessage)
            assert message.text == letta_message.message
            cursor += 1

        elif message.role == MessageRole.tool:
            print(f"i={i}, M=tool, MM={type(letta_message)}")
            assert isinstance(letta_message, FunctionReturn)
            # Check that the value in `text` is the same
            assert message.text == letta_message.function_return
            cursor += 1

        else:
            raise ValueError(f"Unexpected message role: {message.role}")
|
||||
|
||||
|
||||
def test_get_messages_letta_format(server, user_id, agent_id):
    """Run the message-format check in chronological order, then in reverse order."""
    shared_args = (server, user_id, agent_id)
    for ordering in (False, True):
        _test_get_messages_letta_format(*shared_args, reverse=ordering)
|
||||
|
||||
Reference in New Issue
Block a user