From e6a0a746bb1dfb495ca0e956ea74731303546321 Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Tue, 24 Oct 2023 01:39:56 -0700
Subject: [PATCH] inner monologue airo parser

---
 memgpt/local_llm/chat_completion_proxy.py     |  12 +-
 .../llm_chat_completion_wrappers/airoboros.py | 213 ++++++++++++++++++
 memgpt/local_llm/webui/settings.py            |   1 +
 memgpt/main.py                                |   2 +
 4 files changed, 224 insertions(+), 4 deletions(-)

diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py
index de9da221..a9fccf3e 100644
--- a/memgpt/local_llm/chat_completion_proxy.py
+++ b/memgpt/local_llm/chat_completion_proxy.py
@@ -5,12 +5,13 @@ import requests
 import json
 
 from .webui.api import get_webui_completion
-from .llm_chat_completion_wrappers import airoboros
+from .llm_chat_completion_wrappers import airoboros, dolphin
 from .utils import DotDict
 
 HOST = os.getenv("OPENAI_API_BASE")
 HOST_TYPE = os.getenv("BACKEND_TYPE")  # default None == ChatCompletion
-DEBUG = False
+# DEBUG = False
+DEBUG = True
 
 
 async def get_chat_completion(
@@ -22,8 +23,11 @@
     if function_call != "auto":
         raise ValueError(f"function_call == {function_call} not supported (auto only)")
 
-    if model == "airoboros_v2.1":
-        llm_wrapper = airoboros.Airoboros21Wrapper()
+    if model == "airoboros-l2-70b-2.1":
+        # llm_wrapper = airoboros.Airoboros21Wrapper()
+        llm_wrapper = airoboros.Airoboros21InnerMonologueWrapper()
+    elif model == "dolphin-2.1-mistral-7b":
+        llm_wrapper = dolphin.Dolphin21MistralWrapper()
     else:
         # Warn the user that we're using the fallback
         print(
diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
index 98d3625e..60f8ee6b 100644
--- a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
+++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
@@ -150,6 +150,7 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
         if self.include_opening_brance_in_prefix:
             prompt += "\n{"
 
+        print(prompt)
         return prompt
 
     def clean_function_args(self, function_name, function_args):
@@ -202,3 +203,215 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
             },
         }
         return message
