add database test

This commit is contained in:
Sarah Wooders
2023-10-26 15:30:31 -07:00
parent 85ac22ff9e
commit bbacf0fb33
4 changed files with 110 additions and 36 deletions

View File

@@ -1,7 +1,7 @@
import tempfile
import asyncio
import os
from memgpt.connectors.connector import load_directory
from memgpt.connectors.connector import load_directory, load_database, load_webpage
import memgpt.agent as agent
import memgpt.system as system
import memgpt.utils as utils
@@ -20,7 +20,7 @@ import memgpt.interface # for printing to terminal
import asyncio
from datasets import load_dataset
def test_archival():
def test_load_directory():
# downloading hugging face dataset (if does not exist)
dataset = load_dataset("MemGPT/example_short_stories")
@@ -58,4 +58,56 @@ def test_archival():
results = query("cinderella be getting sick")
assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}"
test_archival()
def test_load_webpage():
pass
def test_load_database():
from sqlalchemy import create_engine, MetaData
import pandas as pd
db_path = "memgpt/personas/examples/sqldb/test.db"
engine = create_engine(f'sqlite:///{db_path}')
# Create a MetaData object and reflect the database to get table information.
metadata = MetaData()
metadata.reflect(bind=engine)
# Get a list of table names from the reflected metadata.
table_names = metadata.tables.keys()
print(table_names)
# Define a SQL query to retrieve data from a table (replace 'your_table_name' with your actual table name).
query = f"SELECT * FROM {list(table_names)[0]}"
# Use Pandas to read data from the database into a DataFrame.
df = pd.read_sql_query(query, engine)
print(df)
load_database(
name="tmp_db_dataset",
#engine=engine,
dump_path=db_path,
query=f"SELECT * FROM {list(table_names)[0]}",
)
persistence_manager = LocalStateManager(archival_memory_db="tmp_db_dataset")
# create agent
memgpt_agent = presets.use_preset(
presets.DEFAULT,
DEFAULT_MEMGPT_MODEL,
personas.get_persona_text(personas.DEFAULT),
humans.get_human_text(humans.DEFAULT),
memgpt.interface,
persistence_manager,
)
print("Successfully loaded into index")
assert True
#test_load_directory()
test_load_database()