feat: isolate test config from main config (#1063)
Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
@@ -4,13 +4,12 @@ import os
|
||||
import memgpt.utils as utils
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from tests.config import TestMGPTConfig
|
||||
|
||||
utils.DEBUG = True
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
from memgpt.data_types import EmbeddingConfig, LLMConfig
|
||||
from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector
|
||||
|
||||
|
||||
@@ -22,9 +21,10 @@ def server():
|
||||
|
||||
# Use os.getenv with a fallback to os.environ.get
|
||||
db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL")
|
||||
assert db_url, "Missing PGVECTOR_TEST_DB_URL"
|
||||
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
config = MemGPTConfig(
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
@@ -48,7 +48,7 @@ def server():
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
else: # hosted
|
||||
config = MemGPTConfig(
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
@@ -141,7 +141,13 @@ def test_load_data(server, user_id, agent_id):
|
||||
source = server.create_source("test_source", user_id)
|
||||
|
||||
# load data
|
||||
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
|
||||
archival_memories = [
|
||||
"alpha",
|
||||
"Cinderella wore a blue dress",
|
||||
"Dog eat dog",
|
||||
"ZZZ",
|
||||
"Shishir loves indian food",
|
||||
]
|
||||
connector = DummyDataConnector(archival_memories)
|
||||
server.load_data(user_id, connector, source.name)
|
||||
|
||||
@@ -215,10 +221,19 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
# test archival memory cursor pagination
|
||||
cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
|
||||
cursor2, passages_2 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_id, reverse=False, after=cursor1, order_by="text"
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
reverse=False,
|
||||
after=cursor1,
|
||||
order_by="text",
|
||||
)
|
||||
cursor3, passages_3 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_id, reverse=False, before=cursor2, limit=1000, order_by="text"
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
reverse=False,
|
||||
before=cursor2,
|
||||
limit=1000,
|
||||
order_by="text",
|
||||
)
|
||||
print("p1", [p["text"] for p in passages_1])
|
||||
print("p2", [p["text"] for p in passages_2])
|
||||
|
||||
Reference in New Issue
Block a user