+
+
+class Airoboros21InnerMonologueWrapper(Airoboros21Wrapper):
+    """Still expect only JSON outputs from model, but add inner monologue as a field"""
+
+    def __init__(
+        self,
+        simplify_json_content=True,
+        clean_function_args=True,
+        include_assistant_prefix=True,
+        include_opening_brace_in_prefix=True,
+        include_section_separators=True,
+    ):
+        self.simplify_json_content = simplify_json_content
+        self.clean_func_args = clean_function_args
+        self.include_assistant_prefix = include_assistant_prefix
+        self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
+        self.include_section_separators = include_section_separators
+
+    def chat_completion_to_prompt(self, messages, functions):
+        """Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
+
+        A chat.
+        USER: {prompt}
+        ASSISTANT:
+
+        Functions support: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling
+
+        As an AI assistant, please select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format.
+
+        Input: I want to know how many times 'Python' is mentioned in my text file.
+
+        Available functions:
+        file_analytics:
+          description: This tool performs various operations on a text file.
+          params:
+            action: The operation we want to perform on the data, such as "count_occurrences", "find_line", etc.
+            filters:
+              keyword: The word or phrase we want to search for.
+
+        OpenAI functions schema style:
+
+        {
+            "name": "send_message",
+            "description": "Sends a message to the human user",
+            "parameters": {
+                "type": "object",
+                "properties": {
+                    # https://json-schema.org/understanding-json-schema/reference/array.html
+                    "message": {
+                        "type": "string",
+                        "description": "Message contents. All unicode (including emojis) are supported.",
+                    },
+                },
+                "required": ["message"],
+            }
+        },
+        """
+        prompt = ""
+
+        # System insturctions go first
+        assert messages[0]["role"] == "system"
+        prompt += messages[0]["content"]
+
+        # Next is the functions preamble
+        def create_function_description(schema, add_inner_thoughts=True):
+            # airorobos style
+            func_str = ""
+            func_str += f"{schema['name']}:"
+            func_str += f"\n  description: {schema['description']}"
+            func_str += f"\n  params:"
+            if add_inner_thoughts:
+                func_str += (
+                    f"\n    inner_thoughts: Deep inner monologue private to you only."
+                )
+            for param_k, param_v in schema["parameters"]["properties"].items():
+                # TODO we're ignoring type
+                func_str += f"\n    {param_k}: {param_v['description']}"
+            # TODO we're ignoring schema['parameters']['required']
+            return func_str
+
+        # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
+        prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
+        prompt += f"\nAvailable functions:"
+        for function_dict in functions:
+            prompt += f"\n{create_function_description(function_dict)}"
+
+        def create_function_call(function_call, inner_thoughts=None):
+            """Go from ChatCompletion to Airoboros style function trace (in prompt)
+
+            ChatCompletion data (inside message['function_call']):
+                "function_call": {
+                    "name": ...
+                    "arguments": {
+                        "arg1": val1,
+                        ...
+                    }
+
+            Airoboros output:
+                {
+                  "function": "send_message",
+                  "params": {
+                    "message": "Hello there! I am Sam, an AI developed by Liminal Corp. How can I assist you today?"
+                  }
+                }
+            """
+            airo_func_call = {
+                "function": function_call["name"],
+                "params": {
+                    "inner_thoughts": inner_thoughts,
+                    **json.loads(function_call["arguments"]),
+                },
+            }
+            return json.dumps(airo_func_call, indent=2)
+
+        # Add a sep for the conversation
+        if self.include_section_separators:
+            prompt += "\n### INPUT"
+
+        # Last are the user/assistant messages
+        for message in messages[1:]:
+            assert message["role"] in ["user", "assistant", "function"], message
+
+            if message["role"] == "user":
+                if self.simplify_json_content:
+                    try:
+                        content_json = json.loads(message["content"])
+                        content_simple = content_json["message"]
+                        prompt += f"\nUSER: {content_simple}"
+                    except:
+                        prompt += f"\nUSER: {message['content']}"
+            elif message["role"] == "assistant":
+                prompt += f"\nASSISTANT:"
+                # need to add the function call if there was one
+                inner_thoughts = message["content"]
+                if message["function_call"]:
+                    prompt += f"\n{create_function_call(message['function_call'], inner_thoughts=inner_thoughts)}"
+            elif message["role"] == "function":
+                # TODO find a good way to add this
+                # prompt += f"\nASSISTANT: (function return) {message['content']}"
+                prompt += f"\nFUNCTION RETURN: {message['content']}"
+                continue
+            else:
+                raise ValueError(message)
+
+        # Add a sep for the response
+        if self.include_section_separators:
+            prompt += "\n### RESPONSE"
+
+        if self.include_assistant_prefix:
+            prompt += f"\nASSISTANT:"
+            if self.include_opening_brance_in_prefix:
+                prompt += "\n{"
+
+        return prompt
+
+    def clean_function_args(self, function_name, function_args):
+        """Some basic MemGPT-specific cleaning of function args"""
+        cleaned_function_name = function_name
+        cleaned_function_args = function_args.copy()
+
+        if function_name == "send_message":
+            # strip request_heartbeat
+            cleaned_function_args.pop("request_heartbeat", None)
+
+        inner_thoughts = None
+        if "inner_thoughts" in function_args:
+            inner_thoughts = cleaned_function_args.pop("inner_thoughts")
+
+        # TODO more cleaning to fix errors LLM makes
+        return inner_thoughts, cleaned_function_name, cleaned_function_args
+
+    def output_to_chat_completion_response(self, raw_llm_output):
+        """Turn raw LLM output into a ChatCompletion style response with:
+        "message" = {
+            "role": "assistant",
+            "content": ...,
+            "function_call": {
+                "name": ...
+                "arguments": {
+                    "arg1": val1,
+                    ...
+                }
+            }
+        }
+        """
+        if self.include_opening_brance_in_prefix and raw_llm_output[0] != "{":
+            raw_llm_output = "{" + raw_llm_output
+
+        try:
+            function_json_output = json.loads(raw_llm_output)
+        except Exception as e:
+            raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
+        function_name = function_json_output["function"]
+        function_parameters = function_json_output["params"]
+
+        if self.clean_func_args:
+            (
+                inner_thoughts,
+                function_name,
+                function_parameters,
+            ) = self.clean_function_args(function_name, function_parameters)
+
+        message = {
+            "role": "assistant",
+            "content": inner_thoughts,
+            "function_call": {
+                "name": function_name,
+                "arguments": json.dumps(function_parameters),
+            },
+        }
+        return message
diff --git a/memgpt/local_llm/webui/settings.py b/memgpt/local_llm/webui/settings.py
index 2e9ecbce..64335199 100644
--- a/memgpt/local_llm/webui/settings.py
+++ b/memgpt/local_llm/webui/settings.py
@@ -2,6 +2,7 @@ SIMPLE = {
     "stopping_strings": [
         "\nUSER:",
         "\nASSISTANT:",
+        "\nFUNCTION RETURN:",
         # '\n' +
         # '',
         # '<|',
diff --git a/memgpt/main.py b/memgpt/main.py
index 2ef440eb..8db0d9d2 100644
--- a/memgpt/main.py
+++ b/memgpt/main.py
@@ -394,6 +394,8 @@ async def main(
         ).ask_async()
         clear_line()
 
+        user_input = user_input.rstrip()
+
         if user_input.startswith("!"):
             print(f"Commands for CLI begin with '/' not '!'")
             continue