feat: added basic heartbeat override heuristics (#621)
* added basic heartbeat override * tested and working on lmstudio (patched typo + patched new bug emerging in latest lmstudio build) * added lmstudio patch to chatml wrapper * updated the system messages to be informative about the source * updated string constants after some tuning
This commit is contained in:
@@ -19,6 +19,8 @@ INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG = STARTUP_QUOTES[2]
|
||||
|
||||
CLI_WARNING_PREFIX = "Warning: "
|
||||
|
||||
NON_USER_MSG_PREFIX = "[This is an automated system message hidden from the user] "
|
||||
|
||||
# Constants to do with summarization / conversation length window
|
||||
# The max amount of tokens supported by the underlying model (eg 8k for gpt-4 and Mistral 7B)
|
||||
LLM_MAX_TOKENS = {
|
||||
@@ -43,7 +45,9 @@ LLM_MAX_TOKENS = {
|
||||
# The amount of tokens before a system warning about upcoming truncation is sent to MemGPT
|
||||
MESSAGE_SUMMARY_WARNING_FRAC = 0.75
|
||||
# The error message that MemGPT will receive
|
||||
MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
|
||||
# MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
|
||||
# Much longer and more specific variant of the prompt
|
||||
MESSAGE_SUMMARY_WARNING_STR = f"{NON_USER_MSG_PREFIX}The conversation history will soon reach its maximum length and be trimmed. If there is any important new information or general memories about you or the user that you would like to save, you should save that information immediately by calling function core_memory_append, core_memory_replace, or archival_memory_insert (remember to pass request_heartbeat = true if you would like to send a message immediately after)."
|
||||
# The fraction of tokens we truncate down to
|
||||
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC = 0.75
|
||||
|
||||
@@ -65,9 +69,13 @@ MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = "You are a helpful assistant. Keep you
|
||||
|
||||
#### Functions related
|
||||
|
||||
REQ_HEARTBEAT_MESSAGE = "request_heartbeat == true"
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE = "Function call failed"
|
||||
# REQ_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}request_heartbeat == true"
|
||||
REQ_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function called using request_heartbeat=true, returning control"
|
||||
# FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed"
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed, returning control"
|
||||
|
||||
FUNCTION_PARAM_NAME_REQ_HEARTBEAT = "request_heartbeat"
|
||||
FUNCTION_PARAM_TYPE_REQ_HEARTBEAT = "boolean"
|
||||
FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT = "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function."
|
||||
|
||||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
|
||||
|
||||
@@ -16,6 +16,7 @@ from memgpt.local_llm.vllm.api import get_vllm_completion
|
||||
from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
|
||||
from memgpt.local_llm.constants import DEFAULT_WRAPPER
|
||||
from memgpt.local_llm.utils import get_available_wrappers, count_tokens
|
||||
from memgpt.local_llm.function_parser import patch_function
|
||||
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
|
||||
from memgpt.errors import LocalLLMConnectionError, LocalLLMError
|
||||
from memgpt.constants import CLI_WARNING_PREFIX
|
||||
@@ -34,6 +35,8 @@ def get_chat_completion(
|
||||
wrapper=None,
|
||||
endpoint=None,
|
||||
endpoint_type=None,
|
||||
# optional cleanup
|
||||
function_correction=True,
|
||||
# extra hints to allow for additional prompt formatting hacks
|
||||
# TODO this could alternatively be supported via passing function_call="send_message" into the wrapper
|
||||
first_message=False,
|
||||
@@ -126,6 +129,10 @@ def get_chat_completion(
|
||||
except Exception as e:
|
||||
raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}")
|
||||
|
||||
# Run through some manual function correction (optional)
|
||||
if function_correction:
|
||||
chat_completion_result = patch_function(message_history=messages, new_message=chat_completion_result)
|
||||
|
||||
# Fill in potential missing usage information (used for tracking token use)
|
||||
if not ("prompt_tokens" in usage and "completion_tokens" in usage and "total_tokens" in usage):
|
||||
raise LocalLLMError(f"usage dict in response was missing fields ({usage})")
|
||||
|
||||
67
memgpt/local_llm/function_parser.py
Normal file
67
memgpt/local_llm/function_parser.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import copy
|
||||
import json
|
||||
|
||||
|
||||
# Function names whose calls should NOT get a heartbeat injected:
# send_message already yields control back to the user, so forcing a
# heartbeat after it would trigger an unwanted extra agent step.
NO_HEARTBEAT_FUNCS = ["send_message"]


def insert_heartbeat(message):
    """Return a copy of an assistant message with request_heartbeat=true added.

    Supports both the legacy single ``function_call`` format and the newer
    ``tool_calls`` list format. The input message is never mutated: a deep
    copy is modified and returned. If the message contains neither kind of
    call, the copy is returned unchanged.
    """
    # Deep copy so the nested function_call/tool_calls dicts inside the
    # caller's message are left untouched (a shallow copy would share them).
    message_copy = copy.deepcopy(message)

    if message_copy.get("function_call"):
        # Legacy OpenAI format: {"name": ..., "arguments": "<json string>"}
        params = json.loads(message_copy["function_call"]["arguments"])
        params["request_heartbeat"] = True
        message_copy["function_call"]["arguments"] = json.dumps(params)

    elif message_copy.get("tool_calls"):
        # BUGFIX: the original checked the nonexistent key "tool_call" and
        # wrote back to the misspelled key "tools_calls", so this branch
        # could never run (and would have raised KeyError if it had).
        # NOTE(review): only the first tool call is patched, mirroring the
        # read path used in heartbeat_correction — confirm multi-call
        # messages are not expected here.
        params = json.loads(message_copy["tool_calls"][0]["function"]["arguments"])
        params["request_heartbeat"] = True
        message_copy["tool_calls"][0]["function"]["arguments"] = json.dumps(params)

    return message_copy
|
||||
|
||||
|
||||
def heartbeat_correction(message_history, new_message):
    """Add heartbeats where we think the agent forgot to add them themselves.

    If the last message in the stack is a user message and the new message is
    an assistant function call (other than the exempt NO_HEARTBEAT_FUNCS),
    return a patched copy of new_message with request_heartbeat=true.
    Otherwise return None, meaning "no correction needed".

    See: https://github.com/cpacker/MemGPT/issues/601
    """
    if len(message_history) < 1:
        return None

    # Step 1: was the latest history entry a genuine user message (as
    # opposed to a packaged system/heartbeat message that also carries
    # role == "user")?
    last_message_was_user = False
    if message_history[-1]["role"] == "user":
        try:
            content = json.loads(message_history[-1]["content"])
        except json.JSONDecodeError:
            # Content is not packaged JSON, so we cannot classify it — bail out.
            return None
        # ROBUSTNESS: the original did content["type"], which raises
        # KeyError (missing field) or TypeError (non-dict JSON such as a
        # bare string/number); use a guarded lookup instead.
        if isinstance(content, dict) and content.get("type") == "user_message":
            last_message_was_user = True

    # Step 2: is the new assistant message calling a function that should
    # have requested a heartbeat?
    new_message_is_heartbeat_function = False
    if new_message["role"] == "assistant":
        function_name = None
        if new_message.get("function_call"):
            function_name = new_message["function_call"].get("name")
        elif new_message.get("tool_calls"):
            # Only the first tool call is inspected, matching insert_heartbeat.
            function_name = new_message["tool_calls"][0].get("function", {}).get("name")
        # A call with no name is malformed — treat it as "do not correct"
        # rather than injecting a heartbeat into something we can't identify.
        if function_name is not None and function_name not in NO_HEARTBEAT_FUNCS:
            new_message_is_heartbeat_function = True

    if last_message_was_user and new_message_is_heartbeat_function:
        return insert_heartbeat(new_message)
    else:
        return None
|
||||
|
||||
|
||||
def patch_function(message_history, new_message):
    """Run the heartbeat-correction heuristics over a freshly generated message.

    Returns the corrected message when a correction applies, otherwise the
    original new_message unchanged.
    """
    patched = heartbeat_correction(message_history=message_history, new_message=new_message)
    if patched is None:
        return new_message
    return patched
|
||||
@@ -418,10 +418,15 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
|
||||
try:
|
||||
# NOTE: weird bug can happen where 'function' gets nested if the prefix in the prompt isn't abided by
|
||||
if isinstance(function_json_output["function"], dict):
|
||||
function_json_output = function_json_output["function"]
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
except KeyError as e:
|
||||
raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}")
|
||||
raise LLMJSONParsingError(
|
||||
f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}. JSON result was:\n{function_json_output}"
|
||||
)
|
||||
|
||||
if self.clean_func_args:
|
||||
(
|
||||
|
||||
@@ -255,10 +255,16 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
|
||||
try:
|
||||
# NOTE: weird bug can happen where 'function' gets nested if the prefix in the prompt isn't abided by
|
||||
if isinstance(function_json_output["function"], dict):
|
||||
function_json_output = function_json_output["function"]
|
||||
# regular unpacking
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
except KeyError as e:
|
||||
raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}")
|
||||
raise LLMJSONParsingError(
|
||||
f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}. JSON result was:\n{function_json_output}"
|
||||
)
|
||||
|
||||
if self.clean_func_args:
|
||||
(
|
||||
@@ -366,6 +372,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper):
|
||||
and function_json_output["function"] is not None
|
||||
and function_json_output["function"].strip().lower() != "none"
|
||||
):
|
||||
# TODO apply lm studio nested bug patch?
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
else:
|
||||
|
||||
50
tests/test_function_parser.py
Normal file
50
tests/test_function_parser.py
Normal file
@@ -0,0 +1,50 @@
|
||||
import json
|
||||
|
||||
from memgpt.local_llm.function_parser import patch_function
|
||||
import memgpt.system as system
|
||||
|
||||
|
||||
# Fixture: send_message is exempt (NO_HEARTBEAT_FUNCS), so patch_function
# should leave this new_message completely untouched.
EXAMPLE_FUNCTION_CALL_SEND_MESSAGE = {
    "message_history": [
        {"role": "user", "content": system.package_user_message("hello")},
    ],
    "new_message": {
        "role": "assistant",
        "content": "I'll send a message.",
        "function_call": {
            "name": "send_message",
            "arguments": "null",
        },
    },
}

