From 20f5231aff7189581b5e4d1a4e9633f68ef813f2 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sun, 24 Dec 2023 23:46:00 -0800 Subject: [PATCH] feat: added basic heartbeat override heuristics (#621) * added basic heartbeat override * tested and working on lmstudio (patched typo + patched new bug emerging in latest lmstudio build * added lmstudio patch to chatml wrapper * update the system messages to be informative about the source * updated string constants after some tuning --- memgpt/constants.py | 14 +++- memgpt/local_llm/chat_completion_proxy.py | 7 ++ memgpt/local_llm/function_parser.py | 67 +++++++++++++++++++ .../llm_chat_completion_wrappers/airoboros.py | 7 +- .../llm_chat_completion_wrappers/chatml.py | 9 ++- tests/test_function_parser.py | 50 ++++++++++++++ 6 files changed, 149 insertions(+), 5 deletions(-) create mode 100644 memgpt/local_llm/function_parser.py create mode 100644 tests/test_function_parser.py diff --git a/memgpt/constants.py b/memgpt/constants.py index 0c0564a7..165a6437 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -19,6 +19,8 @@ INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG = STARTUP_QUOTES[2] CLI_WARNING_PREFIX = "Warning: " +NON_USER_MSG_PREFIX = "[This is an automated system message hidden from the user] " + # Constants to do with summarization / conversation length window # The max amount of tokens supported by the underlying model (eg 8k for gpt-4 and Mistral 7B) LLM_MAX_TOKENS = { @@ -43,7 +45,9 @@ LLM_MAX_TOKENS = { # The amount of tokens before a sytem warning about upcoming truncation is sent to MemGPT MESSAGE_SUMMARY_WARNING_FRAC = 0.75 # The error message that MemGPT will receive -MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed." 
+# MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed." +# Much longer and more specific variant of the prompt +MESSAGE_SUMMARY_WARNING_STR = f"{NON_USER_MSG_PREFIX}The conversation history will soon reach its maximum length and be trimmed. If there is any important new information or general memories about you or the user that you would like to save, you should save that information immediately by calling function core_memory_append, core_memory_replace, or archival_memory_insert (remember to pass request_heartbeat = true if you would like to send a message immediately after)." # The fraction of tokens we truncate down to MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC = 0.75 @@ -65,9 +69,13 @@ MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE = "You are a helpful assistant. Keep you #### Functions related -REQ_HEARTBEAT_MESSAGE = "request_heartbeat == true" -FUNC_FAILED_HEARTBEAT_MESSAGE = "Function call failed" +# REQ_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}request_heartbeat == true" +REQ_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function called using request_heartbeat=true, returning control" +# FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed" +FUNC_FAILED_HEARTBEAT_MESSAGE = f"{NON_USER_MSG_PREFIX}Function call failed, returning control" + FUNCTION_PARAM_NAME_REQ_HEARTBEAT = "request_heartbeat" FUNCTION_PARAM_TYPE_REQ_HEARTBEAT = "boolean" FUNCTION_PARAM_DESCRIPTION_REQ_HEARTBEAT = "Request an immediate heartbeat after function execution. Set to 'true' if you want to send a follow-up message or run a follow-up function." 
+ RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5 diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index cf771c6d..28e7b5c6 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -16,6 +16,7 @@ from memgpt.local_llm.vllm.api import get_vllm_completion from memgpt.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper from memgpt.local_llm.constants import DEFAULT_WRAPPER from memgpt.local_llm.utils import get_available_wrappers, count_tokens +from memgpt.local_llm.function_parser import patch_function from memgpt.prompts.gpt_summarize import SYSTEM as SUMMARIZE_SYSTEM_MESSAGE from memgpt.errors import LocalLLMConnectionError, LocalLLMError from memgpt.constants import CLI_WARNING_PREFIX @@ -34,6 +35,8 @@ def get_chat_completion( wrapper=None, endpoint=None, endpoint_type=None, + # optional cleanup + function_correction=True, # extra hints to allow for additional prompt formatting hacks # TODO this could alternatively be supported via passing function_call="send_message" into the wrapper first_message=False, @@ -126,6 +129,10 @@ def get_chat_completion( except Exception as e: raise LocalLLMError(f"Failed to parse JSON from local LLM response - error: {str(e)}") + # Run through some manual function correction (optional) + if function_correction: + chat_completion_result = patch_function(message_history=messages, new_message=chat_completion_result) + # Fill in potential missing usage information (used for tracking token use) if not ("prompt_tokens" in usage and "completion_tokens" in usage and "total_tokens" in usage): raise LocalLLMError(f"usage dict in response was missing fields ({usage})") diff --git a/memgpt/local_llm/function_parser.py b/memgpt/local_llm/function_parser.py new file mode 100644 index 00000000..c79749b4 --- /dev/null +++ b/memgpt/local_llm/function_parser.py @@ -0,0 +1,67 @@ +import copy +import json + + +NO_HEARTBEAT_FUNCS = ["send_message"] + 
+ +def insert_heartbeat(message): + # message_copy = message.copy() + message_copy = copy.deepcopy(message) + + if message_copy.get("function_call"): + # function_name = message.get("function_call").get("name") + params = message_copy.get("function_call").get("arguments") + params = json.loads(params) + params["request_heartbeat"] = True + message_copy["function_call"]["arguments"] = json.dumps(params) + + elif message_copy.get("tool_calls"): + # function_name = message.get("tool_calls")[0].get("function").get("name") + params = message_copy.get("tool_calls")[0].get("function").get("arguments") + params = json.loads(params) + params["request_heartbeat"] = True + message_copy["tool_calls"][0]["function"]["arguments"] = json.dumps(params) + + return message_copy + + +def heartbeat_correction(message_history, new_message): + """Add heartbeats where we think the agent forgot to add them themselves + + If the last message in the stack is a user message and the new message is an assistant func call, fix the heartbeat + + See: https://github.com/cpacker/MemGPT/issues/601 + """ + if len(message_history) < 1: + return None + + last_message_was_user = False + if message_history[-1]["role"] == "user": + try: + content = json.loads(message_history[-1]["content"]) + except json.JSONDecodeError: + return None + # Check if it's a user message or system message + if content["type"] == "user_message": + last_message_was_user = True + + new_message_is_heartbeat_function = False + if new_message["role"] == "assistant": + if new_message.get("function_call") or new_message.get("tool_calls"): + if new_message.get("function_call"): + function_name = new_message.get("function_call").get("name") + elif new_message.get("tool_calls"): + function_name = new_message.get("tool_calls")[0].get("function").get("name") + if function_name not in NO_HEARTBEAT_FUNCS: + new_message_is_heartbeat_function = True + + if last_message_was_user and new_message_is_heartbeat_function: + return 
insert_heartbeat(new_message) + else: + return None + + +def patch_function(message_history, new_message): + corrected_output = heartbeat_correction(message_history=message_history, new_message=new_message) + return corrected_output if corrected_output is not None else new_message diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py index bf838f04..bda61e04 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py @@ -418,10 +418,15 @@ class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper): except Exception as e: raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}") try: + # NOTE: weird bug can happen where 'function' gets nested if the prefix in the prompt isn't abided by + if isinstance(function_json_output["function"], dict): + function_json_output = function_json_output["function"] function_name = function_json_output["function"] function_parameters = function_json_output["params"] except KeyError as e: - raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}") + raise LLMJSONParsingError( + f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}. 
JSON result was:\n{function_json_output}" + ) if self.clean_func_args: ( diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py b/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py index 06a2ebce..389721ae 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/chatml.py @@ -255,10 +255,16 @@ class ChatMLInnerMonologueWrapper(LLMChatCompletionWrapper): except Exception as e: raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output} - error\n{str(e)}") try: + # NOTE: weird bug can happen where 'function' gets nested if the prefix in the prompt isn't abided by + if isinstance(function_json_output["function"], dict): + function_json_output = function_json_output["function"] + # regular unpacking function_name = function_json_output["function"] function_parameters = function_json_output["params"] except KeyError as e: - raise LLMJSONParsingError(f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}") + raise LLMJSONParsingError( + f"Received valid JSON from LLM, but JSON was missing fields: {str(e)}. JSON result was:\n{function_json_output}" + ) if self.clean_func_args: ( @@ -366,6 +372,7 @@ class ChatMLOuterInnerMonologueWrapper(ChatMLInnerMonologueWrapper): and function_json_output["function"] is not None and function_json_output["function"].strip().lower() != "none" ): + # TODO apply lm studio nested bug patch? 
function_name = function_json_output["function"] function_parameters = function_json_output["params"] else: diff --git a/tests/test_function_parser.py b/tests/test_function_parser.py new file mode 100644 index 00000000..57153009 --- /dev/null +++ b/tests/test_function_parser.py @@ -0,0 +1,50 @@ +import json + +from memgpt.local_llm.function_parser import patch_function +import memgpt.system as system + + +EXAMPLE_FUNCTION_CALL_SEND_MESSAGE = { + "message_history": [ + {"role": "user", "content": system.package_user_message("hello")}, + ], + # "new_message": { + # "role": "function", + # "name": "send_message", + # "content": system.package_function_response(was_success=True, response_string="None"), + # }, + "new_message": { + "role": "assistant", + "content": "I'll send a message.", + "function_call": { + "name": "send_message", + "arguments": "null", + }, + }, +} + +EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING = { + "message_history": [ + {"role": "user", "content": system.package_user_message("hello")}, + ], + "new_message": { + "role": "assistant", + "content": "I'll append to memory.", + "function_call": { + "name": "core_memory_append", + "arguments": json.dumps({"content": "new_stuff"}), + }, + }, +} + + +def test_function_parsers(): + """Try various broken JSON and check that the parsers can fix it""" + + og_message = EXAMPLE_FUNCTION_CALL_SEND_MESSAGE["new_message"] + corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_SEND_MESSAGE) + assert corrected_message == og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}" + + og_message = EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING["new_message"].copy() + corrected_message = patch_function(**EXAMPLE_FUNCTION_CALL_CORE_MEMORY_APPEND_MISSING) + assert corrected_message != og_message, f"Uncorrected:\n{og_message}\nCorrected:\n{corrected_message}"