Fix bug with chroma vector query

This commit is contained in:
Sarah Wooders
2023-12-27 14:40:11 +04:00
parent b598f3e2d4
commit 515d9d0f62
3 changed files with 21 additions and 11 deletions

View File

@@ -136,16 +136,14 @@ class ChromaStorageConnector(StorageConnector):
def insert(self, record: Record):
ids, documents, embeddings, metadatas = self.format_records([record])
if not any(embeddings):
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
else:
self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)
raise ValueError("Embeddings must be provided to chroma")
self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)
def insert_many(self, records: List[Record], show_progress=True):
ids, documents, embeddings, metadatas = self.format_records(records)
if not any(embeddings):
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
else:
self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)
raise ValueError("Embeddings must be provided to chroma")
self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)
def delete(self, filters: Optional[Dict] = {}):
ids, filters = self.get_filters(filters)
@@ -170,7 +168,17 @@ class ChromaStorageConnector(StorageConnector):
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
ids, filters = self.get_filters(filters)
results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=self.include, where=filters)
return self.results_to_records(results)
# flatten, since we only have one query vector
flattened_results = {}
for key, value in results.items():
if value:
flattened_results[key] = value[0]
assert len(value) == 1, f"Value is size {len(value)}: {value}"
else:
flattened_results[key] = value
return self.results_to_records(flattened_results)
def query_date(self, start_date, end_date, start=None, count=None):
raise ValueError("Cannot run query_date with chroma")

View File

@@ -8,8 +8,8 @@ from .constants import TIMEOUT
from .utils import configure_memgpt
def test_configure_memgpt():
configure_memgpt()
# def test_configure_memgpt():
# configure_memgpt()
def test_save_load():
@@ -41,5 +41,5 @@ def test_save_load():
if __name__ == "__main__":
test_configure_memgpt()
# test_configure_memgpt()
test_save_load()

View File

@@ -100,6 +100,7 @@ def test_storage(storage_connector, table_type):
config.embedding_endpoint = None
config.embedding_dim = 384
config.save()
embed_model = embedding_model()
# create agent
agent_config = AgentConfig(
@@ -168,11 +169,12 @@ def test_storage(storage_connector, table_type):
assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}"
# test: query (vector)
if embed_model:
if table_type == TableType.ARCHIVAL_MEMORY:
query = "why was she crying"
query_vec = embed_model.get_text_embedding(query)
res = conn.query(None, query_vec, top_k=2)
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
print("Archival memory results", res)
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
# test optional query functions for recall memory