LM Studio inference server support (#167)

* updated airo wrapper to catch specific case where extra closing } is missing

* added lmstudio support
This commit is contained in:
Charles Packer
2023-10-28 18:30:35 -07:00
committed by GitHub
parent d206de5687
commit 78cb676cd2
5 changed files with 68 additions and 3 deletions

View File

@@ -5,6 +5,7 @@ import requests
import json
from .webui.api import get_webui_completion
from .lmstudio.api import get_lmstudio_completion
from .llm_chat_completion_wrappers import airoboros, dolphin
from .utils import DotDict
@@ -40,6 +41,8 @@ async def get_chat_completion(
try:
if HOST_TYPE == "webui":
result = get_webui_completion(prompt)
elif HOST_TYPE == "lmstudio":
result = get_lmstudio_completion(prompt)
else:
print(f"Warning: BACKEND_TYPE was not set, defaulting to webui")
result = get_webui_completion(prompt)

View File

@@ -391,7 +391,10 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
try:
function_json_output = json.loads(raw_llm_output)
except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
try:
function_json_output = json.loads(raw_llm_output + "\n}")
except:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]

View File

@@ -0,0 +1,41 @@
import os
import requests
# from .settings import SIMPLE
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
LMSTUDIO_API_SUFFIX = "/v1/completions"
DEBUG = False
from .settings import SIMPLE
def get_lmstudio_completion(prompt, settings=SIMPLE):
"""Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
# Settings for the generation, includes the prompt + stop tokens, max length, etc
request = settings
request["prompt"] = prompt
if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
try:
URI = os.path.join(HOST.strip("/"), LMSTUDIO_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
# result = result["results"][0]["text"]
result = result["choices"][0]["text"]
if DEBUG:
print(f"json API response.text: {result}")
else:
raise Exception(
f"API call got non-200 response code for address: {URI}. Make sure that the LM Studio local inference server is running and reachable at {URI}."
)
except:
# TODO handle gracefully
raise
return result

View File

@@ -0,0 +1,13 @@
SIMPLE = {
"stop": [
"\nUSER:",
"\nASSISTANT:",
"\nFUNCTION RETURN:",
# '\n' +
# '</s>',
# '<|',
# '\n#',
# '\n\n\n',
],
"max_tokens": 500,
}

View File

@@ -16,8 +16,11 @@ def get_webui_completion(prompt, settings=SIMPLE):
request = settings
request["prompt"] = prompt
if not HOST.startswith(("http://", "https://")):
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
try:
URI = f"{HOST.strip('/')}{WEBUI_API_SUFFIX}"
URI = os.path.join(HOST.strip("/"), WEBUI_API_SUFFIX.strip("/"))
response = requests.post(URI, json=request)
if response.status_code == 200:
result = response.json()
@@ -25,7 +28,9 @@ def get_webui_completion(prompt, settings=SIMPLE):
if DEBUG:
print(f"json API response.text: {result}")
else:
raise Exception(f"API call got non-200 response code for address: {URI}")
raise Exception(
f"API call got non-200 response code for address: {URI}. Make sure that the web UI server is running and reachable at {URI}."
)
except:
# TODO handle gracefully
raise