feat: Cursor-based pagination for storage connectors and server (#830)

This commit is contained in:
Sarah Wooders
2024-01-16 14:45:20 -08:00
committed by GitHub
parent c441bf15b7
commit 92bbf83fc9
3 changed files with 170 additions and 33 deletions

View File

@@ -4,7 +4,8 @@ import psycopg
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, DateTime
from sqlalchemy import func
from sqlalchemy import func, or_, and_
from sqlalchemy import desc, asc
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
from sqlalchemy.orm.session import close_all_sessions
from sqlalchemy.sql import func
@@ -15,7 +16,7 @@ import uuid
import re
from tqdm import tqdm
from typing import Optional, List, Iterator, Dict
from typing import Optional, List, Iterator, Dict, Tuple
import numpy as np
from tqdm import tqdm
import pandas as pd
@@ -67,16 +68,13 @@ class CommonVector(TypeDecorator):
if value:
assert isinstance(value, np.ndarray) or isinstance(value, list), f"Value must be of type np.ndarray or list, got {type(value)}"
assert isinstance(value[0], float), f"Value must be of type float, got {type(value[0])}"
# print("WRITE", np.array(value).tobytes())
return np.array(value).tobytes()
else:
# print("WRITE", value, type(value))
return value
def process_result_value(self, value, dialect):
    """Convert a value read from the database back into a numpy array.

    Falsy values (``None``, empty bytes) are passed through untouched;
    anything else is interpreted as a raw buffer via ``np.frombuffer``
    using numpy's default dtype.
    """
    if value:
        return np.frombuffer(value)
    return value
@@ -125,6 +123,10 @@ def get_db_model(
raise ValueError(f"User {user_id} not found")
embedding_dim = user.default_embedding_config.embedding_dim
# this cannot be the case if we are making an agent-specific table
assert table_type != TableType.RECALL_MEMORY, f"Agent {agent_id} not found"
assert table_type != TableType.ARCHIVAL_MEMORY, f"Agent {agent_id} not found"
# Define a helper function to create or get the model class
def create_or_get_model(class_name, base_model, table_name):
if class_name in globals():
@@ -276,6 +278,57 @@ class SQLStorageConnector(StorageConnector):
# Increment the offset to get the next chunk in the next iteration
offset += page_size
def get_all_cursor(
    self,
    filters: Optional[Dict] = {},
    after: uuid.UUID = None,
    before: uuid.UUID = None,
    limit: Optional[int] = 1000,
    order_by: str = "created_at",
    reverse: bool = False,
):
    """Return one page of records together with a cursor for the next page.

    Records are sorted by the ``order_by`` column (descending when
    ``reverse`` is true, ascending otherwise) and then by ascending ``id``
    so that rows sharing the same ``order_by`` value have a stable order.

    Args:
        filters: Extra filter conditions, passed through ``self.get_filters``.
        after: ID of the record the page must start after (exclusive).
        before: ID of the record the page must end before (exclusive).
        limit: Maximum number of records to return.
        order_by: Name of the model attribute to sort on.
        reverse: Sort the ``order_by`` column in descending order.

    Returns:
        A ``(next_cursor, records)`` tuple. ``next_cursor`` is the ``id`` of
        the last record in the page; when no record matches, the result is
        ``(None, [])`` so callers can always unpack it.

    Raises:
        ValueError: If ``after`` / ``before`` does not resolve to a record.
    """
    filters = self.get_filters(filters)
    sort_column = getattr(self.db_model, order_by)
    id_column = self.db_model.id

    # generate query
    query = self.session.query(self.db_model).filter(*filters)

    # records are sorted by the order_by field first, and then by the ID if two fields are the same
    primary_order = desc(sort_column) if reverse else asc(sort_column)
    query = query.order_by(primary_order, asc(id_column))

    # cursor logic: filter records based on before/after ID
    if after:
        after_record = self.get(id=after)
        # NOTE(review): assumes self.get() yields None for an unknown id — confirm
        if after_record is None:
            raise ValueError(f"No record found for 'after' cursor id={after}")
        after_value = getattr(after_record, order_by)
        # in reverse order "after" means a smaller sort value, otherwise a larger one
        sort_exp = sort_column < after_value if reverse else sort_column > after_value
        query = query.filter(
            or_(sort_exp, and_(sort_column == after_value, id_column > after))  # tiebreaker case
        )

    if before:
        before_record = self.get(id=before)
        # NOTE(review): assumes self.get() yields None for an unknown id — confirm
        if before_record is None:
            raise ValueError(f"No record found for 'before' cursor id={before}")
        before_value = getattr(before_record, order_by)
        # mirror image of the "after" comparison
        sort_exp = sort_column > before_value if reverse else sort_column < before_value
        query = query.filter(
            or_(sort_exp, and_(sort_column == before_value, id_column < before))  # tiebreaker case
        )

    # get records
    db_record_chunk = query.limit(limit).all()
    if not db_record_chunk:
        # keep the (cursor, records) shape so callers that unpack the result do not crash
        return (None, [])

    records = [record.to_record() for record in db_record_chunk]
    next_cursor = db_record_chunk[-1].id
    assert isinstance(next_cursor, uuid.UUID)

    # return (cursor, list[records])
    return (next_cursor, records)
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[Record]:
filters = self.get_filters(filters)
if limit:

View File

@@ -1,5 +1,5 @@
from abc import abstractmethod
from typing import Union, Callable
from typing import Union, Callable, Optional, Tuple
import uuid
import json
import logging
@@ -513,11 +513,15 @@ class SyncServer(LockingServer):
llm_config=agent_config["llm_config"] if "llm_config" in agent_config else user.default_llm_config,
embedding_config=agent_config["embedding_config"] if "embedding_config" in agent_config else user.default_embedding_config,
)
# NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation
self.ms.create_agent(agent_state)
logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}")
try:
agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface)
except Exception as e:
logger.exception(e)
self.ms.delete_agent(agent_id=agent_state.id)
raise
logger.info(f"Created new agent from config: {agent}")
@@ -647,6 +651,56 @@ class SyncServer(LockingServer):
json_passages = [vars(record) for record in page]
return json_passages
def get_agent_archival_cursor(
    self,
    user_id: uuid.UUID,
    agent_id: uuid.UUID,
    after: Optional[uuid.UUID] = None,
    before: Optional[uuid.UUID] = None,
    limit: Optional[int] = 100,
    order_by: Optional[str] = "created_at",
    reverse: Optional[bool] = False,
):
    """Return one cursor-paginated page of an agent's archival memory.

    The ``user_id`` argument is currently replaced by the anonymous client
    id from the server config (see the TODO below).

    Args:
        user_id: Requesting user (currently overridden).
        agent_id: Agent whose archival memory is read.
        after: Only return records after this record ID.
        before: Only return records before this record ID.
        limit: Maximum number of records to return.
        order_by: Record attribute to sort on.
        reverse: Sort in descending order.

    Returns:
        A ``(cursor, json_records)`` tuple where ``json_records`` is a list
        of ``vars(record)`` dicts. When the storage has no matching
        records the result is ``(None, [])``.

    Raises:
        ValueError: If the user does not exist in the metadata store.
    """
    user_id = uuid.UUID(self.config.anon_clientid)  # TODO use real
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")

    # Get the agent object (loaded in memory)
    memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

    # fetch one page of records
    page = memgpt_agent.persistence_manager.archival_memory.storage.get_all_cursor(
        after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
    )
    if page is None:
        # the storage connector may signal "no records" with None instead of a tuple
        return None, []

    cursor, records = page
    json_records = [vars(record) for record in records]
    return cursor, json_records
