From ecad9a45ad04702f6b33d6e27769d445f7b80d31 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Thu, 9 Nov 2023 14:01:11 -0800 Subject: [PATCH] Use `~/.memgpt/config` to set questionary defaults in `memgpt configure` + update tests to use specific config path (#389) --- memgpt/cli/cli_config.py | 47 ++++++++++++++++++++++++++-------------- memgpt/connectors/db.py | 2 +- tests/test_storage.py | 17 ++++++++++----- 3 files changed, 43 insertions(+), 23 deletions(-) diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index e203ea09..c68dbbb4 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -26,8 +26,11 @@ def configure(): MemGPTConfig.create_config_dir() + # Will pre-populate with defaults, or what the user previously set + config = MemGPTConfig.load() + # openai credentials - use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?").ask() + use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?", default=True).ask() if use_openai: # search for key in enviornment openai_key = os.getenv("OPENAI_API_KEY") @@ -37,7 +40,7 @@ def configure(): # openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask() # azure credentials - use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=False).ask() + use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask() use_azure_deployment_ids = False if use_azure: # search for key in enviornment @@ -69,30 +72,39 @@ def configure(): model_endpoint_options = [] if os.getenv("OPENAI_API_BASE") is not None: model_endpoint_options.append(os.getenv("OPENAI_API_BASE")) - if use_azure: - model_endpoint_options += ["azure"] if use_openai: model_endpoint_options += ["openai"] - + if use_azure: + model_endpoint_options += ["azure"] assert len(model_endpoint_options) > 0, "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE." - default_endpoint = questionary.select("Select default inference endpoint:", model_endpoint_options).ask() + valid_default_model = config.model_endpoint in model_endpoint_options + default_endpoint = questionary.select( + "Select default inference endpoint:", + model_endpoint_options, + default=config.model_endpoint if valid_default_model else model_endpoint_options[0], + ).ask() # configure embedding provider embedding_endpoint_options = ["local"] # cannot configure custom endpoint (too confusing) if use_azure: - model_endpoint_options += ["azure"] + embedding_endpoint_options += ["azure"] if use_openai: - model_endpoint_options += ["openai"] - default_embedding_endpoint = questionary.select("Select default embedding endpoint:", embedding_endpoint_options).ask() + embedding_endpoint_options += ["openai"] + valid_default_embedding = config.embedding_model in embedding_endpoint_options + default_embedding_endpoint = questionary.select( + "Select default embedding endpoint:", + embedding_endpoint_options, + default=config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1], + ).ask() # configure embedding dimentions - default_embedding_dim = 1536 + default_embedding_dim = config.embedding_dim if default_embedding_endpoint == "local": # HF model uses lower dimentionality default_embedding_dim = 384 # configure preset - default_preset = questionary.select("Select default preset:", preset_options, default=DEFAULT_PRESET).ask() + default_preset = questionary.select("Select default preset:", preset_options, default=config.preset).ask() # default model if use_openai or use_azure: @@ -100,7 +112,7 @@ def configure(): if use_openai: model_options += ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"] default_model = questionary.select( - "Select default model (recommended: gpt-4):", choices=["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"], default="gpt-4" + "Select default model (recommended: gpt-4):", choices=["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"], default=config.model ).ask() else: default_model = "local" # TODO: figure out if this is ok? this is for local endpoint @@ -108,10 +120,10 @@ def configure(): # defaults personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()] print(personas) - default_persona = questionary.select("Select default persona:", personas, default="sam_pov").ask() + default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask() humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()] print(humans) - default_human = questionary.select("Select default human:", humans, default="cs_phd").ask() + default_human = questionary.select("Select default human:", humans, default=config.default_human).ask() # TODO: figure out if we should set a default agent or not default_agent = None @@ -126,11 +138,14 @@ def configure(): # Configure archival storage backend archival_storage_options = ["local", "postgres"] - archival_storage_type = questionary.select("Select storage backend for archival data:", archival_storage_options, default="local").ask() + archival_storage_type = questionary.select( + "Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type + ).ask() archival_storage_uri = None if archival_storage_type == "postgres": archival_storage_uri = questionary.text( - "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):" + "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):", + default=config.archival_storage_uri if config.archival_storage_uri else "", ).ask() # TODO: allow configuring embedding model diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index 86e25ef4..9ae873b8 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -70,7 +70,7 @@ class PostgresStorageConnector(StorageConnector): # create table self.uri = config.archival_storage_uri if config.archival_storage_uri is None: - raise ValueError(f"Must specifiy archival_storage_uri in config") + raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}") self.db_model = get_db_model(self.table_name) self.engine = create_engine(self.uri) Base.metadata.create_all(self.engine) # Create the table if it doesn't exist diff --git a/tests/test_storage.py b/tests/test_storage.py index 1518088a..fc1286b6 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -20,9 +20,9 @@ def test_postgres_openai(): if os.getenv("OPENAI_API_KEY") is None: return # soft pass - os.environ["MEMGPT_CONFIG_FILE"] = "./config" - config = MemGPTConfig() - config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") # the URI for a postgres DB w/ the pgvector extension + # os.environ["MEMGPT_CONFIG_PATH"] = "./config" + config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL")) + print(config.config_path) assert config.archival_storage_uri is not None config.archival_storage_uri = config.archival_storage_uri.replace( "postgres://", "postgresql://" @@ -56,10 +56,15 @@ def test_postgres_openai(): def test_postgres_local(): assert os.getenv("PGVECTOR_TEST_DB_URL") is not None - os.environ["MEMGPT_CONFIG_FILE"] = "./config" + # os.environ["MEMGPT_CONFIG_PATH"] = "./config" - config = MemGPTConfig(embedding_model="local", embedding_dim=384) # use HF model - config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") # the URI for a postgres DB w/ the pgvector extension + config = MemGPTConfig( + archival_storage_type="postgres", + archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"), + embedding_model="local", + embedding_dim=384, # use HF model + ) + print(config.config_path) assert config.archival_storage_uri is not None config.archival_storage_uri = config.archival_storage_uri.replace( "postgres://", "postgresql://"