reformat
This commit is contained in:
@@ -9,10 +9,7 @@ import memgpt.presets as presets
|
||||
import memgpt.constants as constants
|
||||
import memgpt.personas.personas as personas
|
||||
import memgpt.humans.humans as humans
|
||||
from memgpt.persistence_manager import (
|
||||
InMemoryStateManager,
|
||||
LocalStateManager
|
||||
)
|
||||
from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager
|
||||
from memgpt.config import Config
|
||||
from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL
|
||||
from memgpt.connectors import connector
|
||||
@@ -20,6 +17,7 @@ import memgpt.interface # for printing to terminal
|
||||
import asyncio
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def test_load_directory():
|
||||
# downloading hugging face dataset (if does not exist)
|
||||
dataset = load_dataset("MemGPT/example_short_stories")
|
||||
@@ -30,12 +28,12 @@ def test_load_directory():
|
||||
# Construct the default path if the environment variable is not set.
|
||||
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
|
||||
|
||||
# load directory
|
||||
# load directory
|
||||
print("Loading dataset into index...")
|
||||
print(cache_dir)
|
||||
load_directory(
|
||||
name="tmp_hf_dataset",
|
||||
input_dir=cache_dir,
|
||||
input_dir=cache_dir,
|
||||
recursive=True,
|
||||
)
|
||||
|
||||
@@ -51,23 +49,25 @@ def test_load_directory():
|
||||
memgpt.interface,
|
||||
persistence_manager,
|
||||
)
|
||||
def query(q):
|
||||
|
||||
def query(q):
|
||||
res = asyncio.run(memgpt_agent.archival_memory_search(q))
|
||||
return res
|
||||
|
||||
results = query("cinderella be getting sick")
|
||||
assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}"
|
||||
|
||||
def test_load_webpage():
|
||||
|
||||
def test_load_webpage():
|
||||
pass
|
||||
|
||||
def test_load_database():
|
||||
|
||||
def test_load_database():
|
||||
from sqlalchemy import create_engine, MetaData
|
||||
import pandas as pd
|
||||
|
||||
db_path = "memgpt/personas/examples/sqldb/test.db"
|
||||
engine = create_engine(f'sqlite:///{db_path}')
|
||||
engine = create_engine(f"sqlite:///{db_path}")
|
||||
|
||||
# Create a MetaData object and reflect the database to get table information.
|
||||
metadata = MetaData()
|
||||
@@ -87,7 +87,7 @@ def test_load_database():
|
||||
|
||||
load_database(
|
||||
name="tmp_db_dataset",
|
||||
#engine=engine,
|
||||
# engine=engine,
|
||||
dump_path=db_path,
|
||||
query=f"SELECT * FROM {list(table_names)[0]}",
|
||||
)
|
||||
@@ -107,7 +107,5 @@ def test_load_database():
|
||||
assert True
|
||||
|
||||
|
||||
|
||||
|
||||
#test_load_directory()
|
||||
test_load_database()
|
||||
# test_load_directory()
|
||||
test_load_database()
|
||||
|
||||
Reference in New Issue
Block a user