feat: Cursor-based pagination for storage connectors and server (#830)

This commit is contained in:
Sarah Wooders
2024-01-16 14:45:20 -08:00
committed by GitHub
parent c441bf15b7
commit 92bbf83fc9
3 changed files with 170 additions and 33 deletions

View File

@@ -1,6 +1,5 @@
import uuid
import os
import memgpt.utils as utils
utils.DEBUG = True
@@ -8,6 +7,7 @@ from memgpt.config import MemGPTConfig
from memgpt.server.server import SyncServer
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage
from memgpt.embeddings import embedding_model
from memgpt.metadata import MetadataStore
from .utils import wipe_config, wipe_memgpt_home
@@ -24,6 +24,7 @@ def test_server():
config.save()
user_id = uuid.UUID(config.anon_clientid)
ms = MetadataStore(config)
server = SyncServer()
try:
@@ -44,9 +45,10 @@ def test_server():
embedding_dim=1536,
openai_key=os.getenv("OPENAI_API_KEY"),
)
print("Using OpenAI embeddings")
else:
embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384)
print("Using local embeddings")
agent_state = server.create_agent(
user_id=user_id,
@@ -67,41 +69,69 @@ def test_server():
print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory"))
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
# test recall memory
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
assert len(messages_1) == 1
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5)
# not sure exactly how many messages there should be
assert len(messages_2) > len(messages_3)
# test safe empty return
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
assert len(messages_none) == 0
# test archival memory
# add data into archival memory
agent = server._load_agent(user_id=user_id, agent_id=agent_state.id)
archival_memories = ["Cinderella wore a blue dress", "Dog eat dog", "Shishir loves indian food"]
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
embed_model = embedding_model(embedding_config)
for text in archival_memories:
embedding = embed_model.get_text_embedding(text)
agent.persistence_manager.archival_memory.storage.insert(
Passage(user_id=user_id, agent_id=agent_state.id, text=text, embedding=embedding)
)
# add data into recall memory
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
# test recall memory cursor pagination
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=2)
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=1000)
ids3 = [m["id"] for m in messages_3]
ids2 = [m["id"] for m in messages_2]
timestamps = [m["created_at"] for m in messages_3]
print("timestamps", timestamps)
assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
assert len(messages_3) == len(messages_1) + len(messages_2)
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, before=cursor1)
assert len(messages_4) == 1
# test archival memory cursor pagination
cursor1, passages_1 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
)
cursor2, passages_2 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
)
cursor3, passages_3 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
)
print("p1", [p["text"] for p in passages_1])
print("p2", [p["text"] for p in passages_2])
print("p3", [p["text"] for p in passages_3])
assert passages_1[0]["text"] == "alpha"
assert len(passages_2) == 3
assert len(passages_3) == 4
# test recall memory
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
assert len(messages_1) == 1
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5)
# not sure exactly how many messages there should be
assert len(messages_2) > len(messages_3)
# test safe empty return
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
assert len(messages_none) == 0
# test archival memory
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
assert len(passage_1) == 1
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
assert len(passage_2) == 2
print(passage_1)
assert len(passage_2) == 4
# test safe empty return
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
assert len(passage_none) == 0