From ec2bda49661f2433570a26bfe38acba3d4eeb6cd Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 14 Nov 2023 15:58:19 -0800 Subject: [PATCH] Refactor config + determine LLM via `config.model_endpoint_type` (#422) * mark depricated API section * CLI bug fixes for azure * check azure before running * Update README.md * Update README.md * bug fix with persona loading * remove print * make errors for cli flags more clear * format * fix imports * fix imports * add prints * update lock * update config fields * cleanup config loading * commit * remove asserts * refactor configure * put into different functions * add embedding default * pass in config * fixes * allow overriding openai embedding endpoint * black * trying to patch tests (some circular import errors) * update flags and docs * patched support for local llms using endpoint and endpoint type passed via configs, not env vars * missing files * fix naming * fix import * fix two runtime errors * patch ollama typo, move ollama model question pre-wrapper, modify question phrasing to include link to readthedocs, also have a default ollama model that has a tag included * disable debug messages * made error message for failed load more informative * don't print dynamic linking function warning unless --debug * updated tests to work with new cli workflow (disabled openai config test for now) * added skips for tests when vars are missing * update bad arg * revise test to soft pass on empty string too * don't run configure twice * extend timeout (try to pass against nltk download) * update defaults * typo with endpoint type default * patch runtime errors for when model is None * catching another case of 'x in model' when model is None (preemptively) * allow overrides to local llm related config params * made model wrapper selection from a list vs raw input * update test for select instead of input * Fixed bug in endpoint when using local->openai selection, also added validation loop to manual endpoint entry * updated error messages to be more informative with links to readthedocs * add back gpt3.5-turbo --------- Co-authored-by: cpacker --- docs/config.md | 10 +- memgpt/agent.py | 96 ++++--- memgpt/cli/cli.py | 66 +++-- memgpt/cli/cli_config.py | 331 +++++++++++++--------- memgpt/config.py | 253 +++++++++-------- memgpt/embeddings.py | 4 +- memgpt/functions/function_sets/base.py | 2 +- memgpt/functions/function_sets/extras.py | 4 +- memgpt/local_llm/chat_completion_proxy.py | 81 +++--- memgpt/local_llm/constants.py | 14 + memgpt/local_llm/koboldcpp/api.py | 14 +- memgpt/local_llm/llamacpp/api.py | 14 +- memgpt/local_llm/lmstudio/api.py | 12 +- memgpt/local_llm/ollama/api.py | 19 +- memgpt/local_llm/utils.py | 15 + memgpt/local_llm/webui/api.py | 10 +- memgpt/main.py | 2 +- memgpt/openai_tools.py | 39 ++- memgpt/presets/presets.py | 2 +- tests/test_cli.py | 2 +- tests/test_load_archival.py | 8 +- tests/test_storage.py | 13 +- tests/utils.py | 52 ++-- 23 files changed, 628 insertions(+), 435 deletions(-) create mode 100644 memgpt/local_llm/constants.py diff --git a/docs/config.md b/docs/config.md index 1a390495..b8e70b29 100644 --- a/docs/config.md +++ b/docs/config.md @@ -6,15 +6,21 @@ The `memgpt run` command supports the following optional flags (if set, will ove * `--agent`: (str) Name of agent to create or to resume chatting with. * `--human`: (str) Name of the human to run the agent with. * `--persona`: (str) Name of agent persona to use. -* `--model`: (str) LLM model to run [gpt-4, gpt-3.5]. +* `--model`: (str) LLM model to run (e.g. 
`gpt-4`, `dolphin_xxx`) * `--preset`: (str) MemGPT preset to run agent with. * `--first`: (str) Allow user to sent the first message. * `--debug`: (bool) Show debug logs (default=False) * `--no-verify`: (bool) Bypass message verification (default=False) * `--yes`/`-y`: (bool) Skip confirmation prompt and use defaults (default=False) +You can override the parameters you set with `memgpt configure` with the following additional flags specific to local LLMs: +* `--model-wrapper`: (str) Model wrapper used by backend (e.g. `airoboros_xxx`) +* `--model-endpoint-type`: (str) Model endpoint backend type (e.g. lmstudio, ollama) +* `--model-endpoint`: (str) Model endpoint url (e.g. `localhost:5000`) +* `--context-window`: (int) Size of model context window (specific to model type) + #### Updating the config location -You can override the location of the config path by setting the enviornment variable `MEMGPT_CONFIG_PATH`: +You can override the location of the config path by setting the environment variable `MEMGPT_CONFIG_PATH`: ``` export MEMGPT_CONFIG_PATH=/my/custom/path/config # make sure this is a file, not a directory ``` diff --git a/memgpt/agent.py b/memgpt/agent.py index e15c9d84..7aa2d63c 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -9,6 +9,7 @@ from memgpt.config import AgentConfig from .system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages from .memory import CoreMemory as Memory, summarize_messages from .openai_tools import completions_with_backoff as create +from memgpt.openai_tools import chat_completion_with_backoff from .utils import get_local_time, parse_json, united_diff, printd, count_tokens, get_schema_diff from .constants import ( FIRST_MESSAGE_ATTEMPTS, @@ -73,7 +74,7 @@ def initialize_message_sequence( first_user_message = get_login_event() # event letting MemGPT know the user just logged in if include_initial_boot_message: - if "gpt-3.5" in model: + if model is not None and "gpt-3.5" in model: initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35") else: initial_boot_messages = get_initial_boot_messages("startup_with_send_message") @@ -96,37 +97,6 @@ def initialize_message_sequence( return messages -def get_ai_reply( - model, - message_sequence, - functions, - function_call="auto", - context_window=None, -): - try: - response = create( - model=model, - context_window=context_window, - messages=message_sequence, - functions=functions, - function_call=function_call, - ) - - # special case for 'length' - if response.choices[0].finish_reason == "length": - raise Exception("Finish reason was length (maximum context length)") - - # catches for soft errors - if response.choices[0].finish_reason not in ["stop", "function_call"]: - raise Exception(f"API call finish with bad finish reason: {response}") - - # unpack with response.choices[0].message.content - return response - - except Exception as e: - raise e - - class Agent(object): def __init__( self, @@ -310,7 +280,7 @@ class Agent(object): json_files = glob.glob(os.path.join(directory, "*.json")) # This will list all .json files in the current directory. if not json_files: print(f"/load error: no .json checkpoint files found") - raise ValueError(f"Cannot load {agent_name}: does not exist in {directory}") + raise ValueError(f"Cannot load {agent_name} - no saved checkpoints found in {directory}") # Sort files based on modified timestamp, with the latest file being the first. 
filename = max(json_files, key=os.path.getmtime) @@ -360,7 +330,7 @@ class Agent(object): # NOTE to handle old configs, instead of erroring here let's just warn # raise ValueError(error_message) - print(error_message) + printd(error_message) linked_function_set[f_name] = linked_function messages = state["messages"] @@ -602,8 +572,7 @@ class Agent(object): printd(f"This is the first message. Running extra verifier on AI response.") counter = 0 while True: - response = get_ai_reply( - model=self.model, + response = self.get_ai_reply( message_sequence=input_message_sequence, functions=self.functions, context_window=None if self.config.context_window is None else int(self.config.context_window), @@ -616,8 +585,7 @@ class Agent(object): raise Exception(f"Hit first message retry limit ({first_message_retry_limit})") else: - response = get_ai_reply( - model=self.model, + response = self.get_ai_reply( message_sequence=input_message_sequence, functions=self.functions, context_window=None if self.config.context_window is None else int(self.config.context_window), @@ -785,3 +753,55 @@ class Agent(object): # Check if it's been more than pause_heartbeats_minutes since pause_heartbeats_start elapsed_time = datetime.datetime.now() - self.pause_heartbeats_start return elapsed_time.total_seconds() < self.pause_heartbeats_minutes * 60 + + def get_ai_reply( + self, + message_sequence, + function_call="auto", + ): + """Get response from LLM API""" + + # TODO: Legacy code - delete + if self.config is None: + try: + response = create( + model=self.model, + context_window=self.context_window, + messages=message_sequence, + functions=self.functions, + function_call=function_call, + ) + + # special case for 'length' + if response.choices[0].finish_reason == "length": + raise Exception("Finish reason was length (maximum context length)") + + # catches for soft errors + if response.choices[0].finish_reason not in ["stop", "function_call"]: + raise Exception(f"API call finish with bad finish reason: {response}") + + # unpack with response.choices[0].message.content + return response + except Exception as e: + raise e + + try: + response = chat_completion_with_backoff( + agent_config=self.config, + model=self.model, # TODO: remove (is redundant) + messages=message_sequence, + functions=self.functions, + function_call=function_call, + ) + # special case for 'length' + if response.choices[0].finish_reason == "length": + raise Exception("Finish reason was length (maximum context length)") + + # catches for soft errors + if response.choices[0].finish_reason not in ["stop", "function_call"]: + raise Exception(f"API call finish with bad finish reason: {response}") + + # unpack with response.choices[0].message.content + return response + except Exception as e: + raise e diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 12b02b89..407ae2e7 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -1,4 +1,5 @@ import typer +import json import sys import io import logging @@ -35,16 +36,21 @@ def run( persona: str = typer.Option(None, help="Specify persona"), agent: str = typer.Option(None, help="Specify agent save file"), human: str = typer.Option(None, help="Specify human"), - model: str = typer.Option(None, help="Specify the LLM model"), preset: str = typer.Option(None, help="Specify preset"), + # model flags + model: str = typer.Option(None, help="Specify the LLM model"), + model_wrapper: str = typer.Option(None, help="Specify the LLM model wrapper"), + model_endpoint: str = typer.Option(None, help="Specify the LLM 
model endpoint"), + model_endpoint_type: str = typer.Option(None, help="Specify the LLM model endpoint type"), + context_window: int = typer.Option( + None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)" + ), + # other first: bool = typer.Option(False, "--first", help="Use --first to send the first message in the sequence"), strip_ui: bool = typer.Option(False, "--strip_ui", help="Remove all the bells and whistles in CLI output (helpful for testing)"), debug: bool = typer.Option(False, "--debug", help="Use --debug to enable debugging output"), no_verify: bool = typer.Option(False, "--no_verify", help="Bypass message verification"), yes: bool = typer.Option(False, "-y", help="Skip confirmation prompt and use defaults"), - context_window: int = typer.Option( - None, "--context_window", help="The context window of the LLM you are using (e.g. 8k for most Mistral 7B variants)" - ), ): """Start chatting with an MemGPT agent @@ -99,11 +105,6 @@ def run( set_global_service_context(service_context) sys.stdout = original_stdout - # overwrite the context_window if specified - if context_window is not None and int(context_window) != int(config.context_window): - typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW) - config.context_window = str(context_window) - # create agent config if agent and AgentConfig.exists(agent): # use existing agent typer.secho(f"Using existing agent {agent}", fg=typer.colors.GREEN) @@ -121,10 +122,34 @@ def run( typer.secho(f"Warning: Overriding existing human {agent_config.human} with {human}", fg=typer.colors.YELLOW) agent_config.human = human # raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}") + + # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window) if model and model != agent_config.model: typer.secho(f"Warning: Overriding existing model {agent_config.model} with {model}", fg=typer.colors.YELLOW) agent_config.model = model - # raise ValueError(f"Cannot override {agent_config.name} existing model {agent_config.model} with {model}") + if context_window is not None and int(context_window) != agent_config.context_window: + typer.secho( + f"Warning: Overriding existing context window {agent_config.context_window} with {context_window}", fg=typer.colors.YELLOW + ) + agent_config.context_window = context_window + if model_wrapper and model_wrapper != agent_config.model_wrapper: + typer.secho( + f"Warning: Overriding existing model wrapper {agent_config.model_wrapper} with {model_wrapper}", fg=typer.colors.YELLOW + ) + agent_config.model_wrapper = model_wrapper + if model_endpoint and model_endpoint != agent_config.model_endpoint: + typer.secho( + f"Warning: Overriding existing model endpoint {agent_config.model_endpoint} with {model_endpoint}", fg=typer.colors.YELLOW + ) + agent_config.model_endpoint = model_endpoint + if model_endpoint_type and model_endpoint_type != agent_config.model_endpoint_type: + typer.secho( + f"Warning: Overriding existing model endpoint type {agent_config.model_endpoint_type} with {model_endpoint_type}", + fg=typer.colors.YELLOW, + ) + agent_config.model_endpoint_type = model_endpoint_type + + # Update the agent config with any overrides agent_config.save() # load existing agent @@ -133,17 +158,17 @@ def run( # create new agent config: override defaults with args if provided typer.secho("Creating new agent...", 
fg=typer.colors.GREEN) agent_config = AgentConfig( - name=agent if agent else None, - persona=persona if persona else config.default_persona, - human=human if human else config.default_human, - model=model if model else config.model, - context_window=context_window if context_window else config.context_window, - preset=preset if preset else config.preset, + name=agent, + persona=persona, + human=human, + preset=preset, + model=model, + model_wrapper=model_wrapper, + model_endpoint_type=model_endpoint_type, + model_endpoint=model_endpoint, + context_window=context_window, ) - ## attach data source to agent - # agent_config.attach_data_source(data_source) - # TODO: allow configrable state manager (only local is supported right now) persistence_manager = LocalStateManager(agent_config) # TODO: insert dataset/pre-fill @@ -162,6 +187,9 @@ def run( persistence_manager, ) + # pretty print agent config + printd(json.dumps(vars(agent_config), indent=4, sort_keys=True)) + # start event loop from memgpt.main import run_agent_loop diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 8053779b..2433efb1 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -1,3 +1,4 @@ +import builtins import questionary import openai from prettytable import PrettyTable @@ -11,126 +12,118 @@ from memgpt import utils import memgpt.humans.humans as humans import memgpt.personas.personas as personas -from memgpt.config import MemGPTConfig, AgentConfig +from memgpt.config import MemGPTConfig, AgentConfig, Config from memgpt.constants import MEMGPT_DIR from memgpt.connectors.storage import StorageConnector from memgpt.constants import LLM_MAX_TOKENS +from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME +from memgpt.local_llm.utils import get_available_wrappers app = typer.Typer() -@app.command() -def configure(): - """Updates default MemGPT configurations""" +def get_azure_credentials(): + azure_key = os.getenv("AZURE_OPENAI_KEY") + azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") + azure_version = os.getenv("AZURE_OPENAI_VERSION") + azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") + azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT") + return azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment - from memgpt.presets.presets import DEFAULT_PRESET, preset_options - MemGPTConfig.create_config_dir() +def get_openai_credentials(): + openai_key = os.getenv("OPENAI_API_KEY") + return openai_key - # Will pre-populate with defaults, or what the user previously set - config = MemGPTConfig.load() - # openai credentials - use_openai = questionary.confirm("Do you want to enable MemGPT with OpenAI?", default=True).ask() - if use_openai: - # search for key in enviornment - openai_key = os.getenv("OPENAI_API_KEY") - if not openai_key: - print("Missing enviornment variables for OpenAI. 
Please set them and run `memgpt configure` again.") - # TODO: eventually stop relying on env variables and pass in keys explicitly - # openai_key = questionary.text("Open AI API keys not found in enviornment - please enter:").ask() +def configure_llm_endpoint(config: MemGPTConfig): + # configure model endpoint + model_endpoint_type, model_endpoint = None, None - # azure credentials - use_azure = questionary.confirm("Do you want to enable MemGPT with Azure?", default=(config.azure_key is not None)).ask() - use_azure_deployment_ids = False - if use_azure: - # search for key in enviornment - azure_key = os.getenv("AZURE_OPENAI_KEY") - azure_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT") - azure_version = os.getenv("AZURE_OPENAI_VERSION") - azure_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT") - azure_embedding_deployment = os.getenv("AZURE_OPENAI_EMBEDDINGS_DEPLOYMENT") + # get default + default_model_endpoint_type = config.model_endpoint_type + if config.model_endpoint_type is not None and config.model_endpoint_type not in ["openai", "azure"]: # local model + default_model_endpoint_type = "local" - if all([azure_key, azure_endpoint, azure_version]): - print(f"Using Microsoft endpoint {azure_endpoint}.") - if all([azure_deployment, azure_embedding_deployment]): - print(f"Using deployment id {azure_deployment}") - use_azure_deployment_ids = True + provider = questionary.select( + "Select LLM inference provider:", choices=["openai", "azure", "local"], default=default_model_endpoint_type + ).ask() - # configure openai - openai.api_type = "azure" - openai.api_key = azure_key - openai.api_base = azure_endpoint - openai.api_version = azure_version + # set: model_endpoint_type, model_endpoint + if provider == "openai": + model_endpoint_type = "openai" + model_endpoint = "https://api.openai.com/v1" + model_endpoint = questionary.text("Override default endpoint:", default=model_endpoint).ask() + provider = "openai" + elif provider == "azure": + model_endpoint_type = "azure" + _, model_endpoint, _, _, _ = get_azure_credentials() + else: # local models + backend_options = ["webui", "llamacpp", "koboldcpp", "ollama", "lmstudio", "openai"] + default_model_endpoint_type = None + if config.model_endpoint_type in backend_options: + # set from previous config + default_model_endpoint_type = config.model_endpoint_type else: - print("Missing enviornment variables for Azure. 
Please set then run `memgpt configure` again.") - # TODO: allow for manual setting - use_azure = False + # set form env variable (ok if none) + default_model_endpoint_type = os.getenv("BACKEND_TYPE") + model_endpoint_type = questionary.select( + "Select LLM backend (select 'openai' if you have an OpenAI compatible proxy):", + backend_options, + default=default_model_endpoint_type, + ).ask() - # TODO: configure local model + # set default endpoint + # if OPENAI_API_BASE is set, assume that this is the IP+port the user wanted to use + default_model_endpoint = os.getenv("OPENAI_API_BASE") + # if OPENAI_API_BASE is not set, try to pull a default IP+port format from a hardcoded set + if default_model_endpoint is None: + if model_endpoint_type in DEFAULT_ENDPOINTS: + default_model_endpoint = DEFAULT_ENDPOINTS[model_endpoint_type] + model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask() + else: + # default_model_endpoint = None + model_endpoint = None + while not model_endpoint: + model_endpoint = questionary.text("Enter default endpoint:").ask() + if "http://" not in model_endpoint and "https://" not in model_endpoint: + typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW) + model_endpoint = None + assert model_endpoint, f"Environment variable OPENAI_API_BASE must be set." - # configure provider - model_endpoint_options = [] - if os.getenv("OPENAI_API_BASE") is not None: - model_endpoint_options.append(os.getenv("OPENAI_API_BASE")) - if use_openai: - model_endpoint_options += ["openai"] - if use_azure: - model_endpoint_options += ["azure"] - assert ( - len(model_endpoint_options) > 0 - ), "No endpoints found. Please enable OpenAI, Azure, or set OPENAI_API_BASE to point at the IP address of your LLM server." 
- valid_default_model = config.model_endpoint in model_endpoint_options - default_endpoint = questionary.select( - "Select default inference endpoint:", - model_endpoint_options, - default=config.model_endpoint if valid_default_model else model_endpoint_options[0], - ).ask() + return model_endpoint_type, model_endpoint - # configure embedding provider - embedding_endpoint_options = [] - if use_azure: - embedding_endpoint_options += ["azure"] - if use_openai: - embedding_endpoint_options += ["openai"] - embedding_endpoint_options += ["local"] - valid_default_embedding = config.embedding_model in embedding_endpoint_options - # determine the default selection in a smart way - if "openai" in embedding_endpoint_options and default_endpoint == "openai": - # openai llm -> openai embeddings - default_embedding_endpoint_default = "openai" - elif default_endpoint not in ["openai", "azure"]: # is local - # local llm -> local embeddings - default_embedding_endpoint_default = "local" - else: - default_embedding_endpoint_default = config.embedding_model if valid_default_embedding else embedding_endpoint_options[-1] - default_embedding_endpoint = questionary.select( - "Select default embedding endpoint:", embedding_endpoint_options, default=default_embedding_endpoint_default - ).ask() - # configure embedding dimentions - default_embedding_dim = config.embedding_dim - if default_embedding_endpoint == "local": - # HF model uses lower dimentionality - default_embedding_dim = 384 - - # configure preset - default_preset = questionary.select("Select default preset:", preset_options, default=config.preset).ask() - - # default model - if use_openai or use_azure: - model_options = [] - if use_openai: - model_options += ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo-16k"] +def configure_model(config: MemGPTConfig, model_endpoint_type: str): + # set: model, model_wrapper + model, model_wrapper = None, None + if model_endpoint_type == "openai" or model_endpoint_type == "azure": + model_options = ["gpt-4", "gpt-4-1106-preview", "gpt-3.5-turbo", "gpt-3.5-turbo-16k"] + # TODO: select valid_model = config.model in model_options - default_model = questionary.select( + model = questionary.select( "Select default model (recommended: gpt-4):", choices=model_options, default=config.model if valid_model else model_options[0] ).ask() - else: - default_model = "local" # TODO: figure out if this is ok? 
this is for local endpoint + else: # local models + # ollama also needs model type + if model_endpoint_type == "ollama": + default_model = config.model if config.model and config.model_endpoint_type == "ollama" else DEFAULT_OLLAMA_MODEL + model = questionary.text( + "Enter default model name (required for Ollama, see: https://memgpt.readthedocs.io/en/latest/ollama):", + default=default_model, + ).ask() + model = None if len(model) == 0 else model - # get the max tokens (context window) for the model - if default_model == "local" or str(default_model) not in LLM_MAX_TOKENS: + # model wrapper + available_model_wrappers = builtins.list(get_available_wrappers().keys()) + model_wrapper = questionary.select( + f"Select default model wrapper (recommended: {DEFAULT_WRAPPER_NAME}):", + choices=available_model_wrappers, + default=DEFAULT_WRAPPER_NAME, + ).ask() + + # set: context_window + if str(model) not in LLM_MAX_TOKENS: # Ask the user to specify the context length context_length_options = [ str(2**12), # 4096 @@ -140,46 +133,80 @@ def configure(): str(2**18), # 262144 "custom", # enter yourself ] - default_model_context_window = questionary.select( + context_window = questionary.select( "Select your model's context window (for Mistral 7B models, this is probably 8k / 8192):", choices=context_length_options, default=str(LLM_MAX_TOKENS["DEFAULT"]), ).ask() # If custom, ask for input - if default_model_context_window == "custom": + if context_window == "custom": while True: - default_model_context_window = questionary.text("Enter context window (e.g. 8192)").ask() + context_window = questionary.text("Enter context window (e.g. 8192)").ask() try: - default_model_context_window = int(default_model_context_window) + context_window = int(context_window) break except ValueError: print(f"Context window must be a valid integer") else: - default_model_context_window = int(default_model_context_window) + context_window = int(context_window) else: # Pull the context length from the models - default_model_context_window = LLM_MAX_TOKENS[default_model] + context_window = LLM_MAX_TOKENS[model] + return model, model_wrapper, context_window - # defaults + +def configure_embedding_endpoint(config: MemGPTConfig): + # configure embedding endpoint + + default_embedding_endpoint_type = config.embedding_endpoint_type + if config.embedding_endpoint_type is not None and config.embedding_endpoint_type not in ["openai", "azure"]: # local model + default_embedding_endpoint_type = "local" + + embedding_endpoint_type, embedding_endpoint, embedding_dim = None, None, None + embedding_provider = questionary.select( + "Select embedding provider:", choices=["openai", "azure", "local"], default=default_embedding_endpoint_type + ).ask() + if embedding_provider == "openai": + embedding_endpoint_type = "openai" + embedding_endpoint = "https://api.openai.com/v1" + embedding_dim = 1536 + elif embedding_provider == "azure": + embedding_endpoint_type = "azure" + _, _, _, _, embedding_endpoint = get_azure_credentials() + embedding_dim = 1536 + else: # local models + embedding_endpoint_type = "local" + embedding_endpoint = None + embedding_dim = 384 + return embedding_endpoint_type, embedding_endpoint, embedding_dim + + +def configure_cli(config: MemGPTConfig): + # set: preset, default_persona, default_human, default_agent`` + from memgpt.presets.presets import preset_options + + # preset + default_preset = config.preset if config.preset and config.preset in preset_options else None + preset = questionary.select("Select default preset:", 
preset_options, default=default_preset).ask() + + # persona personas = [os.path.basename(f).replace(".txt", "") for f in utils.list_persona_files()] - # print(personas) - default_persona = questionary.select("Select default persona:", personas, default=config.default_persona).ask() + default_persona = config.persona if config.persona and config.persona in personas else None + persona = questionary.select("Select default persona:", personas, default=default_persona).ask() + + # human humans = [os.path.basename(f).replace(".txt", "") for f in utils.list_human_files()] - # print(humans) - default_human = questionary.select("Select default human:", humans, default=config.default_human).ask() + default_human = config.human if config.human and config.human in humans else None + human = questionary.select("Select default human:", humans, default=default_human).ask() # TODO: figure out if we should set a default agent or not - default_agent = None - # agents = [os.path.basename(f).replace(".json", "") for f in utils.list_agent_config_files()] - # if len(agents) > 0: # agents have been created - # default_agent = questionary.select( - # "Select default agent:", - # agents - # ).ask() - # else: - # default_agent = None + agent = None + return preset, persona, human, agent + + +def configure_archival_storage(config: MemGPTConfig): # Configure archival storage backend archival_storage_options = ["local", "postgres"] archival_storage_type = questionary.select( @@ -191,25 +218,65 @@ def configure(): "Enter postgres connection string (e.g. postgresql+pg8000://{user}:{password}@{ip}:5432/{database}):", default=config.archival_storage_uri if config.archival_storage_uri else "", ).ask() + return archival_storage_type, archival_storage_uri - # TODO: allow configuring embedding model + +@app.command() +def configure(): + """Updates default MemGPT configurations""" + + MemGPTConfig.create_config_dir() + + # Will pre-populate with defaults, or what the user previously set + config = MemGPTConfig.load() + model_endpoint_type, model_endpoint = configure_llm_endpoint(config) + model, model_wrapper, context_window = configure_model(config, model_endpoint_type) + embedding_endpoint_type, embedding_endpoint, embedding_dim = configure_embedding_endpoint(config) + default_preset, default_persona, default_human, default_agent = configure_cli(config) + archival_storage_type, archival_storage_uri = configure_archival_storage(config) + + # check credentials + azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = get_azure_credentials() + openai_key = get_openai_credentials() + if model_endpoint_type == "azure" or embedding_endpoint_type == "azure": + if all([azure_key, azure_endpoint, azure_version]): + print(f"Using Microsoft endpoint {azure_endpoint}.") + if all([azure_deployment, azure_embedding_deployment]): + print(f"Using deployment id {azure_deployment}") + else: + raise ValueError( + "Missing environment variables for Azure (see https://memgpt.readthedocs.io/en/latest/endpoints/#azure). Please set then run `memgpt configure` again." + ) + if model_endpoint_type == "openai" or embedding_endpoint_type == "openai": + if not openai_key: + raise ValueError( + "Missing environment variables for OpenAI (see https://memgpt.readthedocs.io/en/latest/endpoints/#openai). Please set them and run `memgpt configure` again." 
+ ) config = MemGPTConfig( - model=default_model, - context_window=default_model_context_window, + # model configs + model=model, + model_endpoint=model_endpoint, + model_endpoint_type=model_endpoint_type, + model_wrapper=model_wrapper, + context_window=context_window, + # embedding configs + embedding_endpoint_type=embedding_endpoint_type, + embedding_endpoint=embedding_endpoint, + embedding_dim=embedding_dim, + # cli configs preset=default_preset, - model_endpoint=default_endpoint, - embedding_model=default_embedding_endpoint, - embedding_dim=default_embedding_dim, - default_persona=default_persona, - default_human=default_human, - default_agent=default_agent, - openai_key=openai_key if use_openai else None, - azure_key=azure_key if use_azure else None, - azure_endpoint=azure_endpoint if use_azure else None, - azure_version=azure_version if use_azure else None, - azure_deployment=azure_deployment if use_azure_deployment_ids else None, - azure_embedding_deployment=azure_embedding_deployment if use_azure_deployment_ids else None, + persona=default_persona, + human=default_human, + agent=default_agent, + # credentials + openai_key=openai_key, + azure_key=azure_key, + azure_endpoint=azure_endpoint, + azure_version=azure_version, + azure_deployment=azure_deployment, + azure_embedding_deployment=azure_embedding_deployment, + # storage archival_storage_type=archival_storage_type, archival_storage_uri=archival_storage_uri, ) diff --git a/memgpt/config.py b/memgpt/config.py index feb4bbea..6a57bf8d 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -16,6 +16,7 @@ from colorama import Fore, Style from typing import List, Type +import memgpt import memgpt.utils as utils from memgpt.interface import CLIInterface as interface from memgpt.personas.personas import get_persona_text @@ -40,6 +41,24 @@ model_choices = [ ] +# helper functions for writing to configs +def get_field(config, section, field): + if section not in config: + return None + if config.has_option(section, field): + return config.get(section, field) + else: + return None + + +def set_field(config, section, field, value): + if value is None: # cannot write None + return + if section not in config: # create section + config.add_section(section) + config.set(section, field, value) + + @dataclass class MemGPTConfig: config_path: str = os.path.join(MEMGPT_DIR, "config") @@ -49,9 +68,10 @@ class MemGPTConfig: preset: str = DEFAULT_PRESET # model parameters - # provider: str = "openai" # openai, azure, local (TODO) - model_endpoint: str = "openai" - model: str = "gpt-4" # gpt-4, gpt-3.5-turbo, local + model: str = None + model_endpoint_type: str = None + model_endpoint: str = None # localhost:8000 + model_wrapper: str = None context_window: int = LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"] # model parameters: openai @@ -65,12 +85,13 @@ class MemGPTConfig: azure_embedding_deployment: str = None # persona parameters - default_persona: str = personas.DEFAULT - default_human: str = humans.DEFAULT - default_agent: str = None + persona: str = personas.DEFAULT + human: str = humans.DEFAULT + agent: str = None # embedding parameters - embedding_model: str = "openai" + embedding_endpoint_type: str = "openai" # openai, azure, local + embedding_endpoint: str = None embedding_dim: int = 1536 embedding_chunk_size: int = 300 # number of tokens @@ -89,6 +110,12 @@ class MemGPTConfig: persistence_manager_save_file: str = None # local file persistence_manager_uri: str = None # db URI + def __post_init__(self): + # 
ensure types + self.embedding_chunk_size = int(self.embedding_chunk_size) + self.embedding_dim = int(self.embedding_dim) + self.context_window = int(self.context_window) + @staticmethod def generate_uuid() -> str: return uuid.UUID(int=uuid.getnode()).hex @@ -104,72 +131,38 @@ class MemGPTConfig: config_path = MemGPTConfig.config_path if os.path.exists(config_path): + # read existing config config.read(config_path) + config_dict = { + "model": get_field(config, "model", "model"), + "model_endpoint": get_field(config, "model", "model_endpoint"), + "model_endpoint_type": get_field(config, "model", "model_endpoint_type"), + "model_wrapper": get_field(config, "model", "model_wrapper"), + "context_window": get_field(config, "model", "context_window"), + "preset": get_field(config, "defaults", "preset"), + "persona": get_field(config, "defaults", "persona"), + "human": get_field(config, "defaults", "human"), + "agent": get_field(config, "defaults", "agent"), + "openai_key": get_field(config, "openai", "key"), + "azure_key": get_field(config, "azure", "key"), + "azure_endpoint": get_field(config, "azure", "endpoint"), + "azure_version": get_field(config, "azure", "version"), + "azure_deployment": get_field(config, "azure", "deployment"), + "azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"), + "embedding_endpoint": get_field(config, "embedding", "embedding_endpoint"), + "embedding_endpoint_type": get_field(config, "embedding", "embedding_endpoint_type"), + "embedding_dim": get_field(config, "embedding", "embedding_dim"), + "embedding_chunk_size": get_field(config, "embedding", "chunk_size"), + "archival_storage_type": get_field(config, "archival_storage", "type"), + "archival_storage_path": get_field(config, "archival_storage", "path"), + "archival_storage_uri": get_field(config, "archival_storage", "uri"), + "anon_clientid": get_field(config, "client", "anon_clientid"), + "config_path": config_path, + } + config_dict = {k: v for k, v in config_dict.items() if v is not None} + return cls(**config_dict) - # read config values - model = config.get("defaults", "model") - context_window = ( - int(config.get("defaults", "context_window")) - if config.has_option("defaults", "context_window") - else LLM_MAX_TOKENS["DEFAULT"] - ) - preset = config.get("defaults", "preset") - model_endpoint = config.get("defaults", "model_endpoint") - default_persona = config.get("defaults", "persona") - default_human = config.get("defaults", "human") - default_agent = config.get("defaults", "agent") if config.has_option("defaults", "agent") else None - - openai_key, openai_model = None, None - if "openai" in config: - openai_key = config.get("openai", "key") - - azure_key, azure_endpoint, azure_version, azure_deployment, azure_embedding_deployment = None, None, None, None, None - if "azure" in config: - azure_key = config.get("azure", "key") - azure_endpoint = config.get("azure", "endpoint") - azure_version = config.get("azure", "version") - azure_deployment = config.get("azure", "deployment") if config.has_option("azure", "deployment") else None - azure_embedding_deployment = ( - config.get("azure", "embedding_deployment") if config.has_option("azure", "embedding_deployment") else None - ) - - embedding_model = config.get("embedding", "model") - embedding_dim = config.getint("embedding", "dim") - embedding_chunk_size = config.getint("embedding", "chunk_size") - - # archival storage - archival_storage_type, archival_storage_path, archival_storage_uri = "local", None, None - if 
"archival_storage" in config: - archival_storage_type = config.get("archival_storage", "type") - archival_storage_path = config.get("archival_storage", "path") if config.has_option("archival_storage", "path") else None - archival_storage_uri = config.get("archival_storage", "uri") if config.has_option("archival_storage", "uri") else None - - anon_clientid = config.get("client", "anon_clientid") - - return cls( - model=model, - context_window=context_window, - preset=preset, - model_endpoint=model_endpoint, - default_persona=default_persona, - default_human=default_human, - default_agent=default_agent, - openai_key=openai_key, - azure_key=azure_key, - azure_endpoint=azure_endpoint, - azure_version=azure_version, - azure_deployment=azure_deployment, - azure_embedding_deployment=azure_embedding_deployment, - embedding_model=embedding_model, - embedding_dim=embedding_dim, - embedding_chunk_size=embedding_chunk_size, - archival_storage_type=archival_storage_type, - archival_storage_path=archival_storage_path, - archival_storage_uri=archival_storage_uri, - anon_clientid=anon_clientid, - config_path=config_path, - ) - + # create new config anon_clientid = MemGPTConfig.generate_uuid() config = cls(anon_clientid=anon_clientid, config_path=config_path) config.save() # save updated config @@ -179,51 +172,43 @@ class MemGPTConfig: config = configparser.ConfigParser() # CLI defaults - config.add_section("defaults") - config.set("defaults", "model", self.model) - config.set("defaults", "context_window", str(self.context_window)) - config.set("defaults", "preset", self.preset) - assert self.model_endpoint is not None, "Endpoint must be set" - config.set("defaults", "model_endpoint", self.model_endpoint) - config.set("defaults", "persona", self.default_persona) - config.set("defaults", "human", self.default_human) - if self.default_agent: - config.set("defaults", "agent", self.default_agent) + set_field(config, "defaults", "preset", self.preset) + set_field(config, "defaults", "persona", self.persona) + set_field(config, "defaults", "human", self.human) + set_field(config, "defaults", "agent", self.agent) - # security credentials - if self.openai_key: - config.add_section("openai") - config.set("openai", "key", self.openai_key) + # model defaults + set_field(config, "model", "model", self.model) + set_field(config, "model", "model_endpoint", self.model_endpoint) + set_field(config, "model", "model_endpoint_type", self.model_endpoint_type) + set_field(config, "model", "model_wrapper", self.model_wrapper) + set_field(config, "model", "context_window", str(self.context_window)) - if self.azure_key: - config.add_section("azure") - config.set("azure", "key", self.azure_key) - config.set("azure", "endpoint", self.azure_endpoint) - config.set("azure", "version", self.azure_version) - if self.azure_deployment: - config.set("azure", "deployment", self.azure_deployment) - config.set("azure", "embedding_deployment", self.azure_embedding_deployment) + # security credentials: openai + set_field(config, "openai", "key", self.openai_key) + + # security credentials: azure + set_field(config, "azure", "key", self.azure_key) + set_field(config, "azure", "endpoint", self.azure_endpoint) + set_field(config, "azure", "version", self.azure_version) + set_field(config, "azure", "deployment", self.azure_deployment) + set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment) # embeddings - config.add_section("embedding") - config.set("embedding", "model", self.embedding_model) - config.set("embedding", 
"dim", str(self.embedding_dim)) - config.set("embedding", "chunk_size", str(self.embedding_chunk_size)) + set_field(config, "embedding", "embedding_endpoint_type", self.embedding_endpoint_type) + set_field(config, "embedding", "embedding_endpoint", self.embedding_endpoint) + set_field(config, "embedding", "embedding_dim", str(self.embedding_dim)) + set_field(config, "embedding", "embedding_chunk_size", str(self.embedding_chunk_size)) # archival storage - config.add_section("archival_storage") - # print("archival storage", self.archival_storage_type) - config.set("archival_storage", "type", self.archival_storage_type) - if self.archival_storage_path: - config.set("archival_storage", "path", self.archival_storage_path) - if self.archival_storage_uri: - config.set("archival_storage", "uri", self.archival_storage_uri) + set_field(config, "archival_storage", "type", self.archival_storage_type) + set_field(config, "archival_storage", "path", self.archival_storage_path) + set_field(config, "archival_storage", "uri", self.archival_storage_uri) # client - config.add_section("client") if not self.anon_clientid: self.anon_clientid = self.generate_uuid() - config.set("client", "anon_clientid", self.anon_clientid) + set_field(config, "client", "anon_clientid", self.anon_clientid) if not os.path.exists(MEMGPT_DIR): os.makedirs(MEMGPT_DIR, exist_ok=True) @@ -262,32 +247,54 @@ class AgentConfig: self, persona, human, + # model info model, + model_endpoint_type=None, + model_endpoint=None, + model_wrapper=None, context_window=None, - preset=DEFAULT_PRESET, - name=None, - data_sources=[], + # embedding info + embedding_endpoint_type=None, + embedding_endpoint=None, + embedding_dim=None, + embedding_chunk_size=None, + # other + preset=None, + data_sources=None, + # agent info agent_config_path=None, + name=None, create_time=None, - data_source=None, + memgpt_version=None, ): if name is None: self.name = f"agent_{self.generate_agent_id()}" else: self.name = name - self.persona = persona - self.human = human - self.model = model - self.context_window = context_window - self.preset = preset - self.data_sources = data_sources - self.create_time = create_time if create_time is not None else utils.get_local_time() - self.data_source = None # deprecated - if context_window is None: - self.context_window = LLM_MAX_TOKENS[self.model] if self.model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"] + config = MemGPTConfig.load() # get default values + self.persona = config.persona if persona is None else persona + self.human = config.human if human is None else human + self.preset = config.preset if preset is None else preset + self.context_window = config.context_window if context_window is None else context_window + self.model = config.model if model is None else model + self.model_endpoint_type = config.model_endpoint_type if model_endpoint_type is None else model_endpoint_type + self.model_endpoint = config.model_endpoint if model_endpoint is None else model_endpoint + self.model_wrapper = config.model_wrapper if model_wrapper is None else model_wrapper + self.embedding_endpoint_type = config.embedding_endpoint_type if embedding_endpoint_type is None else embedding_endpoint_type + self.embedding_endpoint = config.embedding_endpoint if embedding_endpoint is None else embedding_endpoint + self.embedding_dim = config.embedding_dim if embedding_dim is None else embedding_dim + self.embedding_chunk_size = config.embedding_chunk_size if embedding_chunk_size is None else embedding_chunk_size + + # agent metadata + 
self.data_sources = data_sources if data_sources is not None else [] + self.create_time = create_time if create_time is not None else utils.get_local_time() + if memgpt_version is None: + import memgpt + + self.memgpt_version = memgpt.__version__ else: - self.context_window = context_window + self.memgpt_version = memgpt_version # save agent config self.agent_config_path = ( @@ -326,6 +333,8 @@ class AgentConfig: def save(self): # save state of persistence manager os.makedirs(os.path.join(MEMGPT_DIR, "agents", self.name), exist_ok=True) + # save version + self.memgpt_version = memgpt.__version__ with open(self.agent_config_path, "w") as f: json.dump(vars(self), f, indent=4) @@ -342,7 +351,6 @@ class AgentConfig: assert os.path.exists(agent_config_path), f"Agent config file does not exist at {agent_config_path}" with open(agent_config_path, "r") as f: agent_config = json.load(f) - # allow compatibility accross versions try: class_args = inspect.getargspec(cls.__init__).args @@ -354,7 +362,6 @@ class AgentConfig: if key not in class_args: utils.printd(f"Removing missing argument {key} from agent config") del agent_config[key] - return cls(**agent_config) diff --git a/memgpt/embeddings.py b/memgpt/embeddings.py index 20260ed1..d2c2e476 100644 --- a/memgpt/embeddings.py +++ b/memgpt/embeddings.py @@ -11,9 +11,9 @@ def embedding_model(): # load config config = MemGPTConfig.load() - endpoint = config.embedding_model + endpoint = config.embedding_endpoint_type if endpoint == "openai": - model = OpenAIEmbedding(api_base="https://api.openai.com/v1", api_key=config.openai_key) + model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=config.openai_key) return model elif endpoint == "azure": return OpenAIEmbedding( diff --git a/memgpt/functions/function_sets/base.py b/memgpt/functions/function_sets/base.py index 424017d0..38d75286 100644 --- a/memgpt/functions/function_sets/base.py +++ b/memgpt/functions/function_sets/base.py @@ -4,7 +4,7 @@ import os import json import math -from ...constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE +from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE ### Functions / tools the agent can use # All functions should return a response string (or None) diff --git a/memgpt/functions/function_sets/extras.py b/memgpt/functions/function_sets/extras.py index 86883e3e..eda4d53c 100644 --- a/memgpt/functions/function_sets/extras.py +++ b/memgpt/functions/function_sets/extras.py @@ -4,8 +4,8 @@ import json import requests -from ...constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS -from ...openai_tools import completions_with_backoff as create +from memgpt.constants import MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MAX_PAUSE_HEARTBEATS +from memgpt.openai_tools import completions_with_backoff as create def message_chatgpt(self, message: str): diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index 9a59dce2..63a0b540 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -10,74 +10,65 @@ from .llamacpp.api import get_llamacpp_completion from .koboldcpp.api import get_koboldcpp_completion from .ollama.api import get_ollama_completion from .llm_chat_completion_wrappers import airoboros, dolphin, zephyr, simple_summary_wrapper -from .utils import DotDict +from .constants import DEFAULT_WRAPPER +from .utils import DotDict, 
get_available_wrappers from ..prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE from ..errors import LocalLLMConnectionError, LocalLLMError -HOST = os.getenv("OPENAI_API_BASE") -HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion +endpoint = os.getenv("OPENAI_API_BASE") +endpoint_type = os.getenv("BACKEND_TYPE") # default None == ChatCompletion DEBUG = False # DEBUG = True -DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper + has_shown_warning = False def get_chat_completion( - model, # no model, since the model is fixed to whatever you set in your own backend + model, # no model required (except for Ollama), since the model is fixed to whatever you set in your own backend messages, functions=None, function_call="auto", context_window=None, + # required + wrapper=None, + endpoint=None, + endpoint_type=None, ): assert context_window is not None, "Local LLM calls need the context length to be explicitly set" + assert endpoint is not None, "Local LLM calls need the endpoint (eg http://localendpoint:1234) to be explicitly set" + assert endpoint_type is not None, "Local LLM calls need the endpoint type (eg webui) to be explicitly set" global has_shown_warning grammar_name = None - if HOST is None: - raise ValueError(f"The OPENAI_API_BASE environment variable is not defined. Please set it in your environment.") - if HOST_TYPE is None: - raise ValueError(f"The BACKEND_TYPE environment variable is not defined. Please set it in your environment.") - if function_call != "auto": raise ValueError(f"function_call == {function_call} not supported (auto only)") + available_wrappers = get_available_wrappers() if messages[0]["role"] == "system" and messages[0]["content"].strip() == SUMMARIZE_SYSTEM_MESSAGE.strip(): # Special case for if the call we're making is coming from the summarizer llm_wrapper = simple_summary_wrapper.SimpleSummaryWrapper() - elif model == "airoboros-l2-70b-2.1": - llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper() - elif model == "airoboros-l2-70b-2.1-grammar": - llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False) - # grammar_name = "json" - grammar_name = "json_func_calls_with_inner_thoughts" - elif model == "dolphin-2.1-mistral-7b": - llm_wrapper = dolphin.Dolphin21MistralWrapper() - elif model == "dolphin-2.1-mistral-7b-grammar": - llm_wrapper = dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False) - # grammar_name = "json" - grammar_name = "json_func_calls_with_inner_thoughts" - elif model == "zephyr-7B-alpha" or model == "zephyr-7B-beta": - llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper() - elif model == "zephyr-7B-alpha-grammar" or model == "zephyr-7B-beta-grammar": - llm_wrapper = zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False) - # grammar_name = "json" - grammar_name = "json_func_calls_with_inner_thoughts" - else: + elif wrapper is None: # Warn the user that we're using the fallback if not has_shown_warning: print( - f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --model)" + f"Warning: no wrapper specified for local LLM, using the default wrapper (you can remove this warning by specifying the wrapper with --wrapper)" ) has_shown_warning = True - if HOST_TYPE in ["koboldcpp", "llamacpp", "webui"]: + if endpoint_type in ["koboldcpp", "llamacpp", "webui"]: # make the default to use grammar llm_wrapper = 
DEFAULT_WRAPPER(include_opening_brace_in_prefix=False)
             # grammar_name = "json"
             grammar_name = "json_func_calls_with_inner_thoughts"
         else:
             llm_wrapper = DEFAULT_WRAPPER()
