feat: isolate test config from main config (#1063)

Co-authored-by: Charles Packer <packercharles@gmail.com>
This commit is contained in:
tombedor
2024-03-06 11:21:37 +11:00
committed by GitHub
parent 83eb401be8
commit b665e67b01
17 changed files with 197 additions and 135 deletions

View File

@@ -188,7 +188,6 @@ class Agent(object):
messages_total: Optional[int] = None, # TODO remove?
first_message_verify_mono: bool = True, # TODO move to config?
):
# An agent can be created from a Preset object
if preset is not None:
assert agent_state is None, "Can create an agent from a Preset or AgentState (but both were provided)"

View File

@@ -33,7 +33,7 @@ def set_field(config, section, field, value):
@dataclass
class MemGPTConfig:
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") if os.getenv("MEMGPT_CONFIG_PATH") else os.path.join(MEMGPT_DIR, "config")
config_path: str = os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")
anon_clientid: str = str(uuid.UUID(int=0))
# preset
@@ -196,16 +196,51 @@ class MemGPTConfig:
# model defaults
set_field(config, "model", "model", self.default_llm_config.model)
set_field(config, "model", "model_endpoint", self.default_llm_config.model_endpoint)
set_field(config, "model", "model_endpoint_type", self.default_llm_config.model_endpoint_type)
set_field(
config,
"model",
"model_endpoint_type",
self.default_llm_config.model_endpoint_type,
)
set_field(config, "model", "model_wrapper", self.default_llm_config.model_wrapper)
set_field(config, "model", "context_window", str(self.default_llm_config.context_window))
set_field(
config,
"model",
"context_window",
str(self.default_llm_config.context_window),
)
# embeddings
set_field(config, "embedding", "embedding_endpoint_type", self.default_embedding_config.embedding_endpoint_type)
set_field(config, "embedding", "embedding_endpoint", self.default_embedding_config.embedding_endpoint)
set_field(config, "embedding", "embedding_model", self.default_embedding_config.embedding_model)
set_field(config, "embedding", "embedding_dim", str(self.default_embedding_config.embedding_dim))
set_field(config, "embedding", "embedding_chunk_size", str(self.default_embedding_config.embedding_chunk_size))
set_field(
config,
"embedding",
"embedding_endpoint_type",
self.default_embedding_config.embedding_endpoint_type,
)
set_field(
config,
"embedding",
"embedding_endpoint",
self.default_embedding_config.embedding_endpoint,
)
set_field(
config,
"embedding",
"embedding_model",
self.default_embedding_config.embedding_model,
)
set_field(
config,
"embedding",
"embedding_dim",
str(self.default_embedding_config.embedding_dim),
)
set_field(
config,
"embedding",
"embedding_chunk_size",
str(self.default_embedding_config.embedding_chunk_size),
)
# archival storage
set_field(config, "archival_storage", "type", self.archival_storage_type)
@@ -253,7 +288,16 @@ class MemGPTConfig:
if not os.path.exists(MEMGPT_DIR):
os.makedirs(MEMGPT_DIR, exist_ok=True)
folders = ["personas", "humans", "archival", "agents", "functions", "system_prompts", "presets", "settings"]
folders = [
"personas",
"humans",
"archival",
"agents",
"functions",
"system_prompts",
"presets",
"settings",
]
for folder in folders:
if not os.path.exists(os.path.join(MEMGPT_DIR, folder)):

View File

@@ -0,0 +1,4 @@
from tests.config import TestMGPTConfig
TEST_MEMGPT_CONFIG = TestMGPTConfig()

7
tests/config.py Normal file
View File

@@ -0,0 +1,7 @@
import os
from memgpt.config import MemGPTConfig
from memgpt.constants import MEMGPT_DIR
class TestMGPTConfig(MemGPTConfig):
config_path: str = os.getenv("TEST_MEMGPT_CONFIG_PATH") or os.getenv("MEMGPT_CONFIG_PATH") or os.path.join(MEMGPT_DIR, "config")

View File

@@ -1,16 +1,14 @@
from collections import UserDict
import json
import os
import inspect
import uuid
from memgpt.config import MemGPTConfig
from memgpt import create_client
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from memgpt.functions.functions import USER_FUNCTIONS_DIR
from memgpt.utils import assistant_function_to_tool
from memgpt.models import chat_completion_response
from tests import TEST_MEMGPT_CONFIG
from tests.utils import wipe_config, create_config
@@ -39,10 +37,8 @@ def agent():
# create memgpt client
client = create_client()
config = MemGPTConfig.load()
# ensure user exists
user_id = uuid.UUID(config.anon_clientid)
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
if not client.server.get_user(user_id=user_id):
client.server.create_user({"id": user_id})

