Files
letta-server/tests/test_storage.py
Sarah Wooders ec2bda4966 Refactor config + determine LLM via config.model_endpoint_type (#422)
* mark deprecated API section

* CLI bug fixes for azure

* check azure before running

* Update README.md

* Update README.md

* bug fix with persona loading

* remove print

* make errors for cli flags more clear

* format

* fix imports

* fix imports

* add prints

* update lock

* update config fields

* cleanup config loading

* commit

* remove asserts

* refactor configure

* put into different functions

* add embedding default

* pass in config

* fixes

* allow overriding openai embedding endpoint

* black

* trying to patch tests (some circular import errors)

* update flags and docs

* patched support for local llms using endpoint and endpoint type passed via configs, not env vars

* missing files

* fix naming

* fix import

* fix two runtime errors

* patch ollama typo, move ollama model question pre-wrapper, modify question phrasing to include link to readthedocs, also have a default ollama model that has a tag included

* disable debug messages

* made error message for failed load more informative

* don't print dynamic linking function warning unless --debug

* updated tests to work with new cli workflow (disabled openai config test for now)

* added skips for tests when vars are missing

* update bad arg

* revise test to soft pass on empty string too

* don't run configure twice

* extend timeout (try to pass against nltk download)

* update defaults

* typo with endpoint type default

* patch runtime errors for when model is None

* catching another case of 'x in model' when model is None (preemptively)

* allow overrides to local llm related config params

* made model wrapper selection from a list vs raw input

* update test for select instead of input

* Fixed bug in endpoint when using local->openai selection, also added validation loop to manual endpoint entry

* updated error messages to be more informative with links to readthedocs

* add back gpt3.5-turbo

---------

Co-authored-by: cpacker <packercharles@gmail.com>
2023-11-14 15:58:19 -08:00

105 lines
3.4 KiB
Python

import os
import subprocess
import sys
import pytest
# Install the Postgres/pgvector client packages with the running interpreter's
# pip before the pgvector import below is attempted. This executes every time
# the module is imported (including pytest collection); check_call raises
# subprocess.CalledProcessError if pip exits with a non-zero status.
subprocess.check_call(
[sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
import pgvector # Try to import again after installing
from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.connectors.db import PostgresStorageConnector
from memgpt.embeddings import embedding_model
from memgpt.config import MemGPTConfig, AgentConfig
import argparse
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
def test_postgres_openai():
    """Insert three passages into Postgres storage and check a similarity query.

    Requires the ``PGVECTOR_TEST_DB_URL`` and ``OPENAI_API_KEY`` environment
    variables. Under pytest the ``skipif`` marker skips the test when either
    is missing; when the function is called directly it returns ``None``
    without touching the database (soft pass).

    The test saves a ``MemGPTConfig`` pointing at the test database, embeds
    and inserts the passages, then asserts that a query for
    "why was she crying" returns two results with the "wept" passage first.
    """
    # The skipif marker is only honoured by pytest; these guards keep a
    # direct call a no-op when the environment is not set up.
    db_url = os.getenv("PGVECTOR_TEST_DB_URL")
    if not db_url:
        return  # soft pass
    if not os.getenv("OPENAI_API_KEY"):
        return  # soft pass

    # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
    config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=db_url)
    print(config.config_path)
    assert config.archival_storage_uri is not None
    # Rewrite the legacy "postgres://" scheme to "postgresql://".
    config.archival_storage_uri = config.archival_storage_uri.replace(
        "postgres://", "postgresql://"
    )  # https://stackoverflow.com/a/64698899
    # NOTE(review): the config is persisted rather than passed along, so
    # embedding_model() and PostgresStorageConnector presumably load it
    # from disk -- confirm against their implementations.
    config.save()
    print(config)

    embed_model = embedding_model()

    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = PostgresStorageConnector(name="test-openai")

    # Use a distinct loop variable so the list of passages is not rebound.
    for text in passages:
        db.insert(Passage(text=text, embedding=embed_model.get_text_embedding(text)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"

    # TODO fix (causes a hang for some reason)
    # print("deleting...")
    # db.delete()
    # print("...finished")
@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
def test_postgres_local():
    """Insert three passages into Postgres storage using a local embedding config.

    Requires the ``PGVECTOR_TEST_DB_URL`` environment variable. Under pytest
    the ``skipif`` marker skips the test when it is missing; when the
    function is called directly it returns ``None`` without touching the
    database.

    The saved ``MemGPTConfig`` sets ``embedding_endpoint_type="local"`` and
    ``embedding_dim=384``. The test embeds and inserts the passages, then
    asserts that a query for "why was she crying" returns two results with
    the "wept" passage first.
    """
    # The skipif marker is only honoured by pytest; this guard keeps a
    # direct call a no-op when the environment is not set up.
    db_url = os.getenv("PGVECTOR_TEST_DB_URL")
    if not db_url:
        return

    # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
    config = MemGPTConfig(
        archival_storage_type="postgres",
        archival_storage_uri=db_url,
        embedding_endpoint_type="local",
        embedding_dim=384,  # use HF model
    )
    print(config.config_path)
    assert config.archival_storage_uri is not None
    # Rewrite the legacy "postgres://" scheme to "postgresql://".
    config.archival_storage_uri = config.archival_storage_uri.replace(
        "postgres://", "postgresql://"
    )  # https://stackoverflow.com/a/64698899
    # NOTE(review): the config is persisted rather than passed along, so
    # embedding_model() and PostgresStorageConnector presumably load it
    # from disk -- confirm against their implementations.
    config.save()
    print(config)

    embed_model = embedding_model()

    passages = ["This is a test passage", "This is another test passage", "Cinderella wept"]

    db = PostgresStorageConnector(name="test-local")

    # Use a distinct loop variable so the list of passages is not rebound.
    for text in passages:
        db.insert(Passage(text=text, embedding=embed_model.get_text_embedding(text)))

    print(db.get_all())

    query = "why was she crying"
    query_vec = embed_model.get_text_embedding(query)
    res = db.query(None, query_vec, top_k=2)

    assert len(res) == 2, f"Expected 2 results, got {len(res)}"
    assert "wept" in res[0].text, f"Expected 'wept' in results, but got {res[0].text}"

    # TODO fix (causes a hang for some reason)
    # print("deleting...")
    # db.delete()
    # print("...finished")
# test_postgres()