From e2b29d899569095c3c7838d3c8df0ed338d6dc6e Mon Sep 17 00:00:00 2001
From: Sarah Wooders
Date: Fri, 22 Dec 2023 10:29:27 +0400
Subject: [PATCH] Bugfixes and test updates for passing tests for both postgres and chroma

---
 memgpt/connectors/chroma.py  |  17 +++---
 memgpt/connectors/db.py      | 102 +++++++++++++++++++++++++++--------
 memgpt/connectors/local.py   |  63 +++++++++++++++-------
 memgpt/connectors/storage.py |   7 +++
 tests/test_storage.py        |  54 ++++++++++---------
 5 files changed, 167 insertions(+), 76 deletions(-)

diff --git a/memgpt/connectors/chroma.py b/memgpt/connectors/chroma.py
index 3d303e2b..10c42f15 100644
--- a/memgpt/connectors/chroma.py
+++ b/memgpt/connectors/chroma.py
@@ -19,6 +19,8 @@ class ChromaStorageConnector(StorageConnector):
         super().__init__(table_type=table_type, agent_config=agent_config)
         config = MemGPTConfig.load()
 
+        assert table_type == TableType.ARCHIVAL_MEMORY, "Chroma only supports archival memory"
+
         # create chroma client
         if config.archival_storage_path:
             self.client = chromadb.PersistentClient(config.archival_storage_path)
@@ -34,7 +36,6 @@ class ChromaStorageConnector(StorageConnector):
 
     def get_filters(self, filters: Optional[Dict] = {}):
         # get all filters for query
-        print("GET FILTER", filters)
         if filters is not None:
             filter_conditions = {**self.filters, **filters}
         else:
@@ -53,12 +54,9 @@ class ChromaStorageConnector(StorageConnector):
     def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
         offset = 0
         ids, filters = self.get_filters(filters)
-        print("FILTERS", filters)
         while True:
             # Retrieve a chunk of records with the given page_size
-            print("querying...", self.collection.count(), "offset", offset, "page", page_size)
             results = self.collection.get(ids=ids, offset=offset, limit=page_size, include=self.include, where=filters)
-            print(len(results["embeddings"]))
 
             # If the chunk is empty, we've retrieved all records
             if len(results["embeddings"]) == 0:
@@ -72,9 +70,6 @@ class ChromaStorageConnector(StorageConnector):
 
     def results_to_records(self, results):
         # convert timestamps to datetime
-        print("ID", results["ids"])
-        print("ID TYPE", type(results["ids"][0]))
-        print(uuid.UUID(results["ids"][0]))
         for metadata in results["metadatas"]:
             if "created_at" in metadata:
                 metadata["created_at"] = timestamp_to_datetime(metadata["created_at"])
@@ -126,13 +121,11 @@ class ChromaStorageConnector(StorageConnector):
                 record_metadata = {}
             metadata = {key: value for key, value in metadata.items() if value is not None}  # null values not allowed
             metadata = {**metadata, **record_metadata}  # merge with metadata
-            print("m", metadata)
             metadatas.append(metadata)
         return ids, documents, embeddings, metadatas
 
     def insert(self, record: Record):
         ids, documents, embeddings, metadatas = self.format_records([record])
-        print("metadata", record, metadatas)
         if not any(embeddings):
             self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
         else:
@@ -149,10 +142,14 @@ class ChromaStorageConnector(StorageConnector):
         ids, filters = self.get_filters(filters)
         self.collection.delete(ids=ids, where=filters)
 
+    def delete_table(self):
+        # drop collection
+        self.client.delete_collection(self.collection.name)
+
     def save(self):
         # save to persistence file (nothing needs to be done)
         printd("Saving chroma")
         pass
 
     def size(self, filters: Optional[Dict] = {}) -> int:
         # unfortuantely, need to use pagination to get filtering
diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py
index 5134c803..29974bc7 100644
--- a/memgpt/connectors/db.py
+++ b/memgpt/connectors/db.py
@@ -81,12 +82,18 @@ def get_db_model(table_name: str, table_type: TableType):
             id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4)
             user_id = Column(String, nullable=False)
             agent_id = Column(String, nullable=False)
+
+            # openai info
             role = Column(String, nullable=False)
             text = Column(String, nullable=False)
             model = Column(String, nullable=False)
+            user = Column(String)  # optional: multi-agent only
+
+            # function info
             function_name = Column(String)
             function_args = Column(String)
             function_response = Column(String)
+
             embedding = mapped_column(Vector(config.embedding_dim))
 
             # Add a datetime column, with default value as the current time
@@ -100,6 +107,7 @@ def get_db_model(table_name: str, table_type: TableType):
                     user_id=self.user_id,
                     agent_id=self.agent_id,
                     role=self.role,
+                    user=self.user,
                     text=self.text,
                     model=self.model,
                     function_name=self.function_name,
@@ -118,7 +126,6 @@ def get_db_model(table_name: str, table_type: TableType):
         raise ValueError(f"Table type {table_type} not implemented")
 
 
-class PostgresStorageConnector(StorageConnector):
-    """Storage via Postgres"""
+class SQLStorageConnector(StorageConnector):
+    """Storage via a SQL database"""
 
-    # TODO: this should probably eventually be moved into a parent DB class
@@ -127,6 +135,8 @@ class PostgresStorageConnector(StorageConnector):
     def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
         super().__init__(table_type=table_type, agent_config=agent_config)
         config = MemGPTConfig.load()
 
+        # TODO: support recall memory only (archival memory requires postgres/pgvector)
+        # get storage URI
         if table_type == TableType.ARCHIVAL_MEMORY:
             self.uri = config.archival_storage_uri
@@ -155,20 +165,20 @@ class PostgresStorageConnector(StorageConnector):
 
         return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
 
-    def get_all_paginated(self, page_size: int, filters: Optional[Dict]) -> Iterator[List[Record]]:
+    def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
         session = self.Session()
         offset = 0
         filters = self.get_filters(filters)
         while True:
             # Retrieve a chunk of records with the given page_size
-            db_passages_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
+            db_record_chunk = session.query(self.db_model).filter(*filters).offset(offset).limit(page_size).all()
 
             # If the chunk is empty, we've retrieved all records
-            if not db_passages_chunk:
+            if not db_record_chunk:
                 break
 
             # Yield a list of Record objects converted from the chunk
-            yield [self.type(**p.to_dict()) for p in db_passages_chunk]
+            yield [record.to_record() for record in db_record_chunk]
 
             # Increment the offset to get the next chunk in the next iteration
             offset += page_size
@@ -179,10 +189,9 @@ class PostgresStorageConnector(StorageConnector):
         db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
         return [record.to_record() for record in db_records]
 
-    def get(self, id: str, filters: Optional[Dict] = {}) -> Optional[Record]:
+    def get(self, id: str) -> Optional[Record]:
         session = self.Session()
-        filters = self.get_filters(filters)
-        db_record = session.query(self.db_model).filter(*filters).get(id)
+        db_record = session.query(self.db_model).get(id)
         if db_record is None:
             return None
         return db_record.to_record()
@@ -209,15 +218,7 @@ class PostgresStorageConnector(StorageConnector):
         session.commit()
 
     def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
-        session = self.Session()
-        filters = self.get_filters(filters)
-        results = session.scalars(
-            select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
-        ).all()
-
-        # Convert the results into Passage objects
-        records = [result.to_record() for result in results]
-        return records
+        raise NotImplementedError("Vector query not implemented for SQLStorageConnector")
 
     def save(self):
         return
@@ -255,11 +256,41 @@ class PostgresStorageConnector(StorageConnector):
         # todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
         session = self.Session()
         filters = self.get_filters({})
-        results = session.query(self.db_model).filter(*filters).filter(self.db_model.text.contains(query)).all()
-        print(results)
+        results = session.query(self.db_model).filter(*filters).filter(func.lower(self.db_model.text).contains(func.lower(query))).all()
         # return [self.type(**vars(result)) for result in results]
         return [result.to_record() for result in results]
 
+    def delete_table(self):
+        session = self.Session()
+        self.db_model.__table__.drop(session.bind)
+        session.commit()
+
+    def delete(self, filters: Optional[Dict] = {}):
+        session = self.Session()
+        filters = self.get_filters(filters)
+        session.query(self.db_model).filter(*filters).delete()
+        session.commit()
+
+
+class PostgresStorageConnector(SQLStorageConnector):
+    """Storage via Postgres"""
+
+    def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
+        super().__init__(table_type=table_type, agent_config=agent_config)
+        self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector"))  # Enables the vector extension
+
+    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
+        session = self.Session()
+        filters = self.get_filters(filters)
+        results = session.scalars(
+            select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
+        ).all()
+
+        # Convert the results into Passage objects
+        records = [result.to_record() for result in results]
+        return records
+
+
 class LanceDBConnector(StorageConnector):
     """Storage via LanceDB"""
 
@@ -277,8 +337,6 @@ class LanceDBConnector(StorageConnector):
         else:
             raise ValueError("Must specify either agent config or name")
 
-        printd(f"Using table name {self.table_name}")
-
         # create table
         self.uri = config.archival_storage_uri
         if config.archival_storage_uri is None:
@@ -326,7 +384,7 @@ class LanceDBConnector(StorageConnector):
         if self.table:
             return len(self.table)
         else:
-            print(f"Table with name {self.table_name} not present")
+            printd(f"Table with name {self.table_name} not present")
             return 0
 
     def insert(self, passage: Passage):
diff --git a/memgpt/connectors/local.py b/memgpt/connectors/local.py
index 8eb4854e..793dff76 100644
--- a/memgpt/connectors/local.py
+++ b/memgpt/connectors/local.py
@@ -1,4 +1,5 @@
 from typing import Optional, List, Iterator
+import shutil
 from memgpt.config import AgentConfig, MemGPTConfig
 from tqdm import tqdm
 import re
@@ -181,39 +182,56 @@ class InMemoryStorageConnector(StorageConnector):
             raise ValueError(f"Table type {table_type} not supported by InMemoryStorageConnector")
 
         # TODO: load if exists
+        self.agent_config = agent_config
         if agent_config is None:
             # is a data source
             raise ValueError("Cannot load data source from InMemoryStorageConnector")
         else:
             directory = agent_config.save_state_dir()
-            json_files = glob.glob(os.path.join(directory, "*.json"))  # This will list all .json files in the current directory.
-            if not json_files:
-                print(f"/load error: no .json checkpoint files found")
-                raise ValueError(f"Cannot load {agent_config.name} - no saved checkpoints found in {directory}")
-
-            # Sort files based on modified timestamp, with the latest file being the first.
-            filename = max(json_files, key=os.path.getmtime)
-            state = json.load(open(filename, "r"))
-
-            # load persistence manager
-            filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
-            directory = agent_config.save_persistence_manager_dir()
-            printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
-            with open(filename, "rb") as f:
-                data = pickle.load(f)
-            self.rows = data["all_messages"]
+            if os.path.exists(directory):
+                print(f"Loading saved agent {agent_config.name} from {directory}")
+                json_files = glob.glob(os.path.join(directory, "*.json"))  # list all .json checkpoint files in the save directory
+                if not json_files:
+                    print("/load error: no .json checkpoint files found")
+                    raise ValueError(f"Cannot load {agent_config.name} - no saved checkpoints found in {directory}")
+
+                # Sort files based on modified timestamp, with the latest file being the first.
+                filename = max(json_files, key=os.path.getmtime)
+                state = json.load(open(filename, "r"))
+
+                # load persistence manager
+                filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
+                directory = agent_config.save_persistence_manager_dir()
+                printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
+                with open(filename, "rb") as f:
+                    data = pickle.load(f)
+                self.rows = data["all_messages"]
+            else:
+                print(f"Creating new agent {agent_config.name}")
+                self.rows = []
 
         # convert to Record class
         self.rows = [self.json_to_message(m) for m in self.rows]
 
     def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
-        raise NotImplementedError
+        offset = 0
+        while True:
+            yield self.rows[offset : offset + page_size]
+            offset += page_size
+            if offset >= len(self.rows):
+                break
 
-    def get_all(self, limit: int, filters: Optional[Dict]) -> List[Record]:
-        raise NotImplementedError
+    def get_all(self, limit: Optional[int] = None, filters: Optional[Dict] = {}) -> List[Record]:
+        if limit is not None:
+            return self.rows[:limit]
+        return self.rows
 
     def get(self, id: str) -> Record:
-        raise NotImplementedError
+        match_row = [row for row in self.rows if row.id == id]
+        if len(match_row) == 0:
+            return None
+        assert len(match_row) == 1, f"Expected 1 match, got {len(match_row)} matches"
+        return match_row[0]
 
     def insert(self, record: Record):
         self.rows.append(record)
@@ -284,3 +302,12 @@ class InMemoryStorageConnector(StorageConnector):
 
     def query_text(self, query: str) -> List[Record]:
         return [row for row in self.rows if row.role not in ["system", "function"] and query.lower() in row.text.lower()]
+
+    def delete(self, filters: Optional[Dict] = {}):
+        raise NotImplementedError
+
+    def delete_table(self):
+        if os.path.exists(self.agent_config.save_state_dir()):
+            shutil.rmtree(self.agent_config.save_state_dir())
+        if os.path.exists(self.agent_config.save_persistence_manager_dir()):
+            shutil.rmtree(self.agent_config.save_persistence_manager_dir())
diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py
index 9367a562..764ce1cc 100644
--- a/memgpt/connectors/storage.py
+++ b/memgpt/connectors/storage.py
@@ -117,6 +117,11 @@ class StorageConnector:
 
             return LanceDBConnector(agent_config=agent_config, table_type=table_type)
 
+        elif storage_type == "local":
+            from memgpt.connectors.local import InMemoryStorageConnector
+
+            return InMemoryStorageConnector(agent_config=agent_config, table_type=table_type)
+
         else:
             raise NotImplementedError(f"Storage type {storage_type} not implemented")
 
diff --git a/tests/test_storage.py b/tests/test_storage.py
index b1a85313..3ce717d0 100644
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -1,4 +1,5 @@
 import os
+import uuid
 import subprocess
 import sys
 import pytest
@@ -11,7 +12,7 @@
     import pgvector  # Try to import again after installing
 
 from memgpt.connectors.storage import StorageConnector, TableType
-from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
+from memgpt.connectors.db import SQLStorageConnector, LanceDBConnector
 from memgpt.embeddings import embedding_model
 from memgpt.data_types import Message, Passage
 from memgpt.config import MemGPTConfig, AgentConfig
@@ -22,13 +23,13 @@ from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
 import argparse
 from datetime import datetime, timedelta
 
-
+# Note: by default, the storage connectors filter out rows that do not belong to agent1 and test_user.
 texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
 start_date = datetime(2009, 10, 5, 18, 00)
-dates = [start_date - timedelta(weeks=1), start_date, start_date + timedelta(weeks=1)]
+dates = [start_date, start_date - timedelta(weeks=1), start_date + timedelta(weeks=1)]
 roles = ["user", "agent", "agent"]
 agent_ids = ["agent1", "agent2", "agent1"]
-ids = ["test1", "test2", "test3"]  # TODO: generate unique uuid
+ids = [uuid.uuid4(), uuid.uuid4(), uuid.uuid4()]
 user_id = "test_user"
 
@@ -41,16 +42,7 @@ def generate_passages(embed_model):
         embedding = None
         if embed_model:
             embedding = embed_model.get_text_embedding(text)
-        passages.append(
-            Passage(
-                user_id=user_id,
-                text=text,
-                agent_id=agent_id,
-                embedding=embedding,
-                data_source="test_source",
-                id=id,
-            )
-        )
+        passages.append(Passage(user_id=user_id, text=text, agent_id=agent_id, embedding=embedding, data_source="test_source", id=id))
     return passages
@@ -65,7 +57,7 @@ def generate_messages():
 
 
 @pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"])
-@pytest.mark.parametrize("table_type", [TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY])
+@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
 def test_storage(storage_connector, table_type):
 
     # setup memgpt config
@@ -88,10 +81,14 @@ def test_storage(storage_connector, table_type):
         config.archival_storage_type = "lancedb"
         config.recall_storage_type = "lancedb"
     if storage_connector == "chroma":
+        if table_type == TableType.RECALL_MEMORY:
+            pytest.skip("Chroma storage only supports archival memory")
         config.archival_storage_type = "chroma"
-        config.recall_storage_type = "chroma"
-        config.recall_storage_path = "./test_chroma"
         config.archival_storage_path = "./test_chroma"
+    if storage_connector == "local":
+        if table_type == TableType.ARCHIVAL_MEMORY:
+            pytest.skip("Local (in-memory) storage only supports recall memory")
+        config.recall_storage_type = "local"
 
     # get embedding model
     embed_model = None
@@ -116,7 +115,7 @@ def test_storage(storage_connector, table_type):
     # create storage connector
     conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
-    conn.delete()  # clear out data
+    conn.delete_table()
     conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
 
     # override filters
@@ -178,8 +179,8 @@ def test_storage(storage_connector, table_type):
     assert len(res) == 2, f"Expected 2 results, got {len(res)}"
     assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
 
-    # test optional query functions
-    if storage_connector != "chroma":
+    # test optional query functions for recall memory
+    if table_type == TableType.RECALL_MEMORY:
         # test: query_text
         query = "CindereLLa"
         res = conn.query_text(query)
@@ -187,12 +188,12 @@ def test_storage(storage_connector, table_type):
         assert len(res) == 1, f"Expected 1 result, got {len(res)}"
         assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}"
 
         # test: query_date (recall memory only)
-        if table_type == TableType.RECALL_MEMORY:
-            print("Testing recall memory date search")
-            start_date = start_date - timedelta(days=1)
-            end_date = start_date + timedelta(days=1)
-            res = conn.query_date(start_date=start_date, end_date=end_date)
-            assert len(res) == 1, f"Expected 1 result, got {len(res): {res}}"
+        print("Testing recall memory date search")
+        start_date = datetime(2009, 10, 5, 18, 00)
+        start_date = start_date - timedelta(days=1)
+        end_date = start_date + timedelta(days=1)
+        res = conn.query_date(start_date=start_date, end_date=end_date)
+        assert len(res) == 1, f"Expected 1 result, got {len(res)}: {res}"
 
     # test: delete
     conn.delete({"id": ids[0]})
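-- 
Reviewer note (illustrative sketch, not part of the patch): the connectors touched
here all share the same get_all_paginated() contract -- fetch pages at an increasing
offset until an empty page signals the end -- and the Chroma size() appears to work
by draining that generator (see the "need to use pagination to get filtering"
comment above). A minimal, self-contained Python sketch of the contract, using
hypothetical standalone names (the real connectors yield lists of Record objects):

    from typing import Iterator, List, TypeVar

    T = TypeVar("T")

    def get_all_paginated(rows: List[T], page_size: int) -> Iterator[List[T]]:
        # Mirrors the connectors' loop: take the page at the current offset,
        # stop as soon as a page comes back empty.
        offset = 0
        while True:
            page = rows[offset : offset + page_size]
            if not page:
                break
            yield page
            offset += page_size

    # Counting by draining the generator, in the style of the connectors' size():
    total = sum(len(page) for page in get_all_paginated(list(range(10)), 3))
    assert total == 10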