feat: Cursor-based pagination for storage connectors and server (#830)
This commit is contained in:
@@ -4,7 +4,8 @@ import psycopg
|
||||
|
||||
|
||||
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text, JSON, BLOB, BINARY, ARRAY, DateTime
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import func, or_, and_
|
||||
from sqlalchemy import desc, asc
|
||||
from sqlalchemy.orm import sessionmaker, mapped_column, declarative_base
|
||||
from sqlalchemy.orm.session import close_all_sessions
|
||||
from sqlalchemy.sql import func
|
||||
@@ -15,7 +16,7 @@ import uuid
|
||||
|
||||
import re
|
||||
from tqdm import tqdm
|
||||
from typing import Optional, List, Iterator, Dict
|
||||
from typing import Optional, List, Iterator, Dict, Tuple
|
||||
import numpy as np
|
||||
from tqdm import tqdm
|
||||
import pandas as pd
|
||||
@@ -67,16 +68,13 @@ class CommonVector(TypeDecorator):
|
||||
if value:
|
||||
assert isinstance(value, np.ndarray) or isinstance(value, list), f"Value must be of type np.ndarray or list, got {type(value)}"
|
||||
assert isinstance(value[0], float), f"Value must be of type float, got {type(value[0])}"
|
||||
# print("WRITE", np.array(value).tobytes())
|
||||
return np.array(value).tobytes()
|
||||
else:
|
||||
# print("WRITE", value, type(value))
|
||||
return value
|
||||
|
||||
def process_result_value(self, value, dialect):
    """Convert a stored binary value back into a NumPy array.

    Falsy values (``None``, empty bytes) are handed back unchanged; any
    other value is decoded with ``np.frombuffer`` using its default dtype.
    ``dialect`` is accepted for the SQLAlchemy type-decorator hook signature
    and is not used here.
    """
    return np.frombuffer(value) if value else value
|
||||
|
||||
|
||||
@@ -125,6 +123,10 @@ def get_db_model(
|
||||
raise ValueError(f"User {user_id} not found")
|
||||
embedding_dim = user.default_embedding_config.embedding_dim
|
||||
|
||||
# this cannot be the case if we are making an agent-specific table
|
||||
assert table_type != TableType.RECALL_MEMORY, f"Agent {agent_id} not found"
|
||||
assert table_type != TableType.ARCHIVAL_MEMORY, f"Agent {agent_id} not found"
|
||||
|
||||
# Define a helper function to create or get the model class
|
||||
def create_or_get_model(class_name, base_model, table_name):
|
||||
if class_name in globals():
|
||||
@@ -276,6 +278,57 @@ class SQLStorageConnector(StorageConnector):
|
||||
# Increment the offset to get the next chunk in the next iteration
|
||||
offset += page_size
|
||||
|
||||
def get_all_cursor(
    self,
    filters: Optional[Dict] = {},
    after: Optional[uuid.UUID] = None,
    before: Optional[uuid.UUID] = None,
    limit: Optional[int] = 1000,
    order_by: str = "created_at",
    reverse: bool = False,
) -> Tuple[Optional[uuid.UUID], List["Record"]]:
    """Return one page of records together with a cursor for the next page.

    Records are ordered by ``order_by`` (descending when ``reverse`` is true,
    ascending otherwise) and then by ``id`` ascending as a tiebreaker.

    Args:
        filters: Extra filter values, forwarded to ``self.get_filters``.
            The default dict is never mutated by this method.
        after: ID of a record; only records sorting strictly after it are returned.
        before: ID of a record; only records sorting strictly before it are returned.
        limit: Maximum number of records in the page.
        order_by: Name of the model attribute to sort on.
        reverse: Sort ``order_by`` descending instead of ascending.

    Returns:
        ``(next_cursor, records)`` where ``next_cursor`` is the ID of the last
        record in the page. When nothing matches, ``(None, [])`` is returned so
        callers can always unpack the result as a 2-tuple.

    Raises:
        ValueError: If ``after`` or ``before`` does not resolve to a record.
    """
    filters = self.get_filters(filters)
    sort_column = getattr(self.db_model, order_by)

    # generate query
    query = self.session.query(self.db_model).filter(*filters)

    # records are sorted by the order_by field first, and then by the ID if two fields are the same
    primary_order = desc(sort_column) if reverse else asc(sort_column)
    query = query.order_by(primary_order, asc(self.db_model.id))

    # cursor logic: filter records based on before/after ID
    if after:
        after_record = self.get(id=after)
        # NOTE(review): assumes self.get returns None for an unknown id - confirm
        if after_record is None:
            raise ValueError(f"Cursor record after={after} not found")
        after_value = getattr(after_record, order_by)
        # in reverse order "after" means a smaller sort value, otherwise a larger one
        sort_exp = sort_column < after_value if reverse else sort_column > after_value
        query = query.filter(
            or_(sort_exp, and_(sort_column == after_value, self.db_model.id > after))  # tiebreaker case
        )
    if before:
        before_record = self.get(id=before)
        if before_record is None:
            raise ValueError(f"Cursor record before={before} not found")
        before_value = getattr(before_record, order_by)
        sort_exp = sort_column > before_value if reverse else sort_column < before_value
        query = query.filter(
            or_(sort_exp, and_(sort_column == before_value, self.db_model.id < before))  # tiebreaker case
        )

    # get records
    db_record_chunk = query.limit(limit).all()
    if not db_record_chunk:
        # Empty page: keep the (cursor, records) shape so tuple-unpacking callers do not crash.
        return (None, [])

    records = [record.to_record() for record in db_record_chunk]
    next_cursor = db_record_chunk[-1].id
    assert isinstance(next_cursor, uuid.UUID)

    # return (cursor, list[records])
    return (next_cursor, records)
|
||||
|
||||
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[Record]:
|
||||
filters = self.get_filters(filters)
|
||||
if limit:
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import abstractmethod
|
||||
from typing import Union, Callable
|
||||
from typing import Union, Callable, Optional, Tuple
|
||||
import uuid
|
||||
import json
|
||||
import logging
|
||||
@@ -513,11 +513,15 @@ class SyncServer(LockingServer):
|
||||
llm_config=agent_config["llm_config"] if "llm_config" in agent_config else user.default_llm_config,
|
||||
embedding_config=agent_config["embedding_config"] if "embedding_config" in agent_config else user.default_embedding_config,
|
||||
)
|
||||
# NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation
|
||||
self.ms.create_agent(agent_state)
|
||||
|
||||
logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}")
|
||||
try:
|
||||
agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self.ms.delete_agent(agent_id=agent_state.id)
|
||||
raise
|
||||
|
||||
logger.info(f"Created new agent from config: {agent}")
|
||||
@@ -647,6 +651,56 @@ class SyncServer(LockingServer):
|
||||
json_passages = [vars(record) for record in page]
|
||||
return json_passages
|
||||
|
||||
def get_agent_archival_cursor(
    self,
    user_id: uuid.UUID,
    agent_id: uuid.UUID,
    after: Optional[uuid.UUID] = None,
    before: Optional[uuid.UUID] = None,
    limit: Optional[int] = 100,
    order_by: Optional[str] = "created_at",
    reverse: Optional[bool] = False,
) -> Tuple[Optional[uuid.UUID], list]:
    """Return one cursor-paginated page of an agent's archival memory.

    Args:
        user_id: Currently ignored; it is overwritten with the anonymous
            client ID from the server config (see the TODO below).
        agent_id: ID of the agent whose archival memory is read.
        after: Only return records sorting after the record with this ID.
        before: Only return records sorting before the record with this ID.
        limit: Maximum number of records in the page.
        order_by: Name of the record attribute to sort on.
        reverse: Sort in descending instead of ascending order.

    Returns:
        ``(cursor, json_records)`` where ``json_records`` is a list of
        ``vars(record)`` dicts. ``(None, [])`` when the storage layer
        reports no matching records.

    Raises:
        ValueError: If the resolved user does not exist.
    """
    user_id = uuid.UUID(self.config.anon_clientid)  # TODO use real
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")

    # Get the agent object (loaded in memory)
    memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

    # fetch one page of records
    page = memgpt_agent.persistence_manager.archival_memory.storage.get_all_cursor(
        after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
    )
    if page is None:
        # The storage layer may signal an empty page with a bare None; unpacking it would raise TypeError.
        return None, []

    cursor, records = page
    json_records = [vars(record) for record in records]
    return cursor, json_records
