From 515d9d0f629afa6dabcd801c4e6475ca3c877512 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Wed, 27 Dec 2023 14:40:11 +0400 Subject: [PATCH] Fix bug with chroma vector query --- memgpt/connectors/chroma.py | 22 +++++++++++++++------- tests/test_cli.py | 6 +++--- tests/test_storage.py | 4 +++- 3 files changed, 21 insertions(+), 11 deletions(-) diff --git a/memgpt/connectors/chroma.py b/memgpt/connectors/chroma.py index 09622abc..8e3f84f3 100644 --- a/memgpt/connectors/chroma.py +++ b/memgpt/connectors/chroma.py @@ -136,16 +136,14 @@ class ChromaStorageConnector(StorageConnector): def insert(self, record: Record): ids, documents, embeddings, metadatas = self.format_records([record]) if not any(embeddings): - self.collection.add(documents=documents, ids=ids, metadatas=metadatas) - else: - self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas) + raise ValueError("Embeddings must be provided to chroma") + self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas) def insert_many(self, records: List[Record], show_progress=True): ids, documents, embeddings, metadatas = self.format_records(records) if not any(embeddings): - self.collection.add(documents=documents, ids=ids, metadatas=metadatas) - else: - self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas) + raise ValueError("Embeddings must be provided to chroma") + self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas) def delete(self, filters: Optional[Dict] = {}): ids, filters = self.get_filters(filters) @@ -170,7 +168,17 @@ class ChromaStorageConnector(StorageConnector): def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]: ids, filters = self.get_filters(filters) results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=self.include, where=filters) - return self.results_to_records(results) + + # flatten, since we only have one query vector + flattened_results = {} + for key, value in results.items(): + if value: + flattened_results[key] = value[0] + assert len(value) == 1, f"Value is size {len(value)}: {value}" + else: + flattened_results[key] = value + + return self.results_to_records(flattened_results) def query_date(self, start_date, end_date, start=None, count=None): raise ValueError("Cannot run query_date with chroma") diff --git a/tests/test_cli.py b/tests/test_cli.py index a7a7774c..0fbb5c56 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -8,8 +8,8 @@ from .constants import TIMEOUT from .utils import configure_memgpt -def test_configure_memgpt(): - configure_memgpt() +# def test_configure_memgpt(): +# configure_memgpt() def test_save_load(): @@ -41,5 +41,5 @@ def test_save_load(): if __name__ == "__main__": - test_configure_memgpt() + # test_configure_memgpt() test_save_load() diff --git a/tests/test_storage.py b/tests/test_storage.py index a106eb9e..8aea809c 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -100,6 +100,7 @@ def test_storage(storage_connector, table_type): config.embedding_endpoint = None config.embedding_dim = 384 config.save() + embed_model = embedding_model() # create agent agent_config = AgentConfig( @@ -168,11 +169,12 @@ def test_storage(storage_connector, table_type): assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}" # test: query (vector) - if embed_model: + if table_type == TableType.ARCHIVAL_MEMORY: query = "why was she crying" query_vec = embed_model.get_text_embedding(query) res = conn.query(None, query_vec, top_k=2) assert len(res) == 2, f"Expected 2 results, got {len(res)}" + print("Archival memory results", res) assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}" # test optional query functions for recall memory