* mark deprecated API section * CLI bug fixes for azure * check azure before running * Update README.md * Update README.md * bug fix with persona loading * remove print * make errors for cli flags more clear * format * fix imports * fix imports * add prints * update lock * update config fields * cleanup config loading * commit * remove asserts * refactor configure * put into different functions * add embedding default * pass in config * fixes * allow overriding openai embedding endpoint * black * trying to patch tests (some circular import errors) * update flags and docs * patched support for local llms using endpoint and endpoint type passed via configs, not env vars * missing files * fix naming * fix import * fix two runtime errors * patch ollama typo, move ollama model question pre-wrapper, modify question phrasing to include link to readthedocs, also have a default ollama model that has a tag included * disable debug messages * made error message for failed load more informative * don't print dynamic linking function warning unless --debug * updated tests to work with new cli workflow (disabled openai config test for now) * added skips for tests when vars are missing * update bad arg * revise test to soft pass on empty string too * don't run configure twice * extend timeout (try to pass against nltk download) * update defaults * typo with endpoint type default * patch runtime errors for when model is None * catching another case of 'x in model' when model is None (preemptively) * allow overrides to local llm related config params * made model wrapper selection from a list vs raw input * update test for select instead of input * Fixed bug in endpoint when using local->openai selection, also added validation loop to manual endpoint entry * updated error messages to be more informative with links to readthedocs * add back gpt3.5-turbo --------- Co-authored-by: cpacker <packercharles@gmail.com>
105 lines
3.4 KiB
Python
105 lines
3.4 KiB
Python
import os
|
|
import subprocess
|
|
import sys
|
|
import pytest
|
|
|
|
import importlib
import importlib.util

# Packages needed by the Postgres/pgvector archival backend, mapped from the
# importable module name to the pip distribution that provides it.
_REQUIRED_PG_PACKAGES = {
    "pgvector": "pgvector",
    "psycopg": "psycopg",
    "psycopg2": "psycopg2-binary",
}

# Only shell out to pip for packages that are actually missing, so that
# collecting this test module does not hit the network on every run.
_missing_pg_packages = [
    dist for module, dist in _REQUIRED_PG_PACKAGES.items() if importlib.util.find_spec(module) is None
]
if _missing_pg_packages:
    subprocess.check_call([sys.executable, "-m", "pip", "install", *_missing_pg_packages])
    # Packages installed while the interpreter is running may be invisible to
    # the import system's cached finders until the caches are invalidated.
    importlib.invalidate_caches()

import pgvector  # noqa: E402,F401  # fails loudly if the install above did not work
|
|
|
|
from memgpt.connectors.storage import StorageConnector, Passage
|
|
from memgpt.connectors.db import PostgresStorageConnector
|
|
from memgpt.embeddings import embedding_model
|
|
from memgpt.config import MemGPTConfig, AgentConfig
|
|
|
|
import argparse
|
|
|
|
|
|
def _normalize_pg_uri(uri):
    """Return *uri* with a leading ``postgres://`` scheme rewritten to ``postgresql://``.

    Only the scheme prefix is touched; the rest of the URI is left unchanged.
    Some Postgres clients reject the ``postgres://`` alias, see
    https://stackoverflow.com/a/64698899.
    """
    prefix = "postgres://"
    if uri.startswith(prefix):
        return "postgresql://" + uri[len(prefix) :]
    return uri


def _run_postgres_roundtrip(config, table_name):
    """Store a few passages in Postgres and check that a semantic query finds the right one.

    Args:
        config: a ``MemGPTConfig`` configured for ``archival_storage_type="postgres"``.
            Its URI is normalized and the config is saved before use.
        table_name: the ``name`` passed to ``PostgresStorageConnector``.
    """
    print(config.config_path)
    assert config.archival_storage_uri is not None
    config.archival_storage_uri = _normalize_pg_uri(config.archival_storage_uri)
    # NOTE(review): presumably embedding_model() / PostgresStorageConnector read
    # the saved config from disk, which is why it must be persisted here. This
    # writes to config.config_path; setting MEMGPT_CONFIG_PATH first would
    # isolate it - TODO confirm.
    config.save()
    print(config)

    embed_model = embedding_model()

    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = PostgresStorageConnector(name=table_name)

    for text in passages:
        db.insert(Passage(text=text, embedding=embed_model.get_text_embedding(text)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"

    # TODO: clean up with db.delete(); it currently causes a hang for some reason.


@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
def test_postgres_openai():
    """Round-trip passages through Postgres using the default (OpenAI) embedding config."""
    config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"))
    _run_postgres_roundtrip(config, "test-openai")


@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_postgres_local():
    """Round-trip passages through Postgres using a local embedding endpoint."""
    config = MemGPTConfig(
        archival_storage_type="postgres",
        archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
        embedding_endpoint_type="local",
        embedding_dim=384,  # use HF model
    )
    _run_postgres_roundtrip(config, "test-local")
|