feat: isolate test config from main config (#1063)

Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
tombedor
2024-03-06 11:21:37 +11:00
committed by GitHub
parent 83eb401be8
commit b665e67b01
17 changed files with 197 additions and 135 deletions

View File

@@ -4,13 +4,12 @@ import os
import memgpt.utils as utils
from dotenv import load_dotenv
from tests.config import TestMGPTConfig
utils.DEBUG = True
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.server.server import SyncServer
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
from memgpt.embeddings import embedding_model
from memgpt.presets.presets import add_default_presets
from memgpt.data_types import EmbeddingConfig, LLMConfig
from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector
@@ -22,9 +21,10 @@ def server():
# Use os.getenv with a fallback to os.environ.get
db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL")
assert db_url, "Missing PGVECTOR_TEST_DB_URL"
if os.getenv("OPENAI_API_KEY"):
config = MemGPTConfig(
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
@@ -48,7 +48,7 @@ def server():
openai_key=os.getenv("OPENAI_API_KEY"),
)
else: # hosted
config = MemGPTConfig(
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
@@ -141,7 +141,13 @@ def test_load_data(server, user_id, agent_id):
source = server.create_source("test_source", user_id)
# load data
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
archival_memories = [
"alpha",
"Cinderella wore a blue dress",
"Dog eat dog",
"ZZZ",
"Shishir loves indian food",
]
connector = DummyDataConnector(archival_memories)
server.load_data(user_id, connector, source.name)
@@ -215,10 +221,19 @@ def test_get_archival_memory(server, user_id, agent_id):
# test archival memory cursor pagination
cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
cursor2, passages_2 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_id, reverse=False, after=cursor1, order_by="text"
user_id=user_id,
agent_id=agent_id,
reverse=False,
after=cursor1,
order_by="text",
)
cursor3, passages_3 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_id, reverse=False, before=cursor2, limit=1000, order_by="text"
user_id=user_id,
agent_id=agent_id,
reverse=False,
before=cursor2,
limit=1000,
order_by="text",
)
print("p1", [p["text"] for p in passages_1])
print("p2", [p["text"] for p in passages_2])