Remove broken tests from chroma merge (#584)

This commit is contained in:
Sarah Wooders
2023-12-05 22:09:44 -08:00
committed by GitHub
parent ff54e2f04a
commit dcceb8671f

View File

@@ -112,75 +112,6 @@ def test_chroma():
)
def test_postgres():
# override config path with enviornment variable
# TODO: make into temporary file
os.environ["MEMGPT_CONFIG_PATH"] = "/Users/sarahwooders/repos/MemGPT/test_config.cfg"
print("env", os.getenv("MEMGPT_CONFIG_PATH"))
config = memgpt.config.MemGPTConfig(archival_storage_type="postgres", config_path=os.getenv("MEMGPT_CONFIG_PATH"))
print(config)
config.save()
# exit()
name = "tmp_hf_dataset2"
dataset = load_dataset("MemGPT/example_short_stories")
cache_dir = os.getenv("HF_DATASETS_CACHE")
if cache_dir is None:
# Construct the default path if the environment variable is not set.
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
load_directory(
name=name,
input_dir=cache_dir,
recursive=True,
)
def test_chroma():
import chromadb
# override config path with enviornment variable
# TODO: make into temporary file
os.environ["MEMGPT_CONFIG_PATH"] = "/Users/sarahwooders/repos/MemGPT/test_config.cfg"
print("env", os.getenv("MEMGPT_CONFIG_PATH"))
config = memgpt.config.MemGPTConfig(archival_storage_type="chroma", config_path=os.getenv("MEMGPT_CONFIG_PATH"))
print(config)
config.save()
# exit()
name = "tmp_hf_dataset"
dataset = load_dataset("MemGPT/example_short_stories")
cache_dir = os.getenv("HF_DATASETS_CACHE")
if cache_dir is None:
# Construct the default path if the environment variable is not set.
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
config = memgpt.config.MemGPTConfig(archival_storage_type="chroma")
load_directory(
name=name,
input_dir=cache_dir,
recursive=True,
)
# index = memgpt.embeddings.Index(name)
## query chroma
##chroma_client = chromadb.Client()
# chroma_client = chromadb.PersistentClient(path="/Users/sarahwooders/repos/MemGPT/chromadb")
# collection = chroma_client.get_collection(name=name)
# results = collection.query(
# query_texts=["cinderella be getting sick"],
# n_results=2
# )
# print(results)
# assert len(results) == 2, f"Expected 2 results, but got {len(results)}"
def test_load_directory():
return
# downloading hugging face dataset (if does not exist)