From 3de3fb83153d33acfbb28d70b12d3e5eaff4e676 Mon Sep 17 00:00:00 2001 From: cpacker Date: Tue, 2 Jan 2024 13:27:19 -0800 Subject: [PATCH] don't allow bad endpoint addresses during memgpt configure --- memgpt/cli/cli_config.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 7718a7d3..92dace38 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -123,16 +123,21 @@ def configure_llm_endpoint(config: MemGPTConfig): if model_endpoint_type in DEFAULT_ENDPOINTS: default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type] model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask() + while not utils.is_valid_url(model_endpoint): + typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW) + model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask() elif config.model_endpoint: model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask() + while not utils.is_valid_url(model_endpoint): + typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW) + model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask() else: # default_model_endpoint = None model_endpoint = None - while not model_endpoint: + model_endpoint = questionary.text("Enter default endpoint:").ask() + while not utils.is_valid_url(model_endpoint): + typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW) model_endpoint = questionary.text("Enter default endpoint:").ask() - if "http://" not in model_endpoint and "https://" not in model_endpoint: - typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW) - model_endpoint = None else: model_endpoint = default_model_endpoint assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set." @@ -330,9 +335,9 @@ def configure_embedding_endpoint(config: MemGPTConfig): # get endpoint embedding_endpoint = questionary.text("Enter default endpoint:").ask() - if "http://" not in embedding_endpoint and "https://" not in embedding_endpoint: + while not utils.is_valid_url(embedding_endpoint): typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW) - embedding_endpoint = None + embedding_endpoint = questionary.text("Enter default endpoint:").ask() # get model type default_embedding_model = config.embedding_model if config.embedding_model else "BAAI/bge-large-en-v1.5"