Set get_all limit to None by default and add postgres to archival memory tests
This commit is contained in:
@@ -4,50 +4,62 @@ import os
|
||||
import pytest
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
|
||||
# import asyncio
|
||||
from datasets import load_dataset
|
||||
|
||||
# import memgpt
|
||||
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
|
||||
from memgpt.cli.cli import attach
|
||||
from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
|
||||
# import memgpt.presets as presets
|
||||
# import memgpt.personas.personas as personas
|
||||
# import memgpt.humans.humans as humans
|
||||
# from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager
|
||||
|
||||
# # from memgpt.config import AgentConfig
|
||||
# from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL
|
||||
# import memgpt.interface # for printing to terminal
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("storage_connector", ["sqllite", "postgres"])
|
||||
@pytest.mark.parametrize("metadata_storage_connector", ["sqlite"])
|
||||
@pytest.mark.parametrize("passage_storage_connector", ["chroma"])
|
||||
@pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"])
|
||||
@pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"])
|
||||
def test_load_directory(metadata_storage_connector, passage_storage_connector):
|
||||
|
||||
# setup config
|
||||
config = MemGPTConfig()
|
||||
if metadata_storage_connector == "postgres":
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
print("Skipping test, missing PG URI")
|
||||
return
|
||||
config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.metadata_storage_type = "postgres"
|
||||
elif metadata_storage_connector == "sqlite":
|
||||
print("testing sqlite metadata")
|
||||
# nothing to do (should be config defaults)
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {metadata_storage_connector} not implemented")
|
||||
if passage_storage_connector == "postgres":
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
print("Skipping test, missing PG URI")
|
||||
return
|
||||
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.archival_storage_type = "postgres"
|
||||
elif passage_storage_connector == "chroma":
|
||||
print("testing chroma passage storage")
|
||||
# nothing to do (should be config defaults)
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented")
|
||||
config.save()
|
||||
|
||||
# setup storage connectors
|
||||
data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
|
||||
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector)
|
||||
|
||||
# load hugging face dataset
|
||||
# dataset_name = "MemGPT/example_short_stories"
|
||||
# dataset = load_dataset(dataset_name)
|
||||
|
||||
# cache_dir = os.getenv("HF_DATASETS_CACHE")
|
||||
# if cache_dir is None:
|
||||
# # Construct the default path if the environment variable is not set.
|
||||
# cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
|
||||
# print("HF Directory", cache_dir)
|
||||
# load data
|
||||
name = "test_dataset"
|
||||
cache_dir = "CONTRIBUTING.md"
|
||||
|
||||
# TODO: load two different data sources
|
||||
|
||||
# clear out data
|
||||
data_source_conn.delete_table()
|
||||
passages_conn.delete_table()
|
||||
data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
|
||||
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector)
|
||||
assert (
|
||||
data_source_conn.size() == 0
|
||||
), f"Expected 0 records, got {data_source_conn.size()}: {[vars(r) for r in data_source_conn.get_all()]}"
|
||||
assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"
|
||||
|
||||
# test: load directory
|
||||
load_directory(name=name, input_dir=None, input_files=[cache_dir], recursive=False) # cache_dir,
|
||||
@@ -59,8 +71,14 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
|
||||
print("Source", sources)
|
||||
|
||||
# test to see if contained in storage
|
||||
assert (
|
||||
len(passages_conn.get_all()) == passages_conn.size()
|
||||
), f"Expected {passages_conn.size()} passages, but got {len(passages_conn.get_all())}"
|
||||
passages = passages_conn.get_all({"data_source": name})
|
||||
print("Source", [p.data_source for p in passages])
|
||||
print("All sources", [p.data_source for p in passages_conn.get_all()])
|
||||
assert len(passages) > 0, f"Expected >0 passages, but got {len(passages)}"
|
||||
assert len(passages) == passages_conn.size(), f"Expected {passages_conn.size()} passages, but got {len(passages)}"
|
||||
assert [p.data_source == name for p in passages]
|
||||
print("Passages", passages)
|
||||
|
||||
@@ -71,7 +89,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
|
||||
# test loading into an agent
|
||||
# create agent
|
||||
agent_config = AgentConfig(
|
||||
name="test_agent",
|
||||
name="memgpt_test_agent",
|
||||
persona=DEFAULT_PERSONA,
|
||||
human=DEFAULT_HUMAN,
|
||||
model=DEFAULT_MEMGPT_MODEL,
|
||||
@@ -81,7 +99,11 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
|
||||
conn = StorageConnector.get_storage_connector(
|
||||
storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
|
||||
)
|
||||
assert conn.size() == 0
|
||||
conn.delete_table()
|
||||
conn = StorageConnector.get_storage_connector(
|
||||
storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
|
||||
)
|
||||
assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}"
|
||||
|
||||
# attach data
|
||||
attach(agent=agent_config.name, data_source=name)
|
||||
|
||||
@@ -56,8 +56,7 @@ def generate_messages():
|
||||
return messages
|
||||
|
||||
|
||||
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqllite", "lancedb"])
|
||||
# @pytest.mark.parametrize("storage_connector", ["sqllite"])
|
||||
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqlite"])
|
||||
@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
|
||||
def test_storage(storage_connector, table_type):
|
||||
|
||||
@@ -86,9 +85,9 @@ def test_storage(storage_connector, table_type):
|
||||
return
|
||||
config.archival_storage_type = "chroma"
|
||||
config.archival_storage_path = "./test_chroma"
|
||||
if storage_connector == "sqllite":
|
||||
if storage_connector == "sqlite":
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
print("Skipping test, sqllite only supported for recall memory")
|
||||
print("Skipping test, sqlite only supported for recall memory")
|
||||
return
|
||||
config.recall_storage_type = "local"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user