Improvements to JSON handling for local LLMs (#269)
* some extra JSON hacks * add 'smart' JSON loader to other wrappers * added ChatML-related stop tokens by default
This commit is contained in:
63
memgpt/local_llm/json_parser.py
Normal file
63
memgpt/local_llm/json_parser.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import json
|
||||
|
||||
|
||||
def extract_first_json(string):
    """Return the first JSON object embedded in ``string``.

    Handles the case of two JSON objects back-to-back (or an object followed
    by trailing junk) by decoding only the object that begins at the first
    "{" and ignoring everything after it.

    Args:
        string: Text expected to contain at least one JSON object.

    Returns:
        The decoded JSON object (a ``dict``).

    Raises:
        json.JSONDecodeError: If ``string`` contains no "{", or if the text
            starting at the first "{" is not a valid JSON object.
    """
    start_index = string.find("{")
    if start_index == -1:
        print("No valid JSON object found.")
        # JSONDecodeError requires (msg, doc, pos); passing only a message
        # would raise TypeError instead of the intended error.
        raise json.JSONDecodeError("Couldn't find starting bracket", string, 0)

    try:
        # raw_decode stops at the end of the first complete JSON value, so
        # anything following the object is ignored. Unlike manual brace
        # counting, it is not confused by "{" or "}" inside string values.
        parsed, _ = json.JSONDecoder().raw_decode(string[start_index:])
    except json.JSONDecodeError as e:
        raise json.JSONDecodeError(
            f"Found starting bracket, but decode failed with error: {str(e)}",
            string,
            start_index + e.pos,  # e.pos is relative to the slice
        ) from e

    return parsed
|
||||
|
||||
|
||||
def add_missing_heartbeat(llm_json):
    """Placeholder for inserting heartbeat requests into messages that lack them.

    Not implemented yet: calling this function always raises
    ``NotImplementedError``.

    Intended heuristic:
      - if the function call is not ``send_message`` and the previous
        message's ``role`` is ``user``, insert a heartbeat request.

    Rationale: when MemGPT calls a function other than ``send_message``
    immediately after a user message, it is probably a retrieval or insertion
    call, in which case we likely want to eventually reply with
    ``send_message``.

    Message layout referred to by the heuristic:

        "message" = {
            "role": "assistant",
            "content": ...,
            "function_call": {
                "name": ...
                "arguments": {
                    "arg1": val1,
                    ...
                }
            }
        }

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError
|
||||
|
||||
|
||||
def clean_json(raw_llm_output, messages=None, functions=None):
    """Parse LLM output as JSON, applying progressively looser repair attempts.

    The attempts, in order:
      1. Decode the output exactly as given.
      2. Decode the output with a closing "}" appended.
      3. Pull the first JSON object out of the output (with "}" appended)
         via ``extract_first_json``.

    Args:
        raw_llm_output: Text produced by the LLM.
        messages: Accepted but not used by this function.
        functions: Accepted but not used by this function.

    Returns:
        The decoded JSON value from the first attempt that succeeds.

    Raises:
        Whatever ``extract_first_json`` raises when the last attempt fails.
    """
    # Attempt 1: the output is already valid JSON.
    try:
        return json.loads(raw_llm_output)
    except json.JSONDecodeError:
        pass

    with_closing_brace = raw_llm_output + "}"

    # Attempt 2: the output is only missing its final closing brace.
    try:
        return json.loads(with_closing_brace)
    except json.JSONDecodeError:
        pass

    # Attempt 3: last resort; any failure here propagates to the caller.
    return extract_first_json(with_closing_brace)
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
from ..json_parser import clean_json
|
||||
from ...errors import LLMJSONParsingError
|
||||
|
||||
|
||||
@@ -184,9 +185,9 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
raw_llm_output = "{" + raw_llm_output
|
||||
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output)
|
||||
function_json_output = clean_json(raw_llm_output)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
|
||||
try:
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
@@ -393,12 +394,9 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
raw_llm_output = "{" + raw_llm_output
|
||||
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output)
|
||||
function_json_output = clean_json(raw_llm_output)
|
||||
except Exception as e:
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output + "\n}")
|
||||
except:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
|
||||
try:
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import json
|
||||
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
from ..json_parser import clean_json
|
||||
from ...errors import LLMJSONParsingError
|
||||
|
||||
|
||||
@@ -219,9 +220,9 @@ class Dolphin21MistralWrapper(LLMChatCompletionWrapper):
|
||||
raw_llm_output = "{" + raw_llm_output
|
||||
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output)
|
||||
function_json_output = clean_json(raw_llm_output)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
|
||||
try:
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
from ..json_parser import clean_json
|
||||
from ...errors import LLMJSONParsingError
|
||||
|
||||
|
||||
@@ -149,9 +151,9 @@ class ZephyrMistralWrapper(LLMChatCompletionWrapper):
|
||||
raw_llm_output = "{" + raw_llm_output
|
||||
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output)
|
||||
function_json_output = clean_json(raw_llm_output)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
|
||||
try:
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
@@ -312,12 +314,9 @@ class ZephyrMistralInnerMonologueWrapper(ZephyrMistralWrapper):
|
||||
raw_llm_output = "{" + raw_llm_output
|
||||
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output)
|
||||
function_json_output = clean_json(raw_llm_output)
|
||||
except Exception as e:
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output + "\n}")
|
||||
except:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
|
||||
try:
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
|
||||
@@ -5,6 +5,9 @@ SIMPLE = {
|
||||
"\nUSER:",
|
||||
"\nASSISTANT:",
|
||||
"\nFUNCTION RETURN:",
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|im_sep|>",
|
||||
# '\n' +
|
||||
# '</s>',
|
||||
# '<|',
|
||||
|
||||
@@ -5,6 +5,9 @@ SIMPLE = {
|
||||
"\nUSER:",
|
||||
"\nASSISTANT:",
|
||||
"\nFUNCTION RETURN:",
|
||||
"<|im_start|>",
|
||||
"<|im_end|>",
|
||||
"<|im_sep|>",
|
||||
# '\n' +
|
||||
# '</s>',
|
||||
# '<|',
|
||||
|
||||
Reference in New Issue
Block a user