def get_agent_recall_cursor(
    self,
    user_id: uuid.UUID,
    agent_id: uuid.UUID,
    after: Optional[uuid.UUID] = None,
    before: Optional[uuid.UUID] = None,
    limit: Optional[int] = 100,
    order_by: Optional[str] = "created_at",
    reverse: Optional[bool] = False,
):
    """Return one cursor-paginated page of an agent's recall memory.

    The ``user_id`` argument is currently replaced by the anonymous client
    id from the server config (see the TODO below).

    Args:
        user_id: Requesting user (currently overridden).
        agent_id: Agent whose recall memory is read.
        after: Only return records after this record ID.
        before: Only return records before this record ID.
        limit: Maximum number of records to return.
        order_by: Record attribute to sort on.
        reverse: Sort in descending order.

    Returns:
        A ``(cursor, json_records)`` tuple where ``json_records`` is a list
        of ``vars(record)`` dicts. When the storage has no matching
        records the result is ``(None, [])``.

    Raises:
        ValueError: If the user does not exist in the metadata store.
    """
    user_id = uuid.UUID(self.config.anon_clientid)  # TODO use real
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")

    # Get the agent object (loaded in memory)
    memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

    # fetch one page of records
    page = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor(
        after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
    )
    if page is None:
        # the storage connector may signal "no records" with None instead of a tuple
        return None, []

    cursor, records = page
    json_records = [vars(record) for record in records]

    # TODO: mark what is in-context versus not
    return cursor, json_records
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the config of an agent"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real