|
||||
|
||||
def get_agent_recall_cursor(
    self,
    user_id: uuid.UUID,
    agent_id: uuid.UUID,
    after: Optional[uuid.UUID] = None,
    before: Optional[uuid.UUID] = None,
    limit: Optional[int] = 100,
    order_by: Optional[str] = "created_at",
    reverse: Optional[bool] = False,
) -> Tuple[Optional[uuid.UUID], list]:
    """Return one cursor-paginated page of an agent's recall memory.

    Args:
        user_id: Currently ignored; it is overwritten with the anonymous
            client ID from the server config (see the TODO below).
        agent_id: ID of the agent whose recall memory is read.
        after: Only return records sorting after the record with this ID.
        before: Only return records sorting before the record with this ID.
        limit: Maximum number of records in the page.
        order_by: Name of the record attribute to sort on.
        reverse: Sort in descending instead of ascending order.

    Returns:
        ``(cursor, json_records)`` where ``json_records`` is a list of
        ``vars(record)`` dicts. ``(None, [])`` when the storage layer
        reports no matching records.

    Raises:
        ValueError: If the resolved user does not exist.
    """
    user_id = uuid.UUID(self.config.anon_clientid)  # TODO use real
    if self.ms.get_user(user_id=user_id) is None:
        raise ValueError(f"User user_id={user_id} does not exist")

    # Get the agent object (loaded in memory)
    memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

    # fetch one page of records
    page = memgpt_agent.persistence_manager.recall_memory.storage.get_all_cursor(
        after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
    )
    if page is None:
        # The storage layer may signal an empty page with a bare None; unpacking it would raise TypeError.
        return None, []

    cursor, records = page
    json_records = [vars(record) for record in records]

    # TODO: mark what is in-context versus not
    return cursor, json_records
