diff --git a/tests/clear_postgres_db.py b/tests/clear_postgres_db.py index 60618363..457beea7 100644 --- a/tests/clear_postgres_db.py +++ b/tests/clear_postgres_db.py @@ -1,8 +1,19 @@ +import os + from sqlalchemy import create_engine, MetaData -engine = create_engine("postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt") -meta = MetaData() -meta.reflect(bind=engine) +def main(): + uri = os.environ.get( + "PGVECTOR_TEST_DB_URL", + "postgresql+pg8000://memgpt:memgpt@localhost:8888/memgpt", + ) -meta.drop_all(bind=engine) + engine = create_engine(uri) + meta = MetaData() + meta.reflect(bind=engine) + meta.drop_all(bind=engine) + + +if __name__ == "__main__": + main()