View File

@@ -2,9 +2,9 @@ import os
import uuid
from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config, create_config
@@ -30,8 +30,7 @@ def create_test_agent():
)
global agent_obj
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)

View File

@@ -1,20 +1,13 @@
import uuid
import time
import os
import threading
from memgpt import Admin, create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset
from memgpt.functions.functions import load_all_function_sets
from memgpt.prompts import gpt_system
from memgpt.constants import DEFAULT_PRESET
import pytest
from .utils import wipe_config
import uuid
@@ -116,9 +109,3 @@ def test_user_message(client):
# print(
# f"[2] MESSAGE SEND SUCCESS!!! AGENT {test_agent_state_post_message.id}\n\tmessages={test_agent_state_post_message.state['messages']}"
# )
if __name__ == "__main__":
# test_create_preset()
test_create_agent()
test_user_message()

View File

@@ -2,11 +2,10 @@ import uuid
import os
from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
from memgpt.data_types import LLMConfig, EmbeddingConfig, AgentState, Passage
from memgpt.data_types import EmbeddingConfig, Passage
from memgpt.embeddings import embedding_model
from memgpt.agent_store.storage import StorageConnector, TableType
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config, create_config
import uuid
@@ -21,7 +20,11 @@ test_user_id = uuid.uuid4()
def generate_passages(user, agent):
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
texts = [
"This is a test passage",
"This is another test passage",
"Cinderella wept",
]
embed_model = embedding_model(agent.embedding_config)
orig_embeddings = []
passages = []
@@ -86,8 +89,7 @@ def test_create_user():
hosted_agent_run.persistence_manager.archival_memory.storage.insert_many(passages)
# test passage dimentionality
config = MemGPTConfig.load()
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, client.user.id)
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, client.user.id)
storage.filters = {} # clear filters to be able to get all passages
passages = storage.get_all()
for passage in passages:

View File

@@ -9,12 +9,11 @@ from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.cli.cli_load import load_directory
# from memgpt.data_sources.connectors import DirectoryConnector, load_data
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.metadata import MetadataStore
from memgpt.data_types import User, AgentState, EmbeddingConfig
from memgpt import create_client
from .utils import wipe_config, create_config
from tests import TEST_MEMGPT_CONFIG
from .utils import wipe_config
@pytest.fixture(autouse=True)
@@ -37,16 +36,20 @@ def recreate_declarative_base():
@pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"])
@pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"])
def test_load_directory(metadata_storage_connector, passage_storage_connector, clear_dynamically_created_models, recreate_declarative_base):
def test_load_directory(
metadata_storage_connector,
passage_storage_connector,
clear_dynamically_created_models,
recreate_declarative_base,
):
wipe_config()
# setup config
config = MemGPTConfig()
if metadata_storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.metadata_storage_type = "postgres"
TEST_MEMGPT_CONFIG.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
TEST_MEMGPT_CONFIG.metadata_storage_type = "postgres"
elif metadata_storage_connector == "sqlite":
print("testing sqlite metadata")
# nothing to do (should be config defaults)
@@ -56,18 +59,18 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
elif passage_storage_connector == "chroma":
print("testing chroma passage storage")
# nothing to do (should be config defaults)
else:
raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented")
config.save()
TEST_MEMGPT_CONFIG.save()
# create metadata store
ms = MetadataStore(config)
user = User(id=uuid.UUID(config.anon_clientid))
ms = MetadataStore(TEST_MEMGPT_CONFIG)
user = User(id=uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid))
# embedding config
if os.getenv("OPENAI_API_KEY"):
@@ -100,18 +103,20 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
)
# write out the config so that the 'load' command will use it (CLI commands pull from config)
config.default_embedding_config = embedding_config
config.save()
TEST_MEMGPT_CONFIG.default_embedding_config = embedding_config
TEST_MEMGPT_CONFIG.save()
# config.default_embedding_config = embedding_config
# config.save()
# create user and agent
agent = AgentState(
user_id=user.id,
name="test_agent",
preset=config.preset,
persona=config.persona,
human=config.human,
llm_config=config.default_llm_config,
embedding_config=embedding_config,
preset=TEST_MEMGPT_CONFIG.preset,
persona=TEST_MEMGPT_CONFIG.persona,
human=TEST_MEMGPT_CONFIG.human,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)
ms.delete_user(user.id)
ms.create_user(user)
@@ -123,7 +128,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
print("Creating storage connectors...")
user_id = user.id
print("User ID", user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)
# load data
name = "test_dataset"
@@ -135,7 +140,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c
print("Resetting tables with delete_table...")
passages_conn.delete_table()
print("Re-creating tables...")
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, TEST_MEMGPT_CONFIG, user_id)
assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"
# test: load directory

