diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py index ea5b904f..ae983339 100644 --- a/memgpt/local_llm/chat_completion_proxy.py +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -1,30 +1,16 @@ -"""MemGPT sends a ChatCompletion request - -Under the hood, we use the functions argument to turn -""" - - """Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend""" import os -import json import requests +import json -from .webui_settings import DETERMINISTIC, SIMPLE +from .webui.api import get_webui_completion from .llm_chat_completion_wrappers import airoboros +from .utils import DotDict HOST = os.getenv("OPENAI_API_BASE") HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion - - -class DotDict(dict): - """Allow dot access on properties similar to OpenAI response object""" - - def __getattr__(self, attr): - return self.get(attr) - - def __setattr__(self, key, value): - self[key] = value +DEBUG = True async def get_chat_completion( @@ -36,60 +22,52 @@ async def get_chat_completion( if function_call != "auto": raise ValueError(f"function_call == {function_call} not supported (auto only)") - if True or model == "airoboros_v2.1": + if model == "airoboros_v2.1": + llm_wrapper = airoboros.Airoboros21Wrapper() + else: + # Warn the user that we're using the fallback + print( + f"Warning: could not find an LLM wrapper for {model}, using the airoboros wrapper" + ) llm_wrapper = airoboros.Airoboros21Wrapper() # First step: turn the message sequence into a prompt that the model expects prompt = llm_wrapper.chat_completion_to_prompt(messages, functions) - # print(prompt) - - if HOST_TYPE != "webui": - raise ValueError(HOST_TYPE) - - request = SIMPLE - request["prompt"] = prompt + if DEBUG: + print(prompt) try: - URI = f"{HOST}/v1/generate" - response = requests.post(URI, json=request) - if response.status_code == 200: - # result = response.json()['results'][0]['history'] - result = response.json() - # print(f"raw API response: {result}") - result = result["results"][0]["text"] - print(f"json API response.text: {result}") + if HOST_TYPE == "webui": + result = get_webui_completion(prompt) else: - raise Exception(f"API call got non-200 response code") + raise ValueError(HOST_TYPE) + except requests.exceptions.ConnectionError as e: + raise ValueError(f"Was unable to connect to host {HOST}") - # cleaned_result, chatcompletion_result = parse_st_json_output(result) - chat_completion_result = llm_wrapper.output_to_chat_completion_response(result) + chat_completion_result = llm_wrapper.output_to_chat_completion_response(result) + if DEBUG: print(json.dumps(chat_completion_result, indent=2)) - # print(cleaned_result) - # unpack with response.choices[0].message.content - response = DotDict( - { - "model": None, - "choices": [ - DotDict( - { - "message": DotDict(chat_completion_result), - "finish_reason": "stop", # TODO vary based on webui response - } - ) - ], - "usage": DotDict( + # unpack with response.choices[0].message.content + response = DotDict( + { + "model": None, + "choices": [ + DotDict( { - # TODO fix - "prompt_tokens": 0, - "completion_tokens": 0, - "total_tokens": 0, + "message": DotDict(chat_completion_result), + "finish_reason": "stop", # TODO vary based on backend response } - ), - } - ) - return response - - except Exception as e: - # TODO - raise e + ) + ], + "usage": DotDict( + { + # TODO fix, actually use real info + "prompt_tokens": 0, + "completion_tokens": 0, + "total_tokens": 0, + } + ), + } + ) + return response diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py index 6b3a117f..98d3625e 100644 --- a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py +++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py @@ -12,12 +12,16 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): def __init__( self, simplify_json_content=True, - include_assistant_prefix=True, clean_function_args=True, + include_assistant_prefix=True, + include_opening_brace_in_prefix=True, + include_section_separators=True, ): self.simplify_json_content = simplify_json_content - self.include_assistant_prefix = include_assistant_prefix self.clean_func_args = clean_function_args + self.include_assistant_prefix = include_assistant_prefix + self.include_opening_brance_in_prefix = include_opening_brace_in_prefix + self.include_section_separators = include_section_separators def chat_completion_to_prompt(self, messages, functions): """Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format @@ -77,11 +81,41 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): # TODO we're ignoring schema['parameters']['required'] return func_str - prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." + # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." + prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format." prompt += f"\nAvailable functions:" for function_dict in functions: prompt += f"\n{create_function_description(function_dict)}" + def create_function_call(function_call): + """Go from ChatCompletion to Airoboros style function trace (in prompt) + + ChatCompletion data (inside message['function_call']): + "function_call": { + "name": ... + "arguments": { + "arg1": val1, + ... + } + + Airoboros output: + { + "function": "send_message", + "params": { + "message": "Hello there! I am Sam, an AI developed by Liminal Corp. How can I assist you today?" + } + } + """ + airo_func_call = { + "function": function_call["name"], + "params": json.loads(function_call["arguments"]), + } + return json.dumps(airo_func_call, indent=2) + + # Add a sep for the conversation + if self.include_section_separators: + prompt += "\n### INPUT" + # Last are the user/assistant messages for message in messages[1:]: assert message["role"] in ["user", "assistant", "function"], message @@ -96,16 +130,25 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): prompt += f"\nUSER: {message['content']}" elif message["role"] == "assistant": prompt += f"\nASSISTANT: {message['content']}" + # need to add the function call if there was one + if message["function_call"]: + prompt += f"\n{create_function_call(message['function_call'])}" elif message["role"] == "function": - # TODO - continue + # TODO find a good way to add this # prompt += f"\nASSISTANT: (function return) {message['content']}" + prompt += f"\nFUNCTION RETURN: {message['content']}" + continue else: raise ValueError(message) + # Add a sep for the response + if self.include_section_separators: + prompt += "\n### RESPONSE" + if self.include_assistant_prefix: - # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." prompt += f"\nASSISTANT:" + if self.include_opening_brance_in_prefix: + prompt += "\n{" return prompt @@ -135,7 +178,13 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper): } } """ - function_json_output = json.loads(raw_llm_output) + if self.include_opening_brance_in_prefix and raw_llm_output[0] != "{": + raw_llm_output = "{" + raw_llm_output + + try: + function_json_output = json.loads(raw_llm_output) + except Exception as e: + raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}") function_name = function_json_output["function"] function_parameters = function_json_output["params"] diff --git a/memgpt/local_llm/utils.py b/memgpt/local_llm/utils.py new file mode 100644 index 00000000..42a0ce27 --- /dev/null +++ b/memgpt/local_llm/utils.py @@ -0,0 +1,8 @@ +class DotDict(dict): + """Allow dot access on properties similar to OpenAI response object""" + + def __getattr__(self, attr): + return self.get(attr) + + def __setattr__(self, key, value): + self[key] = value diff --git a/memgpt/local_llm/webui/api.py b/memgpt/local_llm/webui/api.py new file mode 100644 index 00000000..3cff08e0 --- /dev/null +++ b/memgpt/local_llm/webui/api.py @@ -0,0 +1,33 @@ +import os +import requests + +from .settings import SIMPLE + +HOST = os.getenv("OPENAI_API_BASE") +HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion +WEBUI_API_SUFFIX = "/v1/generate" +DEBUG = True + + +def get_webui_completion(prompt, settings=SIMPLE): + """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server""" + + # Settings for the generation, includes the prompt + stop tokens, max length, etc + request = settings + request["prompt"] = prompt + + try: + URI = f"{HOST}{WEBUI_API_SUFFIX}" + response = requests.post(URI, json=request) + if response.status_code == 200: + result = response.json() + result = result["results"][0]["text"] + if DEBUG: + print(f"json API response.text: {result}") + else: + raise Exception(f"API call got non-200 response code") + except: + # TODO handle gracefully + raise + + return result diff --git a/memgpt/local_llm/webui/settings.py b/memgpt/local_llm/webui/settings.py new file mode 100644 index 00000000..2e9ecbce --- /dev/null +++ b/memgpt/local_llm/webui/settings.py @@ -0,0 +1,12 @@ +SIMPLE = { + "stopping_strings": [ + "\nUSER:", + "\nASSISTANT:", + # '\n' + + # '', + # '<|', + # '\n#', + # '\n\n\n', + ], + "truncation_length": 4096, # assuming llama2 models +} diff --git a/memgpt/local_llm/webui_settings.py b/memgpt/local_llm/webui_settings.py deleted file mode 100644 index 2601f642..00000000 --- a/memgpt/local_llm/webui_settings.py +++ /dev/null @@ -1,54 +0,0 @@ -DETERMINISTIC = { - "max_new_tokens": 250, - "do_sample": False, - "temperature": 0, - "top_p": 0, - "typical_p": 1, - "repetition_penalty": 1.18, - "repetition_penalty_range": 0, - "encoder_repetition_penalty": 1, - "top_k": 1, - "min_length": 0, - "no_repeat_ngram_size": 0, - "num_beams": 1, - "penalty_alpha": 0, - "length_penalty": 1, - "early_stopping": False, - "guidance_scale": 1, - "negative_prompt": "", - "seed": -1, - "add_bos_token": True, - "stopping_strings": [ - "\nUSER:", - "\nASSISTANT:", - # '\n' + - # '', - # '<|', - # '\n#', - # '\n\n\n', - ], - "truncation_length": 4096, - "ban_eos_token": False, - "skip_special_tokens": True, - "top_a": 0, - "tfs": 1, - "epsilon_cutoff": 0, - "eta_cutoff": 0, - "mirostat_mode": 2, - "mirostat_tau": 4, - "mirostat_eta": 0.1, - "use_mancer": False, -} - -SIMPLE = { - "stopping_strings": [ - "\nUSER:", - "\nASSISTANT:", - # '\n' + - # '', - # '<|', - # '\n#', - # '\n\n\n', - ], - "truncation_length": 4096, -}