@@ -1,44 +1,54 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import memgpt.utils as utils
|
||||
from memgpt.constants import BASE_TOOLS
|
||||
|
||||
utils.DEBUG = True
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.schemas.agent import CreateAgent
|
||||
from memgpt.schemas.memory import ChatMemory
|
||||
from memgpt.schemas.source import SourceCreate
|
||||
from memgpt.schemas.user import UserCreate
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.memory import ChatMemory
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.settings import settings
|
||||
|
||||
from .utils import DummyDataConnector
|
||||
from .utils import DummyDataConnector, create_config, wipe_config, wipe_memgpt_home
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
# if os.getenv("OPENAI_API_KEY"):
|
||||
# create_config("openai")
|
||||
# credentials = MemGPTCredentials(
|
||||
# openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
# )
|
||||
# else: # hosted
|
||||
# create_config("memgpt_hosted")
|
||||
# credentials = MemGPTCredentials()
|
||||
load_dotenv()
|
||||
wipe_config()
|
||||
wipe_memgpt_home()
|
||||
|
||||
db_url = settings.memgpt_pg_uri
|
||||
|
||||
# Use os.getenv with a fallback to os.environ.get
|
||||
db_url = settings.memgpt_pg_uri
|
||||
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
create_config("openai")
|
||||
credentials = MemGPTCredentials(
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
else: # hosted
|
||||
create_config("memgpt_hosted")
|
||||
credentials = MemGPTCredentials()
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
print("CONFIG PATH", config.config_path)
|
||||
|
||||
## set to use postgres
|
||||
# config.archival_storage_uri = db_url
|
||||
# config.recall_storage_uri = db_url
|
||||
# config.metadata_storage_uri = db_url
|
||||
# config.archival_storage_type = "postgres"
|
||||
# config.recall_storage_type = "postgres"
|
||||
# config.metadata_storage_type = "postgres"
|
||||
# set to use postgres
|
||||
config.archival_storage_uri = db_url
|
||||
config.recall_storage_uri = db_url
|
||||
config.metadata_storage_uri = db_url
|
||||
config.archival_storage_type = "postgres"
|
||||
config.recall_storage_type = "postgres"
|
||||
config.metadata_storage_type = "postgres"
|
||||
|
||||
config.save()
|
||||
credentials.save()
|
||||
|
||||
server = SyncServer()
|
||||
return server
|
||||
@@ -47,7 +57,7 @@ def server():
|
||||
@pytest.fixture(scope="module")
|
||||
def user_id(server):
|
||||
# create user
|
||||
user = server.create_user(UserCreate(name="test_user"))
|
||||
user = server.create_user()
|
||||
print(f"Created user\n{user.id}")
|
||||
|
||||
yield user.id
|
||||
@@ -60,8 +70,7 @@ def user_id(server):
|
||||
def agent_id(server, user_id):
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(name="test_agent", tools=BASE_TOOLS, memory=ChatMemory(human="Sarah", persona="I am a helpful assistant")),
|
||||
user_id=user_id,
|
||||
user_id=user_id, name="test_agent", tools=BASE_TOOLS, memory=ChatMemory(human="I am Chad", persona="I love testing")
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
yield agent_state.id
|
||||
@@ -99,7 +108,7 @@ def test_user_message(server, user_id, agent_id):
|
||||
@pytest.mark.order(3)
|
||||
def test_load_data(server, user_id, agent_id):
|
||||
# create source
|
||||
source = server.create_source(SourceCreate(name="test_source"), user_id=user_id)
|
||||
source = server.create_source("test_source", user_id)
|
||||
|
||||
# load data
|
||||
archival_memories = [
|
||||
@@ -136,73 +145,72 @@ def test_save_archival_memory(server, user_id, agent_id):
|
||||
def test_user_message(server, user_id, agent_id):
|
||||
# add data into recall memory
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
|
||||
|
||||
@pytest.mark.order(5)
|
||||
def test_get_recall_memory(server, user_id, agent_id):
|
||||
# test recall memory cursor pagination
|
||||
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
|
||||
cursor1 = messages_1[-1].id
|
||||
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
messages_2[-1].id
|
||||
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
|
||||
messages_3[-1].id
|
||||
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
|
||||
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
|
||||
# [m["id"] for m in messages_3]
|
||||
# [m["id"] for m in messages_2]
|
||||
timestamps = [m.created_at for m in messages_3]
|
||||
timestamps = [m["created_at"] for m in messages_3]
|
||||
print("timestamps", timestamps)
|
||||
assert messages_3[-1].created_at >= messages_3[0].created_at
|
||||
assert messages_3[-1]["created_at"] >= messages_3[0]["created_at"]
|
||||
assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
print("MESSAGES")
|
||||
for m in messages_3:
|
||||
print(m["id"], m["role"])
|
||||
if m["role"] == "assistant":
|
||||
print(m["text"])
|
||||
print("------------")
|
||||
|
||||
# test in-context message ids
|
||||
all_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000)
|
||||
in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
|
||||
# TODO: doesn't pass since recall memory also logs all system message changess
|
||||
# print("IN CONTEXT:", [m.text for m in server.get_in_context_messages(agent_id=agent_id)])
|
||||
# print("ALL:", [m.text for m in all_messages])
|
||||
# print()
|
||||
# for message in all_messages:
|
||||
# if message.id not in in_context_ids:
|
||||
# print("NOT IN CONTEXT:", message.id, message.created_at, message.text[-100:])
|
||||
# print()
|
||||
# assert len(in_context_ids) == len(messages_3)
|
||||
message_ids = [m.id for m in messages_3]
|
||||
for message_id in in_context_ids:
|
||||
assert message_id in message_ids, f"{message_id} not in {message_ids}"
|
||||
all_messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=0, count=1000)
|
||||
print("num messages", len(all_messages))
|
||||
in_context_ids = server.get_in_context_message_ids(user_id=user_id, agent_id=agent_id)
|
||||
print(in_context_ids)
|
||||
for m in messages_3:
|
||||
if str(m["id"]) not in [str(i) for i in in_context_ids]:
|
||||
print("missing", m["id"], m["role"])
|
||||
assert len(in_context_ids) == len(messages_3)
|
||||
assert isinstance(in_context_ids[0], uuid.UUID)
|
||||
message_ids = [m["id"] for m in messages_3]
|
||||
for message_id in message_ids:
|
||||
assert message_id in in_context_ids, f"{message_id} not in {in_context_ids}"
|
||||
|
||||
# test recall memory
|
||||
messages_1 = server.get_agent_messages(agent_id=agent_id, start=0, count=1)
|
||||
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=0, count=1)
|
||||
assert len(messages_1) == 1
|
||||
messages_2 = server.get_agent_messages(agent_id=agent_id, start=1, count=1000)
|
||||
messages_3 = server.get_agent_messages(agent_id=agent_id, start=1, count=2)
|
||||
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=1000)
|
||||
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1, count=2)
|
||||
# not sure exactly how many messages there should be
|
||||
assert len(messages_2) > len(messages_3)
|
||||
# test safe empty return
|
||||
messages_none = server.get_agent_messages(agent_id=agent_id, start=1000, count=1000)
|
||||
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=1000, count=1000)
|
||||
assert len(messages_none) == 0
|
||||
|
||||
|
||||
@pytest.mark.order(6)
|
||||
def test_get_archival_memory(server, user_id, agent_id):
|
||||
# test archival memory cursor pagination
|
||||
passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
|
||||
assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
|
||||
cursor1 = passages_1[-1].id
|
||||
passages_2 = server.get_agent_archival_cursor(
|
||||
cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
|
||||
cursor2, passages_2 = server.get_agent_archival_cursor(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
reverse=False,
|
||||
after=cursor1,
|
||||
order_by="text",
|
||||
)
|
||||
cursor2 = passages_2[-1].id
|
||||
passages_3 = server.get_agent_archival_cursor(
|
||||
cursor3, passages_3 = server.get_agent_archival_cursor(
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
reverse=False,
|
||||
@@ -210,8 +218,10 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
limit=1000,
|
||||
order_by="text",
|
||||
)
|
||||
passages_3[-1].id
|
||||
assert passages_1[0].text == "Cinderella wore a blue dress"
|
||||
print("p1", [p["text"] for p in passages_1])
|
||||
print("p2", [p["text"] for p in passages_2])
|
||||
print("p3", [p["text"] for p in passages_3])
|
||||
assert passages_1[0]["text"] == "alpha"
|
||||
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
|
||||
|
||||
Reference in New Issue
Block a user