Use ~/.memgpt/config to set questionary defaults in memgpt configure + update tests to use specific config path (#389)
This commit is contained in:
@@ -26,8 +26,11 @@ def configure():
|
||||
|
||||
MemGPTConfig.create_config_dir()
|
||||
|
||||
# Will pre-populate with defaults, or what the user previously set
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# openai credentials
|
||||
use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?").ask()
|
||||
use_openai = questionary.confirm("Do you want to enable MemGPT with Open AI?", default=True).ask()
|
||||
if use_openai:
|
||||
# search for key in environment
|
||||
openai_key = os.getenv("OPENAI_API_KEY")
|
||||
@@ -37,7 +40,7 @@ def configure():
|
||||
# openai_key = questionary.text("Open AI API keys not found in environment - please enter:").ask()
|
||||
|
||||
# azure credentials
|
||||
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=False).ask()
|
||||
use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask()
|
||||
use_azure_deployment_ids = False
|
||||
if use_azure:
|
||||
# search for key in environment
|
||||
@@ -69,30 +72,39 @@ def configure():
|
||||
model_endpoint_options = []
|
||||
if os.getenv("OPENAI_API_BASE") is not None:
|
||||
model_endpoint_options.append(os.getenv("OPENAI_API_BASE"))
|
||||
if use_azure:
|
||||
model_endpoint_options += ["azure"]
|
||||
if use_openai:
|
||||
model_endpoint_options += ["openai"]
|
||||
|
||||
if use_azure:
|
||||
model_endpoint_options += ["azure"]
|
||||
assert len(model_endpoint_options) > 0, "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE."
|
||||
default_endpoint = questionary.select("Select default inference endpoint:", model_endpoint_options).ask()
|
||||
valid_default_model = config.model_endpoint in model_endpoint_options
|
||||
default_endpoint = questionary.select(
|
||||
"Select default inference endpoint:",
|
||||
model_endpoint_options,
|
||||
default=config.model_endpoint if valid_default_model else model_endpoint_options[0],
|
||||
).ask()
|
||||
|
||||
# configure embedding provider
|
||||
embedding_endpoint_options = ["local"] # cannot configure custom endpoint (too confusing)
|
||||
if use_azure:
|
||||
model_endpoint_options += ["azure"]
|
||||
embedding_endpoint_options += ["azure"]
|
||||
if use_openai:
|
||||
model_endpoint_options += ["openai"]
|
||||
default_embedding_endpoint = questionary.select("Select default embedding endpoint:", embedding_endpoint_options).ask()
|
||||
embedding_endpoint_options += ["openai"]
|
||||
valid_default_embedding = config.embedding_model in embedding_endpoint_options
|
||||
default_embedding_endpoint = questionary.select(
|
||||
"Select default embedding endpoint:",
|
||||
embedding_endpoint_options,
|
||||
default=config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1],
|
||||
).ask()
|
||||
|
||||
# configure embedding dimensions
|
||||
default_embedding_dim = 1536
|
||||
default_embedding_dim = config.embedding_dim
|
||||
if default_embedding_endpoint == "local":
|
||||
# HF model uses lower dimensionality
|
||||
default_embedding_dim = 384
|
||||
|
||||
# configure preset
|
||||
default_preset = questionary.select("Select default preset:", preset_options, default=DEFAULT_PRESET).ask()
|
||||
default_preset = questionary.select("Select default preset:", preset_options, default=config.preset).ask()
|
||||
|
||||
# default model
|
||||
if use_openai or use_azure:
|
||||
@@ -100,7 +112,7 @@ def configure():
|
||||
if use_openai:
|
||||
model_options += ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"]
|
||||
default_model = questionary.select(
|
||||
"Select default model (recommended: gpt-4):", choices=["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"], default="gpt-4"
|
||||
"Select default model (recommended: gpt-4):", choices=["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"], default=config.model
|
||||
).ask()
|
||||
else:
|
||||
default_model = "local" # TODO: figure out if this is ok? this is for local endpoint
|
||||
@@ -108,10 +120,10 @@ def configure():
|
||||
# defaults
|
||||
personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()]
|
||||
print(personas)
|
||||
default_persona = questionary.select("Select default persona:", personas, default="sam_pov").ask()
|
||||
default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask()
|
||||
humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()]
|
||||
print(humans)
|
||||
default_human = questionary.select("Select default human:", humans, default="cs_phd").ask()
|
||||
default_human = questionary.select("Select default human:", humans, default=config.default_human).ask()
|
||||
|
||||
# TODO: figure out if we should set a default agent or not
|
||||
default_agent = None
|
||||
@@ -126,11 +138,14 @@ def configure():
|
||||
|
||||
# Configure archival storage backend
|
||||
archival_storage_options = ["local", "postgres"]
|
||||
archival_storage_type = questionary.select("Select storage backend for archival data:", archival_storage_options, default="local").ask()
|
||||
archival_storage_type = questionary.select(
|
||||
"Select storage backend for archival data:", archival_storage_options, default=config.archival_storage_type
|
||||
).ask()
|
||||
archival_storage_uri = None
|
||||
if archival_storage_type == "postgres":
|
||||
archival_storage_uri = questionary.text(
|
||||
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):"
|
||||
"Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):",
|
||||
default=config.archival_storage_uri if config.archival_storage_uri else "",
|
||||
).ask()
|
||||
|
||||
# TODO: allow configuring embedding model
|
||||
|
||||
@@ -70,7 +70,7 @@ class PostgresStorageConnector(StorageConnector):
|
||||
# create table
|
||||
self.uri = config.archival_storage_uri
|
||||
if config.archival_storage_uri is None:
|
||||
raise ValueError(f"Must specifiy archival_storage_uri in config")
|
||||
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
|
||||
self.db_model = get_db_model(self.table_name)
|
||||
self.engine = create_engine(self.uri)
|
||||
Base.metadata.create_all(self.engine) # Create the table if it doesn't exist
|
||||
|
||||
@@ -20,9 +20,9 @@ def test_postgres_openai():
|
||||
if os.getenv("OPENAI_API_KEY") is None:
|
||||
return # soft pass
|
||||
|
||||
os.environ["MEMGPT_CONFIG_FILE"] = "./config"
|
||||
config = MemGPTConfig()
|
||||
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") # the URI for a postgres DB w/ the pgvector extension
|
||||
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
config = MemGPTConfig(archival_storage_type="postgres", archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"))
|
||||
print(config.config_path)
|
||||
assert config.archival_storage_uri is not None
|
||||
config.archival_storage_uri = config.archival_storage_uri.replace(
|
||||
"postgres://", "postgresql://"
|
||||
@@ -56,10 +56,15 @@ def test_postgres_openai():
|
||||
|
||||
def test_postgres_local():
|
||||
assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
|
||||
os.environ["MEMGPT_CONFIG_FILE"] = "./config"
|
||||
# os.environ["MEMGPT_CONFIG_PATH"] = "./config"
|
||||
|
||||
config = MemGPTConfig(embedding_model="local", embedding_dim=384) # use HF model
|
||||
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL") # the URI for a postgres DB w/ the pgvector extension
|
||||
config = MemGPTConfig(
|
||||
archival_storage_type="postgres",
|
||||
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
embedding_model="local",
|
||||
embedding_dim=384, # use HF model
|
||||
)
|
||||
print(config.config_path)
|
||||
assert config.archival_storage_uri is not None
|
||||
config.archival_storage_uri = config.archival_storage_uri.replace(
|
||||
"postgres://", "postgresql://"
|
||||
|
||||
Reference in New Issue
Block a user