feat: isolate test config from main config (#1063)
Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
@@ -188,7 +188,6 @@ class Agent(object):
|
||||
messages_total: Optional[int] = None, # TODO remove?
|
||||
first_message_verify_mono: bool = True, # TODO move to config?
|
||||
):
|
||||
|
||||
# An agent can be created from a Preset object
|
||||
if preset is not None:
|
||||
assert agent_state is None, "Can create an agent from a Preset or AgentState (but both were provided)"
|
||||
|
||||
@@ -33,7 +33,7 @@ def set_field(config, section, field, value):
|
||||
|
||||
@dataclass
|
||||
class MemGPTConfig:
|
||||
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") if os.getenv("MEMGPT_CONFIG_PATH") else os.path.join(MEMGPT_DIR, "config")
|
||||
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")
|
||||
anon_clientid: str = str(uuid.UUID(int=0))
|
||||
|
||||
# preset
|
||||
@@ -196,16 +196,51 @@ class MemGPTConfig:
|
||||
# model defaults
|
||||
set_field(config, "model", "model", self.default_llm_config.model)
|
||||
set_field(config, "model", "model_endpoint", self.default_llm_config.model_endpoint)
|
||||
set_field(config, "model", "model_endpoint_type", self.default_llm_config.model_endpoint_type)
|
||||
set_field(
|
||||
config,
|
||||
"model",
|
||||
"model_endpoint_type",
|
||||
self.default_llm_config.model_endpoint_type,
|
||||
)
|
||||
set_field(config, "model", "model_wrapper", self.default_llm_config.model_wrapper)
|
||||
set_field(config, "model", "context_window", str(self.default_llm_config.context_window))
|
||||
set_field(
|
||||
config,
|
||||
"model",
|
||||
"context_window",
|
||||
str(self.default_llm_config.context_window),
|
||||
)
|
||||
|
||||
# embeddings
|
||||
set_field(config, "embedding", "embedding_endpoint_type", self.default_embedding_config.embedding_endpoint_type)
|
||||
set_field(config, "embedding", "embedding_endpoint", self.default_embedding_config.embedding_endpoint)
|
||||
set_field(config, "embedding", "embedding_model", self.default_embedding_config.embedding_model)
|
||||
set_field(config, "embedding", "embedding_dim", str(self.default_embedding_config.embedding_dim))
|
||||
set_field(config, "embedding", "embedding_chunk_size", str(self.default_embedding_config.embedding_chunk_size))
|
||||
set_field(
|
||||
config,
|
||||
"embedding",
|
||||
"embedding_endpoint_type",
|
||||
self.default_embedding_config.embedding_endpoint_type,
|
||||
)
|
||||
set_field(
|
||||
config,
|
||||
"embedding",
|
||||
"embedding_endpoint",
|
||||
self.default_embedding_config.embedding_endpoint,
|
||||
)
|
||||
set_field(
|
||||
config,
|
||||
"embedding",
|
||||
"embedding_model",
|
||||
self.default_embedding_config.embedding_model,
|
||||
)
|
||||
set_field(
|
||||
config,
|
||||
"embedding",
|
||||
"embedding_dim",
|
||||
str(self.default_embedding_config.embedding_dim),
|
||||
)
|
||||
set_field(
|
||||
config,
|
||||
"embedding",
|
||||
"embedding_chunk_size",
|
||||
str(self.default_embedding_config.embedding_chunk_size),
|
||||
)
|
||||
|
||||
# archival storage
|
||||
set_field(config, "archival_storage", "type", self.archival_storage_type)
|
||||
@@ -253,7 +288,16 @@ class MemGPTConfig:
|
||||
if not os.path.exists(MEMGPT_DIR):
|
||||
os.makedirs(MEMGPT_DIR, exist_ok=True)
|
||||
|
||||
folders = ["personas", "humans", "archival", "agents", "functions", "system_prompts", "presets", "settings"]
|
||||
folders = [
|
||||
"personas",
|
||||
"humans",
|
||||
"archival",
|
||||
"agents",
|
||||
"functions",
|
||||
"system_prompts",
|
||||
"presets",
|
||||
"settings",
|
||||
]
|
||||
|
||||
for folder in folders:
|
||||
if not os.path.exists(os.path.join(MEMGPT_DIR, folder)):
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
from tests.config import TestMGPTConfig
|
||||
|
||||
|
||||
TEST_MEMGPT_CONFIG = TestMGPTConfig()
|
||||
|
||||
7
tests/config.py
Normal file
7
tests/config.py
Normal file
@@ -0,0 +1,7 @@
|
||||
import os
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
|
||||
|
||||
class TestMGPTConfig(MemGPTConfig):
|
||||
config_path: str = os.getenv("TEST_MEMGPT_CONFIG_PATH") or os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")
|
||||
@@ -1,16 +1,14 @@
|
||||
from collections import UserDict
|
||||
import json
|
||||
import os
|
||||
import inspect
|
||||
import uuid
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt import create_client
|
||||
from memgpt import constants
|
||||
import memgpt.functions.function_sets.base as base_functions
|
||||
from memgpt.functions.functions import USER_FUNCTIONS_DIR
|
||||
from memgpt.utils import assistant_function_to_tool
|
||||
from memgpt.models import chat_completion_response
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
|
||||
from tests.utils import wipe_config, create_config
|
||||
|
||||
@@ -39,10 +37,8 @@ def agent():
|
||||
# create memgpt client
|
||||
client = create_client()
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# ensure user exists
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
|
||||
if not client.server.get_user(user_id=user_id):
|
||||
client.server.create_user({"id": user_id})
|
||||
|
||||
|
||||
@@ -2,9 +2,9 @@ import os
|
||||
import uuid
|
||||
|
||||
from memgpt import create_client
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt import constants
|
||||
import memgpt.functions.function_sets.base as base_functions
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
from .utils import wipe_config, create_config
|
||||
|
||||
|
||||
@@ -30,8 +30,7 @@ def create_test_agent():
|
||||
)
|
||||
|
||||
global agent_obj
|
||||
config = MemGPTConfig.load()
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
|
||||
agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
|
||||
@@ -1,20 +1,13 @@
|
||||
import uuid
|
||||
import time
|
||||
import os
|
||||
import threading
|
||||
|
||||
from memgpt import Admin, create_client
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt import constants
|
||||
from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset
|
||||
from memgpt.functions.functions import load_all_function_sets
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
from .utils import wipe_config
|
||||
import uuid
|
||||
|
||||
|
||||
@@ -116,9 +109,3 @@ def test_user_message(client):
|
||||
# print(
|
||||
# f"[2] MESSAGE SEND SUCCESS!!! AGENT {test_agent_state_post_message.id}\n\tmessages={test_agent_state_post_message.state['messages']}"
|
||||
# )
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# test_create_preset()
|
||||
test_create_agent()
|
||||
test_user_message()
|
||||
|
||||
@@ -2,11 +2,10 @@ import uuid
|
||||
import os
|
||||
|
||||
from memgpt import create_client
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt import constants
|
||||
from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage
|
||||
from memgpt.data_types import EmbeddingConfig, Passage
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
from .utils import wipe_config, create_config
|
||||
import uuid
|
||||
|
||||
@@ -21,7 +20,11 @@ test_user_id = uuid.uuid4()
|
||||
|
||||
def generate_passages(user, agent):
|
||||
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
|
||||
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
texts = [
|
||||
"This is a test passage",
|
||||
"This is another test passage",
|
||||
"Cinderella wept",
|
||||
]
|
||||
embed_model = embedding_model(agent.embedding_config)
|
||||
orig_embeddings = []
|
||||
passages = []
|
||||
@@ -86,8 +89,7 @@ def test_create_user():
|
||||
hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)
|
||||
|
||||
# test passage dimentionality
|
||||
config = MemGPTConfig.load()
|
||||
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, client.user.id)
|
||||
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id)
|
||||
storage.filters = {} # clear filters to be able to get all passages
|
||||
passages = storage.get_all()
|
||||
for passage in passages:
|
||||
|
||||
@@ -9,12 +9,11 @@ from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.cli.cli_load import load_directory
|
||||
|
||||
# from memgpt.data_sources.connectors import DirectoryConnector, load_data
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.data_types import User, AgentState, EmbeddingConfig
|
||||
from memgpt import create_client
|
||||
from .utils import wipe_config, create_config
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
from .utils import wipe_config
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -37,16 +36,20 @@ def recreate_declarative_base():
|
||||
|
||||
@pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"])
|
||||
@pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"])
|
||||
def test_load_directory(metadata_storage_connector, passage_storage_connector, clear_dynamically_created_models, recreate_declarative_base):
|
||||
def test_load_directory(
|
||||
metadata_storage_connector,
|
||||
passage_storage_connector,
|
||||
clear_dynamically_created_models,
|
||||
recreate_declarative_base,
|
||||
):
|
||||
wipe_config()
|
||||
# setup config
|
||||
config = MemGPTConfig()
|
||||
if metadata_storage_connector == "postgres":
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
print("Skipping test, missing PG URI")
|
||||
return
|
||||
config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.metadata_storage_type = "postgres"
|
||||
TEST_MEMGPT_CONFIG.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
TEST_MEMGPT_CONFIG.metadata_storage_type = "postgres"
|
||||
elif metadata_storage_connector == "sqlite":
|
||||
print("testing sqlite metadata")
|
||||
# nothing to do (should be config defaults)
|
||||
@@ -56,18 +59,18 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
print("Skipping test, missing PG URI")
|
||||
return
|
||||
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.archival_storage_type = "postgres"
|
||||
TEST_MEMGPT_CONFIG.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
|
||||
elif passage_storage_connector == "chroma":
|
||||
print("testing chroma passage storage")
|
||||
# nothing to do (should be config defaults)
|
||||
else:
|
||||
raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented")
|
||||
config.save()
|
||||
TEST_MEMGPT_CONFIG.save()
|
||||
|
||||
# create metadata store
|
||||
ms = MetadataStore(config)
|
||||
user = User(id=uuid.UUID(config.anon_clientid))
|
||||
ms = MetadataStore(TEST_MEMGPT_CONFIG)
|
||||
user = User(id=uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid))
|
||||
|
||||
# embedding config
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
@@ -100,18 +103,20 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
|
||||
)
|
||||
|
||||
# write out the config so that the 'load' command will use it (CLI commands pull from config)
|
||||
config.default_embedding_config = embedding_config
|
||||
config.save()
|
||||
TEST_MEMGPT_CONFIG.default_embedding_config = embedding_config
|
||||
TEST_MEMGPT_CONFIG.save()
|
||||
# config.default_embedding_config = embedding_config
|
||||
# config.save()
|
||||
|
||||
# create user and agent
|
||||
agent = AgentState(
|
||||
user_id=user.id,
|
||||
name="test_agent",
|
||||
preset=config.preset,
|
||||
persona=config.persona,
|
||||
human=config.human,
|
||||
llm_config=config.default_llm_config,
|
||||
embedding_config=embedding_config,
|
||||
preset=TEST_MEMGPT_CONFIG.preset,
|
||||
persona=TEST_MEMGPT_CONFIG.persona,
|
||||
human=TEST_MEMGPT_CONFIG.human,
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
)
|
||||
ms.delete_user(user.id)
|
||||
ms.create_user(user)
|
||||
@@ -123,7 +128,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
|
||||
print("Creating storage connectors...")
|
||||
user_id = user.id
|
||||
print("User ID", user_id)
|
||||
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
|
||||
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)
|
||||
|
||||
# load data
|
||||
name = "test_dataset"
|
||||
@@ -135,7 +140,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
|
||||
print("Resetting tables with delete_table...")
|
||||
passages_conn.delete_table()
|
||||
print("Re-creating tables...")
|
||||
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
|
||||
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)
|
||||
assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"
|
||||
|
||||
# test: load directory
|
||||
|
||||
@@ -4,35 +4,29 @@ import pytest
|
||||
|
||||
from memgpt.agent import Agent, save_agent
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
|
||||
from memgpt.data_types import User, AgentState, Source, LLMConfig
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
from memgpt.presets.presets import add_default_presets, add_default_humans_and_personas
|
||||
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
|
||||
@pytest.mark.parametrize("storage_connector", ["sqlite"])
|
||||
def test_storage(storage_connector):
|
||||
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
|
||||
config = MemGPTConfig()
|
||||
if storage_connector == "postgres":
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
print("Skipping test, missing PG URI")
|
||||
return
|
||||
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.archival_storage_type = "postgres"
|
||||
config.recall_storage_type = "postgres"
|
||||
TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
|
||||
TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
|
||||
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
|
||||
TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
|
||||
if storage_connector == "sqlite":
|
||||
config.recall_storage_type = "local"
|
||||
TEST_MEMGPT_CONFIG.recall_storage_type = "local"
|
||||
|
||||
ms = MetadataStore(config)
|
||||
ms = MetadataStore(TEST_MEMGPT_CONFIG)
|
||||
|
||||
# users
|
||||
user_1 = User()
|
||||
@@ -57,8 +51,8 @@ def test_storage(storage_connector):
|
||||
preset=DEFAULT_PRESET,
|
||||
persona=DEFAULT_PERSONA,
|
||||
human=DEFAULT_HUMAN,
|
||||
llm_config=config.default_llm_config,
|
||||
embedding_config=config.default_embedding_config,
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
)
|
||||
source_1 = Source(user_id=user_1.id, name="source_1")
|
||||
|
||||
@@ -108,7 +102,7 @@ def test_storage(storage_connector):
|
||||
# test: updating
|
||||
|
||||
# test: update JSON-stored LLMConfig class
|
||||
print(agent_1.llm_config, config.default_llm_config)
|
||||
print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config)
|
||||
llm_config = ms.get_agent(agent_1.id).llm_config
|
||||
assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}"
|
||||
assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}"
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
import os
|
||||
from memgpt.migrate import migrate_all_agents, migrate_all_sources
|
||||
from memgpt.config import MemGPTConfig
|
||||
from .utils import wipe_config, create_config
|
||||
from memgpt.server.server import SyncServer
|
||||
import shutil
|
||||
import uuid
|
||||
|
||||
from memgpt.migrate import migrate_all_agents, migrate_all_sources
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
from .utils import wipe_config, create_config
|
||||
|
||||
|
||||
def test_migrate_0211():
|
||||
wipe_config()
|
||||
@@ -42,7 +44,12 @@ def test_migrate_0211():
|
||||
assert len(message_ids) > 0
|
||||
|
||||
# assert recall memories exist
|
||||
messages = server.get_agent_messages(user_id=agent_state.user_id, agent_id=agent_state.id, start=0, count=1000)
|
||||
messages = server.get_agent_messages(
|
||||
user_id=agent_state.user_id,
|
||||
agent_id=agent_state.id,
|
||||
start=0,
|
||||
count=1000,
|
||||
)
|
||||
assert len(messages) > 0
|
||||
|
||||
# for source_name in source_res["migration_candidates"]:
|
||||
|
||||
@@ -5,7 +5,6 @@ import uuid
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.server.rest_api.server import app
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.config import MemGPTConfig
|
||||
|
||||
# TODO: modify this to run against an actual running server
|
||||
# def test_list_messages():
|
||||
|
||||
@@ -4,13 +4,12 @@ import os
|
||||
import memgpt.utils as utils
|
||||
from dotenv import load_dotenv
|
||||
|
||||
from tests.config import TestMGPTConfig
|
||||
|
||||
utils.DEBUG = True
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
from memgpt.data_types import EmbeddingConfig, LLMConfig
|
||||
from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector
|
||||
|
||||
|
||||
@@ -22,9 +21,10 @@ def server():
|
||||
|
||||
# Use os.getenv with a fallback to os.environ.get
|
||||
db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL")
|
||||
assert db_url, "Missing PGVECTOR_TEST_DB_URL"
|
||||
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
config = MemGPTConfig(
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
@@ -48,7 +48,7 @@ def server():
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
else: # hosted
|
||||
config = MemGPTConfig(
|
||||
config = TestMGPTConfig(
|
||||
archival_storage_uri=db_url,
|
||||
recall_storage_uri=db_url,
|
||||
metadata_storage_uri=db_url,
|
||||
@@ -141,7 +141,13 @@ def test_load_data(server, user_id, agent_id):
|
||||
source = server.create_source("test_source", user_id)
|
||||
|
||||
# load data
|
||||
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
|
||||
archival_memories = [
|
||||
"alpha",
|
||||
"Cinderella wore a blue dress",
|
||||
"Dog eat dog",
|
||||
"ZZZ",
|
||||
"Shishir loves indian food",
|
||||
]
|
||||
connector = DummyDataConnector(archival_memories)
|
||||
server.load_data(user_id, connector, source.name)
|
||||
|
||||
@@ -215,10 +221,19 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
# test archival memory cursor pagination
|
||||
cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
|
||||
cursor2, passages_2 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_id, reverse=False, after=cursor1, order_by="text"
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
reverse=False,
|
||||
after=cursor1,
|
||||
order_by="text",
|
||||
)
|
||||
cursor3, passages_3 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_id, reverse=False, before=cursor2, limit=1000, order_by="text"
|
||||
user_id=user_id,
|
||||
agent_id=agent_id,
|
||||
reverse=False,
|
||||
before=cursor2,
|
||||
limit=1000,
|
||||
order_by="text",
|
||||
)
|
||||
print("p1", [p["text"] for p in passages_1])
|
||||
print("p2", [p["text"] for p in passages_2])
|
||||
|
||||
@@ -5,8 +5,7 @@ import pytest
|
||||
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.embeddings import embedding_model, query_embedding
|
||||
from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, OpenAIEmbeddingConfig
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.metadata import MetadataStore
|
||||
@@ -15,6 +14,8 @@ from memgpt.constants import MAX_EMBEDDING_DIM
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
|
||||
|
||||
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
|
||||
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
|
||||
@@ -104,7 +105,12 @@ def recreate_declarative_base():
|
||||
# @pytest.mark.parametrize("storage_connector", ["sqlite", "chroma"])
|
||||
# @pytest.mark.parametrize("storage_connector", ["postgres"])
|
||||
@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
|
||||
def test_storage(storage_connector, table_type, clear_dynamically_created_models, recreate_declarative_base):
|
||||
def test_storage(
|
||||
storage_connector,
|
||||
table_type,
|
||||
clear_dynamically_created_models,
|
||||
recreate_declarative_base,
|
||||
):
|
||||
# setup memgpt config
|
||||
# TODO: set env for different config path
|
||||
|
||||
@@ -114,35 +120,34 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models
|
||||
# print("Removing messages", globals()['Message'])
|
||||
# del globals()['Message']
|
||||
|
||||
config = MemGPTConfig()
|
||||
if storage_connector == "postgres":
|
||||
if not os.getenv("PGVECTOR_TEST_DB_URL"):
|
||||
print("Skipping test, missing PG URI")
|
||||
return
|
||||
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.archival_storage_type = "postgres"
|
||||
config.recall_storage_type = "postgres"
|
||||
TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
|
||||
TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
|
||||
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
|
||||
TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
|
||||
if storage_connector == "lancedb":
|
||||
# TODO: complete lancedb implementation
|
||||
if not os.getenv("LANCEDB_TEST_URL"):
|
||||
print("Skipping test, missing LanceDB URI")
|
||||
return
|
||||
config.archival_storage_uri = os.getenv("LANCEDB_TEST_URL")
|
||||
config.recall_storage_uri = os.getenv("LANCEDB_TEST_URL")
|
||||
config.archival_storage_type = "lancedb"
|
||||
config.recall_storage_type = "lancedb"
|
||||
TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["LANCEDB_TEST_URL"]
|
||||
TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["LANCEDB_TEST_URL"]
|
||||
TEST_MEMGPT_CONFIG.archival_storage_type = "lancedb"
|
||||
TEST_MEMGPT_CONFIG.recall_storage_type = "lancedb"
|
||||
if storage_connector == "chroma":
|
||||
if table_type == TableType.RECALL_MEMORY:
|
||||
print("Skipping test, chroma only supported for archival memory")
|
||||
return
|
||||
config.archival_storage_type = "chroma"
|
||||
config.archival_storage_path = "./test_chroma"
|
||||
TEST_MEMGPT_CONFIG.archival_storage_type = "chroma"
|
||||
TEST_MEMGPT_CONFIG.archival_storage_path = "./test_chroma"
|
||||
if storage_connector == "sqlite":
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
print("Skipping test, sqlite only supported for recall memory")
|
||||
return
|
||||
config.recall_storage_type = "sqlite"
|
||||
TEST_MEMGPT_CONFIG.recall_storage_type = "sqlite"
|
||||
|
||||
# get embedding model
|
||||
embed_model = None
|
||||
@@ -162,27 +167,27 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models
|
||||
embed_model = embedding_model(embedding_config)
|
||||
|
||||
# create user
|
||||
ms = MetadataStore(config)
|
||||
ms = MetadataStore(TEST_MEMGPT_CONFIG)
|
||||
ms.delete_user(user_id)
|
||||
user = User(id=user_id)
|
||||
agent = AgentState(
|
||||
user_id=user_id,
|
||||
name="agent_1",
|
||||
id=agent_1_id,
|
||||
preset=config.preset,
|
||||
persona=config.persona,
|
||||
human=config.human,
|
||||
llm_config=config.default_llm_config,
|
||||
embedding_config=config.default_embedding_config,
|
||||
preset=TEST_MEMGPT_CONFIG.preset,
|
||||
persona=TEST_MEMGPT_CONFIG.persona,
|
||||
human=TEST_MEMGPT_CONFIG.human,
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
)
|
||||
ms.create_user(user)
|
||||
ms.create_agent(agent)
|
||||
|
||||
# create storage connector
|
||||
conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id)
|
||||
conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
|
||||
# conn.client.delete_collection(conn.collection.name) # clear out data
|
||||
conn.delete_table()
|
||||
conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id)
|
||||
conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
|
||||
|
||||
# generate data
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
|
||||
@@ -2,9 +2,7 @@ import os
|
||||
import uuid
|
||||
|
||||
from memgpt import create_client
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt import constants
|
||||
import memgpt.functions.function_sets.base as base_functions
|
||||
from .utils import wipe_config, create_config
|
||||
|
||||
|
||||
@@ -31,7 +29,6 @@ def create_test_agent():
|
||||
)
|
||||
|
||||
global agent_obj
|
||||
config = MemGPTConfig.load()
|
||||
agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
@@ -48,12 +45,16 @@ def test_summarize():
|
||||
|
||||
# First send a few messages (5)
|
||||
response = client.user_message(
|
||||
agent_id=agent_obj.agent_state.id, message="Hey, how's it going? What do you think about this whole shindig"
|
||||
agent_id=agent_obj.agent_state.id,
|
||||
message="Hey, how's it going? What do you think about this whole shindig",
|
||||
)
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
response = client.user_message(agent_id=agent_obj.agent_state.id, message="Any thoughts on the meaning of life?")
|
||||
response = client.user_message(
|
||||
agent_id=agent_obj.agent_state.id,
|
||||
message="Any thoughts on the meaning of life?",
|
||||
)
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
@@ -62,7 +63,8 @@ def test_summarize():
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
response = client.user_message(
|
||||
agent_id=agent_obj.agent_state.id, message="Would you be surprised to learn that you're actually conversing with an AI right now?"
|
||||
agent_id=agent_obj.agent_state.id,
|
||||
message="Would you be surprised to learn that you're actually conversing with an AI right now?",
|
||||
)
|
||||
assert response is not None and len(response) > 0
|
||||
print(f"test_summarize: response={response}")
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
import os
|
||||
import pytest
|
||||
from unittest.mock import Mock, AsyncMock, MagicMock
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.server.ws_api.interface import SyncWebSocketInterface
|
||||
import memgpt.presets.presets as presets
|
||||
import memgpt.utils as utils
|
||||
import memgpt.system as system
|
||||
from memgpt.persistence_manager import LocalStateManager
|
||||
from memgpt.data_types import AgentState
|
||||
|
||||
|
||||
@@ -62,10 +60,10 @@ async def test_websockets():
|
||||
if api_key is None:
|
||||
ws_interface.close()
|
||||
return
|
||||
config = MemGPTConfig.load()
|
||||
if config.openai_key is None:
|
||||
config.openai_key = api_key
|
||||
config.save()
|
||||
credentials = MemGPTCredentials.load()
|
||||
if credentials.openai_key is None:
|
||||
credentials.openai_key = api_key
|
||||
credentials.save()
|
||||
|
||||
# Mock the persistence manager
|
||||
# create agents with defaults
|
||||
|
||||
@@ -2,11 +2,10 @@ import datetime
|
||||
from typing import Dict, List, Tuple, Iterator
|
||||
import os
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.cli.cli import quickstart, QuickstartChoice
|
||||
from memgpt.data_sources.connectors import DataConnector
|
||||
from memgpt import Admin
|
||||
from memgpt.data_types import Document
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
|
||||
from .constants import TIMEOUT
|
||||
|
||||
@@ -37,17 +36,17 @@ def create_config(endpoint="openai"):
|
||||
|
||||
|
||||
def wipe_config():
|
||||
if MemGPTConfig.exists():
|
||||
if TEST_MEMGPT_CONFIG.exists():
|
||||
# delete
|
||||
if os.getenv("MEMGPT_CONFIG_PATH"):
|
||||
config_path = os.getenv("MEMGPT_CONFIG_PATH")
|
||||
else:
|
||||
config_path = MemGPTConfig.config_path
|
||||
config_path = TEST_MEMGPT_CONFIG.config_path
|
||||
# TODO delete file config_path
|
||||
os.remove(config_path)
|
||||
assert not MemGPTConfig.exists(), "Config should not exist after deletion"
|
||||
assert not TEST_MEMGPT_CONFIG.exists(), "Config should not exist after deletion"
|
||||
else:
|
||||
print("No config to wipe", MemGPTConfig.config_path)
|
||||
print("No config to wipe", TEST_MEMGPT_CONFIG.config_path)
|
||||
|
||||
|
||||
def wipe_memgpt_home():
|
||||
@@ -63,7 +62,7 @@ def wipe_memgpt_home():
|
||||
os.system(f"mv ~/.memgpt {backup_dir}")
|
||||
|
||||
# Setup the initial directory
|
||||
MemGPTConfig.create_config_dir()
|
||||
TEST_MEMGPT_CONFIG.create_config_dir()
|
||||
|
||||
|
||||
def configure_memgpt_localllm():
|
||||
|
||||
Reference in New Issue
Block a user