refactor: store presets in database via metadata store (#1013)
This commit is contained in:
@@ -4,12 +4,18 @@ import os
|
||||
from memgpt import MemGPT
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt import constants
|
||||
from memgpt.data_types import LLMConfig, EmbeddingConfig
|
||||
from memgpt.data_types import LLMConfig, EmbeddingConfig, Preset
|
||||
from memgpt.functions.functions import load_all_function_sets
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
|
||||
|
||||
from .utils import wipe_config
|
||||
import uuid
|
||||
|
||||
|
||||
test_agent_name = f"test_client_{str(uuid.uuid4())}"
|
||||
test_preset_name = "test_preset"
|
||||
test_agent_state = None
|
||||
client = None
|
||||
|
||||
@@ -17,7 +23,7 @@ test_agent_state_post_message = None
|
||||
test_user_id = uuid.uuid4()
|
||||
|
||||
|
||||
def test_create_agent():
|
||||
def test_create_preset():
|
||||
wipe_config()
|
||||
global client
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
@@ -25,6 +31,20 @@ def test_create_agent():
|
||||
else:
|
||||
client = MemGPT(quickstart="memgpt_hosted", user_id=test_user_id)
|
||||
|
||||
available_functions = load_all_function_sets(merge=True)
|
||||
functions_schema = [f_dict["json_schema"] for f_name, f_dict in available_functions.items()]
|
||||
preset = Preset(
|
||||
name=test_preset_name,
|
||||
user_id=test_user_id,
|
||||
description="A preset for testing the MemGPT client",
|
||||
system=gpt_system.get_system_text(DEFAULT_PRESET),
|
||||
functions_schema=functions_schema,
|
||||
)
|
||||
client.create_preset(preset)
|
||||
|
||||
|
||||
def test_create_agent():
|
||||
wipe_config()
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# ensure user exists
|
||||
@@ -36,8 +56,7 @@ def test_create_agent():
|
||||
agent_config={
|
||||
"user_id": test_user_id,
|
||||
"name": test_agent_name,
|
||||
"persona": constants.DEFAULT_PERSONA,
|
||||
"human": constants.DEFAULT_HUMAN,
|
||||
"preset": test_preset_name,
|
||||
}
|
||||
)
|
||||
print(f"\n\n[1] CREATED AGENT {test_agent_state.id}!!!\n\tmessages={test_agent_state.state['messages']}")
|
||||
@@ -109,5 +128,6 @@ def test_save_load():
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test_create_preset()
|
||||
test_create_agent()
|
||||
test_user_message()
|
||||
|
||||
Reference in New Issue
Block a user