diff --git a/memgpt/local_llm/README.md b/memgpt/local_llm/README.md new file mode 100644 index 00000000..d81a58e7 --- /dev/null +++ b/memgpt/local_llm/README.md @@ -0,0 +1,3 @@ +## TODO + +Instructions on how to add additional support for other function calling LLMs + other LLM backends \ No newline at end of file diff --git a/memgpt/local_llm/__init__.py b/memgpt/local_llm/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/memgpt/local_llm/chat_completion_proxy.py b/memgpt/local_llm/chat_completion_proxy.py new file mode 100644 index 00000000..39f69109 --- /dev/null +++ b/memgpt/local_llm/chat_completion_proxy.py @@ -0,0 +1,88 @@ +"""MemGPT sends a ChatCompletion request + +Under the hood, we use the functions argument to turn the request into a single prompt string for the local model backend +""" + + +"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend""" + +import os +import json +import requests + +from .webui_settings import DETERMINISTIC, SIMPLE +from .llm_chat_completion_wrappers import airoboros + +HOST = os.getenv('OPENAI_API_BASE') +HOST_TYPE = os.getenv('BACKEND_TYPE') # default None == ChatCompletion + + +class DotDict(dict): + """Allow dot access on properties similar to OpenAI response object""" + + def __getattr__(self, attr): + return self.get(attr) + + def __setattr__(self, key, value): + self[key] = value + + +async def get_chat_completion( + model, # no model, since the model is fixed to whatever you set in your own backend + messages, + functions, + function_call="auto", + ): + if function_call != "auto": + raise ValueError(f"function_call == {function_call} not supported (auto only)") + + if True or model == 'airoboros_v2.1': + llm_wrapper = airoboros.Airoboros21Wrapper() + + # First step: turn the message sequence into a prompt that the model expects + prompt = llm_wrapper.chat_completion_to_prompt(messages, functions) + # print(prompt) + + if HOST_TYPE != 'webui': + raise ValueError(HOST_TYPE) + + request = SIMPLE + request['prompt'] = 
prompt + + try: + + URI = f'{HOST}/v1/generate' + response = requests.post(URI, json=request) + if response.status_code == 200: + # result = response.json()['results'][0]['history'] + result = response.json() + # print(f"raw API response: {result}") + result = result['results'][0]['text'] + print(f"json API response.text: {result}") + else: + raise Exception(f"API call got non-200 response code") + + # cleaned_result, chatcompletion_result = parse_st_json_output(result) + chat_completion_result = llm_wrapper.output_to_chat_completion_response(result) + print(json.dumps(chat_completion_result, indent=2)) + # print(cleaned_result) + + # unpack with response.choices[0].message.content + response = DotDict({ + 'model': None, + 'choices': [DotDict({ + 'message': DotDict(chat_completion_result), + 'finish_reason': 'stop', # TODO vary based on webui response + })], + 'usage': DotDict({ + # TODO fix + 'prompt_tokens': 0, + 'completion_tokens': 0, + 'total_tokens': 0, + }) + }) + return response + + except Exception as e: + # TODO + raise e diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/__init__.py b/memgpt/local_llm/llm_chat_completion_wrappers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py new file mode 100644 index 00000000..303e2d37 --- /dev/null +++ b/memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py @@ -0,0 +1,146 @@ +import json + +from .wrapper_base import LLMChatCompletionWrapper + + +class Airoboros21Wrapper(LLMChatCompletionWrapper): + """Wrapper for Airoboros 70b v2.1: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1 + """ + + def __init__(self, simplify_json_content=True, include_assistant_prefix=True, clean_function_args=True): + self.simplify_json_content = simplify_json_content + self.include_assistant_prefix = include_assistant_prefix + self.clean_func_args = clean_function_args + + def 
chat_completion_to_prompt(self, messages, functions): + """Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format + + A chat. + USER: {prompt} + ASSISTANT: + + Functions support: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling + + As an AI assistant, please select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format. + + Input: I want to know how many times 'Python' is mentioned in my text file. + + Available functions: + file_analytics: + description: This tool performs various operations on a text file. + params: + action: The operation we want to perform on the data, such as "count_occurrences", "find_line", etc. + filters: + keyword: The word or phrase we want to search for. + + OpenAI functions schema style: + + { + "name": "send_message", + "description": "Sends a message to the human user", + "parameters": { + "type": "object", + "properties": { + # https://json-schema.org/understanding-json-schema/reference/array.html + "message": { + "type": "string", + "description": "Message contents. 
All unicode (including emojis) are supported.", + }, + }, + "required": ["message"], + } + }, + """ + prompt = "" + + # System instructions go first + assert messages[0]['role'] == 'system' + prompt += messages[0]['content'] + + # Next is the functions preamble + def create_function_description(schema): + # airoboros style + func_str = "" + func_str += f"{schema['name']}:" + func_str += f"\n description: {schema['description']}" + func_str += f"\n params:" + for param_k, param_v in schema['parameters']['properties'].items(): + # TODO we're ignoring type + func_str += f"\n {param_k}: {param_v['description']}" + # TODO we're ignoring schema['parameters']['required'] + return func_str + + prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." + prompt += f"\nAvailable functions:" + for function_dict in functions: + prompt += f"\n{create_function_description(function_dict)}" + + # Last are the user/assistant messages + for message in messages[1:]: + assert message['role'] in ['user', 'assistant', 'function'], message + + if message['role'] == 'user': + if self.simplify_json_content: + try: + content_json = json.loads(message['content']) + content_simple = content_json['message'] + prompt += f"\nUSER: {content_simple}" + except: + prompt += f"\nUSER: {message['content']}" + elif message['role'] == 'assistant': + prompt += f"\nASSISTANT: {message['content']}" + elif message['role'] == 'function': + # TODO + continue + # prompt += f"\nASSISTANT: (function return) {message['content']}" + else: + raise ValueError(message) + + if self.include_assistant_prefix: + # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format." 
+ prompt += f"\nASSISTANT:" + + return prompt + + def clean_function_args(self, function_name, function_args): + """Some basic MemGPT-specific cleaning of function args""" + cleaned_function_name = function_name + cleaned_function_args = function_args.copy() + + if function_name == 'send_message': + # strip request_heartbeat + cleaned_function_args.pop('request_heartbeat', None) + + # TODO more cleaning to fix errors LLM makes + return cleaned_function_name, cleaned_function_args + + def output_to_chat_completion_response(self, raw_llm_output): + """Turn raw LLM output into a ChatCompletion style response with: + "message" = { + "role": "assistant", + "content": ..., + "function_call": { + "name": ... + "arguments": { + "arg1": val1, + ... + } + } + } + """ + function_json_output = json.loads(raw_llm_output) + function_name = function_json_output['function'] + function_parameters = function_json_output['params'] + + if self.clean_func_args: + function_name, function_parameters = self.clean_function_args(function_name, function_parameters) + + message = { + 'role': 'assistant', + 'content': None, + 'function_call': { + 'name': function_name, + 'arguments': json.dumps(function_parameters), + } + } + return message diff --git a/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py b/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py new file mode 100644 index 00000000..d2e7584e --- /dev/null +++ b/memgpt/local_llm/llm_chat_completion_wrappers/wrapper_base.py @@ -0,0 +1,14 @@ +from abc import ABC, abstractmethod + + +class LLMChatCompletionWrapper(ABC): + + @abstractmethod + def chat_completion_to_prompt(self, messages, functions): + """Go from ChatCompletion to a single prompt string""" + pass + + @abstractmethod + def output_to_chat_completion_response(self, raw_llm_output): + """Turn the LLM output string into a ChatCompletion response""" + pass diff --git a/memgpt/local_llm/webui_settings.py b/memgpt/local_llm/webui_settings.py new file mode 
100644 index 00000000..dc578084 --- /dev/null +++ b/memgpt/local_llm/webui_settings.py @@ -0,0 +1,54 @@ +DETERMINISTIC = { + 'max_new_tokens': 250, + 'do_sample': False, + 'temperature': 0, + 'top_p': 0, + 'typical_p': 1, + 'repetition_penalty': 1.18, + 'repetition_penalty_range': 0, + 'encoder_repetition_penalty': 1, + 'top_k': 1, + 'min_length': 0, + 'no_repeat_ngram_size': 0, + 'num_beams': 1, + 'penalty_alpha': 0, + 'length_penalty': 1, + 'early_stopping': False, + 'guidance_scale': 1, + 'negative_prompt': '', + 'seed': -1, + 'add_bos_token': True, + 'stopping_strings': [ + '\nUSER:', + '\nASSISTANT:', + # '\n' + + # '', + # '<|', + # '\n#', + # '\n\n\n', + ], + 'truncation_length': 4096, + 'ban_eos_token': False, + 'skip_special_tokens': True, + 'top_a': 0, + 'tfs': 1, + 'epsilon_cutoff': 0, + 'eta_cutoff': 0, + 'mirostat_mode': 2, + 'mirostat_tau': 4, + 'mirostat_eta': 0.1, + 'use_mancer': False + } + +SIMPLE = { + 'stopping_strings': [ + '\nUSER:', + '\nASSISTANT:', + # '\n' + + # '', + # '<|', + # '\n#', + # '\n\n\n', + ], + 'truncation_length': 4096, +} \ No newline at end of file diff --git a/memgpt/openai_tools.py b/memgpt/openai_tools.py index 98444878..7729ae15 100644 --- a/memgpt/openai_tools.py +++ b/memgpt/openai_tools.py @@ -3,7 +3,13 @@ import random import os import time +from .local_llm.chat_completion_proxy import get_chat_completion +HOST = os.getenv('OPENAI_API_BASE') +HOST_TYPE = os.getenv('BACKEND_TYPE') # default None == ChatCompletion + import openai +if HOST is not None: + openai.api_base = HOST def retry_with_exponential_backoff( @@ -102,10 +108,17 @@ def aretry_with_exponential_backoff( @aretry_with_exponential_backoff async def acompletions_with_backoff(**kwargs): - azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT') - if azure_openai_deployment is not None: - kwargs['deployment_id'] = azure_openai_deployment - return await openai.ChatCompletion.acreate(**kwargs) + + # Local model + if HOST_TYPE is not None: + return await 
get_chat_completion(**kwargs) + + # OpenAI / Azure model + else: + azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT') + if azure_openai_deployment is not None: + kwargs['deployment_id'] = azure_openai_deployment + return await openai.ChatCompletion.acreate(**kwargs) @aretry_with_exponential_backoff