diff --git a/memgpt/local_llm/koboldcpp/api.py b/memgpt/local_llm/koboldcpp/api.py index ecf259c0..5883b884 100644 --- a/memgpt/local_llm/koboldcpp/api.py +++ b/memgpt/local_llm/koboldcpp/api.py @@ -2,13 +2,13 @@ import os from urllib.parse import urljoin import requests -from .settings import SIMPLE -from ..utils import load_grammar_file, count_tokens +from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.local_llm.utils import load_grammar_file, count_tokens KOBOLDCPP_API_SUFFIX = "/api/v1/generate" -def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE): +def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None): """See https://lite.koboldai.net/koboldcpp_api for API spec""" from memgpt.utils import printd @@ -17,6 +17,7 @@ def get_koboldcpp_completion(endpoint, prompt, context_window, grammar=None, set raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") # Settings for the generation, includes the prompt + stop tokens, max length, etc + settings = get_completions_settings() request = settings request["prompt"] = prompt request["max_context_length"] = context_window diff --git a/memgpt/local_llm/llamacpp/api.py b/memgpt/local_llm/llamacpp/api.py index 649ec67c..147417e2 100644 --- a/memgpt/local_llm/llamacpp/api.py +++ b/memgpt/local_llm/llamacpp/api.py @@ -2,13 +2,14 @@ import os from urllib.parse import urljoin import requests -from .settings import SIMPLE -from ..utils import load_grammar_file, count_tokens +from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.local_llm.utils import count_tokens, load_grammar_file + LLAMACPP_API_SUFFIX = "/completion" -def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, settings=SIMPLE): +def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None): """See 
https://github.com/ggerganov/llama.cpp/blob/master/examples/server/README.md for instructions on how to run the LLM web server""" from memgpt.utils import printd @@ -17,6 +18,7 @@ def get_llamacpp_completion(endpoint, prompt, context_window, grammar=None, sett raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") # Settings for the generation, includes the prompt + stop tokens, max length, etc + settings = get_completions_settings() request = settings request["prompt"] = prompt diff --git a/memgpt/local_llm/lmstudio/api.py b/memgpt/local_llm/lmstudio/api.py index 2c6af47e..73747659 100644 --- a/memgpt/local_llm/lmstudio/api.py +++ b/memgpt/local_llm/lmstudio/api.py @@ -2,15 +2,15 @@ import os from urllib.parse import urljoin import requests -from .settings import SIMPLE -from ..utils import count_tokens +from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.utils import count_tokens + LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions" LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions" -# TODO move to "completions" by default, not "chat" -def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, api="completions"): +def get_lmstudio_completion(endpoint, prompt, context_window, api="completions"): """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client""" from memgpt.utils import printd @@ -18,6 +18,20 @@ def get_lmstudio_completion(endpoint, prompt, context_window, settings=SIMPLE, a if prompt_tokens > context_window: raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") + settings = get_completions_settings() + settings.update( + { + "input_prefix": "", + "input_suffix": "", + # This controls how LM studio handles context overflow + # In MemGPT we handle this ourselves, so this should be disabled + # 
"context_overflow_policy": 0, + "lmstudio": {"context_overflow_policy": 0}, # 0 = stop at limit + "stream": False, + "model": "local model", + } + ) + # Uses the ChatCompletions API style # Seems to work better, probably because it's applying some extra settings under-the-hood? if api == "chat": diff --git a/memgpt/local_llm/ollama/api.py b/memgpt/local_llm/ollama/api.py index 32311386..38a84937 100644 --- a/memgpt/local_llm/ollama/api.py +++ b/memgpt/local_llm/ollama/api.py @@ -2,14 +2,16 @@ import os from urllib.parse import urljoin import requests -from .settings import SIMPLE -from ..utils import count_tokens -from ...errors import LocalLLMError + +from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.utils import count_tokens +from memgpt.errors import LocalLLMError + OLLAMA_API_SUFFIX = "/api/generate" -def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMPLE, grammar=None): +def get_ollama_completion(endpoint, model, prompt, context_window, grammar=None): """See https://github.com/jmorganca/ollama/blob/main/docs/api.md for instructions on how to run the LLM web server""" from memgpt.utils import printd @@ -23,10 +25,30 @@ def get_ollama_completion(endpoint, model, prompt, context_window, settings=SIMP ) # Settings for the generation, includes the prompt + stop tokens, max length, etc - request = settings - request["prompt"] = prompt - request["model"] = model - request["options"]["num_ctx"] = context_window + # https://github.com/jmorganca/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values + settings = get_completions_settings() + settings.update( + { + # specific naming for context length + "num_ctx": context_window, + } + ) + + # https://github.com/jmorganca/ollama/blob/main/docs/api.md#generate-a-completion + request = { + ## base parameters + "model": model, + "prompt": prompt, + # "images": [], # TODO eventually support + ## advanced parameters + # "format": "json", # TODO eventually 
support + "stream": False, + "options": settings, + "system": "", # no prompt formatting + "template": "{{ .Prompt }}", # no prompt formatting + "raw": True, # no prompt formatting + "context": None, # no memory via prompt formatting + } # Set grammar if grammar is not None: diff --git a/memgpt/local_llm/settings/__init__.py b/memgpt/local_llm/settings/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/memgpt/local_llm/settings/deterministic_mirostat.py b/memgpt/local_llm/settings/deterministic_mirostat.py new file mode 100644 index 00000000..5e19fce4 --- /dev/null +++ b/memgpt/local_llm/settings/deterministic_mirostat.py @@ -0,0 +1,45 @@ +from memgpt.local_llm.settings.simple import settings as simple_settings + +settings = { + "max_new_tokens": 250, + "do_sample": False, + "temperature": 0, + "top_p": 0, + "typical_p": 1, + "repetition_penalty": 1.18, + "repetition_penalty_range": 0, + "encoder_repetition_penalty": 1, + "top_k": 1, + "min_length": 0, + "no_repeat_ngram_size": 0, + "num_beams": 1, + "penalty_alpha": 0, + "length_penalty": 1, + "early_stopping": False, + "guidance_scale": 1, + "negative_prompt": "", + "seed": -1, + "add_bos_token": True, + # NOTE: important - these are the BASE stopping strings, and should be combined with {{user}}/{{char}}-based stopping strings + "stopping_strings": [ + *simple_settings["stop"], # unpack: backend expects a flat list of strings, not a nested list + # '### Response (JSON only, engaging, natural, authentic, descriptive, creative):', + # "", + # "<|", + # "\n#", + # "\n*{{user}} ", + # "\n\n\n", + # "\n{", + # ",\n{", + ], + "truncation_length": 4096, + "ban_eos_token": False, + "skip_special_tokens": True, + "top_a": 0, + "tfs": 1, + "epsilon_cutoff": 0, + "eta_cutoff": 0, + "mirostat_mode": 2, + "mirostat_tau": 4, + "mirostat_eta": 0.1, +} diff --git a/memgpt/local_llm/settings/settings.py b/memgpt/local_llm/settings/settings.py new file mode 100644 index 00000000..d14a69dc --- /dev/null +++ b/memgpt/local_llm/settings/settings.py @@ -0,0 +1,68 @@ +import json
+import os + +from memgpt.constants import MEMGPT_DIR +from memgpt.local_llm.settings.simple import settings as simple_settings +from memgpt.local_llm.settings.deterministic_mirostat import settings as det_miro_settings + + +DEFAULT = "simple" +SETTINGS_FOLDER_NAME = "settings" +COMPLETION_SETTINGS_FILE_NAME = "completions_api_settings.json" + + +def get_completions_settings(defaults="simple") -> dict: + """Pull from the home directory settings if they exist, otherwise default""" + from memgpt.utils import printd + + # Load up some default base settings + printd(f"Loading default settings from '{defaults}'") + if defaults == "simple": + # simple = basic stop strings + settings = dict(simple_settings)  # copy: callers mutate the result, don't leak into module defaults + elif defaults == "deterministic_mirostat": + settings = dict(det_miro_settings)  # copy for the same reason + elif defaults is None: + settings = dict() + else: + raise ValueError(defaults) + + # Check if settings_dir folder exists (if not, create it) + settings_dir = os.path.join(MEMGPT_DIR, SETTINGS_FOLDER_NAME) + if not os.path.exists(settings_dir): + printd(f"Settings folder '{settings_dir}' doesn't exist, creating it...") + try: + os.makedirs(settings_dir) + except Exception as e: + print(f"Error: failed to create settings folder '{settings_dir}'.\n{e}") + return settings + + # Then, check if settings_dir/completions_api_settings.json file exists + settings_file = os.path.join(settings_dir, COMPLETION_SETTINGS_FILE_NAME) + + if os.path.isfile(settings_file): + # Load into a dict called "settings" + printd(f"Found completion settings file '{settings_file}', loading it...") + try: + with open(settings_file, "r") as file: + user_settings = json.load(file) + if len(user_settings) > 0: + settings.update(user_settings) + except json.JSONDecodeError as e: + print(f"Error: failed to load user settings file '{settings_file}', invalid json.\n{e}") + except Exception as e: + print(f"Error: failed to load user settings file.\n{e}") + + else: + printd(f"No completion settings file '{settings_file}', skipping...")
+ # Create the file settings_file to make it easy for the user to edit + try: + with open(settings_file, "w") as file: + # We don't want to dump existing default settings in case we modify + # the default settings in the future + # json.dump(settings, file, indent=4) + json.dump({}, file, indent=4) + except Exception as e: + print(f"Error: failed to create empty settings file '{settings_file}'.\n{e}") + + return settings diff --git a/memgpt/local_llm/settings/simple.py b/memgpt/local_llm/settings/simple.py new file mode 100644 index 00000000..ae56fb57 --- /dev/null +++ b/memgpt/local_llm/settings/simple.py @@ -0,0 +1,21 @@ +settings = { + # "stopping_strings": [ + "stop": [ + "\nUSER:", + "\nASSISTANT:", + "\nFUNCTION RETURN:", + "\nUSER", + "\nASSISTANT", + "\nFUNCTION RETURN", + "\nFUNCTION", + "\nFUNC", + "<|im_start|>", + "<|im_end|>", + "<|im_sep|>", + # '\n' + + # '', + # '<|', + # '\n#', + # '\n\n\n', + ], +} diff --git a/memgpt/local_llm/vllm/api.py b/memgpt/local_llm/vllm/api.py index 5033d04b..d420eda6 100644 --- a/memgpt/local_llm/vllm/api.py +++ b/memgpt/local_llm/vllm/api.py @@ -2,12 +2,13 @@ import os from urllib.parse import urljoin import requests -from ..utils import load_grammar_file, count_tokens +from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.local_llm.utils import load_grammar_file, count_tokens WEBUI_API_SUFFIX = "/v1/completions" -def get_vllm_completion(endpoint, model, prompt, context_window, user, settings={}, grammar=None): +def get_vllm_completion(endpoint, model, prompt, context_window, user, grammar=None): """https://github.com/vllm-project/vllm/blob/main/examples/api_client.py""" from memgpt.utils import printd @@ -16,6 +17,7 @@ def get_vllm_completion(endpoint, model, prompt, context_window, user, settings= raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") # Settings for the generation, includes the prompt + stop tokens, max length, etc + 
settings = get_completions_settings() request = settings request["prompt"] = prompt request["max_tokens"] = int(context_window - prompt_tokens) diff --git a/memgpt/local_llm/webui/api.py b/memgpt/local_llm/webui/api.py index 8fe9e851..b84899b1 100644 --- a/memgpt/local_llm/webui/api.py +++ b/memgpt/local_llm/webui/api.py @@ -2,13 +2,13 @@ import os from urllib.parse import urljoin import requests -from .settings import SIMPLE -from ..utils import load_grammar_file, count_tokens +from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.local_llm.utils import load_grammar_file, count_tokens WEBUI_API_SUFFIX = "/v1/completions" -def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, grammar=None): +def get_webui_completion(endpoint, prompt, context_window, grammar=None): """Compatibility for the new OpenAI API: https://github.com/oobabooga/text-generation-webui/wiki/12-%E2%80%90-OpenAI-API#examples""" from memgpt.utils import printd @@ -17,6 +17,7 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") # Settings for the generation, includes the prompt + stop tokens, max length, etc + settings = get_completions_settings() request = settings request["prompt"] = prompt request["truncation_length"] = context_window diff --git a/memgpt/local_llm/webui/legacy_api.py b/memgpt/local_llm/webui/legacy_api.py index 17f7fd6f..23f88599 100644 --- a/memgpt/local_llm/webui/legacy_api.py +++ b/memgpt/local_llm/webui/legacy_api.py @@ -2,13 +2,13 @@ import os from urllib.parse import urljoin import requests -from .legacy_settings import SIMPLE -from ..utils import load_grammar_file, count_tokens +from memgpt.local_llm.settings.settings import get_completions_settings +from memgpt.local_llm.utils import load_grammar_file, count_tokens WEBUI_API_SUFFIX = "/api/v1/generate" -def get_webui_completion(endpoint, 
prompt, context_window, settings=SIMPLE, grammar=None): +def get_webui_completion(endpoint, prompt, context_window, grammar=None): """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server""" from memgpt.utils import printd @@ -17,7 +17,10 @@ def get_webui_completion(endpoint, prompt, context_window, settings=SIMPLE, gram raise Exception(f"Request exceeds maximum context length ({prompt_tokens} > {context_window} tokens)") # Settings for the generation, includes the prompt + stop tokens, max length, etc + settings = get_completions_settings() request = settings + request["stopping_strings"] = request["stop"] # legacy webui API expects "stopping_strings", not OpenAI-style "stop" + request["max_new_tokens"] = 3072 # TODO: arbitrary cap; derive from context_window - prompt_tokens instead request["prompt"] = prompt request["truncation_length"] = context_window # assuming mistral 7b