Use ~/.memgpt/config to set questionary defaults in memgpt configure + update tests to use specific config path (#389)

This commit is contained in:
Sarah Wooders
2023-11-09 14:01:11 -08:00
committed by GitHub
parent 354bd520e0
commit ecad9a45ad
3 changed files with 43 additions and 23 deletions

View File

@@ -26,8 +26,11 @@ def configure():
MemGPTConfig.create_config_dir()
# Will pre-populate with defaults, or what the user previously set
config = MemGPTConfig.load()
# openai credentials
use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?").ask()
use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?", default=True).ask()
if use_openai:
# search for key in enviornment
openai_key = os.getenv("OPENAI_API_KEY")
@@ -37,7 +40,7 @@ def configure():
# openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask()
# azure credentials
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=False).ask()
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask()
use_azure_deployment_ids = False
if use_azure:
# search for key in enviornment
@@ -69,30 +72,39 @@ def configure():
model_endpoint_options = []
if os.getenv("OPENAI_API_BASE") is not None:
model_endpoint_options.append(os.getenv("OPENAI_API_BASE"))
if use_azure:
model_endpoint_options += ["azure"]
if use_openai:
model_endpoint_options += ["openai"]
if use_azure:
model_endpoint_options += ["azure"]
assert len(model_endpoint_options) > 0, "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE."
default_endpoint = questionary.select("Select default inference endpoint:", model_endpoint_options).ask()
valid_default_model = config.model_endpoint in model_endpoint_options
default_endpoint = questionary.select(
"Select default inference endpoint:",
model_endpoint_options,
default=config.model_endpoint if valid_default_model else model_endpoint_options[0],
).ask()
# configure embedding provider
embedding_endpoint_options = ["local"] # cannot configure custom endpoint (too confusing)
if use_azure:
model_endpoint_options += ["azure"]
embedding_endpoint_options += ["azure"]
if use_openai:
model_endpoint_options += ["openai"]
default_embedding_endpoint = questionary.select("Select default embedding endpoint:", embedding_endpoint_options).ask()
embedding_endpoint_options += ["openai"]
valid_default_embedding = config.embedding_model in embedding_endpoint_options
default_embedding_endpoint = questionary.select(
"Select default embedding endpoint:",
embedding_endpoint_options,
default=config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1],
).ask()
# configure embedding dimentions
default_embedding_dim = 1536
default_embedding_dim = config.embedding_dim
if default_embedding_endpoint == "local":
# HF model uses lower dimentionality
default_embedding_dim = 384
# configure preset
default_preset = questionary.select("Select default preset:", preset_options, default=DEFAULT_PRESET).ask()
default_preset = questionary.select("Select default preset:", preset_options, default=config.preset).ask()
# default model
if use_openai or use_azure:
@@ -100,7 +112,7 @@ def configure():
if use_openai:
model_options += ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"]
default_model = questionary.select(
"Select default model (recommended: gpt-4):", choices=["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"], default="gpt-4"
"Select default model (recommended: gpt-4):", choices=["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"], default=config.model
).ask()
else:
default_model = "local" # TODO: figure out if this is ok? this is for local endpoint
@@ -108,10 +120,10 @@ def configure():
# defaults
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
print(personas)
default_persona = questionary.select("Select default persona:", personas, default="sam_pov").ask()
default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask()
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
print(humans)
default_human = questionary.select("Select default human:", humans, default="cs_phd").ask()
default_human = questionary.select("Select default human:", humans, default=config.default_human).ask()
# TODO: figure out if we should set a default agent or not
default_agent = None
@@ -126,11 +138,14 @@ def configure():
# Configure archival storage backend
archival_storage_options = ["local", "postgres"]
archival_storage_type = questionary.select("Select storage backend for archival data:", archival_storage_options, default="local").ask()
archival_storage_type = questionary.select(
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
).ask()
archival_storage_uri = None
if archival_storage_type == "postgres":
archival_storage_uri = questionary.text(
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):"
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
default=config.archival_storage_uri if config.archival_storage_uri else "",
).ask()
# TODO: allow configuring embedding model

View File

@@ -70,7 +70,7 @@ class PostgresStorageConnector(StorageConnector):
# create table
self.uri = config.archival_storage_uri
if config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config")
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
self.db_model = get_db_model(self.table_name)
self.engine = create_engine(self.uri)
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist

View File

@@ -20,9 +20,9 @@ def test_postgres_openai():
if os.getenv("OPENAI_API_KEY") is None:
return # soft pass
os.environ["MEMGPT_CONFIG_FILE"] = "./config"
config = MemGPTConfig()
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") # the URI for a postgres DB w/ the pgvector extension
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"))
print(config.config_path)
assert config.archival_storage_uri is not None
config.archival_storage_uri = config.archival_storage_uri.replace(
"postgres://", "postgresql://"
@@ -56,10 +56,15 @@ def test_postgres_openai():
def test_postgres_local():
assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
os.environ["MEMGPT_CONFIG_FILE"] = "./config"
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
config = MemGPTConfig(embedding_model="local", embedding_dim=384) # use HF model
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") # the URI for a postgres DB w/ the pgvector extension
config = MemGPTConfig(
archival_storage_type="postgres",
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
embedding_model="local",
embedding_dim=384, # use HF model
)
print(config.config_path)
assert config.archival_storage_uri is not None
config.archival_storage_uri = config.archival_storage_uri.replace(
"postgres://", "postgresql://"