feat: Separate out streaming route (#2111)

This commit is contained in:
Matthew Zhou
2024-11-27 14:03:46 -08:00
committed by GitHub
parent cfb48a112f
commit 5a59d2ac42
16 changed files with 301 additions and 411 deletions

View File

@@ -5,8 +5,17 @@ import warnings
import pytest
import letta.utils as utils
from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.constants import BASE_TOOLS
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import (
FunctionCallMessage,
FunctionReturn,
InternalMonologue,
LettaMessage,
SystemMessage,
UserMessage,
)
from letta.schemas.message import Message
from letta.schemas.user import User
from .test_managers import DEFAULT_EMBEDDING_CONFIG
@@ -15,18 +24,8 @@ utils.DEBUG = True
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCallMessage,
FunctionReturn,
InternalMonologue,
LettaMessage,
SystemMessage,
UserMessage,
)
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.message import Message
from letta.schemas.source import Source
from letta.server.server import SyncServer
@@ -174,27 +173,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
messages_2[-1].id
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
messages_3[-1].id
# [m["id"] for m in messages_3]
# [m["id"] for m in messages_2]
timestamps = [m.created_at for m in messages_3]
print("timestamps", timestamps)
assert messages_3[-1].created_at >= messages_3[0].created_at
assert len(messages_3) == len(messages_1) + len(messages_2)
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
assert len(messages_4) == 1
# test in-context message ids
all_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
# TODO: doesn't pass since recall memory also logs all system message changess
# print("IN CONTEXT:", [m.text for m in server.get_in_context_messages(agent_id=agent_id)])
# print("ALL:", [m.text for m in all_messages])
# print()
# for message in all_messages:
# if message.id not in in_context_ids:
# print("NOT IN CONTEXT:", message.id, message.created_at, message.text[-100:])
# print()
# assert len(in_context_ids) == len(messages_3)
message_ids = [m.id for m in messages_3]
for message_id in in_context_ids:
assert message_id in message_ids, f"{message_id} not in {message_ids}"
@@ -248,201 +233,6 @@ def test_get_archival_memory(server, user_id, agent_id):
assert len(passage_none) == 0
def _test_get_messages_letta_format(
server,
user_id,
agent_id,
reverse=False,
# flag that determines whether or not to use AssistantMessage, or just FunctionCallMessage universally
use_assistant_message=False,
):
"""Reverse is off by default, the GET goes in chronological order"""
messages = server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
return_message_object=True,
use_assistant_message=use_assistant_message,
)
# messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
assert all(isinstance(m, Message) for m in messages)
letta_messages = server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
return_message_object=False,
use_assistant_message=use_assistant_message,
)
# letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False)
assert all(isinstance(m, LettaMessage) for m in letta_messages)
# Loop through `messages` while also looping through `letta_messages`
# Each message in `messages` should have 1+ corresponding messages in `letta_messages`
# If role of message (in `messages`) is `assistant`,
# then there should be two messages in `letta_messages`, one which is type InternalMonologue and one which is type FunctionCallMessage.
# If role of message (in `messages`) is `user`, then there should be one message in `letta_messages` which is type UserMessage.
# If role of message (in `messages`) is `system`, then there should be one message in `letta_messages` which is type SystemMessage.
# If role of message (in `messages`) is `tool`, then there should be one message in `letta_messages` which is type FunctionReturn.
print("MESSAGES (obj):")
for i, m in enumerate(messages):
# print(m)
print(f"{i}: {m.role}, {m.text[:50]}...")
# print(m.role)
print("MEMGPT_MESSAGES:")
for i, m in enumerate(letta_messages):
print(f"{i}: {type(m)} ...{str(m)[-50:]}")
# Collect system messages and their texts
system_messages = [m for m in messages if m.role == MessageRole.system]
system_texts = [m.text for m in system_messages]
# If there are multiple system messages, print the diff
if len(system_messages) > 1:
print("Differences between system messages:")
for i in range(len(system_texts) - 1):
for j in range(i + 1, len(system_texts)):
import difflib
diff = difflib.unified_diff(
system_texts[i].splitlines(),
system_texts[j].splitlines(),
fromfile=f"System Message {i+1}",
tofile=f"System Message {j+1}",
lineterm="",
)
print("\n".join(diff))
else:
print("There is only one or no system message.")
letta_message_index = 0
for i, message in enumerate(messages):
assert isinstance(message, Message)
print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
while letta_message_index < len(letta_messages):
letta_message = letta_messages[letta_message_index]
print(f"letta_message {letta_message_index}: {str(letta_message)[:50]}")
if message.role == MessageRole.assistant:
print(f"i={i}, M=assistant, MM={type(letta_message)}")
# If reverse, function call will come first
if reverse:
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
for tool_call in message.tool_calls:
# Try to parse the tool call args
try:
func_args = json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
func_args = {}
# If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool
if (
use_assistant_message
and tool_call.function.name == DEFAULT_MESSAGE_TOOL
and DEFAULT_MESSAGE_TOOL_KWARG in func_args
):
assert isinstance(letta_message, AssistantMessage)
assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
# Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage
else:
assert isinstance(letta_message, FunctionCallMessage)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
if message.text is not None:
assert isinstance(letta_message, InternalMonologue)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
else:
# If there's no inner thoughts then there needs to be a tool call
assert message.tool_calls is not None
else:
if message.text is not None:
assert isinstance(letta_message, InternalMonologue)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
else:
# If there's no inner thoughts then there needs to be a tool call
assert message.tool_calls is not None
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
for tool_call in message.tool_calls:
# Try to parse the tool call args
try:
func_args = json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
func_args = {}
# If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool
if (
use_assistant_message
and tool_call.function.name == DEFAULT_MESSAGE_TOOL
and DEFAULT_MESSAGE_TOOL_KWARG in func_args
):
assert isinstance(letta_message, AssistantMessage)
assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
# Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage
else:
assert isinstance(letta_message, FunctionCallMessage)
assert tool_call.function.name == letta_message.function_call.name
assert tool_call.function.arguments == letta_message.function_call.arguments
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
elif message.role == MessageRole.user:
print(f"i={i}, M=user, MM={type(letta_message)}")
assert isinstance(letta_message, UserMessage)
assert message.text == letta_message.message
letta_message_index += 1
elif message.role == MessageRole.system:
print(f"i={i}, M=system, MM={type(letta_message)}")
assert isinstance(letta_message, SystemMessage)
assert message.text == letta_message.message
letta_message_index += 1
elif message.role == MessageRole.tool:
print(f"i={i}, M=tool, MM={type(letta_message)}")
assert isinstance(letta_message, FunctionReturn)
# Check the the value in `text` is the same
assert message.text == letta_message.function_return
letta_message_index += 1
else:
raise ValueError(f"Unexpected message role: {message.role}")
# Move to the next message in the original messages list
break
def test_get_messages_letta_format(server, user_id, agent_id):
for reverse in [False, True]:
for assistant_message in [False, True]:
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse, use_assistant_message=assistant_message)
def test_agent_rethink_rewrite_retry(server, user_id, agent_id):
"""Test the /rethink, /rewrite, and /retry commands in the CLI
@@ -597,3 +387,165 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
# test that another user in the same org can delete the agent
server.delete_agent(another_user.id, agent_state.id)
def _test_get_messages_letta_format(
server,
user_id,
agent_id,
reverse=False,
):
"""Reverse is off by default, the GET goes in chronological order"""
messages = server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
return_message_object=True,
)
# messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
assert all(isinstance(m, Message) for m in messages)
letta_messages = server.get_agent_recall_cursor(
user_id=user_id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
return_message_object=False,
)
# letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False)
assert all(isinstance(m, LettaMessage) for m in letta_messages)
# Loop through `messages` while also looping through `letta_messages`
# Each message in `messages` should have 1+ corresponding messages in `letta_messages`
# If role of message (in `messages`) is `assistant`,
# then there should be two messages in `letta_messages`, one which is type InternalMonologue and one which is type FunctionCallMessage.
# If role of message (in `messages`) is `user`, then there should be one message in `letta_messages` which is type UserMessage.
# If role of message (in `messages`) is `system`, then there should be one message in `letta_messages` which is type SystemMessage.
# If role of message (in `messages`) is `tool`, then there should be one message in `letta_messages` which is type FunctionReturn.
print("MESSAGES (obj):")
for i, m in enumerate(messages):
# print(m)
print(f"{i}: {m.role}, {m.text[:50]}...")
# print(m.role)
print("MEMGPT_MESSAGES:")
for i, m in enumerate(letta_messages):
print(f"{i}: {type(m)} ...{str(m)[-50:]}")
# Collect system messages and their texts
system_messages = [m for m in messages if m.role == MessageRole.system]
system_texts = [m.text for m in system_messages]
# If there are multiple system messages, print the diff
if len(system_messages) > 1:
print("Differences between system messages:")
for i in range(len(system_texts) - 1):
for j in range(i + 1, len(system_texts)):
import difflib
diff = difflib.unified_diff(
system_texts[i].splitlines(),
system_texts[j].splitlines(),
fromfile=f"System Message {i+1}",
tofile=f"System Message {j+1}",
lineterm="",
)
print("\n".join(diff))
else:
print("There is only one or no system message.")
letta_message_index = 0
for i, message in enumerate(messages):
assert isinstance(message, Message)
print(f"\n\nmessage {i}: {message.role}, {message.text[:50] if message.text else 'null'}")
while letta_message_index < len(letta_messages):
letta_message = letta_messages[letta_message_index]
print(f"letta_message {letta_message_index}: {str(letta_message)[:50]}")
if message.role == MessageRole.assistant:
print(f"i={i}, M=assistant, MM={type(letta_message)}")
# If reverse, function call will come first
if reverse:
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
for tool_call in message.tool_calls:
# Try to parse the tool call args
try:
json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
assert isinstance(letta_message, FunctionCallMessage)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
if message.text is not None:
assert isinstance(letta_message, InternalMonologue)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
else:
# If there's no inner thoughts then there needs to be a tool call
assert message.tool_calls is not None
else:
if message.text is not None:
assert isinstance(letta_message, InternalMonologue)
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
else:
# If there's no inner thoughts then there needs to be a tool call
assert message.tool_calls is not None
# If there are multiple tool calls, we should have multiple back to back FunctionCallMessages
if message.tool_calls is not None:
for tool_call in message.tool_calls:
# Try to parse the tool call args
try:
json.loads(tool_call.function.arguments)
except:
warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}")
assert isinstance(letta_message, FunctionCallMessage)
assert tool_call.function.name == letta_message.function_call.name
assert tool_call.function.arguments == letta_message.function_call.arguments
letta_message_index += 1
letta_message = letta_messages[letta_message_index]
elif message.role == MessageRole.user:
print(f"i={i}, M=user, MM={type(letta_message)}")
assert isinstance(letta_message, UserMessage)
assert message.text == letta_message.message
letta_message_index += 1
elif message.role == MessageRole.system:
print(f"i={i}, M=system, MM={type(letta_message)}")
assert isinstance(letta_message, SystemMessage)
assert message.text == letta_message.message
letta_message_index += 1
elif message.role == MessageRole.tool:
print(f"i={i}, M=tool, MM={type(letta_message)}")
assert isinstance(letta_message, FunctionReturn)
# Check the the value in `text` is the same
assert message.text == letta_message.function_return
letta_message_index += 1
else:
raise ValueError(f"Unexpected message role: {message.role}")
# Move to the next message in the original messages list
break
def test_get_messages_letta_format(server, user_id, agent_id):
for reverse in [False, True]:
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)