feat: added basic heartbeat override heuristics (#621)

* added basic heartbeat override

* tested and working on lmstudio (patched typo + patched new bug emerging in latest lmstudio build)

* added lmstudio patch to chatml wrapper

* update the system messages to be informative about the source

* updated string constants after some tuning
This commit is contained in:
Charles Packer
2023-12-24 23:46:00 -08:00
committed by GitHub
parent 6e3d9e143e
commit 20f5231aff
6 changed files with 149 additions and 5 deletions

View File

@@ -19,6 +19,8 @@ INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG = STARTUP_QUOTES[2]
CLI_WARNING_PREFIX = "Warning: "
NON_USER_MSG_PREFIX = "[This is an automated system message hidden from the user] "
# Constants to do with summarization / conversation length window
# The max amount of tokens supported by the underlying model (eg 8k for gpt-4 and Mistral 7B)
LLM_MAX_TOKENS = {
@@ -43,7 +45,9 @@ LLM_MAX_TOKENS = {
# The number of tokens before a system warning about upcoming truncation is sent to MemGPT
MESSAGE_SUMMARY_WARNING_FRAC = 0.75
# The error message that MemGPT will receive
MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
# MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
# Much longer and more specific variant of the prompt
MESSAGE_SUMMARY_WARNING_STR = f"{NON_USER_MSG_PREFIX}The conversation history will soon reach its maximum length and be trimmed. If there is any important new information or general memories about you or the user that you would like to save, you should save that information immediately by calling function core_memory_append, core_memory_replace, or archival_memory_insert (remember to pass request_heartbeat = true if you would like to send a message immediately after)."
# The fraction of tokens we truncate down to
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC = 0.75
@@ -65,9 +69,13 @@ MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = "You are a helpful assistant. Keep you
#### Functions related
REQ_HEARTBEAT_MESSAGE = "request_heartbeat == true"
FUNC_FAILED_HEARTBEAT_MESSAGE = "Function call failed"
# REQ_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}request_heartbeat == true"
REQ_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function called using request_heartbeat=true, returning control"
# FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed"
FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed, returning control"
FUNCTION_PARAM_NAME_REQ_HEARTBEAT = "request_heartbeat"
FUNCTION_PARAM_TYPE_REQ_HEARTBEAT = "boolean"
FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT = "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function."
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5

View File

@@ -16,6 +16,7 @@ from memgpt.local_llm.vllm.api import get_vllm_completion
from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper
from memgpt.local_llm.constants import DEFAULT_WRAPPER
from memgpt.local_llm.utils import get_available_wrappers, count_tokens
from memgpt.local_llm.function_parser import patch_function
from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE
from memgpt.errors import LocalLLMConnectionError, LocalLLMError
from memgpt.constants import CLI_WARNING_PREFIX
@@ -34,6 +35,8 @@ def get_chat_completion(
wrapper=None,
endpoint=None,
endpoint_type=None,
# optional cleanup
function_correction=True,
# extra hints to allow for additional prompt formatting hacks
# TODO this could alternatively be supported via passing function_call="send_message" into the wrapper
first_message=False,
@@ -126,6 +129,10 @@ def get_chat_completion(
except Exception as e:
raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}")
# Run through some manual function correction (optional)
if function_correction:
chat_completion_result = patch_function(message_history=messages, new_message=chat_completion_result)
# Fill in potential missing usage information (used for tracking token use)
if not ("prompt_tokens" in usage and "completion_tokens" in usage and "total_tokens" in usage):
raise LocalLLMError(f"usage dict in response was missing fields ({usage})")

View File

@@ -0,0 +1,67 @@
import copy
import json
# Function calls that should never have request_heartbeat auto-inserted
# (presumably because send_message already ends the agent's turn — confirm with agent loop)
NO_HEARTBEAT_FUNCS = ["send_message"]
def insert_heartbeat(message):
    """Return a deep copy of ``message`` with ``request_heartbeat=True`` injected.

    Supports both the legacy OpenAI ``function_call`` message shape and the newer
    ``tool_calls`` list shape. The heartbeat flag is added to the JSON-encoded
    ``arguments`` string of the (first) call. The input message is not mutated.

    Args:
        message (dict): An assistant message containing a function/tool call.

    Returns:
        dict: A corrected deep copy of the message.
    """
    # Deep copy so nested function_call / tool_calls dicts are not shared with the input
    message_copy = copy.deepcopy(message)

    if message_copy.get("function_call"):
        params = json.loads(message_copy["function_call"]["arguments"])
        # Guard against non-dict arguments (e.g. the JSON literal "null")
        if not isinstance(params, dict):
            params = {}
        params["request_heartbeat"] = True
        message_copy["function_call"]["arguments"] = json.dumps(params)

    # BUGFIX: was `message_copy.get("tool_call")` (wrong key, branch never fired)
    elif message_copy.get("tool_calls"):
        params = json.loads(message_copy["tool_calls"][0]["function"]["arguments"])
        if not isinstance(params, dict):
            params = {}
        params["request_heartbeat"] = True
        # BUGFIX: was `message_copy["tools_calls"]` (typo → KeyError)
        message_copy["tool_calls"][0]["function"]["arguments"] = json.dumps(params)

    return message_copy
def heartbeat_correction(message_history, new_message):
    """Add heartbeats where we think the agent forgot to add them themselves.

    If the last message in the stack is a user message and the new message is an
    assistant function call (other than the no-heartbeat functions), return a
    copy of the new message with request_heartbeat injected.
    See: https://github.com/cpacker/MemGPT/issues/601

    Args:
        message_history (list[dict]): Prior messages; only the last one is inspected.
        new_message (dict): The freshly generated assistant message.

    Returns:
        dict | None: Corrected message, or None when no correction applies.
    """
    if len(message_history) < 1:
        return None

    # Did the model just respond to a genuine user message (not a system packet)?
    last_message_was_user = False
    if message_history[-1]["role"] == "user":
        try:
            content = json.loads(message_history[-1]["content"])
        except json.JSONDecodeError:
            return None
        # BUGFIX: content may be any JSON value / lack "type" — the old
        # content["type"] lookup could raise KeyError or TypeError here
        if isinstance(content, dict) and content.get("type") == "user_message":
            last_message_was_user = True

    # Is the new message a function/tool call that should carry a heartbeat?
    new_message_is_heartbeat_function = False
    if new_message["role"] == "assistant":
        function_name = None
        if new_message.get("function_call"):
            function_name = new_message["function_call"].get("name")
        elif new_message.get("tool_calls"):
            function_name = new_message["tool_calls"][0].get("function", {}).get("name")
        # Only treat it as correctable when a name was actually present
        if function_name is not None and function_name not in NO_HEARTBEAT_FUNCS:
            new_message_is_heartbeat_function = True

    if last_message_was_user and new_message_is_heartbeat_function:
        return insert_heartbeat(new_message)
    return None
def patch_function(message_history, new_message):
    """Apply post-hoc corrections to a model message, falling back to the original.

    Currently the only correction is heartbeat injection; if no correction
    applies, the message is returned untouched.
    """
    fixed = heartbeat_correction(message_history=message_history, new_message=new_message)
    if fixed is None:
        return new_message
    return fixed

View File

@@ -418,10 +418,15 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
try:
# NOTE: weird bug can happen where 'function' gets nested if the prefix in the prompt isn't abided by
if isinstance(function_json_output["function"], dict):
function_json_output = function_json_output["function"]
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
except KeyError as e:
raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}")
raise LLMJSONParsingError(
f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}. JSON result was:\n{function_json_output}"
)
if self.clean_func_args:
(

View File

@@ -255,10 +255,16 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper):
except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}")
try:
# NOTE: weird bug can happen where 'function' gets nested if the prefix in the prompt isn't abided by
if isinstance(function_json_output["function"], dict):
function_json_output = function_json_output["function"]
# regular unpacking
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
except KeyError as e:
raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}")
raise LLMJSONParsingError(
f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}. JSON result was:\n{function_json_output}"
)
if self.clean_func_args:
(
@@ -366,6 +372,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper):
and function_json_output["function"] is not None
and function_json_output["function"].strip().lower() != "none"
):
# TODO apply lm studio nested bug patch?
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]
else:

