From da5a8cdbfe57ebb75c1559c8c61f518acde2a0e9 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Thu, 18 Jan 2024 16:11:35 -0800 Subject: [PATCH] refactor: remove User LLM/embed. defaults, add credentials file, add authentication option for custom LLM backends (#835) --- memgpt/agent.py | 2 +- memgpt/agent_store/db.py | 2 +- memgpt/cli/cli.py | 36 ++- memgpt/cli/cli_config.py | 164 ++++++++------ memgpt/cli/cli_load.py | 10 +- memgpt/config.py | 222 +++++-------------- memgpt/credentials.py | 117 ++++++++++ memgpt/data_types.py | 158 +------------ memgpt/embeddings.py | 14 +- memgpt/functions/function_sets/extras.py | 2 +- memgpt/{openai_tools.py => llm_api_tools.py} | 21 +- memgpt/local_llm/chat_completion_proxy.py | 19 +- memgpt/local_llm/koboldcpp/api.py | 6 +- memgpt/local_llm/llamacpp/api.py | 6 +- memgpt/local_llm/lmstudio/api.py | 5 +- memgpt/local_llm/ollama/api.py | 5 +- memgpt/local_llm/utils.py | 28 +++ memgpt/local_llm/vllm/api.py | 7 +- memgpt/local_llm/webui/api.py | 6 +- memgpt/local_llm/webui/legacy_api.py | 6 +- memgpt/memory.py | 4 +- memgpt/metadata.py | 16 -- memgpt/server/server.py | 38 ++-- tests/test_load_archival.py | 13 +- tests/test_metadata_store.py | 8 +- tests/test_server.py | 41 ++-- tests/test_storage.py | 11 +- 27 files changed, 454 insertions(+), 513 deletions(-) create mode 100644 memgpt/credentials.py rename memgpt/{openai_tools.py => llm_api_tools.py} (94%) diff --git a/memgpt/agent.py b/memgpt/agent.py index c5ae5380..ca5d379a 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -17,7 +17,7 @@ from memgpt.persistence_manager import PersistenceManager, LocalStateManager from memgpt.config import MemGPTConfig from memgpt.system import get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages from memgpt.memory import CoreMemory as InContextMemory, summarize_messages -from memgpt.openai_tools import create, is_context_overflow_error +from memgpt.llm_api_tools import create, is_context_overflow_error from memgpt.utils import ( get_tool_call_id, get_local_time, diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index 0c9c078f..6cf62e0a 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -122,7 +122,7 @@ def get_db_model( user = ms.get_user(user_id) if user is None: raise ValueError(f"User {user_id} not found") - embedding_dim = user.default_embedding_config.embedding_dim + embedding_dim = config.default_embedding_config.embedding_dim # this cannot be the case if we are making an agent-specific table assert table_type != TableType.RECALL_MEMORY, f"Agent {agent_id} not found" diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 88f620c9..409da048 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -22,6 +22,7 @@ import memgpt.presets.presets as presets import memgpt.utils as utils from memgpt.utils import printd, open_folder_in_explorer, suppress_stdout from memgpt.config import MemGPTConfig +from memgpt.credentials import MemGPTCredentials from memgpt.constants import MEMGPT_DIR, CLI_WARNING_PREFIX, JSON_ENSURE_ASCII from memgpt.agent import Agent from memgpt.embeddings import embedding_model @@ -71,11 +72,27 @@ def set_config_with_dict(new_config: dict) -> bool: printd(f"Saving new config file.") old_config.save() typer.secho(f"šŸ“– MemGPT configuration file updated!", fg=typer.colors.GREEN) - typer.secho(f"🧠 model\t-> {old_config.model}\nšŸ–„ļø endpoint\t-> {old_config.model_endpoint}", fg=typer.colors.GREEN) + typer.secho( + "\n".join( + [ + f"🧠 model\t-> {old_config.default_llm_config.model}", + f"šŸ–„ļø endpoint\t-> {old_config.default_llm_config.model_endpoint}", + ] + ), + fg=typer.colors.GREEN, + ) return True else: typer.secho(f"šŸ“– MemGPT configuration file unchanged.", fg=typer.colors.WHITE) - typer.secho(f"🧠 model\t-> {old_config.model}\nšŸ–„ļø endpoint\t-> {old_config.model_endpoint}", fg=typer.colors.WHITE) + typer.secho( + "\n".join( + [ + f"🧠 model\t-> {old_config.default_llm_config.model}", + f"šŸ–„ļø endpoint\t-> {old_config.default_llm_config.model_endpoint}", + ] + ), + fg=typer.colors.WHITE, + ) return False @@ -95,6 +112,7 @@ def quickstart( # make sure everything is set up properly MemGPTConfig.create_config_dir() + credentials = MemGPTCredentials.load() config_was_modified = False if backend == QuickstartChoice.memgpt_hosted: @@ -145,6 +163,8 @@ def quickstart( while api_key is None or len(api_key) == 0: # Ask for API key as input api_key = questionary.text("Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):").ask() + credentials.openai_key = api_key + credentials.save() # if latest, try to pull the config from the repo # fallback to using local @@ -158,8 +178,6 @@ def quickstart( config = response.json() # Output a success message and the first few items in the dictionary as a sample print("JSON config file downloaded successfully.") - # Add the API key - config["openai_key"] = api_key config_was_modified = set_config_with_dict(config) else: typer.secho(f"Failed to download config from {url}. Status code: {response.status_code}", fg=typer.colors.RED) @@ -170,7 +188,6 @@ def quickstart( try: with open(backup_config_path, "r") as file: backup_config = json.load(file) - backup_config["openai_key"] = api_key printd("Loaded backup config file successfully.") config_was_modified = set_config_with_dict(backup_config) except FileNotFoundError: @@ -183,7 +200,6 @@ def quickstart( try: with open(backup_config_path, "r") as file: backup_config = json.load(file) - backup_config["openai_key"] = api_key printd("Loaded config file successfully.") config_was_modified = set_config_with_dict(backup_config) except FileNotFoundError: @@ -492,8 +508,8 @@ def run( # agent = f"agent_{agent_count}" agent = utils.create_random_username() - llm_config = user.default_llm_config - embedding_config = user.default_embedding_config # TODO allow overriding embedding params via CLI run + llm_config = config.default_llm_config + embedding_config = config.default_embedding_config # TODO allow overriding embedding params via CLI run # Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window) if model and model != llm_config.model: @@ -579,7 +595,9 @@ def run( original_stdout = sys.stdout # unfortunate hack required to suppress confusing print statements from llama index sys.stdout = io.StringIO() embed_model = embedding_model(config=agent_state.embedding_config, user_id=user.id) - service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size) + service_context = ServiceContext.from_defaults( + llm=None, embed_model=embed_model, chunk_size=agent_state.embedding_config.embedding_chunk_size + ) set_global_service_context(service_context) sys.stdout = original_stdout diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 4f0792f8..8574cafb 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -16,13 +16,14 @@ from memgpt.log import logger from memgpt import utils from memgpt.config import MemGPTConfig +from memgpt.credentials import MemGPTCredentials from memgpt.constants import MEMGPT_DIR # from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.constants import LLM_MAX_TOKENS from memgpt.local_llm.constants import DEFAULT_ENDPOINTS, DEFAULT_OLLAMA_MODEL, DEFAULT_WRAPPER_NAME from memgpt.local_llm.utils import get_available_wrappers -from memgpt.openai_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin +from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin from memgpt.server.utils import shorten_key_middle from memgpt.data_types import User, LLMConfig, EmbeddingConfig from memgpt.metadata import MetadataStore @@ -50,13 +51,16 @@ def get_openai_credentials(): return openai_key -def configure_llm_endpoint(config: MemGPTConfig): +def configure_llm_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials): # configure model endpoint model_endpoint_type, model_endpoint = None, None # get default - default_model_endpoint_type = config.model_endpoint_type - if config.model_endpoint_type is not None and config.model_endpoint_type not in ["openai", "azure"]: # local model + default_model_endpoint_type = config.default_llm_config.model_endpoint_type + if config.default_llm_config.model_endpoint_type is not None and config.default_llm_config.model_endpoint_type not in [ + "openai", + "azure", + ]: # local model default_model_endpoint_type = "local" provider = questionary.select( @@ -68,7 +72,7 @@ def configure_llm_endpoint(config: MemGPTConfig): # set: model_endpoint_type, model_endpoint if provider == "openai": # check for key - if config.openai_key is None: + if credentials.openai_key is None: # allow key to get pulled from env vars openai_api_key = os.getenv("OPENAI_API_KEY", None) # if we still can't find it, ask for it as input @@ -80,12 +84,14 @@ def configure_llm_endpoint(config: MemGPTConfig): ).ask() if openai_api_key is None: raise KeyboardInterrupt - config.openai_key = openai_api_key - config.save() + credentials.openai_key = openai_api_key + credentials.save() else: # Give the user an opportunity to overwrite the key openai_api_key = None - default_input = shorten_key_middle(config.openai_key) if config.openai_key.startswith("sk-") else config.openai_key + default_input = ( + shorten_key_middle(credentials.openai_key) if credentials.openai_key.startswith("sk-") else credentials.openai_key + ) openai_api_key = questionary.text( "Enter your OpenAI API key (starts with 'sk-', see https://platform.openai.com/api-keys):", default=default_input, @@ -94,8 +100,8 @@ def configure_llm_endpoint(config: MemGPTConfig): raise KeyboardInterrupt # If the user modified it, use the new one if openai_api_key != default_input: - config.openai_key = openai_api_key - config.save() + credentials.openai_key = openai_api_key + credentials.save() model_endpoint_type = "openai" model_endpoint = "https://api.openai.com/v1" @@ -112,9 +118,9 @@ def configure_llm_endpoint(config: MemGPTConfig): "Missing environment variables for Azure (see https://memgpt.readme.io/docs/endpoints#azure-openai). Please set then run `memgpt configure` again." ) else: - config.azure_key = azure_creds["azure_key"] - config.azure_endpoint = azure_creds["azure_endpoint"] - config.azure_version = azure_creds["azure_version"] + credentials.azure_key = azure_creds["azure_key"] + credentials.azure_endpoint = azure_creds["azure_endpoint"] + credentials.azure_version = azure_creds["azure_version"] config.save() model_endpoint_type = "azure" @@ -123,9 +129,9 @@ def configure_llm_endpoint(config: MemGPTConfig): else: # local models backend_options = ["webui", "webui-legacy", "llamacpp", "koboldcpp", "ollama", "lmstudio", "lmstudio-legacy", "vllm", "openai"] default_model_endpoint_type = None - if config.model_endpoint_type in backend_options: + if config.default_llm_config.model_endpoint_type in backend_options: # set from previous config - default_model_endpoint_type = config.model_endpoint_type + default_model_endpoint_type = config.default_llm_config.model_endpoint_type model_endpoint_type = questionary.select( "Select LLM backend (select 'openai' if you have an OpenAI compatible proxy):", backend_options, @@ -149,13 +155,13 @@ def configure_llm_endpoint(config: MemGPTConfig): model_endpoint = questionary.text("Enter default endpoint:", default=default_model_endpoint).ask() if model_endpoint is None: raise KeyboardInterrupt - elif config.model_endpoint: - model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask() + elif config.default_llm_config.model_endpoint: + model_endpoint = questionary.text("Enter default endpoint:", default=config.default_llm_config.model_endpoint).ask() if model_endpoint is None: raise KeyboardInterrupt while not utils.is_valid_url(model_endpoint): typer.secho(f"Endpoint must be a valid address", fg=typer.colors.YELLOW) - model_endpoint = questionary.text("Enter default endpoint:", default=config.model_endpoint).ask() + model_endpoint = questionary.text("Enter default endpoint:", default=config.default_llm_config.model_endpoint).ask() if model_endpoint is None: raise KeyboardInterrupt else: @@ -176,7 +182,7 @@ def configure_llm_endpoint(config: MemGPTConfig): return model_endpoint_type, model_endpoint -def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoint: str): +def configure_model(config: MemGPTConfig, credentials: MemGPTCredentials, model_endpoint_type: str, model_endpoint: str): # set: model, model_wrapper model, model_wrapper = None, None if model_endpoint_type == "openai" or model_endpoint_type == "azure": @@ -185,10 +191,10 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi fetched_model_options = None try: if model_endpoint_type == "openai": - fetched_model_options = openai_get_model_list(url=model_endpoint, api_key=config.openai_key) + fetched_model_options = openai_get_model_list(url=model_endpoint, api_key=credentials.openai_key) elif model_endpoint_type == "azure": fetched_model_options = azure_openai_get_model_list( - url=model_endpoint, api_key=config.azure_key, api_version=config.azure_version + url=model_endpoint, api_key=credentials.azure_key, api_version=credentials.azure_version ) fetched_model_options = [obj["id"] for obj in fetched_model_options["data"] if obj["id"].startswith("gpt-")] except: @@ -217,7 +223,7 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi model = questionary.select( "Select default model (recommended: gpt-4):", choices=fetched_model_options + [other_option_str], - default=config.model if valid_model else fetched_model_options[0], + default=config.default_llm_config.model if valid_model else fetched_model_options[0], ).ask() if model is None: raise KeyboardInterrupt @@ -235,7 +241,11 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi else: # local models # ollama also needs model type if model_endpoint_type == "ollama": - default_model = config.model if config.model and config.model_endpoint_type == "ollama" else DEFAULT_OLLAMA_MODEL + default_model = ( + config.default_llm_config.model + if config.default_llm_config.model and config.default_llm_config.model_endpoint_type == "ollama" + else DEFAULT_OLLAMA_MODEL + ) model = questionary.text( "Enter default model name (required for Ollama, see: https://memgpt.readme.io/docs/ollama):", default=default_model, @@ -244,7 +254,11 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi raise KeyboardInterrupt model = None if len(model) == 0 else model - default_model = config.model if config.model and config.model_endpoint_type == "vllm" else "" + default_model = ( + config.default_llm_config.model + if config.default_llm_config.model and config.default_llm_config.model_endpoint_type == "vllm" + else "" + ) # vllm needs huggingface model tag if model_endpoint_type == "vllm": @@ -260,10 +274,12 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi # If we got model options from vLLM endpoint, allow selection + custom input if model_options is not None: other_option_str = "other (enter name)" - valid_model = config.model in model_options + valid_model = config.default_llm_config.model in model_options model_options.append(other_option_str) model = questionary.select( - "Select default model:", choices=model_options, default=config.model if valid_model else model_options[0] + "Select default model:", + choices=model_options, + default=config.default_llm_config.model if valid_model else model_options[0], ).ask() if model is None: raise KeyboardInterrupt @@ -336,10 +352,10 @@ def configure_model(config: MemGPTConfig, model_endpoint_type: str, model_endpoi return model, model_wrapper, context_window -def configure_embedding_endpoint(config: MemGPTConfig): +def configure_embedding_endpoint(config: MemGPTConfig, credentials: MemGPTCredentials): # configure embedding endpoint - default_embedding_endpoint_type = config.embedding_endpoint_type + default_embedding_endpoint_type = config.default_embedding_config.embedding_endpoint_type embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = None, None, None, None embedding_provider = questionary.select( @@ -350,7 +366,7 @@ def configure_embedding_endpoint(config: MemGPTConfig): if embedding_provider == "openai": # check for key - if config.openai_key is None: + if credentials.openai_key is None: # allow key to get pulled from env vars openai_api_key = os.getenv("OPENAI_API_KEY", None) if openai_api_key is None: @@ -362,7 +378,7 @@ def configure_embedding_endpoint(config: MemGPTConfig): ).ask() if openai_api_key is None: raise KeyboardInterrupt - config.openai_key = openai_api_key + credentials.openai_key = openai_api_key config.save() embedding_endpoint_type = "openai" @@ -397,7 +413,9 @@ def configure_embedding_endpoint(config: MemGPTConfig): embedding_endpoint = questionary.text("Enter default endpoint:").ask() # get model type - default_embedding_model = config.embedding_model if config.embedding_model else "BAAI/bge-large-en-v1.5" + default_embedding_model = ( + config.default_embedding_config.embedding_model if config.default_embedding_config.embedding_model else "BAAI/bge-large-en-v1.5" + ) embedding_model = questionary.text( "Enter HuggingFace model tag (e.g. BAAI/bge-large-en-v1.5):", default=default_embedding_model, @@ -406,7 +424,7 @@ def configure_embedding_endpoint(config: MemGPTConfig): raise KeyboardInterrupt # get model dimentions - default_embedding_dim = config.embedding_dim if config.embedding_dim else "1024" + default_embedding_dim = config.default_embedding_config.embedding_dim if config.default_embedding_config.embedding_dim else "1024" embedding_dim = questionary.text("Enter embedding model dimentions (e.g. 1024):", default=str(default_embedding_dim)).ask() if embedding_dim is None: raise KeyboardInterrupt @@ -422,7 +440,7 @@ def configure_embedding_endpoint(config: MemGPTConfig): return embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model -def configure_cli(config: MemGPTConfig): +def configure_cli(config: MemGPTConfig, credentials: MemGPTCredentials): # set: preset, default_persona, default_human, default_agent`` from memgpt.presets.presets import preset_options @@ -452,7 +470,7 @@ def configure_cli(config: MemGPTConfig): return preset, persona, human, agent -def configure_archival_storage(config: MemGPTConfig): +def configure_archival_storage(config: MemGPTConfig, credentials: MemGPTCredentials): # Configure archival storage backend archival_storage_options = ["postgres", "chroma"] archival_storage_type = questionary.select( @@ -522,6 +540,7 @@ def configure(): """Updates default MemGPT configurations""" # check credentials + credentials = MemGPTCredentials.load() openai_key = get_openai_credentials() azure_creds = get_azure_credentials() @@ -530,34 +549,54 @@ def configure(): # Will pre-populate with defaults, or what the user previously set config = MemGPTConfig.load() try: - model_endpoint_type, model_endpoint = configure_llm_endpoint(config) - model, model_wrapper, context_window = configure_model( - config=config, model_endpoint_type=model_endpoint_type, model_endpoint=model_endpoint + model_endpoint_type, model_endpoint = configure_llm_endpoint( + config=config, + credentials=credentials, + ) + model, model_wrapper, context_window = configure_model( + config=config, + credentials=credentials, + model_endpoint_type=model_endpoint_type, + model_endpoint=model_endpoint, + ) + embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint( + config=config, + credentials=credentials, + ) + default_preset, default_persona, default_human, default_agent = configure_cli( + config=config, + credentials=credentials, + ) + archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage( + config=config, + credentials=credentials, + ) + recall_storage_type, recall_storage_uri, recall_storage_path = configure_recall_storage( + config=config, + credentials=credentials, ) - embedding_endpoint_type, embedding_endpoint, embedding_dim, embedding_model = configure_embedding_endpoint(config) - default_preset, default_persona, default_human, default_agent = configure_cli(config) - archival_storage_type, archival_storage_uri, archival_storage_path = configure_archival_storage(config) - recall_storage_type, recall_storage_uri, recall_storage_path = configure_recall_storage(config) except ValueError as e: typer.secho(str(e), fg=typer.colors.RED) return # openai key might have gotten added along the way - openai_key = config.openai_key if config.openai_key is not None else openai_key + openai_key = credentials.openai_key if credentials.openai_key is not None else openai_key # TODO: remove most of this (deplicated with User table) config = MemGPTConfig( - # model configs - model=model, - model_endpoint=model_endpoint, - model_endpoint_type=model_endpoint_type, - model_wrapper=model_wrapper, - context_window=context_window, - # embedding configs - embedding_endpoint_type=embedding_endpoint_type, - embedding_endpoint=embedding_endpoint, - embedding_dim=embedding_dim, - embedding_model=embedding_model, + default_llm_config=LLMConfig( + model=model, + model_endpoint=model_endpoint, + model_endpoint_type=model_endpoint_type, + model_wrapper=model_wrapper, + context_window=context_window, + ), + default_embedding_config=EmbeddingConfig( + embedding_endpoint_type=embedding_endpoint_type, + embedding_endpoint=embedding_endpoint, + embedding_dim=embedding_dim, + embedding_model=embedding_model, + ), # cli configs preset=default_preset, persona=default_persona, @@ -596,25 +635,6 @@ def configure(): default_persona=default_persona, default_human=default_human, default_agent=default_agent, - default_llm_config=LLMConfig( - model=model, - model_endpoint=model_endpoint, - model_endpoint_type=model_endpoint_type, - model_wrapper=model_wrapper, - context_window=context_window, - openai_key=openai_key, - ), - default_embedding_config=EmbeddingConfig( - embedding_endpoint_type=embedding_endpoint_type, - embedding_endpoint=embedding_endpoint, - embedding_dim=embedding_dim, - embedding_model=embedding_model, - openai_key=openai_key, - azure_key=azure_creds["azure_key"], - azure_endpoint=azure_creds["azure_endpoint"], - azure_version=azure_creds["azure_version"], - azure_deployment=azure_creds["azure_deployment"], # OK if None - ), ) if ms.get_user(user_id): # update user diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index 37453a7e..302ba9ee 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -103,11 +103,13 @@ def store_docs(name, docs, user_id=None, show_progress=True): print(f"Source {name} for user {user.id} already exists") # compute and record passages - embed_model = embedding_model(user.default_embedding_config) + embed_model = embedding_model(config.default_embedding_config) # use llama index to run embeddings code with suppress_stdout(): - service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size) + service_context = ServiceContext.from_defaults( + llm=None, embed_model=embed_model, chunk_size=config.default_embedding_config.embedding_chunk_size + ) index = VectorStoreIndex.from_documents(docs, service_context=service_context, show_progress=True) embed_dict = index._vector_store._data.embedding_dict node_dict = index._docstore.docs @@ -121,8 +123,8 @@ def store_docs(name, docs, user_id=None, show_progress=True): node.embedding = vector text = node.text.replace("\x00", "\uFFFD") # hacky fix for error on null characters assert ( - len(node.embedding) == user.default_embedding_config.embedding_dim - ), f"Expected embedding dimension {user.default_embedding_config.embedding_dim}, got {len(node.embedding)}: {node.embedding}" + len(node.embedding) == config.default_embedding_config.embedding_dim + ), f"Expected embedding dimension {config.default_embedding_config.embedding_dim}, got {len(node.embedding)}: {node.embedding}" passages.append( Passage( user_id=user.id, diff --git a/memgpt/config.py b/memgpt/config.py index a0ffd86b..a4c24c44 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -3,7 +3,7 @@ import inspect import json import os import uuid -from dataclasses import dataclass +from dataclasses import dataclass, field import configparser import typer import questionary @@ -35,119 +35,6 @@ def set_field(config, section, field, value): config.set(section, field, value) -@dataclass -class Config: - # system config for MemGPT - config_path = os.path.join(MEMGPT_DIR, "config") - anon_clientid = None - - # database configs: archival - archival_storage_type: str = "chroma" # local, db - archival_storage_path: str = os.path.join(MEMGPT_DIR, "chroma") - archival_storage_uri: str = None # TODO: eventually allow external vector DB - - # database configs: recall - recall_storage_type: str = "sqlite" # local, db - recall_storage_path: str = MEMGPT_DIR - recall_storage_uri: str = None # TODO: eventually allow external vector DB - - # database configs: metadata storage (sources, agents, data sources) - metadata_storage_type: str = "sqlite" - metadata_storage_path: str = MEMGPT_DIR - metadata_storage_uri: str = None - - memgpt_version: str = None - - @classmethod - def load(cls) -> "MemGPTConfig": - config = configparser.ConfigParser() - - # allow overriding with env variables - if os.getenv("MEMGPT_CONFIG_PATH"): - config_path = os.getenv("MEMGPT_CONFIG_PATH") - else: - config_path = MemGPTConfig.config_path - - if os.path.exists(config_path): - # read existing config - config.read(config_path) - config_dict = { - "archival_storage_type": get_field(config, "archival_storage", "type"), - "archival_storage_path": get_field(config, "archival_storage", "path"), - "archival_storage_uri": get_field(config, "archival_storage", "uri"), - "recall_storage_type": get_field(config, "recall_storage", "type"), - "recall_storage_path": get_field(config, "recall_storage", "path"), - "recall_storage_uri": get_field(config, "recall_storage", "uri"), - "metadata_storage_type": get_field(config, "metadata_storage", "type"), - "metadata_storage_path": get_field(config, "metadata_storage", "path"), - "metadata_storage_uri": get_field(config, "metadata_storage", "uri"), - "anon_clientid": get_field(config, "client", "anon_clientid"), - "config_path": config_path, - "memgpt_version": get_field(config, "version", "memgpt_version"), - } - config_dict = {k: v for k, v in config_dict.items() if v is not None} - return cls(**config_dict) - - # create new config - anon_clientid = str(uuid.uuid()) - config = cls(anon_clientid=anon_clientid, config_path=config_path) - config.save() # save updated config - return config - - def save(self): - import memgpt - - config = configparser.ConfigParser() - # archival storage - set_field(config, "archival_storage", "type", self.archival_storage_type) - set_field(config, "archival_storage", "path", self.archival_storage_path) - set_field(config, "archival_storage", "uri", self.archival_storage_uri) - - # recall storage - set_field(config, "recall_storage", "type", self.recall_storage_type) - set_field(config, "recall_storage", "path", self.recall_storage_path) - set_field(config, "recall_storage", "uri", self.recall_storage_uri) - - # metadata storage - set_field(config, "metadata_storage", "type", self.metadata_storage_type) - set_field(config, "metadata_storage", "path", self.metadata_storage_path) - set_field(config, "metadata_storage", "uri", self.metadata_storage_uri) - - # set version - set_field(config, "version", "memgpt_version", memgpt.__version__) - - # client - if not self.anon_clientid: - self.anon_clientid = str(uuid.uuid()) - set_field(config, "client", "anon_clientid", self.anon_clientid) - - if not os.path.exists(MEMGPT_DIR): - os.makedirs(MEMGPT_DIR, exist_ok=True) - with open(self.config_path, "w") as f: - config.write(f) - - @staticmethod - def exists(): - # allow overriding with env variables - if os.getenv("MEMGPT_CONFIG_PATH"): - config_path = os.getenv("MEMGPT_CONFIG_PATH") - else: - config_path = MemGPTConfig.config_path - - assert not os.path.isdir(config_path), f"Config path {config_path} cannot be set to a directory." - return os.path.exists(config_path) - - @staticmethod - def create_config_dir(): - if not os.path.exists(MEMGPT_DIR): - os.makedirs(MEMGPT_DIR, exist_ok=True) - - folders = ["functions", "system_prompts", "presets", "settings"] - for folder in folders: - if not os.path.exists(os.path.join(MEMGPT_DIR, folder)): - os.makedirs(os.path.join(MEMGPT_DIR, folder)) - - @dataclass class MemGPTConfig: config_path: str = os.path.join(MEMGPT_DIR, "config") @@ -156,34 +43,16 @@ class MemGPTConfig: # preset preset: str = DEFAULT_PRESET - # model parameters - model: str = None - model_endpoint_type: str = None - model_endpoint: str = None # localhost:8000 - model_wrapper: str = None - context_window: int = LLM_MAX_TOKENS[model] if model in LLM_MAX_TOKENS else LLM_MAX_TOKENS["DEFAULT"] - - # model parameters: openai - openai_key: str = None - - # model parameters: azure - azure_key: str = None - azure_endpoint: str = None - azure_version: str = None - azure_deployment: str = None - azure_embedding_deployment: str = None - # persona parameters persona: str = DEFAULT_PERSONA human: str = DEFAULT_HUMAN agent: str = None + # model parameters + default_llm_config: LLMConfig = field(default_factory=LLMConfig) + # embedding parameters - embedding_endpoint_type: str = "openai" # openai, azure, local - embedding_endpoint: str = None - embedding_model: str = None - embedding_dim: int = 1536 - embedding_chunk_size: int = 300 # number of tokens + default_embedding_config: EmbeddingConfig = field(default_factory=EmbeddingConfig) # database configs: archival archival_storage_type: str = "chroma" # local, db @@ -213,9 +82,10 @@ class MemGPTConfig: def __post_init__(self): # ensure types - self.embedding_chunk_size = int(self.embedding_chunk_size) - self.embedding_dim = int(self.embedding_dim) - self.context_window = int(self.context_window) + # self.embedding_chunk_size = int(self.embedding_chunk_size) + # self.embedding_dim = int(self.embedding_dim) + # self.context_window = int(self.context_window) + pass @staticmethod def generate_uuid() -> str: @@ -248,27 +118,46 @@ class MemGPTConfig: if os.path.exists(config_path): # read existing config config.read(config_path) - config_dict = { + + # Handle extraction of nested LLMConfig and EmbeddingConfig + llm_config_dict = { + # Extract relevant LLM configuration from the config file "model": get_field(config, "model", "model"), "model_endpoint": get_field(config, "model", "model_endpoint"), "model_endpoint_type": get_field(config, "model", "model_endpoint_type"), "model_wrapper": get_field(config, "model", "model_wrapper"), "context_window": get_field(config, "model", "context_window"), - "preset": get_field(config, "defaults", "preset"), - "persona": get_field(config, "defaults", "persona"), - "human": get_field(config, "defaults", "human"), - "agent": get_field(config, "defaults", "agent"), - "openai_key": get_field(config, "openai", "key"), - "azure_key": get_field(config, "azure", "key"), - "azure_endpoint": get_field(config, "azure", "endpoint"), - "azure_version": get_field(config, "azure", "version"), - "azure_deployment": get_field(config, "azure", "deployment"), - "azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"), + } + embedding_config_dict = { + # Extract relevant Embedding configuration from the config file "embedding_endpoint": get_field(config, "embedding", "embedding_endpoint"), "embedding_model": get_field(config, "embedding", "embedding_model"), "embedding_endpoint_type": get_field(config, "embedding", "embedding_endpoint_type"), "embedding_dim": get_field(config, "embedding", "embedding_dim"), "embedding_chunk_size": get_field(config, "embedding", "chunk_size"), + } + # Correct the types that aren't strings + if llm_config_dict["context_window"] is not None: + llm_config_dict["context_window"] = int(llm_config_dict["context_window"]) + if embedding_config_dict["embedding_dim"] is not None: + embedding_config_dict["embedding_dim"] = int(embedding_config_dict["embedding_dim"]) + if embedding_config_dict["embedding_chunk_size"] is not None: + embedding_config_dict["embedding_chunk_size"] = int(embedding_config_dict["embedding_chunk_size"]) + # Construct the inner properties + llm_config = LLMConfig(**llm_config_dict) + embedding_config = EmbeddingConfig(**embedding_config_dict) + + # Everything else + config_dict = { + # Two prepared configs + "default_llm_config": llm_config, + "default_embedding_config": embedding_config, + # Agent related + "preset": get_field(config, "defaults", "preset"), + "persona": get_field(config, "defaults", "persona"), + "human": get_field(config, "defaults", "human"), + "agent": get_field(config, "defaults", "agent"), + # Storage related "archival_storage_type": get_field(config, "archival_storage", "type"), "archival_storage_path": get_field(config, "archival_storage", "path"), "archival_storage_uri": get_field(config, "archival_storage", "uri"), @@ -278,10 +167,13 @@ class MemGPTConfig: "metadata_storage_type": get_field(config, "metadata_storage", "type"), "metadata_storage_path": get_field(config, "metadata_storage", "path"), "metadata_storage_uri": get_field(config, "metadata_storage", "uri"), + # Misc "anon_clientid": get_field(config, "client", "anon_clientid"), "config_path": config_path, "memgpt_version": get_field(config, "version", "memgpt_version"), } + + # Don't include null values config_dict = {k: v for k, v in config_dict.items() if v is not None} return cls(**config_dict) @@ -306,28 +198,18 @@ class MemGPTConfig: set_field(config, "defaults", "agent", self.agent) # model defaults - set_field(config, "model", "model", self.model) - set_field(config, "model", "model_endpoint", self.model_endpoint) - set_field(config, "model", "model_endpoint_type", self.model_endpoint_type) - set_field(config, "model", "model_wrapper", self.model_wrapper) - set_field(config, "model", "context_window", str(self.context_window)) - - # security credentials: openai - set_field(config, "openai", "key", self.openai_key) - - # security credentials: azure - set_field(config, "azure", "key", self.azure_key) - set_field(config, "azure", "endpoint", self.azure_endpoint) - set_field(config, "azure", "version", self.azure_version) - set_field(config, "azure", "deployment", self.azure_deployment) - set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment) + set_field(config, "model", "model", self.default_llm_config.model) + set_field(config, "model", "model_endpoint", self.default_llm_config.model_endpoint) + set_field(config, "model", "model_endpoint_type", self.default_llm_config.model_endpoint_type) + set_field(config, "model", "model_wrapper", self.default_llm_config.model_wrapper) + set_field(config, "model", "context_window", str(self.default_llm_config.context_window)) # embeddings - set_field(config, "embedding", "embedding_endpoint_type", self.embedding_endpoint_type) - set_field(config, "embedding", "embedding_endpoint", self.embedding_endpoint) - set_field(config, "embedding", "embedding_model", self.embedding_model) - set_field(config, "embedding", "embedding_dim", str(self.embedding_dim)) - set_field(config, "embedding", "embedding_chunk_size", str(self.embedding_chunk_size)) + set_field(config, "embedding", "embedding_endpoint_type", self.default_embedding_config.embedding_endpoint_type) + set_field(config, "embedding", "embedding_endpoint", self.default_embedding_config.embedding_endpoint) + set_field(config, "embedding", "embedding_model", self.default_embedding_config.embedding_model) + set_field(config, "embedding", "embedding_dim", str(self.default_embedding_config.embedding_dim)) + set_field(config, "embedding", "embedding_chunk_size", str(self.default_embedding_config.embedding_chunk_size)) # archival storage set_field(config, "archival_storage", "type", self.archival_storage_type) diff --git a/memgpt/credentials.py b/memgpt/credentials.py new file mode 100644 index 00000000..be78f89a --- /dev/null +++ b/memgpt/credentials.py @@ -0,0 +1,117 @@ +from memgpt.log import logger +import inspect +import json +import os +import uuid +from dataclasses import dataclass +import configparser +import typer +import questionary + +import memgpt +import memgpt.utils as utils +from memgpt.utils import printd, get_schema_diff +from memgpt.functions.functions import load_all_function_sets + +from memgpt.constants import MEMGPT_DIR, LLM_MAX_TOKENS, DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET +from memgpt.data_types import AgentState, User, LLMConfig, EmbeddingConfig +from memgpt.config import get_field, set_field + + +SUPPORTED_AUTH_TYPES = ["bearer_token", "api_key"] + + +@dataclass +class MemGPTCredentials: + # credentials for MemGPT + credentials_path: str = os.path.join(MEMGPT_DIR, "credentials") + + # openai config + openai_auth_type: str = "bearer_token" + openai_key: str = None + + # azure config + azure_auth_type: str = "api_key" + azure_key: str = None + azure_endpoint: str = None + azure_version: str = None + azure_deployment: str = None + azure_embedding_deployment: str = None + + # custom llm API config + openllm_auth_type: str = None + openllm_key: str = None + + @classmethod + def load(cls) -> "MemGPTCredentials": + config = configparser.ConfigParser() + + # allow overriding with env variables + if os.getenv("MEMGPT_CREDENTIALS_PATH"): + credentials_path = os.getenv("MEMGPT_CREDENTIALS_PATH") + else: + credentials_path = MemGPTCredentials.credentials_path + + if os.path.exists(credentials_path): + # read existing credentials + config.read(credentials_path) + config_dict = { + # openai + "openai_auth_type": get_field(config, "openai", "auth_type"), + "openai_key": get_field(config, "openai", "key"), + # azure + "azure_auth_type": get_field(config, "azure", "auth_type"), + "azure_key": get_field(config, "azure", "key"), + "azure_endpoint": get_field(config, "azure", "endpoint"), + "azure_version": get_field(config, "azure", "version"), + "azure_deployment": get_field(config, "azure", "deployment"), + "azure_embedding_deployment": get_field(config, "azure", "embedding_deployment"), + # open llm + "openllm_auth_type": get_field(config, "openllm", "auth_type"), + "openllm_key": get_field(config, "openllm", "key"), + # path + "credentials_path": credentials_path, + } + config_dict = {k: v for k, v in config_dict.items() if v is not None} + return cls(**config_dict) + + # create new config + config = cls(credentials_path=credentials_path) + config.save() # save updated config + return config + + def save(self): + import memgpt + + config = configparser.ConfigParser() + # openai config + set_field(config, "openai", "auth_type", self.openai_auth_type) + set_field(config, "openai", "key", self.openai_key) + + # azure config + set_field(config, "azure", "auth_type", self.azure_auth_type) + set_field(config, "azure", "key", self.azure_key) + set_field(config, "azure", "endpoint", self.azure_endpoint) + set_field(config, "azure", "version", self.azure_version) + set_field(config, "azure", "deployment", self.azure_deployment) + set_field(config, "azure", "embedding_deployment", self.azure_embedding_deployment) + + # openai config + set_field(config, "openllm", "auth_type", self.openllm_auth_type) + set_field(config, "openllm", "key", self.openllm_key) + + if not os.path.exists(MEMGPT_DIR): + os.makedirs(MEMGPT_DIR, exist_ok=True) + with open(self.credentials_path, "w") as f: + config.write(f) + + @staticmethod + def exists(): + # allow overriding with env variables + if os.getenv("MEMGPT_CREDENTIALS_PATH"): + credentials_path = os.getenv("MEMGPT_CREDENTIALS_PATH") + else: + credentials_path = MemGPTCredentials.credentials_path + + assert not os.path.isdir(credentials_path), f"Credentials path {credentials_path} cannot be set to a directory." + return os.path.exists(credentials_path) diff --git a/memgpt/data_types.py b/memgpt/data_types.py index b4b71eb1..8eaabe16 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -291,9 +291,6 @@ class Passage(Record): assert not agent_id or isinstance(self.agent_id, uuid.UUID), f"UUID {self.agent_id} must be a UUID type" assert not doc_id or isinstance(self.doc_id, uuid.UUID), f"UUID {self.doc_id} must be a UUID type" - # def __repr__(self): - # pass - class LLMConfig: def __init__( @@ -303,13 +300,6 @@ class LLMConfig: model_endpoint: Optional[str] = "https://api.openai.com/v1", model_wrapper: Optional[str] = None, context_window: Optional[int] = None, - # openai-only - openai_key: Optional[str] = None, - # azure-only - azure_key: Optional[str] = None, - azure_endpoint: Optional[str] = None, - azure_version: Optional[str] = None, - azure_deployment: Optional[str] = None, ): self.model = model self.model_endpoint_type = model_endpoint_type @@ -322,31 +312,15 @@ class LLMConfig: else: self.context_window = context_window - # openai - self.openai_key = openai_key - - # azure - self.azure_key = azure_key - self.azure_endpoint = azure_endpoint - self.azure_version = azure_version - self.azure_deployment = azure_deployment - class EmbeddingConfig: def __init__( self, - embedding_endpoint_type: Optional[str] = "local", - embedding_endpoint: Optional[str] = None, - embedding_model: Optional[str] = None, - embedding_dim: Optional[int] = 384, + embedding_endpoint_type: Optional[str] = "openai", + embedding_endpoint: Optional[str] = "https://api.openai.com/v1", + embedding_model: Optional[str] = "text-embedding-ada-002", + embedding_dim: Optional[int] = 1536, embedding_chunk_size: Optional[int] = 300, - # openai-only - openai_key: Optional[str] = None, - # azure-only - azure_key: Optional[str] = None, - azure_endpoint: Optional[str] = None, - azure_version: Optional[str] = None, - azure_deployment: Optional[str] = None, ): self.embedding_endpoint_type = embedding_endpoint_type self.embedding_endpoint = embedding_endpoint @@ -354,15 +328,6 @@ class EmbeddingConfig: self.embedding_dim = embedding_dim self.embedding_chunk_size = embedding_chunk_size - # openai - self.openai_key = openai_key - - # azure - self.azure_key = azure_key - self.azure_endpoint = azure_endpoint - self.azure_version = azure_version - self.azure_deployment = azure_deployment - class OpenAIEmbeddingConfig(EmbeddingConfig): def __init__(self, openai_key: Optional[str] = None, **kwargs): @@ -400,15 +365,6 @@ class User: default_persona=DEFAULT_PERSONA, default_human=DEFAULT_HUMAN, default_agent=None, - default_llm_config: Optional[LLMConfig] = None, # defaults: llm model - default_embedding_config: Optional[EmbeddingConfig] = None, # defaults: embeddings - # azure information - azure_key=None, - azure_endpoint=None, - azure_version=None, - azure_deployment=None, - # openai information - openai_key=None, # other policies_accepted=False, ): @@ -423,82 +379,6 @@ class User: self.default_human = default_human self.default_agent = default_agent - # model defaults - self.default_llm_config = default_llm_config if default_llm_config is not None else LLMConfig() - self.default_embedding_config = default_embedding_config if default_embedding_config is not None else EmbeddingConfig() - - # azure information - # TODO: split this up accross model config and embedding config? - self.azure_key = azure_key - self.azure_endpoint = azure_endpoint - self.azure_version = azure_version - self.azure_deployment = azure_deployment - - # openai information - self.openai_key = openai_key - - # set default embedding config - if default_embedding_config is None: - if self.openai_key: - self.default_embedding_config = OpenAIEmbeddingConfig( - openai_key=self.openai_key, - embedding_endpoint_type="openai", - embedding_endpoint="https://api.openai.com/v1", - embedding_dim=1536, - ) - elif self.azure_key: - self.default_embedding_config = AzureEmbeddingConfig( - azure_key=self.azure_key, - azure_endpoint=self.azure_endpoint, - azure_version=self.azure_version, - azure_deployment=self.azure_deployment, - embedding_endpoint_type="azure", - embedding_endpoint="https://api.openai.com/v1", - embedding_dim=1536, - ) - else: - # memgpt hosted - self.default_embedding_config = EmbeddingConfig( - embedding_endpoint_type="hugging-face", - embedding_endpoint="https://embeddings.memgpt.ai", - embedding_model="BAAI/bge-large-en-v1.5", - embedding_dim=1024, - embedding_chunk_size=300, - ) - - # set default LLM config - if default_llm_config is None: - if self.openai_key: - self.default_llm_config = OpenAILLMConfig( - openai_key=self.openai_key, - model="gpt-4", - model_endpoint_type="openai", - model_endpoint="https://api.openai.com/v1", - model_wrapper=None, - context_window=LLM_MAX_TOKENS["gpt-4"], - ) - elif self.azure_key: - self.default_llm_config = AzureLLMConfig( - azure_key=self.azure_key, - azure_endpoint=self.azure_endpoint, - azure_version=self.azure_version, - azure_deployment=self.azure_deployment, - model="gpt-4", - model_endpoint_type="azure", - model_endpoint="https://api.openai.com/v1", - model_wrapper=None, - context_window=LLM_MAX_TOKENS["gpt-4"], - ) - else: - # memgpt hosted - self.default_llm_config = LLMConfig( - model="ehartford/dolphin-2.5-mixtral-8x7b", - model_endpoint_type="vllm", - model_endpoint="https://api.memgpt.ai", - model_wrapper="chatml", - context_window=16384, - ) - # misc self.policies_accepted = policies_accepted @@ -546,36 +426,6 @@ class AgentState: # state self.state = {} if not state else state - # def __eq__(self, other): - # if not isinstance(other, AgentState): - # # return False - # return NotImplemented - - # return ( - # self.name == other.name - # and self.user_id == other.user_id - # and self.persona == other.persona - # and self.human == other.human - # and vars(self.llm_config) == vars(other.llm_config) - # and vars(self.embedding_config) == vars(other.embedding_config) - # and self.preset == other.preset - # and self.state == other.state - # ) - - # def __dict__(self): - # return { - # "id": self.id, - # "name": self.name, - # "user_id": self.user_id, - # "preset": self.preset, - # "persona": self.persona, - # "human": self.human, - # "llm_config": self.llm_config, - # "embedding_config": self.embedding_config, - # "created_at": format_datetime(self.created_at), - # "state": self.state, - # } - class Source: def __init__( diff --git a/memgpt/embeddings.py b/memgpt/embeddings.py index d41830c6..e717e063 100644 --- a/memgpt/embeddings.py +++ b/memgpt/embeddings.py @@ -5,6 +5,7 @@ import os from memgpt.utils import is_valid_url from memgpt.data_types import EmbeddingConfig +from memgpt.credentials import MemGPTCredentials from llama_index.embeddings import OpenAIEmbedding, AzureOpenAIEmbedding from llama_index.bridge.pydantic import PrivateAttr @@ -161,20 +162,23 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None endpoint_type = config.embedding_endpoint_type + # TODO refactor to pass credentials through args + credentials = MemGPTCredentials.load() + if endpoint_type == "openai": additional_kwargs = {"user_id": user_id} if user_id else {} - model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=config.openai_key, additional_kwargs=additional_kwargs) + model = OpenAIEmbedding(api_base=config.embedding_endpoint, api_key=credentials.openai_key, additional_kwargs=additional_kwargs) return model elif endpoint_type == "azure": # https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings model = "text-embedding-ada-002" - deployment = config.azure_embedding_deployment if config.azure_embedding_deployment is not None else model + deployment = credentials.azure_embedding_deployment if credentials.azure_embedding_deployment is not None else model return AzureOpenAIEmbedding( model=model, deployment_name=deployment, - api_key=config.azure_key, - azure_endpoint=config.azure_endpoint, - api_version=config.azure_version, + api_key=credentials.azure_key, + azure_endpoint=credentials.azure_endpoint, + api_version=credentials.azure_version, ) elif endpoint_type == "hugging-face": try: diff --git a/memgpt/functions/function_sets/extras.py b/memgpt/functions/function_sets/extras.py index ca8de3a2..15d51ed4 100644 --- a/memgpt/functions/function_sets/extras.py +++ b/memgpt/functions/function_sets/extras.py @@ -10,7 +10,7 @@ from memgpt.constants import ( MAX_PAUSE_HEARTBEATS, JSON_ENSURE_ASCII, ) -from memgpt.openai_tools import create +from memgpt.llm_api_tools import create def message_chatgpt(self, message: str): diff --git a/memgpt/openai_tools.py b/memgpt/llm_api_tools.py similarity index 94% rename from memgpt/openai_tools.py rename to memgpt/llm_api_tools.py index 708a3d91..9230c145 100644 --- a/memgpt/openai_tools.py +++ b/memgpt/llm_api_tools.py @@ -7,6 +7,7 @@ import urllib from box import Box +from memgpt.credentials import MemGPTCredentials from memgpt.local_llm.chat_completion_proxy import get_chat_completion from memgpt.constants import CLI_WARNING_PREFIX from memgpt.models.chat_completion_response import ChatCompletionResponse @@ -392,10 +393,13 @@ def create( printd(f"Using model {agent_state.llm_config.model_endpoint_type}, endpoint: {agent_state.llm_config.model_endpoint}") + # TODO eventually refactor so that credentials are passed through + credentials = MemGPTCredentials.load() + # openai if agent_state.llm_config.model_endpoint_type == "openai": # TODO do the same for Azure? - if agent_state.llm_config.openai_key is None: + if credentials.openai_key is None: raise ValueError(f"OpenAI key is missing from MemGPT config file") if use_tool_naming: data = dict( @@ -415,15 +419,15 @@ def create( ) return openai_chat_completions_request( url=agent_state.llm_config.model_endpoint, # https://api.openai.com/v1 -> https://api.openai.com/v1/chat/completions - api_key=agent_state.llm_config.openai_key, + api_key=credentials.openai_key, data=data, ) # azure elif agent_state.llm_config.model_endpoint_type == "azure": azure_deployment = ( - agent_state.llm_config.azure_deployment - if agent_state.llm_config.azure_deployment is not None + credentials.azure_deployment + if credentials.azure_deployment is not None else MODEL_TO_AZURE_ENGINE[agent_state.llm_config.model] ) if use_tool_naming: @@ -445,10 +449,10 @@ def create( user=str(agent_state.user_id), ) return azure_openai_chat_completions_request( - resource_name=agent_state.llm_config.azure_endpoint, + resource_name=credentials.azure_endpoint, deployment_id=azure_deployment, - api_version=agent_state.llm_config.azure_version, - api_key=agent_state.llm_config.azure_key, + api_version=credentials.azure_version, + api_key=credentials.azure_key, data=data, ) @@ -467,4 +471,7 @@ def create( user=str(agent_state.user_id), # hint first_message=first_message, + # auth-related + auth_type=credentials.openllm_auth_type, + auth_key=credentials.openllm_key, ) diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index b4e74e56..1ff08bbd 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -48,6 +48,9 @@ def get_chat_completion( # extra hints to allow for additional prompt formatting hacks # TODO this could alternatively be supported via passing function_call="send_message" into the wrapper first_message=False, + # optional auth headers + auth_type=None, + auth_key=None, ) -> ChatCompletionResponse: from memgpt.utils import printd @@ -139,21 +142,21 @@ def get_chat_completion( try: if endpoint_type == "webui": - result, usage = get_webui_completion(endpoint, prompt, context_window, grammar=grammar) + result, usage = get_webui_completion(endpoint, auth_type, auth_key, prompt, context_window, grammar=grammar) elif endpoint_type == "webui-legacy": - result, usage = get_webui_completion_legacy(endpoint, prompt, context_window, grammar=grammar) + result, usage = get_webui_completion_legacy(endpoint, auth_type, auth_key, prompt, context_window, grammar=grammar) elif endpoint_type == "lmstudio": - result, usage = get_lmstudio_completion(endpoint, prompt, context_window, api="completions") + result, usage = get_lmstudio_completion(endpoint, auth_type, auth_key, prompt, context_window, api="completions") elif endpoint_type == "lmstudio-legacy": - result, usage = get_lmstudio_completion(endpoint, prompt, context_window, api="chat") + result, usage = get_lmstudio_completion(endpoint, auth_type, auth_key, prompt, context_window, api="chat") elif endpoint_type == "llamacpp": - result, usage = get_llamacpp_completion(endpoint, prompt, context_window, grammar=grammar) + result, usage = get_llamacpp_completion(endpoint, auth_type, auth_key, prompt, context_window, grammar=grammar) elif endpoint_type == "koboldcpp": - result, usage = get_koboldcpp_completion(endpoint, prompt, context_window, grammar=grammar) + result, usage = get_koboldcpp_completion(endpoint, auth_type, auth_key, prompt, context_window, grammar=grammar) elif endpoint_type == "ollama": - result, usage = get_ollama_completion(endpoint, model, prompt, context_window) + result, usage = get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_window) elif endpoint_type == "vllm": - result, usage = get_vllm_completion(endpoint, model, prompt, context_window, user) + result, usage = get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user) else: raise LocalLLMError( f"Invalid endpoint type {endpoint_type}, please set variable depending on your backend (webui, lmstudio, llamacpp, koboldcpp)" diff --git a/memgpt/local_llm/koboldcpp/api.py b/memgpt/local_llm/koboldcpp/api.py index 7a197ac3..79729af1 100644 --- a/memgpt/local_llm/koboldcpp/api.py +++ b/memgpt/local_llm/koboldcpp/api.py @@ -3,12 +3,12 @@ from urllib.parse import urljoin import requests from memgpt.local_llm.settings.settings import get_completions_settings -from memgpt.local_llm.utils import load_grammar_file, count_tokens +from memgpt.local_llm.utils import count_tokens, post_json_auth_request KOBOLDCPP_API_SUFFIX = "/api/v1/generate" -def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None): +def get_koboldcpp_completion(endpoint, auth_type, auth_key, prompt, context_window, grammar=None): """See https://lite.koboldai.net/koboldcpp_api for API spec""" from memgpt.utils import printd @@ -33,7 +33,7 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None): # NOTE: llama.cpp server returns the following when it's out of context # curl: (52) Empty reply from server URI = urljoin(endpoint.strip("/") + "/", KOBOLDCPP_API_SUFFIX.strip("/")) - response = requests.post(URI, json=request) + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) if response.status_code == 200: result_full = response.json() printd(f"JSON API response:\n{result_full}") diff --git a/memgpt/local_llm/llamacpp/api.py b/memgpt/local_llm/llamacpp/api.py index 39ea2372..2e0bea6b 100644 --- a/memgpt/local_llm/llamacpp/api.py +++ b/memgpt/local_llm/llamacpp/api.py @@ -3,13 +3,13 @@ from urllib.parse import urljoin import requests from memgpt.local_llm.settings.settings import get_completions_settings -from memgpt.local_llm.utils import count_tokens, load_grammar_file +from memgpt.local_llm.utils import count_tokens, post_json_auth_request LLAMACPP_API_SUFFIX = "/completion" -def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None): +def get_llamacpp_completion(endpoint, auth_type, auth_key, prompt, context_window, grammar=None): """See https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server""" from memgpt.utils import printd @@ -33,7 +33,7 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None): # NOTE: llama.cpp server returns the following when it's out of context # curl: (52) Empty reply from server URI = urljoin(endpoint.strip("/") + "/", LLAMACPP_API_SUFFIX.strip("/")) - response = requests.post(URI, json=request) + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) if response.status_code == 200: result_full = response.json() printd(f"JSON API response:\n{result_full}") diff --git a/memgpt/local_llm/lmstudio/api.py b/memgpt/local_llm/lmstudio/api.py index 73747659..4aa3db13 100644 --- a/memgpt/local_llm/lmstudio/api.py +++ b/memgpt/local_llm/lmstudio/api.py @@ -3,6 +3,7 @@ from urllib.parse import urljoin import requests from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.local_llm.utils import post_json_auth_request from memgpt.utils import count_tokens @@ -10,7 +11,7 @@ LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions" LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions" -def get_lmstudio_completion(endpoint, prompt, context_window, api="completions"): +def get_lmstudio_completion(endpoint, auth_type, auth_key, prompt, context_window, api="completions"): """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client""" from memgpt.utils import printd @@ -64,7 +65,7 @@ def get_lmstudio_completion(endpoint, prompt, context_window, api="completions") raise ValueError(f"Provided OPENAI_API_BASE value ({endpoint}) must begin with http:// or https://") try: - response = requests.post(URI, json=request) + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) if response.status_code == 200: result_full = response.json() printd(f"JSON API response:\n{result_full}") diff --git a/memgpt/local_llm/ollama/api.py b/memgpt/local_llm/ollama/api.py index b351bbe5..a73d4990 100644 --- a/memgpt/local_llm/ollama/api.py +++ b/memgpt/local_llm/ollama/api.py @@ -4,6 +4,7 @@ import requests from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.local_llm.utils import post_json_auth_request from memgpt.utils import count_tokens from memgpt.errors import LocalLLMError @@ -11,7 +12,7 @@ from memgpt.errors import LocalLLMError OLLAMA_API_SUFFIX = "/api/generate" -def get_ollama_completion(endpoint, model, prompt, context_window, grammar=None): +def get_ollama_completion(endpoint, auth_type, auth_key, model, prompt, context_window, grammar=None): """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server""" from memgpt.utils import printd @@ -61,7 +62,7 @@ def get_ollama_completion(endpoint, model, prompt, context_window, grammar=None) try: URI = urljoin(endpoint.strip("/") + "/", OLLAMA_API_SUFFIX.strip("/")) - response = requests.post(URI, json=request) + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) if response.status_code == 200: # https://github.com/jmorganca/ollama/blob/main/docs/api.md result_full = response.json() diff --git a/memgpt/local_llm/utils.py b/memgpt/local_llm/utils.py index 7a04b187..03a95180 100644 --- a/memgpt/local_llm/utils.py +++ b/memgpt/local_llm/utils.py @@ -1,4 +1,5 @@ import os +import requests import tiktoken import memgpt.local_llm.llm_chat_completion_wrappers.airoboros as airoboros @@ -7,6 +8,33 @@ import memgpt.local_llm.llm_chat_completion_wrappers.zephyr as zephyr import memgpt.local_llm.llm_chat_completion_wrappers.chatml as chatml +def post_json_auth_request(uri, json_payload, auth_type, auth_key): + """Send a POST request with a JSON payload and optional authentication""" + + # By default most local LLM inference servers do not have authorization enabled + if auth_type is None: + response = requests.post(uri, json=json_payload) + + # Used by OpenAI, together.ai, Mistral AI + elif auth_type == "bearer_token": + if auth_key is None: + raise ValueError(f"auth_type is {auth_type}, but auth_key is null") + headers = {"Content-Type": "application/json", "Authorization": f"Bearer {auth_key}"} + response = requests.post(uri, json=json_payload, headers=headers) + + # Used by OpenAI Azure + elif auth_type == "api_key": + if auth_key is None: + raise ValueError(f"auth_type is {auth_type}, but auth_key is null") + headers = {"Content-Type": "application/json", "api-key": f"{auth_key}"} + response = requests.post(uri, json=json_payload, headers=headers) + + else: + raise ValueError(f"Unsupport authentication type: {auth_type}") + + return response + + # deprecated for Box class DotDict(dict): """Allow dot access on properties similar to OpenAI response object""" diff --git a/memgpt/local_llm/vllm/api.py b/memgpt/local_llm/vllm/api.py index 2d6afe2f..aa5fd6a6 100644 --- a/memgpt/local_llm/vllm/api.py +++ b/memgpt/local_llm/vllm/api.py @@ -3,12 +3,12 @@ from urllib.parse import urljoin import requests from memgpt.local_llm.settings.settings import get_completions_settings -from memgpt.local_llm.utils import load_grammar_file, count_tokens +from memgpt.local_llm.utils import count_tokens, post_json_auth_request WEBUI_API_SUFFIX = "/v1/completions" -def get_vllm_completion(endpoint, model, prompt, context_window, user, grammar=None): +def get_vllm_completion(endpoint, auth_type, auth_key, model, prompt, context_window, user, grammar=None): """https://github.com/vllm-project/vllm/blob/main/examples/api_client.py""" from memgpt.utils import printd @@ -30,14 +30,13 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, grammar=N # Set grammar if grammar is not None: raise NotImplementedError - request["grammar_string"] = load_grammar_file(grammar) if not endpoint.startswith(("http://", "https://")): raise ValueError(f"Endpoint ({endpoint}) must begin with http:// or https://") try: URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/")) - response = requests.post(URI, json=request) + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) if response.status_code == 200: result_full = response.json() printd(f"JSON API response:\n{result_full}") diff --git a/memgpt/local_llm/webui/api.py b/memgpt/local_llm/webui/api.py index acbe7787..2b617462 100644 --- a/memgpt/local_llm/webui/api.py +++ b/memgpt/local_llm/webui/api.py @@ -3,12 +3,12 @@ from urllib.parse import urljoin import requests from memgpt.local_llm.settings.settings import get_completions_settings -from memgpt.local_llm.utils import load_grammar_file, count_tokens +from memgpt.local_llm.utils import count_tokens, post_json_auth_request WEBUI_API_SUFFIX = "/v1/completions" -def get_webui_completion(endpoint, prompt, context_window, grammar=None): +def get_webui_completion(endpoint, auth_type, auth_key, prompt, context_window, grammar=None): """Compatibility for the new OpenAI API: https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples""" from memgpt.utils import printd @@ -33,7 +33,7 @@ def get_webui_completion(endpoint, prompt, context_window, grammar=None): try: URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/")) - response = requests.post(URI, json=request) + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) if response.status_code == 200: result_full = response.json() printd(f"JSON API response:\n{result_full}") diff --git a/memgpt/local_llm/webui/legacy_api.py b/memgpt/local_llm/webui/legacy_api.py index 9f1a4c5f..6f3d8b2e 100644 --- a/memgpt/local_llm/webui/legacy_api.py +++ b/memgpt/local_llm/webui/legacy_api.py @@ -3,12 +3,12 @@ from urllib.parse import urljoin import requests from memgpt.local_llm.settings.settings import get_completions_settings -from memgpt.local_llm.utils import load_grammar_file, count_tokens +from memgpt.local_llm.utils import count_tokens, post_json_auth_request WEBUI_API_SUFFIX = "/api/v1/generate" -def get_webui_completion(endpoint, prompt, context_window, grammar=None): +def get_webui_completion(endpoint, auth_type, auth_key, prompt, context_window, grammar=None): """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server""" from memgpt.utils import printd @@ -33,7 +33,7 @@ def get_webui_completion(endpoint, prompt, context_window, grammar=None): try: URI = urljoin(endpoint.strip("/") + "/", WEBUI_API_SUFFIX.strip("/")) - response = requests.post(URI, json=request) + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) if response.status_code == 200: result_full = response.json() printd(f"JSON API response:\n{result_full}") diff --git a/memgpt/memory.py b/memgpt/memory.py index f8d42649..a17e9a8b 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -5,7 +5,7 @@ from typing import Optional, List, Tuple from memgpt.constants import MESSAGE_SUMMARY_WARNING_FRAC from memgpt.utils import get_local_time, printd, count_tokens, validate_date_format, extract_date_from_timestamp from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM -from memgpt.openai_tools import create +from memgpt.llm_api_tools import create from memgpt.data_types import Message, Passage, AgentState from memgpt.embeddings import embedding_model from llama_index import Document @@ -357,7 +357,7 @@ class BaseRecallMemory(RecallMemory): class EmbeddingArchivalMemory(ArchivalMemory): """Archival memory with embedding based search""" - def __init__(self, agent_state, top_k: Optional[int] = 100): + def __init__(self, agent_state: AgentState, top_k: Optional[int] = 100): """Init function for archival memory :param archival_memory_database: name of dataset to pre-fill archival with diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 05e84100..1af4ed70 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -100,15 +100,6 @@ class UserModel(Base): default_human = Column(String) default_agent = Column(String) - default_llm_config = Column(LLMConfigColumn) - default_embedding_config = Column(EmbeddingConfigColumn) - - azure_key = Column(String, nullable=True) - azure_endpoint = Column(String, nullable=True) - azure_version = Column(String, nullable=True) - azure_deployment = Column(String, nullable=True) - - openai_key = Column(String, nullable=True) policies_accepted = Column(Boolean, nullable=False, default=False) def __repr__(self) -> str: @@ -122,13 +113,6 @@ class UserModel(Base): default_persona=self.default_persona, default_human=self.default_human, default_agent=self.default_agent, - default_llm_config=self.default_llm_config, - default_embedding_config=self.default_embedding_config, - azure_key=self.azure_key, - azure_endpoint=self.azure_endpoint, - azure_version=self.azure_version, - azure_deployment=self.azure_deployment, - openai_key=self.openai_key, policies_accepted=self.policies_accepted, ) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index f605b668..9f4e2167 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -9,6 +9,7 @@ from fastapi import HTTPException from memgpt.agent_store.storage import StorageConnector from memgpt.config import MemGPTConfig +from memgpt.credentials import MemGPTCredentials from memgpt.agent import Agent import memgpt.system as system import memgpt.constants as constants @@ -198,6 +199,9 @@ class SyncServer(LockingServer): # Initialize the connection to the DB self.config = MemGPTConfig.load() + # TODO figure out how to handle credentials for the server + self.credentials = MemGPTCredentials.load() + # Ensure valid database configuration # TODO: add back once tests are matched # assert ( @@ -211,22 +215,22 @@ class SyncServer(LockingServer): # Generate default LLM/Embedding configs for the server # TODO: we may also want to do the same thing with default persona/human/etc. self.server_llm_config = LLMConfig( - model=self.config.model, - model_endpoint_type=self.config.model_endpoint_type, - model_endpoint=self.config.model_endpoint, - model_wrapper=self.config.model_wrapper, - context_window=self.config.context_window, - openai_key=self.config.openai_key, - azure_key=self.config.azure_key, - azure_endpoint=self.config.azure_endpoint, - azure_version=self.config.azure_version, - azure_deployment=self.config.azure_deployment, + model=self.config.default_llm_config.model, + model_endpoint_type=self.config.default_llm_config.model_endpoint_type, + model_endpoint=self.config.default_llm_config.model_endpoint, + model_wrapper=self.config.default_llm_config.model_wrapper, + context_window=self.config.default_llm_config.context_window, + # openai_key=self.credentials.openai_key, + # azure_key=self.credentials.azure_key, + # azure_endpoint=self.credentials.azure_endpoint, + # azure_version=self.credentials.azure_version, + # azure_deployment=self.credentials.azure_deployment, ) self.server_embedding_config = EmbeddingConfig( - embedding_endpoint_type=self.config.embedding_endpoint_type, - embedding_endpoint=self.config.embedding_endpoint, - embedding_dim=self.config.embedding_dim, - openai_key=self.config.openai_key, + embedding_endpoint_type=self.config.default_embedding_config.embedding_endpoint_type, + embedding_endpoint=self.config.default_embedding_config.embedding_endpoint, + embedding_dim=self.config.default_embedding_config.embedding_dim, + # openai_key=self.credentials.openai_key, ) # Initialize the metadata store @@ -558,8 +562,6 @@ class SyncServer(LockingServer): default_preset=user_config["default_preset"] if "default_preset" in user_config else "memgpt_chat", default_persona=user_config["default_persona"] if "default_persona" in user_config else constants.DEFAULT_PERSONA, default_human=user_config["default_human"] if "default_human" in user_config else constants.DEFAULT_HUMAN, - default_llm_config=self.server_llm_config, - default_embedding_config=self.server_embedding_config, ) self.ms.create_user(user) logger.info(f"Created new user from config: {user}") @@ -599,8 +601,8 @@ class SyncServer(LockingServer): # TODO we need to allow passing raw persona/human text via the server request persona=agent_config["persona"] if "persona" in agent_config else user.default_persona, human=agent_config["human"] if "human" in agent_config else user.default_human, - llm_config=agent_config["llm_config"] if "llm_config" in agent_config else user.default_llm_config, - embedding_config=agent_config["embedding_config"] if "embedding_config" in agent_config else user.default_embedding_config, + llm_config=agent_config["llm_config"] if "llm_config" in agent_config else self.server_llm_config, + embedding_config=agent_config["embedding_config"] if "embedding_config" in agent_config else self.server_embedding_config, ) # NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation # TODO: fix this db dependency and remove diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index 52df3ffa..11578dba 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -9,6 +9,7 @@ from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.cli.cli_load import load_directory, load_database, load_webpage from memgpt.cli.cli import attach from memgpt.config import MemGPTConfig +from memgpt.credentials import MemGPTCredentials from memgpt.metadata import MetadataStore from memgpt.data_types import User, AgentState, EmbeddingConfig @@ -69,27 +70,31 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector, c embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", embedding_dim=1536, + # openai_key=os.getenv("OPENAI_API_KEY"), + ) + credentials = MemGPTCredentials( openai_key=os.getenv("OPENAI_API_KEY"), ) + credentials.save() else: embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384) # create user and agent - user = User(id=uuid.UUID(config.anon_clientid), default_embedding_config=embedding_config) + user = User(id=uuid.UUID(config.anon_clientid)) agent = AgentState( user_id=user.id, name="test_agent", preset=user.default_preset, persona=user.default_persona, human=user.default_human, - llm_config=user.default_llm_config, - embedding_config=user.default_embedding_config, + llm_config=config.default_llm_config, + embedding_config=config.default_embedding_config, ) ms.delete_user(user.id) ms.create_user(user) ms.create_agent(agent) user = ms.get_user(user.id) - print("Got user:", user, user.default_embedding_config) + print("Got user:", user, config.default_embedding_config) # setup storage connectors print("Creating storage connectors...") diff --git a/tests/test_metadata_store.py b/tests/test_metadata_store.py index 7ac32c43..ad6e8236 100644 --- a/tests/test_metadata_store.py +++ b/tests/test_metadata_store.py @@ -24,7 +24,7 @@ def test_storage(storage_connector): ms = MetadataStore(config) # generate data - user_1 = User(default_llm_config=LLMConfig(model="gpt-4")) + user_1 = User() user_2 = User() agent_1 = AgentState( user_id=user_1.id, @@ -32,8 +32,8 @@ def test_storage(storage_connector): preset=user_1.default_preset, persona=user_1.default_persona, human=user_1.default_human, - llm_config=user_1.default_llm_config, - embedding_config=user_1.default_embedding_config, + llm_config=config.default_llm_config, + embedding_config=config.default_embedding_config, ) source_1 = Source(user_id=user_1.id, name="source_1") @@ -52,7 +52,7 @@ def test_storage(storage_connector): # test: updating # test: update JSON-stored LLMConfig class - print(agent_1.llm_config, user_1.default_llm_config) + print(agent_1.llm_config, config.default_llm_config) llm_config = ms.get_agent(agent_1.id).llm_config assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}" assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}" diff --git a/tests/test_server.py b/tests/test_server.py index 12652899..963d53dc 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -4,6 +4,7 @@ import memgpt.utils as utils utils.DEBUG = True from memgpt.config import MemGPTConfig +from memgpt.credentials import MemGPTCredentials from memgpt.server.server import SyncServer from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User from memgpt.embeddings import embedding_model @@ -22,14 +23,20 @@ def test_server(): recall_storage_type="postgres", metadata_storage_type="postgres", # embeddings - embedding_endpoint_type="openai", - embedding_endpoint="https://api.openai.com/v1", - embedding_dim=1536, - openai_key=os.getenv("OPENAI_API_KEY"), + default_embedding_config=EmbeddingConfig( + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + ), # llms - model_endpoint_type="openai", - model_endpoint="https://api.openai.com/v1", - model="gpt-4", + default_llm_config=LLMConfig( + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + model="gpt-4", + ), + ) + credentials = MemGPTCredentials( + openai_key=os.getenv("OPENAI_API_KEY"), ) else: # hosted config = MemGPTConfig( @@ -40,16 +47,22 @@ def test_server(): recall_storage_type="postgres", metadata_storage_type="postgres", # embeddings - embedding_endpoint_type="hugging-face", - embedding_endpoint="https://embeddings.memgpt.ai", - embedding_model="BAAI/bge-large-en-v1.5", - embedding_dim=1024, + default_embedding_config=EmbeddingConfig( + embedding_endpoint_type="hugging-face", + embedding_endpoint="https://embeddings.memgpt.ai", + embedding_model="BAAI/bge-large-en-v1.5", + embedding_dim=1024, + ), # llms - model_endpoint_type="vllm", - model_endpoint="https://api.memgpt.ai", - model="ehartford/dolphin-2.5-mixtral-8x7b", + default_llm_config=LLMConfig( + model_endpoint_type="vllm", + model_endpoint="https://api.memgpt.ai", + model="ehartford/dolphin-2.5-mixtral-8x7b", + ), ) + credentials = MemGPTCredentials() config.save() + credentials.save() server = SyncServer() diff --git a/tests/test_storage.py b/tests/test_storage.py index 9567aa00..21faf176 100644 --- a/tests/test_storage.py +++ b/tests/test_storage.py @@ -7,6 +7,7 @@ from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.embeddings import embedding_model from memgpt.data_types import Message, Passage, EmbeddingConfig, AgentState, OpenAIEmbeddingConfig from memgpt.config import MemGPTConfig +from memgpt.credentials import MemGPTCredentials from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.metadata import MetadataStore from memgpt.data_types import User @@ -123,8 +124,12 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models embedding_endpoint_type="openai", embedding_endpoint="https://api.openai.com/v1", embedding_dim=1536, + # openai_key=os.getenv("OPENAI_API_KEY"), + ) + credentials = MemGPTCredentials( openai_key=os.getenv("OPENAI_API_KEY"), ) + credentials.save() else: embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384) embed_model = embedding_model(embedding_config) @@ -132,7 +137,7 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models # create user ms = MetadataStore(config) ms.delete_user(user_id) - user = User(id=user_id, default_embedding_config=embedding_config) + user = User(id=user_id) agent = AgentState( user_id=user_id, name="agent_1", @@ -140,8 +145,8 @@ def test_storage(storage_connector, table_type, clear_dynamically_created_models preset=user.default_preset, persona=user.default_persona, human=user.default_human, - llm_config=user.default_llm_config, - embedding_config=user.default_embedding_config, + llm_config=config.default_llm_config, + embedding_config=config.default_embedding_config, ) ms.create_user(user) ms.create_agent(agent)