This commit is contained in:
Sarah Wooders
2023-10-26 16:08:25 -07:00
parent 1bc8e7a601
commit 0ab3d098d2
26 changed files with 600 additions and 851 deletions

View File

@@ -9,10 +9,7 @@ import memgpt.presets as presets
import memgpt.constants as constants
import memgpt.personas.personas as personas
import memgpt.humans.humans as humans
from memgpt.persistence_manager import (
InMemoryStateManager,
LocalStateManager
)
from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager
from memgpt.config import Config
from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL
from memgpt.connectors import connector
@@ -20,6 +17,7 @@ import memgpt.interface # for printing to terminal
import asyncio
from datasets import load_dataset
def test_load_directory():
# downloading hugging face dataset (if does not exist)
dataset = load_dataset("MemGPT/example_short_stories")
@@ -30,12 +28,12 @@ def test_load_directory():
# Construct the default path if the environment variable is not set.
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
# load directory
# load directory
print("Loading dataset into index...")
print(cache_dir)
load_directory(
name="tmp_hf_dataset",
input_dir=cache_dir,
input_dir=cache_dir,
recursive=True,
)
@@ -51,23 +49,25 @@ def test_load_directory():
memgpt.interface,
persistence_manager,
)
def query(q):
def query(q):
res = asyncio.run(memgpt_agent.archival_memory_search(q))
return res
results = query("cinderella be getting sick")
assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}"
def test_load_webpage():
def test_load_webpage():
pass
def test_load_database():
def test_load_database():
from sqlalchemy import create_engine, MetaData
import pandas as pd
db_path = "memgpt/personas/examples/sqldb/test.db"
engine = create_engine(f'sqlite:///{db_path}')
engine = create_engine(f"sqlite:///{db_path}")
# Create a MetaData object and reflect the database to get table information.
metadata = MetaData()
@@ -87,7 +87,7 @@ def test_load_database():
load_database(
name="tmp_db_dataset",
#engine=engine,
# engine=engine,
dump_path=db_path,
query=f"SELECT * FROM {list(table_names)[0]}",
)
@@ -107,7 +107,5 @@ def test_load_database():
assert True
#test_load_directory()
test_load_database()
# test_load_directory()
test_load_database()