diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 855f8fad..3f5c440a 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -65,10 +65,12 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials) model_endpoint_type, model_endpoint = None, None # get default - default_model_endpoint_type = config.default_llm_config.model_endpoint_type - if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [ - provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local" - ]: # local model + default_model_endpoint_type = config.default_llm_config.model_endpoint_type if config.default_embedding_config else None + if ( + config.default_llm_config + and config.default_llm_config.model_endpoint_type is not None + and config.default_llm_config.model_endpoint_type not in [provider for provider in LLM_API_PROVIDER_OPTIONS if provider != "local"] + ): # local model default_model_endpoint_type = "local" provider = questionary.select( @@ -270,7 +272,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials) backend_options = builtins.list(DEFAULT_ENDPOINTS.keys()) # assert backend_options_old == backend_options, (backend_options_old, backend_options) default_model_endpoint_type = None - if config.default_llm_config.model_endpoint_type in backend_options: + if config.default_llm_config and config.default_llm_config.model_endpoint_type in backend_options: # set from previous config default_model_endpoint_type = config.default_llm_config.model_endpoint_type model_endpoint_type = questionary.select( @@ -296,7 +298,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials) model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask() if model_endpoint is None: raise KeyboardInterrupt - elif config.default_llm_config.model_endpoint: + elif config.default_llm_config and config.default_llm_config.model_endpoint: model_endpoint = questionary.text("Enter default endpoint:", default=config.default_llm_config.model_endpoint).ask() if model_endpoint is None: raise KeyboardInterrupt @@ -425,7 +427,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ other_option_str = "[enter model name manually]" # Check if the model we have set already is even in the list (informs our default) - valid_model = config.default_llm_config.model in hardcoded_model_options + valid_model = config.default_llm_config and config.default_llm_config.model in hardcoded_model_options model = questionary.select( "Select default model (recommended: gpt-4):", choices=hardcoded_model_options + [see_all_option_str, other_option_str], @@ -440,7 +442,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ model = questionary.select( "Select default model (recommended: gpt-4):", choices=fetched_model_options + [other_option_str], - default=config.default_llm_config.model if valid_model else fetched_model_options[0], + default=config.default_llm_config.model if (valid_model and config.default_llm_config) else fetched_model_options[0], ).ask() if model is None: raise KeyboardInterrupt @@ -590,7 +592,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ if model_endpoint_type == "ollama": default_model = ( config.default_llm_config.model - if config.default_llm_config.model and config.default_llm_config.model_endpoint_type == "ollama" + if config.default_llm_config and config.default_llm_config.model_endpoint_type == "ollama" else DEFAULT_OLLAMA_MODEL ) model = questionary.text( @@ -602,9 +604,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ model = None if len(model) == 0 else model default_model = ( - config.default_llm_config.model - if config.default_llm_config.model and config.default_llm_config.model_endpoint_type == "vllm" - else "" + config.default_llm_config.model if config.default_llm_config and config.default_llm_config.model_endpoint_type == "vllm" else "" ) # vllm needs huggingface model tag @@ -773,7 +773,7 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials): # configure embedding endpoint - default_embedding_endpoint_type = config.default_embedding_config.embedding_endpoint_type + default_embedding_endpoint_type = config.default_embedding_config.embedding_endpoint_type if config.default_embedding_config else None embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = None, None, None, None embedding_provider = questionary.select( @@ -839,7 +839,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden # get model type default_embedding_model = ( - config.default_embedding_config.embedding_model if config.default_embedding_config.embedding_model else "BAAI/bge-large-en-v1.5" + config.default_embedding_config.embedding_model if config.default_embedding_config else "BAAI/bge-large-en-v1.5" ) embedding_model = questionary.text( "Enter HuggingFace model tag (e.g. BAAI/bge-large-en-v1.5):", @@ -849,7 +849,7 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden raise KeyboardInterrupt # get model dimentions - default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config.embedding_dim else "1024" + default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config else "1024" embedding_dim = questionary.text("Enter embedding model dimentions (e.g. 1024):", default=str(default_embedding_dim)).ask() if embedding_dim is None: raise KeyboardInterrupt