add database test
This commit is contained in:
@@ -1,7 +1,7 @@
|
||||
import tempfile
|
||||
import asyncio
|
||||
import os
|
||||
from memgpt.connectors.connector import load_directory
|
||||
from memgpt.connectors.connector import load_directory, load_database, load_webpage
|
||||
import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
import memgpt.utils as utils
|
||||
@@ -20,7 +20,7 @@ import memgpt.interface # for printing to terminal
|
||||
import asyncio
|
||||
from datasets import load_dataset
|
||||
|
||||
def test_archival():
|
||||
def test_load_directory():
|
||||
# downloading hugging face dataset (if does not exist)
|
||||
dataset = load_dataset("MemGPT/example_short_stories")
|
||||
|
||||
@@ -58,4 +58,56 @@ def test_archival():
|
||||
results = query("cinderella be getting sick")
|
||||
assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}"
|
||||
|
||||
test_archival()
|
||||
def test_load_webpage():
|
||||
pass
|
||||
|
||||
def test_load_database():
|
||||
|
||||
from sqlalchemy import create_engine, MetaData
|
||||
import pandas as pd
|
||||
|
||||
db_path = "memgpt/personas/examples/sqldb/test.db"
|
||||
engine = create_engine(f'sqlite:///{db_path}')
|
||||
|
||||
# Create a MetaData object and reflect the database to get table information.
|
||||
metadata = MetaData()
|
||||
metadata.reflect(bind=engine)
|
||||
|
||||
# Get a list of table names from the reflected metadata.
|
||||
table_names = metadata.tables.keys()
|
||||
|
||||
print(table_names)
|
||||
|
||||
# Define a SQL query to retrieve data from a table (replace 'your_table_name' with your actual table name).
|
||||
query = f"SELECT * FROM {list(table_names)[0]}"
|
||||
|
||||
# Use Pandas to read data from the database into a DataFrame.
|
||||
df = pd.read_sql_query(query, engine)
|
||||
print(df)
|
||||
|
||||
load_database(
|
||||
name="tmp_db_dataset",
|
||||
#engine=engine,
|
||||
dump_path=db_path,
|
||||
query=f"SELECT * FROM {list(table_names)[0]}",
|
||||
)
|
||||
|
||||
persistence_manager = LocalStateManager(archival_memory_db="tmp_db_dataset")
|
||||
|
||||
# create agent
|
||||
memgpt_agent = presets.use_preset(
|
||||
presets.DEFAULT,
|
||||
DEFAULT_MEMGPT_MODEL,
|
||||
personas.get_persona_text(personas.DEFAULT),
|
||||
humans.get_human_text(humans.DEFAULT),
|
||||
memgpt.interface,
|
||||
persistence_manager,
|
||||
)
|
||||
print("Successfully loaded into index")
|
||||
assert True
|
||||
|
||||
|
||||
|
||||
|
||||
#test_load_directory()
|
||||
test_load_database()
|
||||
Reference in New Issue
Block a user