From 78cb676cd2dceaae498b2c254a98e444d26762b4 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sat, 28 Oct 2023 18:30:35 -0700 Subject: [PATCH] LM Studio inference server support (#167) * updated airo wrapper to catch specific case where extra closing } is missing * added lmstudio support --- memgpt/local_llm/chat_completion_proxy.py | 3 ++ .../llm_chat_completion_wrappers/airoboros.py | 5 ++- memgpt/local_llm/lmstudio/api.py | 41 +++++++++++++++++++ memgpt/local_llm/lmstudio/settings.py | 13 ++++++ memgpt/local_llm/webui/api.py | 9 +++- 5 files changed, 68 insertions(+), 3 deletions(-) create mode 100644 memgpt/local_llm/lmstudio/api.py create mode 100644 memgpt/local_llm/lmstudio/settings.py diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index 41442781..497381da 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -5,6 +5,7 @@ import requests import json from .webui.api import get_webui_completion +from .lmstudio.api import get_lmstudio_completion from .llm_chat_completion_wrappers import airoboros, dolphin from .utils import DotDict @@ -40,6 +41,8 @@ async def get_chat_completion( try: if HOST_TYPE == "webui": result = get_webui_completion(prompt) + elif HOST_TYPE == "lmstudio": + result = get_lmstudio_completion(prompt) else: print(f"Warning: BACKEND_TYPE was not set, defaulting to webui") result = get_webui_completion(prompt) diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py index 0b2100fd..7bd793e3 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py @@ -391,7 +391,10 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): try: function_json_output = json.loads(raw_llm_output) except Exception as e: - raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}") + 
try: +            function_json_output = json.loads(raw_llm_output + "\n}") +        except Exception: +            raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}") function_name = function_json_output["function"] function_parameters = function_json_output["params"] diff --git a/memgpt/local_llm/lmstudio/api.py b/memgpt/local_llm/lmstudio/api.py new file mode 100644 index 00000000..e6f59685 --- /dev/null +++ b/memgpt/local_llm/lmstudio/api.py @@ -0,0 +1,41 @@ +import os +import requests + +# from .settings import SIMPLE + +HOST = os.getenv("OPENAI_API_BASE") +HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion +LMSTUDIO_API_SUFFIX = "/v1/completions" +DEBUG = False + +from .settings import SIMPLE + + +def get_lmstudio_completion(prompt, settings=SIMPLE): +    """Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client""" + +    # Settings for the generation, includes the prompt + stop tokens, max length, etc +    request = dict(settings)  # copy so the shared SIMPLE defaults are never mutated +    request["prompt"] = prompt + +    if not HOST.startswith(("http://", "https://")): +        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://") + +    try: +        URI = HOST.strip("/") + "/" + LMSTUDIO_API_SUFFIX.strip("/")  # not os.path.join: it uses "\" on Windows +        response = requests.post(URI, json=request) +        if response.status_code == 200: +            result = response.json() +            # result = result["results"][0]["text"] +            result = result["choices"][0]["text"] +            if DEBUG: +                print(f"json API response.text: {result}") +        else: +            raise Exception( +                f"API call got non-200 response code for address: {URI}. Make sure that the LM Studio local inference server is running and reachable at {URI}." 
+            ) +    except Exception: +        # TODO handle gracefully +        raise + +    return result diff --git a/memgpt/local_llm/lmstudio/settings.py b/memgpt/local_llm/lmstudio/settings.py new file mode 100644 index 00000000..cdf8962e --- /dev/null +++ b/memgpt/local_llm/lmstudio/settings.py @@ -0,0 +1,13 @@ +SIMPLE = { +    "stop": [ +        "\nUSER:", +        "\nASSISTANT:", +        "\nFUNCTION RETURN:", +        # '\n' + +        # '', +        # '<|', +        # '\n#', +        # '\n\n\n', +    ], +    "max_tokens": 500, +} diff --git a/memgpt/local_llm/webui/api.py b/memgpt/local_llm/webui/api.py index 547377b1..a79ec98a 100644 --- a/memgpt/local_llm/webui/api.py +++ b/memgpt/local_llm/webui/api.py @@ -16,8 +16,11 @@ def get_webui_completion(prompt, settings=SIMPLE): request = settings request["prompt"] = prompt +    if not HOST.startswith(("http://", "https://")): +        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://") + try: -        URI = f"{HOST.strip('/')}{WEBUI_API_SUFFIX}" +        URI = HOST.strip("/") + "/" + WEBUI_API_SUFFIX.strip("/")  # not os.path.join: it uses "\" on Windows response = requests.post(URI, json=request) if response.status_code == 200: result = response.json() @@ -25,7 +28,9 @@ def get_webui_completion(prompt, settings=SIMPLE): if DEBUG: print(f"json API response.text: {result}") else: -        raise Exception(f"API call got non-200 response code for address: {URI}") +        raise Exception( +            f"API call got non-200 response code for address: {URI}. Make sure that the web UI server is running and reachable at {URI}." +        ) except: # TODO handle gracefully raise