Improvements to JSON handling for local LLMs (#269)

* some extra json hacks

* add 'smart' json loader to other wrappers

* added chatml related stop tokens by default
This commit is contained in:
Charles Packer
2023-11-03 00:18:31 -07:00
committed by GitHub
parent fde0087a19
commit 437306388f
6 changed files with 83 additions and 16 deletions

View File

@@ -0,0 +1,63 @@
import json
def extract_first_json(string):
    """Extract and parse the first top-level JSON object in *string*.

    Handles the case of two JSON objects back-to-back by tracking brace
    depth and stopping at the first balanced ``{...}`` span.

    NOTE(review): braces inside quoted JSON string values are not
    special-cased — assumes the LLM output has no ``{``/``}`` inside
    strings before the object closes; confirm if that case matters.

    Args:
        string: raw text that should contain at least one JSON object.

    Returns:
        The parsed first JSON object (typically a dict).

    Raises:
        json.JSONDecodeError: if no balanced object is found, or the
            matched span fails to parse.
    """
    depth = 0
    start_index = None
    for i, char in enumerate(string):
        if char == "{":
            if depth == 0:
                start_index = i
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0 and start_index is not None:
                try:
                    return json.loads(string[start_index : i + 1])
                except json.JSONDecodeError as e:
                    # JSONDecodeError requires (msg, doc, pos); the original
                    # one-argument call raised TypeError instead of the
                    # intended JSONDecodeError.
                    raise json.JSONDecodeError(
                        f"Matched closing bracket, but decode failed with error: {str(e)}",
                        string,
                        start_index,
                    ) from e
    print("No valid JSON object found.")
    raise json.JSONDecodeError("Couldn't find starting bracket", string, 0)
def add_missing_heartbeat(llm_json):
    """Insert heartbeat requests into messages that should carry them.

    Planned heuristic: when the assistant calls a function other than
    ``send_message`` and the previous message's role is ``user``, insert a
    heartbeat — such a call right after user input is likely a retrieval or
    insertion step, after which a ``send_message`` reply is expected.

    Expected message shape::

        "message" = {
            "role": "assistant",
            "content": ...,
            "function_call": {
                "name": ...
                "arguments": {
                    "arg1": val1,
                    ...
                }
            }
        }

    Raises:
        NotImplementedError: always — this heuristic is not written yet.
    """
    raise NotImplementedError
def clean_json(raw_llm_output, messages=None, functions=None):
    """Try a bunch of hacks to parse the data coming out of the LLM.

    Attempts, in order:
      1. parse the raw output as-is,
      2. parse with a ``"}"`` appended (LLMs often truncate the final brace),
      3. extract the first balanced JSON object from the patched output.

    Args:
        raw_llm_output: string produced by the LLM, expected to contain JSON.
        messages: unused; reserved for heuristics needing chat context.
        functions: unused; reserved for heuristics needing function schemas.

    Returns:
        The parsed JSON data (typically a dict).

    Raises:
        json.JSONDecodeError: if every strategy fails.
    """
    try:
        return json.loads(raw_llm_output)
    except json.JSONDecodeError:
        pass

    try:
        # Common failure mode: output truncated just before the closing brace.
        return json.loads(raw_llm_output + "}")
    except json.JSONDecodeError:
        pass

    # Last resort: scan for the first balanced {...} span. The original
    # wrapped this in a no-op bare `except: raise`, which is removed here.
    return extract_first_json(raw_llm_output + "}")

View File

@@ -1,6 +1,7 @@
import json
from .wrapper_base import LLMChatCompletionWrapper
from ..json_parser import clean_json
from ...errors import LLMJSONParsingError
@@ -184,9 +185,9 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
raw_llm_output = "{" + raw_llm_output
try:
function_json_output = json.loads(raw_llm_output)
function_json_output = clean_json(raw_llm_output)
except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
@@ -393,12 +394,9 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
raw_llm_output = "{" + raw_llm_output
try:
function_json_output = json.loads(raw_llm_output)
function_json_output = clean_json(raw_llm_output)
except Exception as e:
try:
function_json_output = json.loads(raw_llm_output + "\n}")
except:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]

View File

@@ -1,6 +1,7 @@
import json
from .wrapper_base import LLMChatCompletionWrapper
from ..json_parser import clean_json
from ...errors import LLMJSONParsingError
@@ -219,9 +220,9 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
raw_llm_output = "{" + raw_llm_output
try:
function_json_output = json.loads(raw_llm_output)
function_json_output = clean_json(raw_llm_output)
except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]

View File

@@ -1,5 +1,7 @@
import json
from .wrapper_base import LLMChatCompletionWrapper
from ..json_parser import clean_json
from ...errors import LLMJSONParsingError
@@ -149,9 +151,9 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
raw_llm_output = "{" + raw_llm_output
try:
function_json_output = json.loads(raw_llm_output)
function_json_output = clean_json(raw_llm_output)
except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
@@ -312,12 +314,9 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
raw_llm_output = "{" + raw_llm_output
try:
function_json_output = json.loads(raw_llm_output)
function_json_output = clean_json(raw_llm_output)
except Exception as e:
try:
function_json_output = json.loads(raw_llm_output + "\n}")
except:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
try:
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]

View File

@@ -5,6 +5,9 @@ SIMPLE = {
"\nUSER:",
"\nASSISTANT:",
"\nFUNCTION RETURN:",
"<|im_start|>",
"<|im_end|>",
"<|im_sep|>",
# '\n' +
# '</s>',
# '<|',

View File

@@ -5,6 +5,9 @@ SIMPLE = {
"\nUSER:",
"\nASSISTANT:",
"\nFUNCTION RETURN:",
"<|im_start|>",
"<|im_end|>",
"<|im_sep|>",
# '\n' +
# '</s>',
# '<|',