Bugfixes and test updates for passing tests for both postgres and chroma
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import uuid
|
||||
import subprocess
|
||||
import sys
|
||||
import pytest
|
||||
@@ -11,7 +12,7 @@ import pytest
|
||||
import pgvector # Try to import again after installing
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
from memgpt.connectors.db import PostgresStorageConnector, LanceDBConnector
|
||||
from memgpt.connectors.db import SQLStorageConnector, LanceDBConnector
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.data_types import Message, Passage
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
@@ -22,13 +23,13 @@ from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMA
|
||||
import argparse
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
|
||||
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
|
||||
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
start_date = datetime(2009, 10, 5, 18, 00)
|
||||
dates = [start_date - timedelta(weeks=1), start_date, start_date + timedelta(weeks=1)]
|
||||
dates = [start_date, start_date - timedelta(weeks=1), start_date + timedelta(weeks=1)]
|
||||
roles = ["user", "agent", "agent"]
|
||||
agent_ids = ["agent1", "agent2", "agent1"]
|
||||
ids = ["test1", "test2", "test3"] # TODO: generate unique uuid
|
||||
ids = [uuid.uuid4(), uuid.uuid4(), uuid.uuid4()]
|
||||
user_id = "test_user"
|
||||
|
||||
|
||||
@@ -41,16 +42,7 @@ def generate_passages(embed_model):
|
||||
embedding = None
|
||||
if embed_model:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
passages.append(
|
||||
Passage(
|
||||
user_id=user_id,
|
||||
text=text,
|
||||
agent_id=agent_id,
|
||||
embedding=embedding,
|
||||
data_source="test_source",
|
||||
id=id,
|
||||
)
|
||||
)
|
||||
passages.append(Passage(user_id=user_id, text=text, agent_id=agent_id, embedding=embedding, data_source="test_source", id=id))
|
||||
return passages
|
||||
|
||||
|
||||
@@ -65,7 +57,8 @@ def generate_messages():
|
||||
|
||||
|
||||
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "lancedb"])
|
||||
@pytest.mark.parametrize("table_type", [TableType.ARCHIVAL_MEMORY, TableType.RECALL_MEMORY])
|
||||
# @pytest.mark.parametrize("storage_connector", ["postgres"])
|
||||
@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
|
||||
def test_storage(storage_connector, table_type):
|
||||
|
||||
# setup memgpt config
|
||||
@@ -88,10 +81,16 @@ def test_storage(storage_connector, table_type):
|
||||
config.archival_storage_type = "lancedb"
|
||||
config.recall_storage_type = "lancedb"
|
||||
if storage_connector == "chroma":
|
||||
if table_type == TableType.RECALL_MEMORY:
|
||||
print("Skipping test, chroma only supported for archival memory")
|
||||
return
|
||||
config.archival_storage_type = "chroma"
|
||||
config.recall_storage_type = "chroma"
|
||||
config.recall_storage_path = "./test_chroma"
|
||||
config.archival_storage_path = "./test_chroma"
|
||||
if storage_connector == "local":
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
print("Skipping test, local only supported for recall memory")
|
||||
return
|
||||
config.recall_storage_type = "local"
|
||||
|
||||
# get embedding model
|
||||
embed_model = None
|
||||
@@ -116,7 +115,8 @@ def test_storage(storage_connector, table_type):
|
||||
|
||||
# create storage connector
|
||||
conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
|
||||
conn.delete() # clear out data
|
||||
# conn.client.delete_collection(conn.collection.name) # clear out data
|
||||
conn.delete_table()
|
||||
conn = StorageConnector.get_storage_connector(storage_type=storage_connector, table_type=table_type, agent_config=agent_config)
|
||||
|
||||
# override filters
|
||||
@@ -161,6 +161,7 @@ def test_storage(storage_connector, table_type):
|
||||
assert len(all_records) == 1, f"Expected 1 records, got {len(all_records)}"
|
||||
|
||||
# test: get
|
||||
print("GET ID", ids[0], records)
|
||||
res = conn.get(id=ids[0])
|
||||
assert res.text == texts[0], f"Expected {texts[0]}, got {res.text}"
|
||||
|
||||
@@ -178,8 +179,8 @@ def test_storage(storage_connector, table_type):
|
||||
assert len(res) == 2, f"Expected 2 results, got {len(res)}"
|
||||
assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"
|
||||
|
||||
# test optional query functions
|
||||
if storage_connector != "chroma":
|
||||
# test optional query functions for recall memory
|
||||
if table_type == TableType.RECALL_MEMORY:
|
||||
# test: query_text
|
||||
query = "CindereLLa"
|
||||
res = conn.query_text(query)
|
||||
@@ -187,12 +188,13 @@ def test_storage(storage_connector, table_type):
|
||||
assert "Cinderella" in res[0].text, f"Expected 'Cinderella' in results, but got {res[0].text}"
|
||||
|
||||
# test: query_date (recall memory only)
|
||||
if table_type == TableType.RECALL_MEMORY:
|
||||
print("Testing recall memory date search")
|
||||
start_date = start_date - timedelta(days=1)
|
||||
end_date = start_date + timedelta(days=1)
|
||||
res = conn.query_date(start_date=start_date, end_date=end_date)
|
||||
assert len(res) == 1, f"Expected 1 result, got {len(res): {res}}"
|
||||
print("Testing recall memory date search")
|
||||
start_date = datetime(2009, 10, 5, 18, 00)
|
||||
start_date = start_date - timedelta(days=1)
|
||||
end_date = start_date + timedelta(days=1)
|
||||
res = conn.query_date(start_date=start_date, end_date=end_date)
|
||||
print("DATE", res)
|
||||
assert len(res) == 1, f"Expected 1 result, got {len(res)}: {res}"
|
||||
|
||||
# test: delete
|
||||
conn.delete({"id": ids[0]})
|
||||
|
||||
Reference in New Issue
Block a user