feat: local auth config (#854)

This commit is contained in:
Charles Packer
2024-01-18 16:43:41 -08:00
committed by GitHub
parent da5a8cdbfe
commit 77e5c43c8f
3 changed files with 33 additions and 15 deletions

View File

@@ -162,7 +162,7 @@ def quickstart(
api_key = os.getenv("OPENAI_API_KEY")
while api_key is None or len(api_key) == 0:
# Ask for API key as input
api_key = questionary.text("Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):").ask()
api_key = questionary.password("Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):").ask()
credentials.openai_key = api_key
credentials.save()

View File

@@ -16,7 +16,7 @@ from memgpt.log import logger
from memgpt import utils
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.credentials import MemGPTCredentials, SUPPORTED_AUTH_TYPES
from memgpt.constants import MEMGPT_DIR
# from memgpt.agent_store.storage import StorageConnector, TableType
@@ -79,7 +79,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
if openai_api_key is None:
while openai_api_key is None or len(openai_api_key) == 0:
# Ask for API key as input
openai_api_key = questionary.text(
openai_api_key = questionary.password(
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
).ask()
if openai_api_key is None:
@@ -92,7 +92,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
default_input = (
shorten_key_middle(credentials.openai_key) if credentials.openai_key.startswith("sk-") else credentials.openai_key
)
openai_api_key = questionary.text(
openai_api_key = questionary.password(
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):",
default=default_input,
).ask()
@@ -314,6 +314,31 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
if model_wrapper is None:
raise KeyboardInterrupt
# ask about local auth
use_local_auth = questionary.confirm(
"Is your LLM endpoint authenticated? (default no)",
default=False,
).ask()
if use_local_auth is None:
raise KeyboardInterrupt
if use_local_auth:
local_auth_type = questionary.select(
"What HTTP authentication method does your endpoint require?",
choices=SUPPORTED_AUTH_TYPES,
default=SUPPORTED_AUTH_TYPES[0],
).ask()
if local_auth_type is None:
raise KeyboardInterrupt
local_auth_key = questionary.password(
"Enter your authentication key:",
).ask()
if local_auth_key is None:
raise KeyboardInterrupt
# credentials = MemGPTCredentials.load()
credentials.openllm_auth_type = local_auth_type
credentials.openllm_key = local_auth_key
credentials.save()
# set: context_window
if str(model) not in LLM_MAX_TOKENS:
# Ask the user to specify the context length
@@ -373,13 +398,13 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
# if we still can't find it, ask for it as input
while openai_api_key is None or len(openai_api_key) == 0:
# Ask for API key as input
openai_api_key = questionary.text(
openai_api_key = questionary.password(
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
).ask()
if openai_api_key is None:
raise KeyboardInterrupt
credentials.openai_key = openai_api_key
config.save()
credentials.save()
embedding_endpoint_type = "openai"
embedding_endpoint = "https://api.openai.com/v1"
@@ -514,7 +539,7 @@ def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredenti
# TODO: allow configuring embedding model
def configure_recall_storage(config: MemGPTConfig):
def configure_recall_storage(config: MemGPTConfig, credentials: MemGPTCredentials):
# Configure recall storage backend
recall_storage_options = ["sqlite", "postgres"]
recall_storage_type = questionary.select(
@@ -602,13 +627,6 @@ def configure():
persona=default_persona,
human=default_human,
agent=default_agent,
# credentials
openai_key=openai_key,
azure_key=azure_creds["azure_key"],
azure_endpoint=azure_creds["azure_endpoint"],
azure_version=azure_creds["azure_version"],
azure_deployment=azure_creds["azure_deployment"], # OK if None
azure_embedding_deployment=azure_creds["azure_embedding_deployment"], # OK if None
# storage
archival_storage_type=archival_storage_type,
archival_storage_uri=archival_storage_uri,

View File

@@ -172,7 +172,7 @@ class Message(Record):
if "tool_call_id" in openai_message_dict:
assert openai_message_dict["tool_call_id"] is None, openai_message_dict
if "tool_calls" in openai_message_dict:
if "tool_calls" in openai_message_dict and openai_message_dict["tool_calls"] is not None:
assert openai_message_dict["role"] == "assistant", openai_message_dict
tool_calls = [