View File

@@ -4,35 +4,29 @@ import pytest
from memgpt.agent import Agent, save_agent
from memgpt.metadata import MetadataStore
from memgpt.config import MemGPTConfig
from memgpt.data_types import User, AgentState, Source, LLMConfig, EmbeddingConfig
from memgpt.data_types import User, AgentState, Source, LLMConfig
from memgpt.utils import get_human_text, get_persona_text
from tests import TEST_MEMGPT_CONFIG
from memgpt.presets.presets import add_default_presets, add_default_humans_and_personas
from memgpt.models.pydantic_models import HumanModel, PersonaModel
from memgpt.models.pydantic_models import HumanModel, PersonaModel
# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
@pytest.mark.parametrize("storage_connector", ["sqlite"])
def test_storage(storage_connector):
from memgpt.presets.presets import add_default_presets
config = MemGPTConfig()
if storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
if storage_connector == "sqlite":
config.recall_storage_type = "local"
TEST_MEMGPT_CONFIG.recall_storage_type = "local"
ms = MetadataStore(config)
ms = MetadataStore(TEST_MEMGPT_CONFIG)
# users
user_1 = User()
@@ -57,8 +51,8 @@ def test_storage(storage_connector):
preset=DEFAULT_PRESET,
persona=DEFAULT_PERSONA,
human=DEFAULT_HUMAN,
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)
source_1 = Source(user_id=user_1.id, name="source_1")
@@ -108,7 +102,7 @@ def test_storage(storage_connector):
# test: updating
# test: update JSON-stored LLMConfig class
print(agent_1.llm_config, config.default_llm_config)
print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config)
llm_config = ms.get_agent(agent_1.id).llm_config
assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}"
assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}"

View File

@@ -1,11 +1,13 @@
import os
from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.config import MemGPTConfig
from .utils import wipe_config, create_config
from memgpt.server.server import SyncServer
import shutil
import uuid
from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.config import MemGPTConfig
from memgpt.server.server import SyncServer
from .utils import wipe_config, create_config
def test_migrate_0211():
wipe_config()
@@ -42,7 +44,12 @@ def test_migrate_0211():
assert len(message_ids) > 0
# assert recall memories exist
messages = server.get_agent_messages(user_id=agent_state.user_id, agent_id=agent_state.id, start=0, count=1000)
messages = server.get_agent_messages(
user_id=agent_state.user_id,
agent_id=agent_state.id,
start=0,
count=1000,
)
assert len(messages) > 0
# for source_name in source_res["migration_candidates"]:

View File

@@ -5,7 +5,6 @@ import uuid
from memgpt.server.server import SyncServer
from memgpt.server.rest_api.server import app
from memgpt.constants import DEFAULT_PRESET
from memgpt.config import MemGPTConfig
# TODO: modify this to run against an actual running server
# def test_list_messages():

View File

@@ -4,13 +4,12 @@ import os
import memgpt.utils as utils
from dotenv import load_dotenv
from tests.config import TestMGPTConfig
utils.DEBUG = True
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.server.server import SyncServer
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
from memgpt.embeddings import embedding_model
from memgpt.presets.presets import add_default_presets
from memgpt.data_types import EmbeddingConfig, LLMConfig
from .utils import wipe_config, wipe_memgpt_home, DummyDataConnector
@@ -22,9 +21,10 @@ def server():
# Use os.getenv with a fallback to os.environ.get
db_url = os.getenv("PGVECTOR_TEST_DB_URL") or os.environ.get("PGVECTOR_TEST_DB_URL")
assert db_url, "Missing PGVECTOR_TEST_DB_URL"
if os.getenv("OPENAI_API_KEY"):
config = MemGPTConfig(
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
@@ -48,7 +48,7 @@ def server():
openai_key=os.getenv("OPENAI_API_KEY"),
)
else: # hosted
config = MemGPTConfig(
config = TestMGPTConfig(
archival_storage_uri=db_url,
recall_storage_uri=db_url,
metadata_storage_uri=db_url,
@@ -141,7 +141,13 @@ def test_load_data(server, user_id, agent_id):
source = server.create_source("test_source", user_id)
# load data
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
archival_memories = [
"alpha",
"Cinderella wore a blue dress",
"Dog eat dog",
"ZZZ",
"Shishir loves indian food",
]
connector = DummyDataConnector(archival_memories)
server.load_data(user_id, connector, source.name)
@@ -215,10 +221,19 @@ def test_get_archival_memory(server, user_id, agent_id):
# test archival memory cursor pagination
cursor1, passages_1 = server.get_agent_archival_cursor(user_id=user_id, agent_id=agent_id, reverse=False, limit=2, order_by="text")
cursor2, passages_2 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_id, reverse=False, after=cursor1, order_by="text"
user_id=user_id,
agent_id=agent_id,
reverse=False,
after=cursor1,
order_by="text",
)
cursor3, passages_3 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_id, reverse=False, before=cursor2, limit=1000, order_by="text"
user_id=user_id,
agent_id=agent_id,
reverse=False,
before=cursor2,
limit=1000,
order_by="text",
)
print("p1", [p["text"] for p in passages_1])
print("p2", [p["text"] for p in passages_2])