View File

@@ -0,0 +1,50 @@
import json
from memgpt.local_llm.function_parser import patch_function
import memgpt.system as system
# Fixture: the assistant called send_message, which is in NO_HEARTBEAT_FUNCS,
# so patch_function should return the message unchanged.
EXAMPLE_FUNCTION_CALL_SEND_MESSAGE = {
    "message_history": [
        {"role": "user", "content": system.package_user_message("hello")},
    ],
    # "new_message": {
    #     "role": "function",
    #     "name": "send_message",
    #     "content": system.package_function_response(was_success=True, response_string="None"),
    # },
    "new_message": {
        "role": "assistant",
        "content": "I'll send a message.",
        "function_call": {
            "name": "send_message",
            "arguments": "null",
        },
    },
}

# Fixture: the assistant called core_memory_append directly after a user message
# but omitted request_heartbeat, so patch_function should inject it.
EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING = {
    "message_history": [
        {"role": "user", "content": system.package_user_message("hello")},
    ],
    "new_message": {
        "role": "assistant",
        "content": "I'll append to memory.",
        "function_call": {
            "name": "core_memory_append",
            "arguments": json.dumps({"content": "new_stuff"}),
        },
    },
}
def test_function_parsers():
    """Try various broken JSON and check that the parsers can fix it"""
    # (fixture, whether patch_function is expected to modify the message)
    cases = [
        (EXAMPLE_FUNCTION_CALL_SEND_MESSAGE, False),
        (EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING, True),
    ]
    for example, expect_correction in cases:
        og_message = example["new_message"].copy()
        corrected_message = patch_function(**example)
        if expect_correction:
            assert corrected_message != og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}"
        else:
            assert corrected_message == og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}"