fix: errors on getting default config values on fresh install (#1249)
This commit is contained in:
@@ -65,10 +65,12 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
|
||||
model_endpoint_type, model_endpoint = None, None
|
||||
|
||||
# get default
|
||||
default_model_endpoint_type = config.default_llm_config.model_endpoint_type
|
||||
if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [
|
||||
provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"
|
||||
]: # local model
|
||||
default_model_endpoint_type = config.default_llm_config.model_endpoint_type if config.default_embedding_config else None
|
||||
if (
|
||||
config.default_llm_config
|
||||
and config.default_llm_config.model_endpoint_type is not None
|
||||
and config.default_llm_config.model_endpoint_type not in [provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"]
|
||||
): # local model
|
||||
default_model_endpoint_type = "local"
|
||||
|
||||
provider = questionary.select(
|
||||
@@ -270,7 +272,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
|
||||
backend_options = builtins.list(DEFAULT_ENDPOINTS.keys())
|
||||
# assert backend_options_old == backend_options, (backend_options_old, backend_options)
|
||||
default_model_endpoint_type = None
|
||||
if config.default_llm_config.model_endpoint_type in backend_options:
|
||||
if config.default_llm_config and config.default_llm_config.model_endpoint_type in backend_options:
|
||||
# set from previous config
|
||||
default_model_endpoint_type = config.default_llm_config.model_endpoint_type
|
||||
model_endpoint_type = questionary.select(
|
||||
@@ -296,7 +298,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask()
|
||||
if model_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
elif config.default_llm_config.model_endpoint:
|
||||
elif config.default_llm_config and config.default_llm_config.model_endpoint:
|
||||
model_endpoint = questionary.text("Enter default endpoint:", default=config.default_llm_config.model_endpoint).ask()
|
||||
if model_endpoint is None:
|
||||
raise KeyboardInterrupt
|
||||
@@ -425,7 +427,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
other_option_str = "[enter model name manually]"
|
||||
|
||||
# Check if the model we have set already is even in the list (informs our default)
|
||||
valid_model = config.default_llm_config.model in hardcoded_model_options
|
||||
valid_model = config.default_llm_config and config.default_llm_config.model in hardcoded_model_options
|
||||
model = questionary.select(
|
||||
"Select default model (recommended: gpt-4):",
|
||||
choices=hardcoded_model_options + [see_all_option_str, other_option_str],
|
||||
@@ -440,7 +442,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
model = questionary.select(
|
||||
"Select default model (recommended: gpt-4):",
|
||||
choices=fetched_model_options + [other_option_str],
|
||||
default=config.default_llm_config.model if valid_model else fetched_model_options[0],
|
||||
default=config.default_llm_config.model if (valid_model and config.default_llm_config) else fetched_model_options[0],
|
||||
).ask()
|
||||
if model is None:
|
||||
raise KeyboardInterrupt
|
||||
@@ -590,7 +592,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
if model_endpoint_type == "ollama":
|
||||
default_model = (
|
||||
config.default_llm_config.model
|
||||
if config.default_llm_config.model and config.default_llm_config.model_endpoint_type == "ollama"
|
||||
if config.default_llm_config and config.default_llm_config.model_endpoint_type == "ollama"
|
||||
else DEFAULT_OLLAMA_MODEL
|
||||
)
|
||||
model = questionary.text(
|
||||
@@ -602,9 +604,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
model = None if len(model) == 0 else model
|
||||
|
||||
default_model = (
|
||||
config.default_llm_config.model
|
||||
if config.default_llm_config.model and config.default_llm_config.model_endpoint_type == "vllm"
|
||||
else ""
|
||||
config.default_llm_config.model if config.default_llm_config and config.default_llm_config.model_endpoint_type == "vllm" else ""
|
||||
)
|
||||
|
||||
# vllm needs huggingface model tag
|
||||
@@ -773,7 +773,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials):
|
||||
# configure embedding endpoint
|
||||
|
||||
default_embedding_endpoint_type = config.default_embedding_config.embedding_endpoint_type
|
||||
default_embedding_endpoint_type = config.default_embedding_config.embedding_endpoint_type if config.default_embedding_config else None
|
||||
|
||||
embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = None, None, None, None
|
||||
embedding_provider = questionary.select(
|
||||
@@ -839,7 +839,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
|
||||
|
||||
# get model type
|
||||
default_embedding_model = (
|
||||
config.default_embedding_config.embedding_model if config.default_embedding_config.embedding_model else "BAAI/bge-large-en-v1.5"
|
||||
config.default_embedding_config.embedding_model if config.default_embedding_config else "BAAI/bge-large-en-v1.5"
|
||||
)
|
||||
embedding_model = questionary.text(
|
||||
"Enter HuggingFace model tag (e.g. BAAI/bge-large-en-v1.5):",
|
||||
@@ -849,7 +849,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# get model dimentions
|
||||
default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config.embedding_dim else "1024"
|
||||
default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config else "1024"
|
||||
embedding_dim = questionary.text("Enter embedding model dimentions (e.g. 1024):", default=str(default_embedding_dim)).ask()
|
||||
if embedding_dim is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
Reference in New Issue
Block a user