View File

@@ -5,8 +5,7 @@ import pytest
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.embeddings import embedding_model, query_embedding
from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, OpenAIEmbeddingConfig
from memgpt.config import MemGPTConfig
from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState
from memgpt.credentials import MemGPTCredentials
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.metadata import MetadataStore
@@ -15,6 +14,8 @@ from memgpt.constants import MAX_EMBEDDING_DIM
from datetime import datetime, timedelta
from tests import TEST_MEMGPT_CONFIG
# Note: the database will filter out rows that do not correspond to agent1 and test_user by default.
texts = ["This is a test passage", "This is another test passage", "Cinderella wept"]
@@ -104,7 +105,12 @@ def recreate_declarative_base():
# @pytest.mark.parametrize("storage_connector", ["sqlite", "chroma"])
# @pytest.mark.parametrize("storage_connector", ["postgres"])
@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
def test_storage(storage_connector, table_type, clear_dynamically_created_models, recreate_declarative_base):
def test_storage(
storage_connector,
table_type,
clear_dynamically_created_models,
recreate_declarative_base,
):
# setup memgpt config
# TODO: set env for different config path
@@ -114,35 +120,34 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models
# print("Removing messages", globals()['Message'])
# del globals()['Message']
config = MemGPTConfig()
if storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["PGVECTOR_TEST_DB_URL"]
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
if storage_connector == "lancedb":
# TODO: complete lancedb implementation
if not os.getenv("LANCEDB_TEST_URL"):
print("Skipping test, missing LanceDB URI")
return
config.archival_storage_uri = os.getenv("LANCEDB_TEST_URL")
config.recall_storage_uri = os.getenv("LANCEDB_TEST_URL")
config.archival_storage_type = "lancedb"
config.recall_storage_type = "lancedb"
TEST_MEMGPT_CONFIG.archival_storage_uri = os.environ["LANCEDB_TEST_URL"]
TEST_MEMGPT_CONFIG.recall_storage_uri = os.environ["LANCEDB_TEST_URL"]
TEST_MEMGPT_CONFIG.archival_storage_type = "lancedb"
TEST_MEMGPT_CONFIG.recall_storage_type = "lancedb"
if storage_connector == "chroma":
if table_type == TableType.RECALL_MEMORY:
print("Skipping test, chroma only supported for archival memory")
return
config.archival_storage_type = "chroma"
config.archival_storage_path = "./test_chroma"
TEST_MEMGPT_CONFIG.archival_storage_type = "chroma"
TEST_MEMGPT_CONFIG.archival_storage_path = "./test_chroma"
if storage_connector == "sqlite":
if table_type == TableType.ARCHIVAL_MEMORY:
print("Skipping test, sqlite only supported for recall memory")
return
config.recall_storage_type = "sqlite"
TEST_MEMGPT_CONFIG.recall_storage_type = "sqlite"
# get embedding model
embed_model = None
@@ -162,27 +167,27 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models
embed_model = embedding_model(embedding_config)
# create user
ms = MetadataStore(config)
ms = MetadataStore(TEST_MEMGPT_CONFIG)
ms.delete_user(user_id)
user = User(id=user_id)
agent = AgentState(
user_id=user_id,
name="agent_1",
id=agent_1_id,
preset=config.preset,
persona=config.persona,
human=config.human,
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
preset=TEST_MEMGPT_CONFIG.preset,
persona=TEST_MEMGPT_CONFIG.persona,
human=TEST_MEMGPT_CONFIG.human,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)
ms.create_user(user)
ms.create_agent(agent)
# create storage connector
conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id)
conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
# conn.client.delete_collection(conn.collection.name) # clear out data
conn.delete_table()
conn = StorageConnector.get_storage_connector(table_type, config=config, user_id=user_id, agent_id=agent.id)
conn = StorageConnector.get_storage_connector(table_type, config=TEST_MEMGPT_CONFIG, user_id=user_id, agent_id=agent.id)
# generate data
if table_type == TableType.ARCHIVAL_MEMORY:

