Changes to lmstudio to fix JSON decode error (#208)
* Changes to lmstudio to fix JSON decode error * black formatting * properly handle context overflow error (propogate exception up the stack with recognizable error message) + add backwards compat option to use completions endpoint * set max tokens to 8k, comment out the overflow policy (use memgpt's overflow policy) * 8k not 3k --------- Co-authored-by: Matt Poff <mattpoff@Matts-MacBook-Pro-2.local> Co-authored-by: cpacker <packercharles@gmail.com>
This commit is contained in:
@@ -2,39 +2,59 @@ import os
|
||||
from urllib.parse import urljoin
|
||||
import requests
|
||||
|
||||
# from .settings import SIMPLE
|
||||
from .settings import SIMPLE
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
LMSTUDIO_API_SUFFIX = "/v1/completions"
|
||||
LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions"
|
||||
LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions"
|
||||
DEBUG = False
|
||||
|
||||
from .settings import SIMPLE
|
||||
|
||||
|
||||
def get_lmstudio_completion(prompt, settings=SIMPLE):
|
||||
def get_lmstudio_completion(prompt, settings=SIMPLE, api="chat"):
|
||||
"""Based on the example for using LM Studio as a backend from https://github.com/lmstudio-ai/examples/tree/main/Hello%2C%20world%20-%20OpenAI%20python%20client"""
|
||||
|
||||
# Settings for the generation, includes the prompt + stop tokens, max length, etc
|
||||
request = settings
|
||||
request["prompt"] = prompt
|
||||
|
||||
if api == "chat":
|
||||
# Uses the ChatCompletions API style
|
||||
# Seems to work better, probably because it's applying some extra settings under-the-hood?
|
||||
URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_CHAT_SUFFIX.strip("/"))
|
||||
message_structure = [{"role": "user", "content": prompt}]
|
||||
request["messages"] = message_structure
|
||||
elif api == "completions":
|
||||
# Uses basic string completions (string in, string out)
|
||||
# Does not work as well as ChatCompletions for some reason
|
||||
URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_COMPLETIONS_SUFFIX.strip("/"))
|
||||
request["prompt"] = prompt
|
||||
else:
|
||||
raise ValueError(api)
|
||||
|
||||
if not HOST.startswith(("http://", "https://")):
|
||||
raise ValueError(f"Provided OPENAI_API_BASE value ({HOST}) must begin with http:// or https://")
|
||||
|
||||
try:
|
||||
URI = urljoin(HOST.strip("/") + "/", LMSTUDIO_API_SUFFIX.strip("/"))
|
||||
response = requests.post(URI, json=request)
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
# result = result["results"][0]["text"]
|
||||
result = result["choices"][0]["text"]
|
||||
if api == "chat":
|
||||
result = result["choices"][0]["message"]["content"]
|
||||
elif api == "completions":
|
||||
result = result["choices"][0]["text"]
|
||||
if DEBUG:
|
||||
print(f"json API response.text: {result}")
|
||||
else:
|
||||
raise Exception(
|
||||
f"API call got non-200 response code for address: {URI}. Make sure that the LM Studio local inference server is running and reachable at {URI}."
|
||||
)
|
||||
# Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"}
|
||||
if "context length" in str(response.text).lower():
|
||||
# "exceeds context length" is what appears in the LM Studio error message
|
||||
# raise an alternate exception that matches OpenAI's message, which is "maximum context length"
|
||||
raise Exception(f"Request exceeds maximum context length (code={response.status_code}, msg={response.text}, URI={URI})")
|
||||
else:
|
||||
raise Exception(
|
||||
f"API call got non-200 response code (code={response.status_code}, msg={response.text}) for address: {URI}."
|
||||
+ f"Make sure that the LM Studio local inference server is running and reachable at {URI}."
|
||||
)
|
||||
except:
|
||||
# TODO handle gracefully
|
||||
raise
|
||||
|
||||
@@ -9,5 +9,12 @@ SIMPLE = {
|
||||
# '\n#',
|
||||
# '\n\n\n',
|
||||
],
|
||||
"max_tokens": 500,
|
||||
# This controls the maximum number of tokens that the model can generate
|
||||
# Cap this at the model context length (assuming 8k for Mistral 7B)
|
||||
"max_tokens": 8000,
|
||||
# This controls how LM studio handles context overflow
|
||||
# In MemGPT we handle this ourselves, so this should be commented out
|
||||
# "lmstudio": {"context_overflow_policy": 2},
|
||||
"stream": False,
|
||||
"model": "local model",
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user