|
||||
|
||||
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
|
||||
"""Return the config of an agent"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import uuid
|
||||
import os
|
||||
|
||||
import memgpt.utils as utils
|
||||
|
||||
utils.DEBUG = True
|
||||
@@ -8,6 +7,7 @@ from memgpt.config import MemGPTConfig
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.metadata import MetadataStore
|
||||
from .utils import wipe_config, wipe_memgpt_home
|
||||
|
||||
|
||||
@@ -24,6 +24,7 @@ def test_server():
|
||||
config.save()
|
||||
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
ms = MetadataStore(config)
|
||||
server = SyncServer()
|
||||
|
||||
try:
|
||||
@@ -44,9 +45,10 @@ def test_server():
|
||||
embedding_dim=1536,
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
|
||||
print("Using OpenAI embeddings")
|
||||
else:
|
||||
embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384)
|
||||
print("Using local embeddings")
|
||||
|
||||
agent_state = server.create_agent(
|
||||
user_id=user_id,
|
||||
@@ -67,41 +69,69 @@ def test_server():
|
||||
|
||||
print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory"))
|
||||
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
|
||||
# test recall memory
|
||||
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
|
||||
assert len(messages_1) == 1
|
||||
|
||||
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
|
||||
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5)
|
||||
# not sure exactly how many messages there should be
|
||||
assert len(messages_2) > len(messages_3)
|
||||
|
||||
# test safe empty return
|
||||
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
|
||||
assert len(messages_none) == 0
|
||||
|
||||
# test archival memory
|
||||
# add data into archival memory
|
||||
agent = server._load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
archival_memories = ["Cinderella wore a blue dress", "Dog eat dog", "Shishir loves indian food"]
|
||||
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
|
||||
embed_model = embedding_model(embedding_config)
|
||||
for text in archival_memories:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
agent.persistence_manager.archival_memory.storage.insert(
|
||||
Passage(user_id=user_id, agent_id=agent_state.id, text=text, embedding=embedding)
|
||||
)
|
||||
|
||||
# add data into recall memory
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
|
||||
# test recall memory cursor pagination
|
||||
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=2)
|
||||
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
|
||||
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=1000)
|
||||
ids3 = [m["id"] for m in messages_3]
|
||||
ids2 = [m["id"] for m in messages_2]
|
||||
timestamps = [m["created_at"] for m in messages_3]
|
||||
print("timestamps", timestamps)
|
||||
assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
|
||||
assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
# test archival memory cursor pagination
|
||||
cursor1, passages_1 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
|
||||
)
|
||||
cursor2, passages_2 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
|
||||
)
|
||||
cursor3, passages_3 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
|
||||
)
|
||||
print("p1", [p["text"] for p in passages_1])
|
||||
print("p2", [p["text"] for p in passages_2])
|
||||
print("p3", [p["text"] for p in passages_3])
|
||||
assert passages_1[0]["text"] == "alpha"
|
||||
assert len(passages_2) == 3
|
||||
assert len(passages_3) == 4
|
||||
|
||||
# test recall memory
|
||||
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
|
||||
assert len(messages_1) == 1
|
||||
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
|
||||
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5)
|
||||
# not sure exactly how many messages there should be
|
||||
assert len(messages_2) > len(messages_3)
|
||||
# test safe empty return
|
||||
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
|
||||
assert len(messages_none) == 0
|
||||
|
||||
# test archival memory
|
||||
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
|
||||
assert len(passage_1) == 1
|
||||
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
|
||||
assert len(passage_2) == 2
|
||||
|
||||
print(passage_1)
|
||||
|
||||
assert len(passage_2) == 4
|
||||
# test safe empty return
|
||||
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
|
||||
assert len(passage_none) == 0
|
||||
|
||||
Reference in New Issue
Block a user