View File

@@ -2,9 +2,7 @@ import os
import uuid
from memgpt import create_client
from memgpt.config import MemGPTConfig
from memgpt import constants
import memgpt.functions.function_sets.base as base_functions
from .utils import wipe_config, create_config
@@ -31,7 +29,6 @@ def create_test_agent():
)
global agent_obj
config = MemGPTConfig.load()
agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id)
@@ -48,12 +45,16 @@ def test_summarize():
# First send a few messages (5)
response = client.user_message(
agent_id=agent_obj.agent_state.id, message="Hey, how's it going? What do you think about this whole shindig"
agent_id=agent_obj.agent_state.id,
message="Hey, how's it going? What do you think about this whole shindig",
)
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
response = client.user_message(agent_id=agent_obj.agent_state.id, message="Any thoughts on the meaning of life?")
response = client.user_message(
agent_id=agent_obj.agent_state.id,
message="Any thoughts on the meaning of life?",
)
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")
@@ -62,7 +63,8 @@ def test_summarize():
print(f"test_summarize: response={response}")
response = client.user_message(
agent_id=agent_obj.agent_state.id, message="Would you be surprised to learn that you're actually conversing with an AI right now?"
agent_id=agent_obj.agent_state.id,
message="Would you be surprised to learn that you're actually conversing with an AI right now?",
)
assert response is not None and len(response) > 0
print(f"test_summarize: response={response}")

View File

@@ -1,13 +1,11 @@
import os
import pytest
from unittest.mock import Mock, AsyncMock, MagicMock
from unittest.mock import AsyncMock
from memgpt.config import MemGPTConfig, AgentConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.server.ws_api.interface import SyncWebSocketInterface
import memgpt.presets.presets as presets
import memgpt.utils as utils
import memgpt.system as system
from memgpt.persistence_manager import LocalStateManager
from memgpt.data_types import AgentState
@@ -62,10 +60,10 @@ async def test_websockets():
if api_key is None:
ws_interface.close()
return
config = MemGPTConfig.load()
if config.openai_key is None:
config.openai_key = api_key
config.save()
credentials = MemGPTCredentials.load()
if credentials.openai_key is None:
credentials.openai_key = api_key
credentials.save()
# Mock the persistence manager
# create agents with defaults

View File

@@ -2,11 +2,10 @@ import datetime
from typing import Dict, List, Tuple, Iterator
import os
from memgpt.config import MemGPTConfig
from memgpt.cli.cli import quickstart, QuickstartChoice
from memgpt.data_sources.connectors import DataConnector
from memgpt import Admin
from memgpt.data_types import Document
from tests import TEST_MEMGPT_CONFIG
from .constants import TIMEOUT
@@ -37,17 +36,17 @@ def create_config(endpoint="openai"):
def wipe_config():
if MemGPTConfig.exists():
if TEST_MEMGPT_CONFIG.exists():
# delete
if os.getenv("MEMGPT_CONFIG_PATH"):
config_path = os.getenv("MEMGPT_CONFIG_PATH")
else:
config_path = MemGPTConfig.config_path
config_path = TEST_MEMGPT_CONFIG.config_path
# TODO delete file config_path
os.remove(config_path)
assert not MemGPTConfig.exists(), "Config should not exist after deletion"
assert not TEST_MEMGPT_CONFIG.exists(), "Config should not exist after deletion"
else:
print("No config to wipe", MemGPTConfig.config_path)
print("No config to wipe", TEST_MEMGPT_CONFIG.config_path)
def wipe_memgpt_home():
@@ -63,7 +62,7 @@ def wipe_memgpt_home():
os.system(f"mv ~/.memgpt {backup_dir}")
# Setup the initial directory
MemGPTConfig.create_config_dir()
TEST_MEMGPT_CONFIG.create_config_dir()
def configure_memgpt_localllm():