"""Integration tests for the pgvector-backed ``PostgresStorageConnector``.

Requires the ``PGVECTOR_TEST_DB_URL`` environment variable to point at a
Postgres database. ``test_postgres_openai`` additionally needs
``OPENAI_API_KEY`` and returns early (soft pass) when it is not set.

At import time, any of the Postgres client packages below that are not
importable are installed into the current interpreter with pip.
"""

import importlib.util
import os
import subprocess
import sys

# pip distribution name -> top-level module name it provides.
_PG_PACKAGES = {
    "pgvector": "pgvector",
    "psycopg": "psycopg",
    "psycopg2-binary": "psycopg2",
}

_missing = [dist for dist, module in _PG_PACKAGES.items() if importlib.util.find_spec(module) is None]
if _missing:
    # Only invoke pip when something is actually missing, so importing this
    # module does not hit the network on every test run.
    subprocess.check_call([sys.executable, "-m", "pip", "install", *_missing])
    # Make packages installed after interpreter start-up visible to the import system.
    importlib.invalidate_caches()

import pgvector  # noqa: E402,F401  -- fails loudly if the install above did not work

from memgpt.connectors.storage import StorageConnector, Passage  # noqa: E402
from memgpt.connectors.db import PostgresStorageConnector  # noqa: E402
from memgpt.embeddings import embedding_model  # noqa: E402
from memgpt.config import MemGPTConfig, AgentConfig  # noqa: E402

import argparse  # noqa: E402


def _run_postgres_roundtrip(connector_name, **config_overrides):
    """Insert three passages into Postgres and check the nearest-neighbour query.

    Builds a ``MemGPTConfig`` whose archival storage points at the database in
    the ``PGVECTOR_TEST_DB_URL`` environment variable and saves it. It then
    embeds three short passages with ``embedding_model()``, inserts them
    through a ``PostgresStorageConnector``, and queries with
    "why was she crying". The test asserts that exactly two results come back
    and that the best one is the "Cinderella wept" passage.

    Args:
        connector_name: ``name`` passed to ``PostgresStorageConnector``.
        **config_overrides: Extra keyword arguments forwarded to
            ``MemGPTConfig`` (e.g. ``embedding_model`` / ``embedding_dim``).

    Raises:
        AssertionError: if the DB URL is unset or the query results are wrong.
    """
    db_url = os.getenv("PGVECTOR_TEST_DB_URL")
    assert db_url is not None, "PGVECTOR_TEST_DB_URL must be set"

    config = MemGPTConfig(
        archival_storage_type="postgres",
        archival_storage_uri=db_url,
        **config_overrides,
    )
    print(config.config_path)
    assert config.archival_storage_uri is not None
    # Normalize the legacy "postgres://" scheme to "postgresql://"
    # (https://stackoverflow.com/a/64698899).
    config.archival_storage_uri = config.archival_storage_uri.replace("postgres://", "postgresql://")
    # Persist the config. Presumably embedding_model() / PostgresStorageConnector
    # read it from disk rather than taking it as an argument -- TODO confirm.
    config.save()
    print(config)

    embed_model = embedding_model()

    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]
    db = PostgresStorageConnector(name=connector_name)
    for text in passages:
        db.insert(Passage(text=text, embedding=embed_model.get_text_embedding(text)))
    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"

    # TODO fix: db.delete() causes a hang for some reason, so it is not called
    # and the passages inserted above are not removed by this test.


def test_postgres_openai():
    """Round-trip passages through Postgres with MemGPTConfig's default embedding settings.

    Returns early (soft pass) when ``OPENAI_API_KEY`` is not set.
    """
    # Checked before the API-key guard so a missing DB URL still fails the test.
    assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
    if os.getenv("OPENAI_API_KEY") is None:
        return  # soft pass
    _run_postgres_roundtrip("test-openai")


def test_postgres_local():
    """Round-trip passages through Postgres with the local (HF) embedding model."""
    _run_postgres_roundtrip(
        "test-local",
        embedding_model="local",
        embedding_dim=384,  # use HF model
    )