# Fixture: core_memory_append called without request_heartbeat —
# patch_function is expected to inject request_heartbeat=true into the
# JSON-encoded arguments.
EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING = {
    "message_history": [
        {"role": "user", "content": system.package_user_message("hello")},
    ],
    "new_message": {
        "role": "assistant",
        "content": "I'll append to memory.",
        "function_call": {
            "name": "core_memory_append",
            "arguments": json.dumps({"content": "new_stuff"}),
        },
    },
}
|
||||
|
||||
|
||||
def test_function_parsers():
    """Check the heartbeat-correction heuristics in patch_function.

    Case 1: send_message is exempt (NO_HEARTBEAT_FUNCS), so the message must
    pass through unmodified.
    Case 2: core_memory_append without request_heartbeat must come back with
    request_heartbeat=true injected into its JSON arguments.
    """
    # Case 1: exempt function -> no correction applied.
    og_message = EXAMPLE_FUNCTION_CALL_SEND_MESSAGE["new_message"]
    corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_SEND_MESSAGE)
    assert corrected_message == og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}"

    # Case 2: missing heartbeat -> correction applied.
    # CONSISTENCY FIX: snapshot via a json round-trip (deep copy) so a patch
    # that mutated the nested function_call dict in place could not fool the
    # inequality check below — the original's shallow .copy() shared that
    # inner dict with the message under test.
    og_message = json.loads(json.dumps(EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING["new_message"]))
    corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING)
    assert corrected_message != og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}"
    # STRENGTHENED: the original only checked "something changed"; also verify
    # the heartbeat flag itself was injected.
    corrected_args = json.loads(corrected_message["function_call"]["arguments"])
    assert corrected_args.get("request_heartbeat") is True, f"request_heartbeat missing from: {corrected_args}"
|
||||
Reference in New Issue
Block a user