LM Studio inference server support (#167)
* Updated the Airoboros wrapper to catch the specific case where an extra closing } is missing. * Added LM Studio support.
This commit is contained in:
@@ -5,6 +5,7 @@ import requests
|
||||
import json
|
||||
|
||||
from .webui.api import get_webui_completion
|
||||
from .lmstudio.api import get_lmstudio_completion
|
||||
from .llm_chat_completion_wrappers import airoboros, dolphin
|
||||
from .utils import DotDict
|
||||
|
||||
@@ -40,6 +41,8 @@ async def get_chat_completion(
|
||||
try:
|
||||
if HOST_TYPE == "webui":
|
||||
result = get_webui_completion(prompt)
|
||||
elif HOST_TYPE == "lmstudio":
|
||||
result = get_lmstudio_completion(prompt)
|
||||
else:
|
||||
print(f"Warning: BACKEND_TYPE was not set, defaulting to webui")
|
||||
result = get_webui_completion(prompt)
|
||||
|
||||
@@ -391,7 +391,10 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output + "\n}")
|
||||
except:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
|
||||
|
||||
41
memgpt/local_llm/lmstudio/api.py
Normal file
41
memgpt/local_llm/lmstudio/api.py
Normal file
@@ -0,0 +1,41 @@
|
||||
import os

import requests

# from .settings import SIMPLE

# Base URL of the backend server, read from the environment,
# e.g. "http://localhost:1234" for LM Studio's local inference server.
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
# Path of LM Studio's OpenAI-compatible text-completions endpoint.
LMSTUDIO_API_SUFFIX = "/v1/completions"
# When True, print the raw completion text returned by the server.
DEBUG = False

from .settings import SIMPLE
||||
def get_lmstudio_completion(prompt, settings=SIMPLE):
    """Query an LM Studio local inference server for a text completion.

    Based on the example for using LM Studio as a backend from
    https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client

    Args:
        prompt: Full prompt string sent to the /v1/completions endpoint.
        settings: Base generation parameters (stop tokens, max_tokens, ...).
            Defaults to the module-level SIMPLE dict; it is never mutated.

    Returns:
        The generated completion text (the "text" field of the first choice).

    Raises:
        ValueError: If OPENAI_API_BASE (HOST) is unset or not an http(s) URL.
        Exception: If the server responds with a non-200 status code.
    """
    # Copy the settings so we never mutate the shared SIMPLE defaults
    # (writing "prompt" into the caller's dict would pollute it across calls).
    request = dict(settings)
    request["prompt"] = prompt

    # HOST is None when OPENAI_API_BASE is unset; fail with a clear message
    # instead of an AttributeError on .startswith.
    if not HOST or not HOST.startswith(("http://", "https://")):
        raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")

    try:
        # Build the URI with plain string formatting: os.path.join would use a
        # backslash separator on Windows, which is invalid inside a URL.
        URI = f"{HOST.strip('/')}/{LMSTUDIO_API_SUFFIX.strip('/')}"
        response = requests.post(URI, json=request)
        if response.status_code == 200:
            result = response.json()
            result = result["choices"][0]["text"]
            if DEBUG:
                print(f"json API response.text: {result}")
        else:
            raise Exception(
                f"API call got non-200 response code for address: {URI}. Make sure that the LM Studio local inference server is running and reachable at {URI}."
            )
    except:
        # TODO handle gracefully
        raise

    return result
|
||||
13
memgpt/local_llm/lmstudio/settings.py
Normal file
13
memgpt/local_llm/lmstudio/settings.py
Normal file
@@ -0,0 +1,13 @@
|
||||
# Default generation settings for LM Studio's /v1/completions endpoint.
# get_lmstudio_completion() adds the "prompt" key before sending the request.
SIMPLE = {
    # Strings that terminate generation: each marks the start of the next
    # conversation turn in the prompt format.
    "stop": [
        "\nUSER:",
        "\nASSISTANT:",
        "\nFUNCTION RETURN:",
        # '\n' +
        # '</s>',
        # '<|',
        # '\n#',
        # '\n\n\n',
    ],
    # Upper bound on the number of tokens generated per request.
    "max_tokens": 500,
}
|
||||
@@ -16,8 +16,11 @@ def get_webui_completion(prompt, settings=SIMPLE):
|
||||
request = settings
|
||||
request["prompt"] = prompt
|
||||
|
||||
if not HOST.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
|
||||
|
||||
try:
|
||||
URI = f"{HOST.strip('/')}{WEBUI_API_SUFFIX}"
|
||||
URI = os.path.join(HOST.strip("/"), WEBUI_API_SUFFIX.strip("/"))
|
||||
response = requests.post(URI, json=request)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
@@ -25,7 +28,9 @@ def get_webui_completion(prompt, settings=SIMPLE):
|
||||
if DEBUG:
|
||||
print(f"json API response.text: {result}")
|
||||
else:
|
||||
raise Exception(f"API call got non-200 response code for address: {URI}")
|
||||
raise Exception(
|
||||
f"API call got non-200 response code for address: {URI}. Make sure that the web UI server is running and reachable at {URI}."
|
||||
)
|
||||
except:
|
||||
# TODO handle gracefully
|
||||
raise
|
||||
|
||||
Reference in New Issue
Block a user