feat: local auth config (#854)
This commit is contained in:
@@ -162,7 +162,7 @@ def quickstart(
|
||||
api_key = os.getenv("OPENAI_API_KEY")
|
||||
while api_key is None or len(api_key) == 0:
|
||||
# Ask for API key as input
|
||||
api_key = questionary.text("Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):").ask()
|
||||
api_key = questionary.password("Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):").ask()
|
||||
credentials.openai_key = api_key
|
||||
credentials.save()
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from memgpt.log import logger
|
||||
from memgpt import utils
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.credentials import MemGPTCredentials, SUPPORTED_AUTH_TYPES
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
|
||||
# from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
@@ -79,7 +79,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
|
||||
if openai_api_key is None:
|
||||
while openai_api_key is None or len(openai_api_key) == 0:
|
||||
# Ask for API key as input
|
||||
openai_api_key = questionary.text(
|
||||
openai_api_key = questionary.password(
|
||||
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
|
||||
).ask()
|
||||
if openai_api_key is None:
|
||||
@@ -92,7 +92,7 @@ def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials)
|
||||
default_input = (
|
||||
shorten_key_middle(credentials.openai_key) if credentials.openai_key.startswith("sk-") else credentials.openai_key
|
||||
)
|
||||
openai_api_key = questionary.text(
|
||||
openai_api_key = questionary.password(
|
||||
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):",
|
||||
default=default_input,
|
||||
).ask()
|
||||
@@ -314,6 +314,31 @@ def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_
|
||||
if model_wrapper is None:
|
||||
raise KeyboardInterrupt
|
||||
|
||||
# ask about local auth
|
||||
use_local_auth = questionary.confirm(
|
||||
"Is your LLM endpoint authenticated? (default no)",
|
||||
default=False,
|
||||
).ask()
|
||||
if use_local_auth is None:
|
||||
raise KeyboardInterrupt
|
||||
if use_local_auth:
|
||||
local_auth_type = questionary.select(
|
||||
"What HTTP authentication method does your endpoint require?",
|
||||
choices=SUPPORTED_AUTH_TYPES,
|
||||
default=SUPPORTED_AUTH_TYPES[0],
|
||||
).ask()
|
||||
if local_auth_type is None:
|
||||
raise KeyboardInterrupt
|
||||
local_auth_key = questionary.password(
|
||||
"Enter your authentication key:",
|
||||
).ask()
|
||||
if local_auth_key is None:
|
||||
raise KeyboardInterrupt
|
||||
# credentials = MemGPTCredentials.load()
|
||||
credentials.openllm_auth_type = local_auth_type
|
||||
credentials.openllm_key = local_auth_key
|
||||
credentials.save()
|
||||
|
||||
# set: context_window
|
||||
if str(model) not in LLM_MAX_TOKENS:
|
||||
# Ask the user to specify the context length
|
||||
@@ -373,13 +398,13 @@ def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCreden
|
||||
# if we still can't find it, ask for it as input
|
||||
while openai_api_key is None or len(openai_api_key) == 0:
|
||||
# Ask for API key as input
|
||||
openai_api_key = questionary.text(
|
||||
openai_api_key = questionary.password(
|
||||
"Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):"
|
||||
).ask()
|
||||
if openai_api_key is None:
|
||||
raise KeyboardInterrupt
|
||||
credentials.openai_key = openai_api_key
|
||||
config.save()
|
||||
credentials.save()
|
||||
|
||||
embedding_endpoint_type = "openai"
|
||||
embedding_endpoint = "https://api.openai.com/v1"
|
||||
@@ -514,7 +539,7 @@ def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredenti
|
||||
# TODO: allow configuring embedding model
|
||||
|
||||
|
||||
def configure_recall_storage(config: MemGPTConfig):
|
||||
def configure_recall_storage(config: MemGPTConfig, credentials: MemGPTCredentials):
|
||||
# Configure recall storage backend
|
||||
recall_storage_options = ["sqlite", "postgres"]
|
||||
recall_storage_type = questionary.select(
|
||||
@@ -602,13 +627,6 @@ def configure():
|
||||
persona=default_persona,
|
||||
human=default_human,
|
||||
agent=default_agent,
|
||||
# credentials
|
||||
openai_key=openai_key,
|
||||
azure_key=azure_creds["azure_key"],
|
||||
azure_endpoint=azure_creds["azure_endpoint"],
|
||||
azure_version=azure_creds["azure_version"],
|
||||
azure_deployment=azure_creds["azure_deployment"], # OK if None
|
||||
azure_embedding_deployment=azure_creds["azure_embedding_deployment"], # OK if None
|
||||
# storage
|
||||
archival_storage_type=archival_storage_type,
|
||||
archival_storage_uri=archival_storage_uri,
|
||||
|
||||
@@ -172,7 +172,7 @@ class Message(Record):
|
||||
if "tool_call_id" in openai_message_dict:
|
||||
assert openai_message_dict["tool_call_id"] is None, openai_message_dict
|
||||
|
||||
if "tool_calls" in openai_message_dict:
|
||||
if "tool_calls" in openai_message_dict and openai_message_dict["tool_calls"] is not None:
|
||||
assert openai_message_dict["role"] == "assistant", openai_message_dict
|
||||
|
||||
tool_calls = [
|
||||
|
||||
Reference in New Issue
Block a user