Fix bug with chroma vector query
This commit is contained in:
@@ -136,16 +136,14 @@ class ChromaStorageConnector(StorageConnector):
|
||||
def insert(self, record: Record):
    """Insert a single record into the Chroma collection.

    The record is converted with ``self.format_records`` into parallel lists
    of ids, documents, embeddings and metadatas, which are then passed to
    ``self.collection.add``.

    Args:
        record: The record to store.

    Raises:
        ValueError: If none of the formatted embeddings is truthy.
    """
    ids, documents, embeddings, metadatas = self.format_records([record])
    # Fail fast rather than adding documents without vectors.
    if not any(embeddings):
        raise ValueError("Embeddings must be provided to chroma")
    self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)
|
||||
|
||||
def insert_many(self, records: List[Record], show_progress=True):
    """Insert a batch of records into the Chroma collection in one ``add`` call.

    Args:
        records: The records to store; converted with ``self.format_records``.
        show_progress: Accepted for interface compatibility; not used here.

    Raises:
        ValueError: If none of the formatted embeddings is truthy.
    """
    ids, documents, embeddings, metadatas = self.format_records(records)
    # Fail fast rather than adding documents without vectors.
    if not any(embeddings):
        raise ValueError("Embeddings must be provided to chroma")
    self.collection.add(documents=documents, embeddings=embeddings, ids=ids, metadatas=metadatas)
|
||||
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
ids, filters = self.get_filters(filters)
|
||||
@@ -170,7 +168,17 @@ class ChromaStorageConnector(StorageConnector):
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
    """Run a vector similarity query against the Chroma collection.

    Args:
        query: The query text; not used by this method (only ``query_vec`` is sent).
        query_vec: Embedding of the query, sent as the single query embedding.
        top_k: Number of results requested (``n_results``).
        filters: Filters translated by ``self.get_filters`` and passed as ``where``.

    Returns:
        Whatever ``self.results_to_records`` builds from the flattened results.
    """
    _, filters = self.get_filters(filters)
    results = self.collection.query(query_embeddings=[query_vec], n_results=top_k, include=self.include, where=filters)

    # flatten, since we only have one query vector: each populated field is a
    # one-element outer list that wraps the per-query values.
    flattened_results = {}
    for key, value in results.items():
        if value:
            # Check the shape before indexing so a malformed response is reported clearly.
            assert len(value) == 1, f"Value is size {len(value)}: {value}"
            flattened_results[key] = value[0]
        else:
            # Empty / None fields are passed through unchanged.
            flattened_results[key] = value

    return self.results_to_records(flattened_results)
|
||||
|
||||
def query_date(self, start_date, end_date, start=None, count=None):
    """Reject date-range queries, which this connector does not implement.

    All arguments are ignored.

    Raises:
        ValueError: Always.
    """
    raise ValueError("Cannot run query_date with chroma")
|
||||
|
||||
@@ -8,8 +8,8 @@ from .constants import TIMEOUT
|
||||
from .utils import configure_memgpt
|
||||
|
||||
|
||||
# NOTE(review): this commit disables test_configure_memgpt by commenting it out;
# the reason is not stated in the change itself — confirm before re-enabling.
# def test_configure_memgpt():
#     configure_memgpt()
|
||||
|
||||
|
||||
def test_save_load():
|
||||
@@ -41,5 +41,5 @@ def test_save_load():
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # test_configure_memgpt()
    test_save_load()
|
||||
|
||||
@@ -100,6 +100,7 @@ def test_storage(storage_connector, table_type):
|
||||
config.embedding_endpoint = None
|
||||
config.embedding_dim = 384
|
||||
config.save()
|
||||
embed_model = embedding_model()
|
||||
|
||||
# create agent
|
||||
agent_config = AgentConfig(
|
||||
@@ -168,11 +169,12 @@ def test_storage(storage_connector, table_type):
|
||||
assert conn.size(filters={"role": "user"}) == 1, f"Expected 1 record, got {conn.size(filters={'role': 'user'})}"
|
||||
|
||||
# test: query (vector)
|
||||
if embed_model:
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
query = "why was she crying"
|
||||
query_vec = embed_model.get_text_embedding(query)
|
||||
res = conn.query(None, query_vec, top_k=2)
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
print("Archival memory results", res)
|
||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
# test optional query functions for recall memory
|
||||
|
||||
Reference in New Issue
Block a user