+    elif wrapper not in available_wrappers:
+        raise ValueError(f"Could not find requested wrapper '{wrapper}' in available wrappers list:\n{available_wrappers}")
+    else:
+        llm_wrapper = available_wrappers[wrapper]
+        if "grammar" in wrapper:
+            grammar_name = "json_func_calls_with_inner_thoughts"
 
-    if grammar_name is not None and HOST_TYPE not in ["koboldcpp", "llamacpp", "webui"]:
+    if grammar_name is not None and endpoint_type not in ["koboldcpp", "llamacpp", "webui"]:
         print(f"Warning: grammars are currently only supported when using llama.cpp as the MemGPT local LLM backend")
 
     # First step: turn the message sequence into a prompt that the model expects
@@ -91,25 +82,25 @@ def get_chat_completion(
     )
 
     try:
-        if HOST_TYPE == "webui":
-            result = get_webui_completion(prompt, context_window, grammar=grammar_name)
-        elif HOST_TYPE == "lmstudio":
-            result = get_lmstudio_completion(prompt, context_window)
-        elif HOST_TYPE == "llamacpp":
-            result = get_llamacpp_completion(prompt, context_window, grammar=grammar_name)
-        elif HOST_TYPE == "koboldcpp":
-            result = get_koboldcpp_completion(prompt, context_window, grammar=grammar_name)
-        elif HOST_TYPE == "ollama":
-            result = get_ollama_completion(prompt, context_window)
+        if endpoint_type == "webui":
+            result = get_webui_completion(endpoint, prompt, context_window, grammar=grammar_name)
+        elif endpoint_type == "lmstudio":
+            result = get_lmstudio_completion(endpoint, prompt, context_window)
+        elif endpoint_type == "llamacpp":
+            result = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
+        elif endpoint_type == "koboldcpp":
+            result = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar_name)
+        elif endpoint_type == "ollama":
+            result = get_ollama_completion(endpoint, model, prompt, context_window)
         else:
             raise LocalLLMError(
                 f"BACKEND_TYPE is not set, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)"
             )
     except requests.exceptions.ConnectionError as e:
