diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 2fabe556..783e9a42 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -118,6 +118,10 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials) ) else: credentials.azure_key = azure_creds["azure_key"] + credentials.azure_version = azure_creds["azure_version"] + credentials.azure_endpoint = azure_creds["azure_endpoint"] + if "azure_deployment" in azure_creds: + credentials.azure_deployment = azure_creds["azure_deployment"] credentials.azure_embedding_version = azure_creds["azure_embedding_version"] credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"] if "azure_embedding_deployment" in azure_creds: @@ -203,7 +207,7 @@ def get_model_options( model_options = [obj["id"] for obj in fetched_model_options_response["data"]] elif model_endpoint_type == "azure": - if credentials.azure_version is None: + if credentials.azure_key is None: raise ValueError("Missing Azure key") if credentials.azure_version is None: raise ValueError("Missing Azure version") @@ -469,8 +473,6 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden credentials.azure_key = azure_creds["azure_key"] credentials.azure_version = azure_creds["azure_version"] credentials.azure_embedding_endpoint = azure_creds["azure_embedding_endpoint"] - if "azure_deployment" in azure_creds: - credentials.azure_deployment = azure_creds["azure_deployment"] credentials.save() embedding_endpoint_type = "azure" @@ -627,7 +629,6 @@ def configure(): # check credentials credentials = MemGPTCredentials.load() openai_key = get_openai_credentials() - get_azure_credentials() MemGPTConfig.create_config_dir()