Bugfixes and test updates to get tests passing for both postgres and chroma

Sarah Wooders
2023-12-22 10:29:27 +04:00
parent b4b05bd75d
commit e2b29d8995
5 changed files with 167 additions and 76 deletions

View File

@@ -19,6 +19,8 @@ class ChromaStorageConnector(StorageConnector):
super().__init__(table_type=table_type, agent_config=agent_config)
config = MemGPTConfig.load()
assert table_type == TableType.ARCHIVAL_MEMORY, "Chroma only supports archival memory"
# create chroma client
if config.archival_storage_path:
self.client = chromadb.PersistentClient(config.archival_storage_path)
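A minimal standalone sketch of the client selection above: on-disk persistence when an archival path is configured, otherwise an in-memory client (the path and collection name below are illustrative, not from the commit):

import chromadb

archival_storage_path = "./chroma_storage"  # illustrative path
if archival_storage_path:
    # on-disk database persisted under the given directory
    client = chromadb.PersistentClient(path=archival_storage_path)
else:
    # purely in-memory database, useful for tests
    client = chromadb.EphemeralClient()
collection = client.get_or_create_collection(name="archival_memory")  # illustrative name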
@@ -34,7 +36,6 @@ class ChromaStorageConnector(StorageConnector):
def get_filters(self, filters: Optional[Dict] = {}):
# get all filters for query
print("GET FILTER", filters)
if filters is not None:
filter_conditions = {**self.filters, **filters}
else:
@@ -53,12 +54,9 @@ class ChromaStorageConnector(StorageConnector):
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
offset = 0
ids, filters = self.get_filters(filters)
print("FILTERS", filters)
while True:
# Retrieve a chunk of records with the given page_size
print("querying...", self.collection.count(), "offset", offset, "page", page_size)
results = self.collection.get(ids=ids, offset=offset, limit=page_size, include=self.include, where=filters)
print(len(results["embeddings"]))
# If the chunk is empty, we've retrieved all records
if len(results["embeddings"]) == 0:
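A self-contained sketch of the offset/limit pagination loop above, run against an in-memory Chroma collection (names, sizes, and vectors are illustrative); checking results["ids"] works regardless of which fields are requested via include:

import chromadb

client = chromadb.EphemeralClient()
collection = client.get_or_create_collection(name="demo")
collection.add(
    ids=[str(i) for i in range(10)],
    documents=[f"passage {i}" for i in range(10)],
    embeddings=[[float(i), 0.0] for i in range(10)],  # precomputed, so no embedding model is needed
)

page_size = 4
offset = 0
while True:
    # retrieve one page of records; "ids" is always present in the result
    results = collection.get(offset=offset, limit=page_size, include=["documents"])
    if len(results["ids"]) == 0:
        break  # empty page: all records retrieved
    offset += page_size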
@@ -72,9 +70,6 @@ class ChromaStorageConnector(StorageConnector):
def results_to_records(self, results):
# convert timestamps to datetime
print("ID", results["ids"])
print("ID TYPE", type(results["ids"][0]))
print(uuid.UUID(results["ids"][0]))
for metadata in results["metadatas"]:
if "created_at" in metadata:
metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
@@ -126,13 +121,11 @@ class ChromaStorageConnector(StorageConnector):
record_metadata = {}
metadata = {key: value for key, value in metadata.items() if value is not None} # null values not allowed
metadata = {**metadata, **record_metadata} # merge with metadata
print("m", metadata)
metadatas.append(metadata)
return ids, documents, embeddings, metadatas
def insert(self, record: Record):
ids, documents, embeddings, metadatas = self.format_records([record])
print("metadata", record, metadatas)
if not any(embeddings):
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
else:
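The branch above distinguishes records with and without precomputed embeddings; a hedged sketch of both paths (collection as in the Chroma examples above):

def insert_records(collection, ids, documents, embeddings, metadatas):
    if not any(embeddings):
        # no precomputed vectors: the collection's embedding function embeds the documents
        collection.add(documents=documents, ids=ids, metadatas=metadatas)
    else:
        # precomputed vectors are passed through unchanged
        collection.add(documents=documents, ids=ids, metadatas=metadatas, embeddings=embeddings)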
@@ -149,10 +142,14 @@ class ChromaStorageConnector(StorageConnector):
ids, filters = self.get_filters(filters)
self.collection.delete(ids=ids, where=filters)
def delete_table(self):
# drop collection
self.client.delete_collection(self.collection.name)
def save(self):
# save to persistence file (nothing needs to be done)
printd("Saving chroma")
pass
raise NotImplementedError
def size(self, filters: Optional[Dict] = {}) -> int:
# unfortunately, need to use pagination to get filtering
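The size() body is cut off at the hunk boundary; a plausible completion consistent with the comment, summing page lengths because Chroma's count() takes no filters:

def size(self, filters: Optional[Dict] = {}) -> int:
    # filtering is only available through get(), so page through and count
    return sum(len(page) for page in self.get_all_paginated(page_size=1000, filters=filters))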

View File

@@ -4,6 +4,7 @@ import psycopg
from sqlalchemy import create_engine, Column, String, BIGINT, select, inspect, text
from sqlalchemy import func
from sqlalchemy.orm import sessionmaker, mapped_column
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.sql import func
@@ -81,12 +82,18 @@ def get_db_model(table_name: str, table_type: TableType):
id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
user_id = Column(String, nullable=False)
agent_id = Column(String, nullable=False)
# openai info
role = Column(String, nullable=False)
text = Column(String, nullable=False)
model = Column(String, nullable=False)
user = Column(String) # optional: multi-agent only
# function info
function_name = Column(String)
function_args = Column(String)
function_response = Column(String)
embedding = mapped_column(Vector(config.embedding_dim))
# Add a datetime column, with default value as the current time
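A standalone sketch of the model pattern in this hunk: plain SQLAlchemy columns alongside a pgvector embedding column (table name and dimension are illustrative):

import uuid
from pgvector.sqlalchemy import Vector
from sqlalchemy import Column, String
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import declarative_base, mapped_column

Base = declarative_base()

class DemoMessageModel(Base):
    __tablename__ = "demo_messages"
    id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
    user_id = Column(String, nullable=False)
    role = Column(String, nullable=False)
    text = Column(String, nullable=False)
    user = Column(String)  # optional: multi-agent only
    embedding = mapped_column(Vector(1536))  # dimension comes from config in the real code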
@@ -100,6 +107,7 @@ def get_db_model(table_name: str, table_type: TableType):
user_id=self.user_id,
agent_id=self.agent_id,
role=self.role,
user=self.user,
text=self.text,
model=self.model,
function_name=self.function_name,
@@ -118,7 +126,7 @@ def get_db_model(table_name: str, table_type: TableType):
raise ValueError(f"Table type {table_type} not implemented")
class PostgresStorageConnector(StorageConnector):
class SQLStorageConnector(StorageConnector):
"""Storage via Postgres"""
# TODO: this should probably eventually be moved into a parent DB class
@@ -127,6 +135,8 @@ class PostgresStorageConnector(StorageConnector):
super().__init__(table_type=table_type, agent_config=agent_config)
config = MemGPTConfig.load()
# TODO: only support recall memory (need postgres for archival)
# get storage URI
if table_type == TableType.ARCHIVAL_MEMORY:
self.uri = config.archival_storage_uri
@@ -155,20 +165,20 @@ class PostgresStorageConnector(StorageConnector):
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
def get_all_paginated(self, page_size: int, filters: Optional[Dict]) -> Iterator[List[Record]]:
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
session = self.Session()
offset = 0
filters = self.get_filters(filters)
while True:
# Retrieve a chunk of records with the given page_size
db_passages_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
db_record_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
# If the chunk is empty, we've retrieved all records
if not db_passages_chunk:
if not db_record_chunk:
break
# Yield a list of Record objects converted from the chunk
yield [self.type(**p.to_dict()) for p in db_passages_chunk]
yield [record.to_record() for record in db_record_chunk]
# Increment the offset to get the next chunk in the next iteration
offset += page_size
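Usage sketch for the paginated iterator above; conn stands in for a connector instance (an assumption, not code from the commit):

total = 0
for page in conn.get_all_paginated(page_size=100):  # conn: a SQLStorageConnector (assumed)
    total += len(page)  # each page is a list of Record objects
print(f"streamed {total} records")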
@@ -179,10 +189,9 @@ class PostgresStorageConnector(StorageConnector):
db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
return [record.to_record() for record in db_records]
def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Record]:
def get(self, id: str) -> Optional[Record]:
session = self.Session()
filters = self.get_filters(filters)
db_record = session.query(self.db_model).filter(*filters).get(id)
db_record = session.query(self.db_model).get(id)
if db_record is None:
return None
return db_record.to_record()
@@ -209,15 +218,7 @@ class PostgresStorageConnector(StorageConnector):
session.commit()
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
session = self.Session()
filters = self.get_filters(filters)
results = session.scalars(
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
).all()
# Convert the results into Passage objects
records = [result.to_record() for result in results]
return records
raise NotImplementedError("Vector query not implemented for SQLStorageConnector")
def save(self):
return
@@ -255,11 +256,70 @@ class PostgresStorageConnector(StorageConnector):
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
session = self.Session()
filters = self.get_filters({})
results = session.query(self.db_model).filter(*filters).filter(self.db_model.text.contains(query)).all()
print(results)
results = session.query(self.db_model).filter(*filters).filter(func.lower(self.db_model.text).contains(func.lower(query))).all()
# return [self.type(**vars(result)) for result in results]
return [result.to_record() for result in results]
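The rewritten query_text lowercases both sides of the match; on Postgres, ILIKE is an equivalent and arguably more idiomatic formulation. A sketch with session and model as assumptions:

from sqlalchemy import func

def query_text_lower(session, model, query: str):
    # what the diff does: lowercase column and query, then substring match
    return session.query(model).filter(func.lower(model.text).contains(func.lower(query))).all()

def query_text_ilike(session, model, query: str):
    # equivalent on Postgres: ILIKE is a case-insensitive LIKE
    return session.query(model).filter(model.text.ilike(f"%{query}%")).all()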
def delete_table(self):
session = self.Session()
self.db_model.__table__.drop(session.bind)
session.commit()
def delete(self, filters: Optional[Dict] = {}):
session = self.Session()
filters = self.get_filters(filters)
session.query(self.db_model).filter(*filters).delete()
session.commit()
class PostgresStorageConnector(SQLStorageConnector):
"""Storage via Postgres"""
# TODO: this should probably eventually be moved into a parent DB class
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
super().__init__(table_type=table_type, agent_config=agent_config)
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
session = self.Session()
filters = self.get_filters(filters)
results = session.scalars(
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
).all()
# Convert the results into Passage objects
records = [result.to_record() for result in results]
return records
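The vector search that moved into PostgresStorageConnector orders by pgvector's l2_distance comparator; cosine_distance and max_inner_product are the other orderings pgvector's SQLAlchemy integration provides. A sketch with session, model, and query_vec as assumptions:

from sqlalchemy import select

def nearest_by_l2(session, model, query_vec, top_k: int = 10):
    # ORDER BY embedding <-> query_vec LIMIT top_k (Euclidean distance)
    return session.scalars(
        select(model).order_by(model.embedding.l2_distance(query_vec)).limit(top_k)
    ).all()

# model.embedding.cosine_distance(query_vec) and model.embedding.max_inner_product(query_vec)
# order by the cosine and inner-product operators instead.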
class LanceDBConnector(StorageConnector):
"""Storage via LanceDB"""
@@ -277,8 +337,6 @@ class LanceDBConnector(StorageConnector):
else:
raise ValueError("Must specify either agent config or name")
printd(f"Using table name {self.table_name}")
# create table
self.uri = config.archival_storage_uri
if config.archival_storage_uri is None:
@@ -326,7 +384,7 @@ class LanceDBConnector(StorageConnector):
if self.table:
return len(self.table)
else:
print(f"Table with name {self.table_name} not present")
printd(f"Table with name {self.table_name} not present")
return 0
def insert(self, passage: Passage):

View File

@@ -1,4 +1,5 @@
from typing import Optional, List, Iterator
import shutil
from memgpt.config import AgentConfig, MemGPTConfig
from tqdm import tqdm
import re
@@ -181,39 +182,56 @@ class InMemoryStorageConnector(StorageConnector):
raise ValueError(f"Table type {table_type} not supported by InMemoryStorageConnector")
# TODO: load if exists
self.agent_config = agent_config
if agent_config is None:
# is a data source
raise ValueError("Cannot load data source from InMemoryStorageConnector")
else:
directory = agent_config.save_state_dir()
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
if not json_files:
print(f"/load error: no .json checkpoint files found")
raise ValueError(f"Cannot load {agent_config.name} - no saved checkpoints found in {directory}")
if os.path.exists(directory):
print(f"Loading saved agent {agent_config.name} from {directory}")
json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory.
if not json_files:
print(f"/load error: no .json checkpoint files found")
raise ValueError(f"Cannot load {agent_config.name} - no saved checkpoints found in {directory}")
# Sort files based on modified timestamp, with the latest file being the first.
filename = max(json_files, key=os.path.getmtime)
state = json.load(open(filename, "r"))
# Sort files based on modified timestamp, with the latest file being the first.
filename = max(json_files, key=os.path.getmtime)
state = json.load(open(filename, "r"))
# load persistence manager
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
directory = agent_config.save_persistence_manager_dir()
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
with open(filename, "rb") as f:
data = pickle.load(f)
self.rows = data["all_messages"]
# load persistence manager
filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
directory = agent_config.save_persistence_manager_dir()
printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
with open(filename, "rb") as f:
data = pickle.load(f)
self.rows = data["all_messages"]
else:
print(f"Creating new agent {agent_config.name}")
self.rows = []
# convert to Record class
self.rows = [self.json_to_message(m) for m in self.rows]
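A standalone sketch of the checkpoint selection above: pick the most recently modified .json state file, then derive the matching persistence pickle name (the directory is illustrative):

import glob
import json
import os

directory = "./agent_state"  # illustrative
json_files = glob.glob(os.path.join(directory, "*.json"))
if json_files:
    latest = max(json_files, key=os.path.getmtime)  # newest checkpoint by modification time
    with open(latest, "r") as f:
        state = json.load(f)
    # matching persistence manager file shares the stem with a different suffix
    pickle_name = os.path.basename(latest).replace(".json", ".persistence.pickle")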
def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
raise NotImplementedError
offset = 0
while True:
yield self.rows[offset : offset + page_size]
offset += page_size
if offset >= len(self.rows):
break
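Note that the loop above yields one empty page when rows is empty; a range-based variant sidesteps that edge case:

from typing import Iterator, List

def paginate(rows: List, page_size: int) -> Iterator[List]:
    # yields nothing for an empty list, and never an empty trailing page
    for offset in range(0, len(rows), page_size):
        yield rows[offset : offset + page_size]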
def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]:
raise NotImplementedError
def get_all(self, limit: Optional[int] = None, filters: Optional[Dict] = {}) -> List[Record]:
if limit:
return self.rows[:limit]
return self.rows
def get(self, id: str) -> Record:
raise NotImplementedError
match_row = [row for row in self.rows if row.id == id]
if len(match_row) == 0:
return None
assert len(match_row) == 1, f"Expected 1 match, got {len(match_row)} matches"
return match_row[0]
def insert(self, record: Record):
self.rows.append(record)
@@ -284,3 +302,12 @@ class InMemoryStorageConnector(StorageConnector):
def query_text(self, query: str) -> List[Record]:
return [row for row in self.rows if row.role not in ["system", "function"] and query.lower() in row.text.lower()]
def delete(self, filters: Optional[Dict] = {}):
raise NotImplementedError
def delete_table(self, filters: Optional[Dict] = {}):
if os.path.exists(self.agent_config.save_state_dir()):
shutil.rmtree(self.agent_config.save_state_dir())
if os.path.exists(self.agent_config.save_persistence_manager_dir()):
shutil.rmtree(self.agent_config.save_persistence_manager_dir())

View File

@@ -117,6 +117,11 @@ class StorageConnector:
return LanceDBConnector(agent_config=agent_config, table_type=table_type)
elif storage_type == "local":
from memgpt.connectors.local import InMemoryStorageConnector
return InMemoryStorageConnector(agent_config=agent_config, table_type=table_type)
else:
raise NotImplementedError(f"Storage type {storage_type} not implemented")
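The factory dispatches on the storage-type string and imports connectors lazily; a condensed sketch, with the postgres and chroma branches assumed to mirror the visible lancedb and local ones:

def get_storage_connector(storage_type: str, table_type, agent_config=None):
    if storage_type == "postgres":
        from memgpt.connectors.db import PostgresStorageConnector  # assumed branch
        return PostgresStorageConnector(agent_config=agent_config, table_type=table_type)
    elif storage_type == "chroma":
        from memgpt.connectors.chroma import ChromaStorageConnector  # assumed branch
        return ChromaStorageConnector(agent_config=agent_config, table_type=table_type)
    elif storage_type == "lancedb":
        from memgpt.connectors.db import LanceDBConnector
        return LanceDBConnector(agent_config=agent_config, table_type=table_type)
    elif storage_type == "local":
        from memgpt.connectors.local import InMemoryStorageConnector
        return InMemoryStorageConnector(agent_config=agent_config, table_type=table_type)
    raise NotImplementedError(f"Storage type {storage_type} not implemented")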
@@ -134,6 +139,8 @@ class StorageConnector:
if storage_type is None:
storage_type = MemGPTConfig.load().archival_storage_type
return
if storage_type == "local":
from memgpt.connectors.local import VectorIndexStorageConnector

View File

@@ -1,4 +1,5 @@
import os
import uuid
import subprocess
import sys
import pytest
@@ -11,7 +12,7 @@ import pytest
import pgvector # Try to import again after installing
from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.connectors.chroma import ChromaStorageConnector
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
from memgpt.connectors.db import SQLStorageConnector, LanceDBConnector
from memgpt.embeddings import embedding_model
from memgpt.data_types import Message, Passage
from memgpt.config import MemGPTConfig, AgentConfig
@@ -22,13 +23,13 @@ from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMA
import argparse
from datetime import datetime, timedelta
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
start_date = datetime(2009, 10, 5, 18, 00)
dates = [start_date - timedelta(weeks=1), start_date, start_date + timedelta(weeks=1)]
dates = [start_date, start_date - timedelta(weeks=1), start_date + timedelta(weeks=1)]
roles = ["user", "agent", "agent"]
agent_ids = ["agent1", "agent2", "agent1"]
ids = ["test1", "test2", "test3"] # TODO: generate unique uuid
ids = [uuid.uuid4(), uuid.uuid4(), uuid.uuid4()]
user_id = "test_user"
@@ -41,16 +42,7 @@ def generate_passages(embed_model):
embedding = None
if embed_model:
embedding = embed_model.get_text_embedding(text)
passages.append(
Passage(
user_id=user_id,
text=text,
agent_id=agent_id,
embedding=embedding,
data_source="test_source",
id=id,
)
)
passages.append(Passage(user_id=user_id, text=text, agent_id=agent_id, embedding=embedding, data_source="test_source", id=id))
return passages
@@ -65,7 +57,8 @@ def generate_messages():
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"])
@pytest.mark.parametrize("table_type", [TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY])
# @pytest.mark.parametrize("storage_connector", ["postgres"])
@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
def test_storage(storage_connector, table_type):
# setup memgpt config
@@ -88,10 +81,16 @@ def test_storage(storage_connector, table_type):
config.archival_storage_type = "lancedb"
config.recall_storage_type = "lancedb"
if storage_connector == "chroma":
if table_type == TableType.RECALL_MEMORY:
print("Skipping test, chroma only supported for archival memory")
return
config.archival_storage_type = "chroma"
config.recall_storage_type = "chroma"
config.recall_storage_path = "./test_chroma"
config.archival_storage_path = "./test_chroma"
if storage_connector == "local":
if table_type == TableType.ARCHIVAL_MEMORY:
print("Skipping test, local only supported for recall memory")
return
config.recall_storage_type = "local"
# get embedding model
embed_model = None
@@ -116,7 +115,8 @@ def test_storage(storage_connector, table_type):
# create storage connector
conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
conn.delete() # clear out data
# conn.client.delete_collection(conn.collection.name) # clear out data
conn.delete_table()
conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
# override filters
@@ -161,6 +161,7 @@ def test_storage(storage_connector, table_type):
assert len(all_records) == 1, f"Expected 1 records, got {len(all_records)}"
# test: get
print("GET ID", ids[0], records)
res = conn.get(id=ids[0])
assert res.text == texts[0], f"Expected {texts[0]}, got {res.text}"
@@ -178,8 +179,8 @@ def test_storage(storage_connector, table_type):
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
# test optional query functions
if storage_connector != "chroma":
# test optional query functions for recall memory
if table_type == TableType.RECALL_MEMORY:
# test: query_text
query = "CindereLLa"
res = conn.query_text(query)
@@ -187,12 +188,13 @@ def test_storage(storage_connector, table_type):
assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}"
# test: query_date (recall memory only)
if table_type == TableType.RECALL_MEMORY:
print("Testing recall memory date search")
start_date = start_date - timedelta(days=1)
end_date = start_date + timedelta(days=1)
res = conn.query_date(start_date=start_date, end_date=end_date)
assert len(res) == 1, f"Expected 1 result, got {len(res): {res}}"
print("Testing recall memory date search")
start_date = datetime(2009, 10, 5, 18, 00)
start_date = start_date - timedelta(days=1)
end_date = start_date + timedelta(days=1)
res = conn.query_date(start_date=start_date, end_date=end_date)
print("DATE", res)
assert len(res) == 1, f"Expected 1 result, got {len(res)}: {res}"
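A quick check of the fixture arithmetic behind this assertion: with records one week apart, a one-day window around the original start_date matches exactly one record:

from datetime import datetime, timedelta

start = datetime(2009, 10, 5, 18, 0)
record_dates = [start, start - timedelta(weeks=1), start + timedelta(weeks=1)]
window_start, window_end = start - timedelta(days=1), start + timedelta(days=1)
matches = [d for d in record_dates if window_start <= d <= window_end]
assert len(matches) == 1  # only the record at start_date falls inside the window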
# test: delete
conn.delete({"id": ids[0]})