From 77e5c43c8f0585f249db89badfcd1dbda2fb06be Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Thu, 18 Jan 2024 16:43:41 -0800 Subject: [PATCH] feat: local auth config (#854) --- memgpt/cli/cli.py | 2 +- memgpt/cli/cli_config.py | 44 ++++++++++++++++++++++++++++------------ memgpt/data_types.py | 2 +- 3 files changed, 33 insertions(+), 15 deletions(-) diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 409da048..0045cd43 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -162,7 +162,7 @@ def quickstart( api_key = os.getenv("OPENAI_API_KEY") while api_key is None or len(api_key) == 0: # Ask for API key as input - api_key = questionary.text("Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):").ask() + api_key = questionary.password("Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):").ask() credentials.openai_key = api_key credentials.save() diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 8574cafb..85a7f7f3 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -16,7 +16,7 @@ from memgpt.log import logger from memgpt import utils from memgpt.config import MemGPTConfig -from memgpt.credentials import MemGPTCredentials +from memgpt.credentials import MemGPTCredentials, SUPPORTED_AUTH_TYPES from memgpt.constants import MEMGPT_DIR # from memgpt.agent_store.storage import StorageConnector, TableType @@ -79,7 +79,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials) if openai_api_key is None: while openai_api_key is None or len(openai_api_key) == 0: # Ask for API key as input - openai_api_key = questionary.text( + openai_api_key = questionary.password( "Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):" ).ask() if openai_api_key is None: @@ -92,7 +92,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials) default_input = ( shorten_key_middle(credentials.openai_key) if credentials.openai_key.startswith("sk-") else credentials.openai_key ) - openai_api_key = questionary.text( + openai_api_key = questionary.password( "Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):", default=default_input, ).ask() @@ -314,6 +314,31 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_ if model_wrapper is None: raise KeyboardInterrupt + # ask about local auth + use_local_auth = questionary.confirm( + "Is your LLM endpoint authenticated? (default no)", + default=False, + ).ask() + if use_local_auth is None: + raise KeyboardInterrupt + if use_local_auth: + local_auth_type = questionary.select( + "What HTTP authentication method does your endpoint require?", + choices=SUPPORTED_AUTH_TYPES, + default=SUPPORTED_AUTH_TYPES[0], + ).ask() + if local_auth_type is None: + raise KeyboardInterrupt + local_auth_key = questionary.password( + "Enter your authentication key:", + ).ask() + if local_auth_key is None: + raise KeyboardInterrupt + # credentials = MemGPTCredentials.load() + credentials.openllm_auth_type = local_auth_type + credentials.openllm_key = local_auth_key + credentials.save() + # set: context_window if str(model) not in LLM_MAX_TOKENS: # Ask the user to specify the context length @@ -373,13 +398,13 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden # if we still can't find it, ask for it as input while openai_api_key is None or len(openai_api_key) == 0: # Ask for API key as input - openai_api_key = questionary.text( + openai_api_key = questionary.password( "Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):" ).ask() if openai_api_key is None: raise KeyboardInterrupt credentials.openai_key = openai_api_key - config.save() + credentials.save() embedding_endpoint_type = "openai" embedding_endpoint = "https://api.openai.com/v1" @@ -514,7 +539,7 @@ def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredenti # TODO: allow configuring embedding model -def configure_recall_storage(config: MemGPTConfig): +def configure_recall_storage(config: MemGPTConfig, credentials: MemGPTCredentials): # Configure recall storage backend recall_storage_options = ["sqlite", "postgres"] recall_storage_type = questionary.select( @@ -602,13 +627,6 @@ def configure(): persona=default_persona, human=default_human, agent=default_agent, - # credentials - openai_key=openai_key, - azure_key=azure_creds["azure_key"], - azure_endpoint=azure_creds["azure_endpoint"], - azure_version=azure_creds["azure_version"], - azure_deployment=azure_creds["azure_deployment"], # OK if None - azure_embedding_deployment=azure_creds["azure_embedding_deployment"], # OK if None # storage archival_storage_type=archival_storage_type, archival_storage_uri=archival_storage_uri, diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 8eaabe16..9fd6c6bd 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -172,7 +172,7 @@ class Message(Record): if "tool_call_id" in openai_message_dict: assert openai_message_dict["tool_call_id"] is None, openai_message_dict - if "tool_calls" in openai_message_dict: + if "tool_calls" in openai_message_dict and openai_message_dict["tool_calls"] is not None: assert openai_message_dict["role"] == "assistant", openai_message_dict tool_calls = [