Support metadata table via storage connectors for data sources
This commit is contained in:
@@ -1,12 +1,14 @@
|
||||
# import tempfile
|
||||
# import asyncio
|
||||
import os
|
||||
import pytest
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
|
||||
# import asyncio
|
||||
# from datasets import load_dataset
|
||||
from datasets import load_dataset
|
||||
|
||||
# import memgpt
|
||||
# from memgpt.cli.cli_load import load_directory, load_database, load_webpage
|
||||
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
|
||||
|
||||
# import memgpt.presets as presets
|
||||
# import memgpt.personas.personas as personas
|
||||
@@ -18,205 +20,53 @@ import os
|
||||
# import memgpt.interface # for printing to terminal
|
||||
|
||||
|
||||
def test_postgres():
|
||||
return
|
||||
# @pytest.mark.parametrize("storage_connector", ["sqllite", "postgres"])
|
||||
@pytest.mark.parametrize("metadata_storage_connector", ["sqlite"])
|
||||
@pytest.mark.parametrize("passage_storage_connector", ["chroma"])
|
||||
def test_load_directory(metadata_storage_connector, passage_storage_connector):
|
||||
|
||||
# override config path with enviornment variable
|
||||
# TODO: make into temporary file
|
||||
os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg"
|
||||
print("env", os.getenv("MEMGPT_CONFIG_PATH"))
|
||||
config = memgpt.config.MemGPTConfig(archival_storage_type="postgres", config_path=os.getenv("MEMGPT_CONFIG_PATH"))
|
||||
print(config)
|
||||
config.save()
|
||||
# exit()
|
||||
data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
|
||||
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector)
|
||||
|
||||
name = "tmp_hf_dataset2"
|
||||
# load hugging face dataset
|
||||
# dataset_name = "MemGPT/example_short_stories"
|
||||
# dataset = load_dataset(dataset_name)
|
||||
|
||||
dataset = load_dataset("MemGPT/example_short_stories")
|
||||
# cache_dir = os.getenv("HF_DATASETS_CACHE")
|
||||
# if cache_dir is None:
|
||||
# # Construct the default path if the environment variable is not set.
|
||||
# cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
|
||||
# print("HF Directory", cache_dir)
|
||||
name = "test_dataset"
|
||||
cache_dir = "CONTRIBUTING.md"
|
||||
|
||||
cache_dir = os.getenv("HF_DATASETS_CACHE")
|
||||
if cache_dir is None:
|
||||
# Construct the default path if the environment variable is not set.
|
||||
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
|
||||
# clear out data
|
||||
data_source_conn.delete_table()
|
||||
passages_conn.delete_table()
|
||||
data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
|
||||
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector)
|
||||
|
||||
load_directory(
|
||||
name=name,
|
||||
input_dir=cache_dir,
|
||||
recursive=True,
|
||||
)
|
||||
# test: load directory
|
||||
load_directory(name=name, input_dir=None, input_files=[cache_dir], recursive=False) # cache_dir,
|
||||
|
||||
# test to see if contained in storage
|
||||
sources = data_source_conn.get_all({"name": name})
|
||||
assert len(sources) == 1, f"Expected 1 source, but got {len(sources)}"
|
||||
assert sources[0].name == name, f"Expected name {name}, but got {sources[0].name}"
|
||||
print("Source", sources)
|
||||
|
||||
def test_lancedb():
|
||||
return
|
||||
# test to see if contained in storage
|
||||
passages = passages_conn.get_all({"data_source": name})
|
||||
assert len(passages) > 0, f"Expected >0 passages, but got {len(passages)}"
|
||||
assert [p.data_source == name for p in passages]
|
||||
print("Passages", passages)
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
|
||||
import lancedb # Try to import again after installing
|
||||
# test: listing sources
|
||||
sources = data_source_conn.get_all()
|
||||
print("All sources", [s.name for s in sources])
|
||||
|
||||
# override config path with enviornment variable
|
||||
# TODO: make into temporary file
|
||||
os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg"
|
||||
print("env", os.getenv("MEMGPT_CONFIG_PATH"))
|
||||
config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb", config_path=os.getenv("MEMGPT_CONFIG_PATH"))
|
||||
print(config)
|
||||
config.save()
|
||||
|
||||
# loading dataset from hugging face
|
||||
name = "tmp_hf_dataset"
|
||||
|
||||
dataset = load_dataset("MemGPT/example_short_stories")
|
||||
|
||||
cache_dir = os.getenv("HF_DATASETS_CACHE")
|
||||
if cache_dir is None:
|
||||
# Construct the default path if the environment variable is not set.
|
||||
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
|
||||
|
||||
config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb")
|
||||
|
||||
load_directory(
|
||||
name=name,
|
||||
input_dir=cache_dir,
|
||||
recursive=True,
|
||||
)
|
||||
|
||||
|
||||
def test_chroma():
|
||||
return
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "chromadb"])
|
||||
import chromadb # Try to import again after installing
|
||||
|
||||
# override config path with enviornment variable
|
||||
# TODO: make into temporary file
|
||||
os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg"
|
||||
print("env", os.getenv("MEMGPT_CONFIG_PATH"))
|
||||
config = memgpt.config.MemGPTConfig(archival_storage_type="chroma", config_path=os.getenv("MEMGPT_CONFIG_PATH"))
|
||||
print(config)
|
||||
config.save()
|
||||
# exit()
|
||||
|
||||
name = "tmp_hf_dataset"
|
||||
|
||||
dataset = load_dataset("MemGPT/example_short_stories")
|
||||
|
||||
cache_dir = os.getenv("HF_DATASETS_CACHE")
|
||||
if cache_dir is None:
|
||||
# Construct the default path if the environment variable is not set.
|
||||
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
|
||||
|
||||
config = memgpt.config.MemGPTConfig(archival_storage_type="chroma")
|
||||
|
||||
load_directory(
|
||||
name=name,
|
||||
input_dir=cache_dir,
|
||||
recursive=True,
|
||||
)
|
||||
|
||||
|
||||
def test_load_directory():
|
||||
return
|
||||
# downloading hugging face dataset (if does not exist)
|
||||
dataset = load_dataset("MemGPT/example_short_stories")
|
||||
|
||||
cache_dir = os.getenv("HF_DATASETS_CACHE")
|
||||
|
||||
if cache_dir is None:
|
||||
# Construct the default path if the environment variable is not set.
|
||||
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
|
||||
|
||||
# load directory
|
||||
print("Loading dataset into index...")
|
||||
print(cache_dir)
|
||||
load_directory(
|
||||
name="tmp_hf_dataset",
|
||||
input_dir=cache_dir,
|
||||
recursive=True,
|
||||
)
|
||||
|
||||
# create agents with defaults
|
||||
agent_config = AgentConfig(
|
||||
persona=personas.DEFAULT,
|
||||
human=humans.DEFAULT,
|
||||
model=DEFAULT_MEMGPT_MODEL,
|
||||
data_source="tmp_hf_dataset",
|
||||
)
|
||||
|
||||
# create state manager based off loaded data
|
||||
persistence_manager = LocalStateManager(agent_config=agent_config)
|
||||
|
||||
# create agent
|
||||
memgpt_agent = presets.use_preset(
|
||||
presets.DEFAULT_PRESET,
|
||||
agent_config,
|
||||
DEFAULT_MEMGPT_MODEL,
|
||||
personas.get_persona_text(personas.DEFAULT),
|
||||
humans.get_human_text(humans.DEFAULT),
|
||||
memgpt.interface,
|
||||
persistence_manager,
|
||||
)
|
||||
|
||||
def query(q):
|
||||
res = asyncio.run(memgpt_agent.archival_memory_search(q))
|
||||
return res
|
||||
|
||||
results = query("cinderella be getting sick")
|
||||
assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}"
|
||||
|
||||
|
||||
def test_load_webpage():
|
||||
pass
|
||||
|
||||
|
||||
def test_load_database():
|
||||
return
|
||||
from sqlalchemy import create_engine, MetaData
|
||||
import pandas as pd
|
||||
|
||||
db_path = "memgpt/personas/examples/sqldb/test.db"
|
||||
engine = create_engine(f"sqlite:///{db_path}")
|
||||
|
||||
# Create a MetaData object and reflect the database to get table information.
|
||||
metadata = MetaData()
|
||||
metadata.reflect(bind=engine)
|
||||
|
||||
# Get a list of table names from the reflected metadata.
|
||||
table_names = metadata.tables.keys()
|
||||
|
||||
print(table_names)
|
||||
|
||||
# Define a SQL query to retrieve data from a table (replace 'your_table_name' with your actual table name).
|
||||
query = f"SELECT * FROM {list(table_names)[0]}"
|
||||
|
||||
# Use Pandas to read data from the database into a DataFrame.
|
||||
df = pd.read_sql_query(query, engine)
|
||||
print(df)
|
||||
|
||||
load_database(
|
||||
name="tmp_db_dataset",
|
||||
# engine=engine,
|
||||
dump_path=db_path,
|
||||
query=f"SELECT * FROM {list(table_names)[0]}",
|
||||
)
|
||||
|
||||
# create agents with defaults
|
||||
agent_config = AgentConfig(
|
||||
persona=personas.DEFAULT,
|
||||
human=humans.DEFAULT,
|
||||
model=DEFAULT_MEMGPT_MODEL,
|
||||
data_source="tmp_hf_dataset",
|
||||
)
|
||||
|
||||
# create state manager based off loaded data
|
||||
persistence_manager = LocalStateManager(agent_config=agent_config)
|
||||
|
||||
# create agent
|
||||
memgpt_agent = presets.use_preset(
|
||||
presets.DEFAULT,
|
||||
agent_config,
|
||||
DEFAULT_MEMGPT_MODEL,
|
||||
personas.get_persona_text(personas.DEFAULT),
|
||||
humans.get_human_text(humans.DEFAULT),
|
||||
memgpt.interface,
|
||||
persistence_manager,
|
||||
)
|
||||
print("Successfully loaded into index")
|
||||
assert True
|
||||
# test: delete source
|
||||
data_source_conn.delete({"name": name})
|
||||
passages_conn.delete({"data_source": name})
|
||||
assert len(data_source_conn.get_all({"name": name})) == 0
|
||||
assert len(passages_conn.get_all({"data_source": name})) == 0
|
||||
|
||||
Reference in New Issue
Block a user