fix: errors when getting default config values on a fresh install (#1249)

Author: Sarah Wooders
Date: 2024-04-12 20:50:47 -07:00
Committed by: GitHub
Parent: 5ece354e0e
Commit: a3dfb15071

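On a fresh install no config file has been written yet, so config.default_llm_config and config.default_embedding_config are None, and any unguarded attribute access in the configure flow raises AttributeError. Every hunk below applies the same None-guard pattern. A minimal self-contained sketch of the failure and the fix, using hypothetical stand-in dataclasses rather than the real MemGPT classes:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class LLMConfig:  # hypothetical stand-in; attribute names mirror the diff
        model: Optional[str] = None
        model_endpoint: Optional[str] = None
        model_endpoint_type: Optional[str] = None

    @dataclass
    class MemGPTConfig:  # stand-in for the real MemGPTConfig
        default_llm_config: Optional[LLMConfig] = None  # None on a fresh install

    config = MemGPTConfig()

    # Pre-fix pattern: unguarded access raises AttributeError
    # ("'NoneType' object has no attribute 'model_endpoint_type'"):
    #     default_type = config.default_llm_config.model_endpoint_type

    # Post-fix pattern: guard the parent object before touching its attributes.
    default_type = config.default_llm_config.model_endpoint_type if config.default_llm_config else None
    assert default_type is None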

@@ -65,10 +65,12 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
     model_endpoint_type, model_endpoint = None, None
     # get default
-    default_model_endpoint_type = config.default_llm_config.model_endpoint_type
-    if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [
-        provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"
-    ]:  # local model
+    default_model_endpoint_type = config.default_llm_config.model_endpoint_type if config.default_embedding_config else None
+    if (
+        config.default_llm_config
+        and config.default_llm_config.model_endpoint_type is not None
+        and config.default_llm_config.model_endpoint_type not in [provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"]
+    ):  # local model
         default_model_endpoint_type = "local"
     provider = questionary.select(
@@ -270,7 +272,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
         backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
         # assert backend_options_old == backend_options, (backend_options_old, backend_options)
         default_model_endpoint_type = None
-        if config.default_llm_config.model_endpoint_type in backend_options:
+        if config.default_llm_config and config.default_llm_config.model_endpoint_type in backend_options:
             # set from previous config
             default_model_endpoint_type = config.default_llm_config.model_endpoint_type
         model_endpoint_type = questionary.select(
@@ -296,7 +298,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
             model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
             if model_endpoint is None:
                 raise KeyboardInterrupt
-        elif config.default_llm_config.model_endpoint:
+        elif config.default_llm_config and config.default_llm_config.model_endpoint:
             model_endpoint = questionary.text("Enter default endpoint:", default=config.default_llm_config.model_endpoint).ask()
             if model_endpoint is None:
                 raise KeyboardInterrupt
@@ -425,7 +427,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
         other_option_str = "[enter model name manually]"
         # Check if the model we have set already is even in the list (informs our default)
-        valid_model = config.default_llm_config.model in hardcoded_model_options
+        valid_model = config.default_llm_config and config.default_llm_config.model in hardcoded_model_options
         model = questionary.select(
             "Select default model (recommended: gpt-4):",
             choices=hardcoded_model_options + [see_all_option_str, other_option_str],
@@ -440,7 +442,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
                 model = questionary.select(
                     "Select default model (recommended: gpt-4):",
                     choices=fetched_model_options + [other_option_str],
-                    default=config.default_llm_config.model if valid_model else fetched_model_options[0],
+                    default=config.default_llm_config.model if (valid_model and config.default_llm_config) else fetched_model_options[0],
                 ).ask()
                 if model is None:
                     raise KeyboardInterrupt
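A note on the two valid_model hunks above: with the new guard, valid_model becomes None rather than False whenever default_llm_config is unset, because "and" short-circuits to its first falsy operand; None is still falsy, so the later "if valid_model ... else" fallback selects the first fetched option as intended. A tiny sketch of the truthiness, with illustrative values:

    default_llm_config = None  # fresh install
    hardcoded_model_options = ["gpt-4", "gpt-4-turbo"]  # illustrative only

    # "and" short-circuits and returns None here, not False...
    valid_model = default_llm_config and "gpt-4" in hardcoded_model_options
    print(valid_model)  # None

    # ...but None is falsy, so the ternary still takes the fallback branch.
    default = "from-config" if valid_model else hardcoded_model_options[0]
    print(default)  # gpt-4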
@@ -590,7 +592,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
         if model_endpoint_type == "ollama":
             default_model = (
                 config.default_llm_config.model
-                if config.default_llm_config.model and config.default_llm_config.model_endpoint_type == "ollama"
+                if config.default_llm_config and config.default_llm_config.model_endpoint_type == "ollama"
                 else DEFAULT_OLLAMA_MODEL
             )
             model = questionary.text(
@@ -602,9 +604,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
             model = None if len(model) == 0 else model
             default_model = (
-                config.default_llm_config.model
-                if config.default_llm_config.model and config.default_llm_config.model_endpoint_type == "vllm"
-                else ""
+                config.default_llm_config.model if config.default_llm_config and config.default_llm_config.model_endpoint_type == "vllm" else ""
             )
             # vllm needs huggingface model tag
@@ -773,7 +773,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
 def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials):
     # configure embedding endpoint
-    default_embedding_endpoint_type = config.default_embedding_config.embedding_endpoint_type
+    default_embedding_endpoint_type = config.default_embedding_config.embedding_endpoint_type if config.default_embedding_config else None
     embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = None, None, None, None
     embedding_provider = questionary.select(
@@ -839,7 +839,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
         # get model type
         default_embedding_model = (
-            config.default_embedding_config.embedding_model if config.default_embedding_config.embedding_model else "BAAI/bge-large-en-v1.5"
+            config.default_embedding_config.embedding_model if config.default_embedding_config else "BAAI/bge-large-en-v1.5"
         )
         embedding_model = questionary.text(
             "Enter HuggingFace model tag (e.g. BAAI/bge-large-en-v1.5):",
@@ -849,7 +849,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
             raise KeyboardInterrupt
         # get model dimentions
-        default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config.embedding_dim else "1024"
+        default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config else "1024"
         embedding_dim = questionary.text("Enter embedding model dimentions (e.g. 1024):", default=str(default_embedding_dim)).ask()
         if embedding_dim is None:
             raise KeyboardInterrupt
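One detail from the final hunk: the fallback for embedding_dim is the string "1024" while a configured value is presumably an int, and the str() wrapper on the next line normalizes both before the value is used as the questionary.text default (which expects a string). A short sketch, again with hypothetical stand-ins:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class EmbeddingConfig:  # hypothetical stand-in; attribute name mirrors the diff
        embedding_dim: int = 1024

    @dataclass
    class Config:  # stand-in with only the field this sketch needs
        default_embedding_config: Optional[EmbeddingConfig] = None

    for config in (Config(), Config(default_embedding_config=EmbeddingConfig())):
        # Guarded access: falls back to the string "1024" on a fresh install.
        dim = config.default_embedding_config.embedding_dim if config.default_embedding_config else "1024"
        # str() yields a string default whether the source was int or str.
        print(str(dim))  # "1024" both times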