-        raise LocalLLMConnectionError(f"Unable to connect to host {HOST}")
+        raise LocalLLMConnectionError(f"Unable to connect to endpoint {endpoint}")
 
     if result is None or result == "":
-        raise LocalLLMError(f"Got back an empty response string from {HOST}")
+        raise LocalLLMError(f"Got back an empty response string from {endpoint}")
 
     if DEBUG:
         print(f"Raw LLM output:\n{result}")
@@ -123,7 +114,7 @@ def get_chat_completion(
     # unpack with response.choices[0].message.content
     response = DotDict(
         {
-            "model": None,
+            "model": model,
             "choices": [
                 DotDict(
                     {
diff --git a/memgpt/local_llm/constants.py b/memgpt/local_llm/constants.py
new file mode 100644
index 00000000..ecb1e637
--- /dev/null
+++ b/memgpt/local_llm/constants.py
@@ -0,0 +1,14 @@
+import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
+
+DEFAULT_ENDPOINTS = {
+    "koboldcpp": "http://localhost:5001",
+    "llamacpp": "http://localhost:8080",
+    "lmstudio": "http://localhost:1234",
+    "ollama": "http://localhost:11434",
+    "webui": "http://localhost:5000",
+}
+
+DEFAULT_OLLAMA_MODEL = "dolphin2.2-mistral:7b-q6_K"
+
+DEFAULT_WRAPPER = airoboros.Airoboros21InnerMonologueWrapper
+DEFAULT_WRAPPER_NAME = "airoboros-l2-70b-2.1"
diff --git a/memgpt/local_llm/koboldcpp/api.py b/memgpt/local_llm/koboldcpp/api.py
index 1ca93392..0d000017 100644
--- a/memgpt/local_llm/koboldcpp/api.py
+++ b/memgpt/local_llm/koboldcpp/api.py
@@ -5,14 +5,12 @@ import requests
 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens
 
-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 KOBOLDCPP_API_SUFFIX = "/api/v1/generate"
-# DEBUG = False
-DEBUG = True
+DEBUG = False
+# DEBUG = True
 
 
-def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
+def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://lite.koboldai.net/koboldcpp_api for API spec"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
@@ -27,13 +25,13 @@ def get_koboldcpp_completion(prompt, context_window, grammar=None, settings=SIMP
     if grammar is not None:
         request["grammar"] = load_grammar_file(grammar)
 
-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
 
     try:
         # NOTE: llama.cpp server returns the following when it's out of context
         # curl: (52) Empty reply from server
-        URI = urljoin(HOST.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
diff --git a/memgpt/local_llm/llamacpp/api.py b/memgpt/local_llm/llamacpp/api.py
index ce91ad61..4b3c693b 100644
--- a/memgpt/local_llm/llamacpp/api.py
+++ b/memgpt/local_llm/llamacpp/api.py
@@ -5,14 +5,12 @@ import requests
 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens
 
-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 LLAMACPP_API_SUFFIX = "/completion"
-# DEBUG = False
-DEBUG = True
+DEBUG = False
+# DEBUG = True
 
 
-def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPLE):
+def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE):
     """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
@@ -26,13 +24,13 @@ def get_llamacpp_completion(prompt, context_window, grammar=None, settings=SIMPL
     if grammar is not None:
         request["grammar"] = load_grammar_file(grammar)
 
-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
 
     try:
         # NOTE: llama.cpp server returns the following when it's out of context
         # curl: (52) Empty reply from server
-        URI = urljoin(HOST.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
diff --git a/memgpt/local_llm/lmstudio/api.py b/memgpt/local_llm/lmstudio/api.py
index 79c5f93f..e643b17a 100644
--- a/memgpt/local_llm/lmstudio/api.py
+++ b/memgpt/local_llm/lmstudio/api.py
@@ -5,14 +5,12 @@ import requests
 from .settings import SIMPLE
 from ..utils import count_tokens
 
-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
 LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
 DEBUG = False
 
 
-def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat"):
+def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="chat"):
     """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
@@ -25,19 +23,19 @@ def get_lmstudio_completion(prompt, context_window, settings=SIMPLE, api="chat")
     if api == "chat":
         # Uses the ChatCompletions API style
         # Seems to work better, probably because it's applying some extra settings under-the-hood?
-        URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
         message_structure = [{"role": "user", "content": prompt}]
         request["messages"] = message_structure
     elif api == "completions":
         # Uses basic string completions (string in, string out)
         # Does not work as well as ChatCompletions for some reason
-        URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
         request["prompt"] = prompt
     else:
         raise ValueError(api)
 
-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
 
     try:
         response = requests.post(URI, json=request)
diff --git a/memgpt/local_llm/ollama/api.py b/memgpt/local_llm/ollama/api.py
index 6f1fed4d..d2b959f4 100644
--- a/memgpt/local_llm/ollama/api.py
+++ b/memgpt/local_llm/ollama/api.py
@@ -6,26 +6,25 @@ from .settings import SIMPLE
 from ..utils import count_tokens
 from ...errors import LocalLLMError
 
-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
-MODEL_NAME = os.getenv("OLLAMA_MODEL")  # ollama API requires this in the request
 OLLAMA_API_SUFFIX = "/api/generate"
 DEBUG = False
 
 
-def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None):
+def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
         raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)")
 
-    if MODEL_NAME is None:
-        raise LocalLLMError(f"Error: OLLAMA_MODEL not specified. Set OLLAMA_MODEL to the model you want to run (e.g. 'dolphin2.2-mistral')")
+    if model is None:
+        raise LocalLLMError(
+            f"Error: model name not specified. Set model in your config to the model you want to run (e.g. 'dolphin2.2-mistral')"
+        )
 
     # Settings for the generation, includes the prompt + stop tokens, max length, etc
     request = settings
     request["prompt"] = prompt
-    request["model"] = MODEL_NAME
+    request["model"] = model
     request["options"]["num_ctx"] = context_window
 
     # Set grammar
@@ -33,11 +32,11 @@ def get_ollama_completion(prompt, context_window, settings=SIMPLE, grammar=None)
         # request["grammar_string"] = load_grammar_file(grammar)
         raise NotImplementedError(f"Ollama does not support grammars")
 
-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
 
     try:
-        URI = urljoin(HOST.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
diff --git a/memgpt/local_llm/utils.py b/memgpt/local_llm/utils.py
index f6c44eae..e55f017f 100644
--- a/memgpt/local_llm/utils.py
+++ b/memgpt/local_llm/utils.py
@@ -1,6 +1,10 @@
 import os
 import tiktoken
 
+import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros
+import memgpt.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
+import memgpt.local_llm.llm_chat_completion_wrappers.zephyr as zephyr
+
 
 class DotDict(dict):
     """Allow dot access on properties similar to OpenAI response object"""
@@ -37,3 +41,14 @@ def load_grammar_file(grammar):
 def count_tokens(s: str, model: str = "gpt-4") -> int:
     encoding = tiktoken.encoding_for_model(model)
     return len(encoding.encode(s))
+
+
+def get_available_wrappers() -> dict:
+    return {
+        "airoboros-l2-70b-2.1": airoboros.Airoboros21InnerMonologueWrapper(),
+        "airoboros-l2-70b-2.1-grammar": airoboros.Airoboros21InnerMonologueWrapper(include_opening_brace_in_prefix=False),
+        "dolphin-2.1-mistral-7b": dolphin.Dolphin21MistralWrapper(),
+        "dolphin-2.1-mistral-7b-grammar": dolphin.Dolphin21MistralWrapper(include_opening_brace_in_prefix=False),
+        "zephyr-7B": zephyr.ZephyrMistralInnerMonologueWrapper(),
+        "zephyr-7B-grammar": zephyr.ZephyrMistralInnerMonologueWrapper(include_opening_brace_in_prefix=False),
+    }
diff --git a/memgpt/local_llm/webui/api.py b/memgpt/local_llm/webui/api.py
index e9373e20..b447cf8e 100644
--- a/memgpt/local_llm/webui/api.py
+++ b/memgpt/local_llm/webui/api.py
@@ -5,13 +5,11 @@ import requests
 from .settings import SIMPLE
 from ..utils import load_grammar_file, count_tokens
 
-HOST = os.getenv("OPENAI_API_BASE")
-HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
 WEBUI_API_SUFFIX = "/api/v1/generate"
 DEBUG = False
 
 
-def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
+def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None):
     """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server"""
     prompt_tokens = count_tokens(prompt)
     if prompt_tokens > context_window:
@@ -26,11 +24,11 @@ def get_webui_completion(prompt, context_window, settings=SIMPLE, grammar=None):
     if grammar is not None:
         request["grammar_string"] = load_grammar_file(grammar)
 
-    if not HOST.startswith(("http://", "https://")):
-        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
+    if not endpoint.startswith(("http://", "https://")):
+        raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://")
 
     try:
-        URI = urljoin(HOST.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
+        URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/"))
         response = requests.post(URI, json=request)
         if response.status_code == 200:
             result = response.json()
diff --git a/memgpt/main.py b/memgpt/main.py
index eb1e701a..48972ecb 100644
--- a/memgpt/main.py
+++ b/memgpt/main.py
@@ -241,7 +241,7 @@ def main(
     memgpt_persona = persona
     if memgpt_persona is None:
         memgpt_persona = (
-            personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT,
+            personas.GPT35_DEFAULT if (model is not None and "gpt-3.5" in model) else personas.DEFAULT,
             None,  # represents the personas dir in pymemgpt package
         )
     else:
diff --git a/memgpt/openai_tools.py b/memgpt/openai_tools.py
index fce7d796..6e5cd822 100644
--- a/memgpt/openai_tools.py
+++ b/memgpt/openai_tools.py
@@ -2,10 +2,14 @@ import random
 import os
 import time
 
-from .local_llm.chat_completion_proxy import get_chat_completion
+import time
+from typing import Callable, TypeVar
+
+from memgpt.local_llm.chat_completion_proxy import get_chat_completion
 
 HOST = os.getenv("OPENAI_API_BASE")
 HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
+R = TypeVar("R")
 
 import openai
 
@@ -55,6 +59,7 @@ def retry_with_exponential_backoff(
     return wrapper
 
 
+# TODO: delete/ignore --legacy
 @retry_with_exponential_backoff
 def completions_with_backoff(**kwargs):
     # Local model
@@ -75,6 +80,38 @@ def completions_with_backoff(**kwargs):
         return openai.ChatCompletion.create(**kwargs)
 
 
+@retry_with_exponential_backoff
+def chat_completion_with_backoff(agent_config, **kwargs):
+    from memgpt.utils import printd
+    from memgpt.config import AgentConfig, MemGPTConfig
+
+    printd(f"Using model {agent_config.model_endpoint_type}, endpoint: {agent_config.model_endpoint}")
+    if agent_config.model_endpoint_type == "openai":
+        # openai
+        openai.api_base = agent_config.model_endpoint
+        return openai.ChatCompletion.create(**kwargs)
+    elif agent_config.model_endpoint_type == "azure":
+        # configure openai
+        config = MemGPTConfig.load()  # load credentials (currently not stored in agent config)
+        openai.api_type = "azure"
+        openai.api_key = config.azure_key
+        openai.api_base = config.azure_endpoint
+        openai.api_version = config.azure_version
+        if config.azure_deployment is not None:
+            kwargs["deployment_id"] = config.azure_deployment
+        else:
+            kwargs["engine"] = MODEL_TO_AZURE_ENGINE[config.model]
+        del kwargs["model"]
+        return openai.ChatCompletion.create(**kwargs)
+    else:  # local model
+        kwargs["context_window"] = agent_config.context_window  # specify for open LLMs
+        kwargs["endpoint"] = agent_config.model_endpoint  # specify for open LLMs
+        kwargs["endpoint_type"] = agent_config.model_endpoint_type  # specify for open LLMs
+        kwargs["wrapper"] = agent_config.model_wrapper  # specify for open LLMs
+        return get_chat_completion(**kwargs)
+
+
+# TODO: deprecate
 @retry_with_exponential_backoff
 def create_embedding_with_backoff(**kwargs):
     if using_azure():
diff --git a/memgpt/presets/presets.py b/memgpt/presets/presets.py
index 8745aee5..d688f0fb 100644
--- a/memgpt/presets/presets.py
+++ b/memgpt/presets/presets.py
@@ -57,5 +57,5 @@ def use_preset(preset_name, agent_config, model, persona, human, interface, pers
         persona_notes=persona,
         human_notes=human,
         # gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
-        first_message_verify_mono=True if "gpt-4" in model else False,
+        first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
     )
diff --git a/tests/test_cli.py b/tests/test_cli.py
index ea06f549..90a55821 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -13,7 +13,7 @@ def test_configure_memgpt():
 
 
 def test_save_load():
-    configure_memgpt()
+    # configure_memgpt()  # rely on configure running first^
     child = pexpect.spawn("memgpt run --agent test_save_load --first --strip_ui")
 
     child.expect("Enter your message:", timeout=TIMEOUT)
diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py
index d21eb7c2..f58bc0c7 100644
--- a/tests/test_load_archival.py
+++ b/tests/test_load_archival.py
@@ -1,12 +1,12 @@
 # import tempfile
 # import asyncio
-import os
+# import os
 
 # import asyncio
-from datasets import load_dataset
+# from datasets import load_dataset
 
-import memgpt
-from memgpt.cli.cli_load import load_directory, load_database, load_webpage
+# import memgpt
+# from memgpt.cli.cli_load import load_directory, load_database, load_webpage
 
 # import memgpt.presets as presets
 # import memgpt.personas.personas as personas
diff --git a/tests/test_storage.py b/tests/test_storage.py
index fc1286b6..efecfedc 100644
--- a/tests/test_storage.py
+++ b/tests/test_storage.py
@@ -1,6 +1,7 @@
 import os
 import subprocess
 import sys
+import pytest
 
 subprocess.check_call(
     [sys.executable, "-m", "pip", "install", "pgvector", "psycopg", "psycopg2-binary"]
@@ -15,9 +16,11 @@ from memgpt.config import MemGPTConfig, AgentConfig
 import argparse
 
 
+@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL") or not os.getenv("OPENAI_API_KEY"), reason="Missing PG URI and/or OpenAI API key")
 def test_postgres_openai():
-    assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
-    if os.getenv("OPENAI_API_KEY") is None:
+    if not os.getenv("PGVECTOR_TEST_DB_URL"):
+        return  # soft pass
+    if not os.getenv("OPENAI_API_KEY"):
         return  # soft pass
 
     # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
@@ -54,14 +57,16 @@ def test_postgres_openai():
     # print("...finished")
 
 
+@pytest.mark.skipif(not os.getenv("PGVECTOR_TEST_DB_URL"), reason="Missing PG URI")
 def test_postgres_local():
-    assert os.getenv("PGVECTOR_TEST_DB_URL") is not None
+    if not os.getenv("PGVECTOR_TEST_DB_URL"):
+        return
 
     # os.environ["MEMGPT_CONFIG_PATH"] = "./config"
     config = MemGPTConfig(
         archival_storage_type="postgres",
         archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
-        embedding_model="local",
+        embedding_endpoint_type="local",
         embedding_dim=384,  # use HF model
     )
     print(config.config_path)
diff --git a/tests/utils.py b/tests/utils.py
index 271f70b5..23fc969f 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -3,43 +3,55 @@ import pexpect
 from .constants import TIMEOUT
 
 
-def configure_memgpt(enable_openai=True, enable_azure=False):
+def configure_memgpt_localllm():
     child = pexpect.spawn("memgpt configure")
 
-    child.expect("Do you want to enable MemGPT with OpenAI?", timeout=TIMEOUT)
-    if enable_openai:
-        child.sendline("y")
-    else:
-        child.sendline("n")
-
-    child.expect("Do you want to enable MemGPT with Azure?", timeout=TIMEOUT)
-    if enable_azure:
-        child.sendline("y")
-    else:
-        child.sendline("n")
-
-    child.expect("Select default inference endpoint:", timeout=TIMEOUT)
+    child.expect("Select LLM inference provider", timeout=TIMEOUT)
+    child.send("\x1b[B")  # Send the down arrow key
+    child.send("\x1b[B")  # Send the down arrow key
     child.sendline()
 
-    child.expect("Select default embedding endpoint:", timeout=TIMEOUT)
+    child.expect("Select LLM backend", timeout=TIMEOUT)
     child.sendline()
 
-    child.expect("Select default preset:", timeout=TIMEOUT)
+    child.expect("Enter default endpoint", timeout=TIMEOUT)
     child.sendline()
 
-    child.expect("Select default model", timeout=TIMEOUT)
+    child.expect("Select default model wrapper", timeout=TIMEOUT)
     child.sendline()
 
-    child.expect("Select default persona:", timeout=TIMEOUT)
+    child.expect("Select your model's context window", timeout=TIMEOUT)
     child.sendline()
 
-    child.expect("Select default human:", timeout=TIMEOUT)
+    child.expect("Select embedding provider", timeout=TIMEOUT)
+    child.send("\x1b[B")  # Send the down arrow key
+    child.send("\x1b[B")  # Send the down arrow key
+    child.sendline()
+
+    child.expect("Select default preset", timeout=TIMEOUT)
+    child.sendline()
+
+    child.expect("Select default persona", timeout=TIMEOUT)
+    child.sendline()
+
+    child.expect("Select default human", timeout=TIMEOUT)
+    child.sendline()
+
+    child.expect("Select storage backend for archival data", timeout=TIMEOUT)
     child.sendline()
 
-    child.expect("Select storage backend for archival data:", timeout=TIMEOUT)
     child.sendline()
 
     child.expect(pexpect.EOF, timeout=TIMEOUT)  # Wait for child to exit
     child.close()
     assert child.isalive() is False, "CLI should have terminated."
    assert child.exitstatus == 0, "CLI did not exit cleanly."
+
+
+def configure_memgpt(enable_openai=False, enable_azure=False):
+    if enable_openai:
+        raise NotImplementedError
+    elif enable_azure:
+        raise NotImplementedError
+    else:
+        configure_memgpt_localllm()