diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py
index 31431c6d..a744d8af 100644
--- a/memgpt/cli/cli.py
+++ b/memgpt/cli/cli.py
@@ -494,10 +494,10 @@ def attach(
     source_storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES)
     dest_storage = StorageConnector.get_storage_connector(table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config)

-    size = source_storage.size()
+    size = source_storage.size({"data_source": data_source})
     typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN)
     page_size = 100
-    generator = source_storage.get_all_paginated(page_size=page_size)  # yields List[Passage]
+    generator = source_storage.get_all_paginated(filters={"data_source": data_source}, page_size=page_size)  # yields List[Passage]
     passages = []
     for i in tqdm(range(0, size, page_size)):
         passages = next(generator)
diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py
index d10e593a..7f189d5e 100644
--- a/memgpt/cli/cli_load.py
+++ b/memgpt/cli/cli_load.py
@@ -47,6 +47,7 @@ def store_docs(name, docs, show_progress=True):
     # compute and record passages
     storage = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=config.archival_storage_type)
     embed_model = embedding_model()
+    orig_size = storage.size()

     # use llama index to run embeddings code
     service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size)
@@ -77,6 +78,7 @@ def store_docs(name, docs, show_progress=True):

     # insert into storage
     storage.insert_many(passages)
+    assert orig_size + len(passages) == storage.size(), f"Expected {orig_size + len(passages)} passages, got {storage.size()}"
     storage.save()

diff --git a/memgpt/connectors/chroma.py b/memgpt/connectors/chroma.py
index 8d27728f..9e32e0b6 100644
--- a/memgpt/connectors/chroma.py
+++ b/memgpt/connectors/chroma.py
@@ -139,7 +139,6 @@ class ChromaStorageConnector(StorageConnector):

     def insert_many(self, records: List[Record], show_progress=True):
         ids, documents, embeddings, metadatas = self.format_records(records)
-        print("Inserting", ids)
         if not any(embeddings):
             self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
         else:
diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py
index a302d6a7..f949b2a3 100644
--- a/memgpt/connectors/db.py
+++ b/memgpt/connectors/db.py
@@ -191,8 +191,7 @@ class SQLStorageConnector(StorageConnector):
             filter_conditions = {**self.filters, **filters}
         else:
             filter_conditions = self.filters
-        print("FILTERS", filter_conditions)
-
+        print("SQL FILTERS", filter_conditions)
         return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]

     def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
@@ -213,11 +212,13 @@ class SQLStorageConnector(StorageConnector):
             # Increment the offset to get the next chunk in the next iteration
            offset += page_size

-    def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
+    def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[Record]:
         session = self.Session()
         filters = self.get_filters(filters)
-        print("LIMIT", limit)
-        db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
+        if limit:
+            db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
+        else:
+            db_records = session.query(self.db_model).filter(*filters).all()
         return [record.to_record() for record in db_records]

     def get(self, id: str) -> Optional[Record]:
@@ -229,9 +230,9 @@ class SQLStorageConnector(StorageConnector):

     def size(self, filters: Optional[Dict] = {}) -> int:
         # return size of table
-        print("size")
         session = self.Session()
         filters = self.get_filters(filters)
+        print("ALL FILTERS", filters)
         return session.query(self.db_model).filter(*filters).count()

     def insert(self, record: Record):
@@ -404,122 +405,3 @@ class SQLLiteStorageConnector(SQLStorageConnector):
         sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
         sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))
-
-
-class LanceDBConnector(StorageConnector):
-    """Storage via LanceDB"""
-
-    # TODO: this should probably eventually be moved into a parent DB class
-
-    def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
-        config = MemGPTConfig.load()
-        # determine table name
-        if agent_config:
-            assert name is None, f"Cannot specify both agent config and name {name}"
-            self.table_name = self.generate_table_name_agent(agent_config)
-        elif name:
-            assert agent_config is None, f"Cannot specify both agent config and name {name}"
-            self.table_name = self.generate_table_name(name)
-        else:
-            raise ValueError("Must specify either agent config or name")
-
-        # create table
-        self.uri = config.archival_storage_uri
-        if config.archival_storage_uri is None:
-            raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
-        import lancedb
-
-        self.db = lancedb.connect(self.uri)
-        if self.table_name in self.db.table_names():
-            self.table = self.db[self.table_name]
-        else:
-            self.table = None
-
-    def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
-        ds = self.table.to_lance()
-        offset = 0
-        while True:
-            # Retrieve a chunk of records with the given page_size
-            db_passages_chunk = ds.to_table(offset=offset, limit=page_size).to_pylist()
-            # If the chunk is empty, we've retrieved all records
-            if not db_passages_chunk:
-                break
-
-            # Yield a list of Passage objects converted from the chunk
-            yield [
-                Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in db_passages_chunk
-            ]
-
-            # Increment the offset to get the next chunk in the next iteration
-            offset += page_size
-
-    def get_all(self, limit=10) -> List[Passage]:
-        db_passages = self.table.to_lance().to_table(limit=limit).to_pylist()
-        return [Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in db_passages]
-
-    def get(self, id: str) -> Optional[Passage]:
-        db_passage = self.table.where(f"passage_id={id}").to_list()
-        if len(db_passage) == 0:
-            return None
-        return Passage(
-            text=db_passage["text"], embedding=db_passage["embedding"], doc_id=db_passage["doc_id"], passage_id=db_passage["passage_id"]
-        )
-
-    def size(self) -> int:
-        # return size of table
-        if self.table:
-            return len(self.table)
-        else:
-            printd(f"Table with name {self.table_name} not present")
-            return 0
-
-    def insert(self, passage: Passage):
-        data = [{"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}]
-
-        if self.table is not None:
-            self.table.add(data)
-        else:
-            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")
-
-    def insert_many(self, passages: List[Passage], show_progress=True):
-        data = []
-        iterable = tqdm(passages) if show_progress else passages
-        for passage in iterable:
-            temp_dict = {"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}
-            data.append(temp_dict)
-
-        if self.table is not None:
-            self.table.add(data)
-        else:
-            self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")
-
-    def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
-        # Assuming query_vec is of same length as embeddings inside table
-        results = self.table.search(query_vec).limit(top_k).to_list()
-        # Convert the results into Passage objects
-        passages = [
-            Passage(text=result["text"], embedding=result["vector"], doc_id=result["doc_id"], passage_id=result["passage_id"])
-            for result in results
-        ]
-        return passages
-
-    def delete(self):
-        """Drop the passage table from the database."""
-        # Drop the table specified by the PassageModel class
-        self.db.drop_table(self.table_name)
-
-    def save(self):
-        return
-
-    @staticmethod
-    def list_loaded_data():
-        config = MemGPTConfig.load()
-        import lancedb
-
-        db = lancedb.connect(config.archival_storage_uri)
-
-        tables = db.table_names()
-        tables = [table for table in tables if table.startswith("memgpt_")]
-        start_chars = len("memgpt_")
-        tables = [table[start_chars:] for table in tables]
-        return tables
diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py
index b2c0462c..1c267359 100644
--- a/memgpt/connectors/storage.py
+++ b/memgpt/connectors/storage.py
@@ -73,6 +73,8 @@ class StorageConnector:
         else:
             self.filters = {}

+        print("FILTERS", self.filters)
+
     def get_filters(self, filters: Optional[Dict] = {}):
         # get all filters for query
         if filters is not None:
diff --git a/memgpt/memory.py b/memgpt/memory.py
index f100147c..cf3446b4 100644
--- a/memgpt/memory.py
+++ b/memgpt/memory.py
@@ -440,7 +440,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
     def __repr__(self) -> str:
         limit = 10
         passages = []
-        for passage in list(self.storage.get_all(limit)):  # TODO: only get first 10
+        for passage in list(self.storage.get_all(limit=limit)):  # TODO: only get first 10
             passages.append(str(passage.text))
         memory_str = "\n".join(passages)
         return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"
diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py
index 5c0240d9..c45db5af 100644
--- a/tests/test_load_archival.py
+++ b/tests/test_load_archival.py
@@ -4,50 +4,62 @@ import os
 import pytest

 from memgpt.connectors.storage import StorageConnector, TableType

-# import asyncio
-from datasets import load_dataset
-
 # import memgpt
 from memgpt.cli.cli_load import load_directory, load_database, load_webpage
 from memgpt.cli.cli import attach
 from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
 from memgpt.config import AgentConfig, MemGPTConfig

-# import memgpt.presets as presets
-# import memgpt.personas.personas as personas
-# import memgpt.humans.humans as humans
-# from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager
-# # from memgpt.config import AgentConfig
-# from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL
-# import memgpt.interface  # for printing to terminal
-
-
-# @pytest.mark.parametrize("storage_connector", ["sqllite", "postgres"])
-@pytest.mark.parametrize("metadata_storage_connector", ["sqlite"])
-@pytest.mark.parametrize("passage_storage_connector", ["chroma"])
+@pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"])
+@pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"])
 def test_load_directory(metadata_storage_connector, passage_storage_connector):
+    # setup config
+    config = MemGPTConfig()
+    if metadata_storage_connector == "postgres":
+        if not os.getenv("PGVECTOR_TEST_DB_URL"):
+            print("Skipping test, missing PG URI")
+            return
+        config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
+        config.metadata_storage_type = "postgres"
+    elif metadata_storage_connector == "sqlite":
+        print("testing sqlite metadata")
+        # nothing to do (should be config defaults)
+    else:
+        raise NotImplementedError(f"Storage type {metadata_storage_connector} not implemented")
+    if passage_storage_connector == "postgres":
+        if not os.getenv("PGVECTOR_TEST_DB_URL"):
+            print("Skipping test, missing PG URI")
+            return
+        config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
+        config.archival_storage_type = "postgres"
+    elif passage_storage_connector == "chroma":
+        print("testing chroma passage storage")
+        # nothing to do (should be config defaults)
+    else:
+        raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented")
+    config.save()
+
+    # setup storage connectors
     data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
     passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector)

-    # load hugging face dataset
-    # dataset_name = "MemGPT/example_short_stories"
-    # dataset = load_dataset(dataset_name)
-
-    # cache_dir = os.getenv("HF_DATASETS_CACHE")
-    # if cache_dir is None:
-    #     # Construct the default path if the environment variable is not set.
-    #     cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
-    #     print("HF Directory", cache_dir)
+    # load data
     name = "test_dataset"
     cache_dir = "CONTRIBUTING.md"

+    # TODO: load two different data sources
+
     # clear out data
     data_source_conn.delete_table()
     passages_conn.delete_table()
     data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
     passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector)
+    assert (
+        data_source_conn.size() == 0
+    ), f"Expected 0 records, got {data_source_conn.size()}: {[vars(r) for r in data_source_conn.get_all()]}"
+    assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"

     # test: load directory
     load_directory(name=name, input_dir=None, input_files=[cache_dir], recursive=False)  # cache_dir,
@@ -59,8 +71,14 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
     print("Source", sources)

     # test to see if contained in storage
+    assert (
+        len(passages_conn.get_all()) == passages_conn.size()
+    ), f"Expected {passages_conn.size()} passages, but got {len(passages_conn.get_all())}"
     passages = passages_conn.get_all({"data_source": name})
+    print("Source", [p.data_source for p in passages])
+    print("All sources", [p.data_source for p in passages_conn.get_all()])
     assert len(passages) > 0, f"Expected >0 passages, but got {len(passages)}"
+    assert len(passages) == passages_conn.size(), f"Expected {passages_conn.size()} passages, but got {len(passages)}"
     assert [p.data_source == name for p in passages]
     print("Passages", passages)

@@ -71,7 +89,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
     # test loading into an agent
     # create agent
     agent_config = AgentConfig(
-        name="test_agent",
+        name="memgpt_test_agent",
         persona=DEFAULT_PERSONA,
         human=DEFAULT_HUMAN,
         model=DEFAULT_MEMGPT_MODEL,
@@ -81,7 +99,11 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
     conn = StorageConnector.get_storage_connector(
         storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
     )
-    assert conn.size() == 0
+    conn.delete_table()
+    conn = StorageConnector.get_storage_connector(
+        storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
+    )
+    assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}"

     # attach data
     attach(agent=agent_config.name, data_source=name)
diff --git a/tests/test_storage.py b/tests/test_storage.py
index 8c9d57c0..e46915c8 100644
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -56,8 +56,7 @@ def generate_messages():
     return messages


-@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqllite", "lancedb"])
-# @pytest.mark.parametrize("storage_connector", ["sqllite"])
+@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqlite"])
 @pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
 def test_storage(storage_connector, table_type):
@@ -86,9 +85,9 @@ def test_storage(storage_connector, table_type):
             return
         config.archival_storage_type = "chroma"
         config.archival_storage_path = "./test_chroma"
-    if storage_connector == "sqllite":
+    if storage_connector == "sqlite":
         if table_type == TableType.ARCHIVAL_MEMORY:
-            print("Skipping test, sqllite only supported for recall memory")
+            print("Skipping test, sqlite only supported for recall memory")
             return
         config.recall_storage_type = "local"
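
Usage sketch (illustrative only, not part of the patch): after these changes, passage reads are scoped to a data source by passing a filter dict, and get_all() no longer silently truncates to 10 records. A minimal example of the new call pattern, assuming a configured MemGPT install and an already-loaded source named "test_dataset" (the source name is a placeholder; the connector calls are taken verbatim from the diff):

    from memgpt.connectors.storage import StorageConnector, TableType

    # connect to the shared passages table
    source_storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES)

    # size() and get_all_paginated() now accept filters, so the count and the
    # pages cover only the requested data source rather than the whole table
    filters = {"data_source": "test_dataset"}
    size = source_storage.size(filters)
    for page in source_storage.get_all_paginated(filters=filters, page_size=100):
        print(f"got {len(page)} of {size} passages")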