View File

@@ -1,6 +1,5 @@
import uuid
import os
import memgpt.utils as utils
utils.DEBUG = True
@@ -8,6 +7,7 @@ from memgpt.config import MemGPTConfig
from memgpt.server.server import SyncServer
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage
from memgpt.embeddings import embedding_model
from memgpt.metadata import MetadataStore
from .utils import wipe_config, wipe_memgpt_home
@@ -24,6 +24,7 @@ def test_server():
config.save()
user_id = uuid.UUID(config.anon_clientid)
ms = MetadataStore(config)
server = SyncServer()
try:
@@ -44,9 +45,10 @@ def test_server():
embedding_dim=1536,
openai_key=os.getenv("OPENAI_API_KEY"),
)
print("Using OpenAI embeddings")
else:
embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384)
print("Using local embeddings")
agent_state = server.create_agent(
user_id=user_id,
@@ -67,41 +69,69 @@ def test_server():
print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory"))
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
# test recall memory
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
assert len(messages_1) == 1
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5)
# not sure exactly how many messages there should be
assert len(messages_2) > len(messages_3)
# test safe empty return
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
assert len(messages_none) == 0
# test archival memory
# add data into archival memory
agent = server._load_agent(user_id=user_id, agent_id=agent_state.id)
archival_memories = ["Cinderella wore a blue dress", "Dog eat dog", "Shishir loves indian food"]
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
embed_model = embedding_model(embedding_config)
for text in archival_memories:
embedding = embed_model.get_text_embedding(text)
agent.persistence_manager.archival_memory.storage.insert(
Passage(user_id=user_id, agent_id=agent_state.id, text=text, embedding=embedding)
)
# add data into recall memory
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
# test recall memory cursor pagination
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=2)
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=1000)
ids3 = [m["id"] for m in messages_3]
ids2 = [m["id"] for m in messages_2]
timestamps = [m["created_at"] for m in messages_3]
print("timestamps", timestamps)
assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
assert len(messages_3) == len(messages_1) + len(messages_2)
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, before=cursor1)
assert len(messages_4) == 1
# test archival memory cursor pagination
cursor1, passages_1 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
)
cursor2, passages_2 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
)
cursor3, passages_3 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
)
print("p1", [p["text"] for p in passages_1])
print("p2", [p["text"] for p in passages_2])
print("p3", [p["text"] for p in passages_3])
assert passages_1[0]["text"] == "alpha"
assert len(passages_2) == 3
assert len(passages_3) == 4
# test recall memory
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
assert len(messages_1) == 1
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5)
# not sure exactly how many messages there should be
assert len(messages_2) > len(messages_3)
# test safe empty return
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
assert len(messages_none) == 0
# test archival memory
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
assert len(passage_1) == 1
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
assert len(passage_2) == 2
print(passage_1)
assert len(passage_2) == 4
# test safe empty return
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
assert len(passage_none) == 0