basic proof of concept tested on airoboros 70b 2.1

This commit is contained in:
Charles Packer
2023-10-22 22:52:24 -07:00
parent 5935f14d7f
commit 8484f0557d
8 changed files with 322 additions and 4 deletions

View File

@@ -0,0 +1,3 @@
## TODO
Add instructions on how to add support for other function-calling LLMs and for other LLM backends.

View File

View File

@@ -0,0 +1,88 @@
"""MemGPT sends a ChatCompletion request
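Under the hood, we use the messages and functions arguments to build a single prompt string for the local LLM, then turn its raw output back into a ChatCompletion-style response.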
"""
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
import os
import json
import requests
from .webui_settings import DETERMINISTIC, SIMPLE
from .llm_chat_completion_wrappers import airoboros
# Base URL of the backend, read from the environment once at import time
# (None when OPENAI_API_BASE is unset).
HOST = os.getenv('OPENAI_API_BASE')
# Backend selector, also read once at import time; get_chat_completion below
# rejects every value other than 'webui'.
HOST_TYPE = os.getenv('BACKEND_TYPE') # default None == ChatCompletion
class DotDict(dict):
    """Dictionary whose keys can also be read and written as attributes.

    Mirrors the dot-access style of an OpenAI response object: ``d.key`` reads
    ``d['key']`` and ``d.key = v`` stores ``d['key'] = v``. Reading a key that
    is not present yields ``None`` instead of raising.
    """

    def __getattr__(self, attr):
        # Only reached when normal attribute lookup fails, so real dict
        # methods (keys, items, ...) keep working.
        if attr in self:
            return self[attr]
        return None

    def __setattr__(self, key, value):
        # Attribute assignment is stored as a dictionary entry.
        self.__setitem__(key, value)
async def get_chat_completion(
    model,  # no model, since the model is fixed to whatever you set in your own backend
    messages,
    functions,
    function_call="auto",
):
    """Drop-in replacement for the agent's ChatCompletion call, served by a local backend.

    The messages and function schemas are rendered into a single prompt by the
    Airoboros 2.1 wrapper, posted to ``{HOST}/v1/generate``, and the text that
    comes back is converted into a ChatCompletion-shaped ``DotDict``.

    Args:
        model: Currently ignored; the Airoboros 2.1 wrapper is always used.
        messages: ChatCompletion-style message dicts, forwarded to the wrapper.
        functions: OpenAI-style function schemas, forwarded to the wrapper.
        function_call: Only ``"auto"`` is supported.

    Returns:
        A ``DotDict`` with ``model`` (None), ``choices`` (a single entry
        holding ``message`` and ``finish_reason``) and ``usage`` (placeholder
        token counts, all zero).

    Raises:
        ValueError: If ``function_call`` is not ``"auto"`` or ``BACKEND_TYPE``
            is not ``"webui"``.
        Exception: If the backend answers with a non-200 status code.
        Errors raised by ``requests``, by JSON decoding, or by the wrapper
        propagate unchanged.
    """
    # Cheap argument/configuration checks first, before any prompt building.
    if function_call != "auto":
        raise ValueError(f"function_call == {function_call} not supported (auto only)")
    if HOST_TYPE != 'webui':
        raise ValueError(HOST_TYPE)

    # TODO select the wrapper from `model`; only Airoboros 2.1 is wired up so far
    llm_wrapper = airoboros.Airoboros21Wrapper()

    # First step: turn the message sequence into a prompt that the model expects
    prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)

    # Build a fresh payload for this call. Writing the prompt into SIMPLE
    # itself would mutate the shared module-level settings dict.
    request = {**SIMPLE, 'prompt': prompt}

    # NOTE: requests.post blocks the event loop for the duration of the
    # generation; acceptable for a proof of concept, revisit for concurrent use.
    uri = f'{HOST}/v1/generate'
    http_response = requests.post(uri, json=request)
    if http_response.status_code != 200:
        raise Exception(
            f"API call got non-200 response code (code={http_response.status_code})"
        )

    result = http_response.json()['results'][0]['text']
    print(f"json API response.text: {result}")

    chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
    print(json.dumps(chat_completion_result, indent=2))

    # unpack with response.choices[0].message.content
    return DotDict({
        'model': None,
        'choices': [DotDict({
            'message': DotDict(chat_completion_result),
            'finish_reason': 'stop',  # TODO vary based on webui response
        })],
        'usage': DotDict({
            # TODO fix
            'prompt_tokens': 0,
            'completion_tokens': 0,
            'total_tokens': 0,
        }),
    })

View File

@@ -0,0 +1,146 @@
import json
from .wrapper_base import LLMChatCompletionWrapper
class Airoboros21Wrapper(LLMChatCompletionWrapper):
    """Wrapper for Airoboros 70b v2.1: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1

    Renders ChatCompletion-style messages plus function schemas into a single
    prompt string, and parses the model's JSON reply back into a
    ChatCompletion-style assistant message.
    """

    def __init__(self, simplify_json_content=True, include_assistant_prefix=True, clean_function_args=True):
        """Configure prompt rendering and output clean-up.

        Args:
            simplify_json_content: When true, a user message whose content is
                JSON with a "message" field is reduced to that field.
            include_assistant_prefix: When true, the prompt ends with
                "\\nASSISTANT:".
            clean_function_args: When true, parsed function arguments are
                passed through ``clean_function_args()``.
        """
        self.simplify_json_content = simplify_json_content
        self.include_assistant_prefix = include_assistant_prefix
        # Stored under a different name so the flag does not shadow the
        # clean_function_args() method.
        self.clean_func_args = clean_function_args

    def chat_completion_to_prompt(self, messages, functions):
        """Build the prompt string from ChatCompletion messages and function schemas.

        Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format

            A chat.
            USER: {prompt}
            ASSISTANT:

        Functions support: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling

            As an AI assistant, please select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format.

            Input: I want to know how many times 'Python' is mentioned in my text file.

            Available functions:
            file_analytics:
              description: This tool performs various operations on a text file.
              params:
                action: The operation we want to perform on the data, such as "count_occurrences", "find_line", etc.
                filters:
                  keyword: The word or phrase we want to search for.

        OpenAI functions schema style:

            {
                "name": "send_message",
                "description": "Sends a message to the human user",
                "parameters": {
                    "type": "object",
                    "properties": {
                        # https://json-schema.org/understanding-json-schema/reference/array.html
                        "message": {
                            "type": "string",
                            "description": "Message contents. All unicode (including emojis) are supported.",
                        },
                    },
                    "required": ["message"],
                }
            },

        Args:
            messages: Message dicts with 'role' and 'content'. The first must
                have role 'system'; the rest 'user', 'assistant' or 'function'.
            functions: Function schemas in the OpenAI style shown above.

        Returns:
            The prompt: system content, function preamble, then the
            USER/ASSISTANT turns. 'function' messages are currently skipped.
        """
        prompt = ""

        # System instructions go first
        assert messages[0]['role'] == 'system'
        prompt += messages[0]['content']

        # Next is the functions preamble
        def create_function_description(schema):
            # airoboros style
            # NOTE(review): confirm the leading whitespace after each "\n"
            # matches the nesting shown in the model card example.
            func_str = ""
            func_str += f"{schema['name']}:"
            func_str += f"\n description: {schema['description']}"
            func_str += "\n params:"
            for param_k, param_v in schema['parameters']['properties'].items():
                # TODO we're ignoring type
                func_str += f"\n {param_k}: {param_v['description']}"
            # TODO we're ignoring schema['parameters']['required']
            return func_str

        prompt += "\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
        prompt += "\nAvailable functions:"
        for function_dict in functions:
            prompt += f"\n{create_function_description(function_dict)}"

        # Last are the user/assistant messages
        for message in messages[1:]:
            assert message['role'] in ['user', 'assistant', 'function'], message

            if message['role'] == 'user':
                # Always emit the user turn. Previously the turn was dropped
                # entirely when simplify_json_content was False.
                user_text = message['content']
                if self.simplify_json_content:
                    try:
                        user_text = json.loads(message['content'])['message']
                    except (ValueError, KeyError, TypeError):
                        # Not JSON, or JSON without a "message" field: fall
                        # back to the raw content.
                        user_text = message['content']
                prompt += f"\nUSER: {user_text}"
            elif message['role'] == 'assistant':
                prompt += f"\nASSISTANT: {message['content']}"
            elif message['role'] == 'function':
                # TODO
                continue
                # prompt += f"\nASSISTANT: (function return) {message['content']}"
            else:
                # Only reachable when asserts are disabled (python -O).
                raise ValueError(message)

        if self.include_assistant_prefix:
            # prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
            prompt += "\nASSISTANT:"

        return prompt

    def clean_function_args(self, function_name, function_args):
        """Some basic MemGPT-specific cleaning of function args.

        Works on a shallow copy, so the caller's dict is left untouched.

        Returns:
            A ``(function_name, cleaned_args)`` tuple.
        """
        cleaned_function_name = function_name
        cleaned_function_args = function_args.copy()

        if function_name == 'send_message':
            # strip request_heartbeat
            cleaned_function_args.pop('request_heartbeat', None)

        # TODO more cleaning to fix errors LLM makes
        return cleaned_function_name, cleaned_function_args

    def output_to_chat_completion_response(self, raw_llm_output):
        """Turn raw LLM output into a ChatCompletion style response with:

            "message" = {
                "role": "assistant",
                "content": None,
                "function_call": {
                    "name": ...,
                    "arguments": "<JSON-encoded string of the arguments dict>"
                }
            }

        The raw output must be a JSON object with "function" and "params" keys.

        Raises:
            json.JSONDecodeError: If the output is not valid JSON.
            KeyError: If "function" or "params" is missing.
        """
        function_json_output = json.loads(raw_llm_output)
        function_name = function_json_output['function']
        function_parameters = function_json_output['params']

        if self.clean_func_args:
            function_name, function_parameters = self.clean_function_args(function_name, function_parameters)

        message = {
            'role': 'assistant',
            'content': None,
            'function_call': {
                'name': function_name,
                # Serialised to a JSON string rather than left as a dict.
                'arguments': json.dumps(function_parameters),
            }
        }
        return message

View File

@@ -0,0 +1,14 @@
from abc import ABC, abstractmethod
class LLMChatCompletionWrapper(ABC):
    """Abstract interface for adapting a ChatCompletion exchange to one LLM.

    Concrete wrappers implement both directions: request -> prompt string and
    raw model output -> ChatCompletion-style response.
    """

    @abstractmethod
    def chat_completion_to_prompt(self, messages, functions):
        """Render the messages and function schemas as a single prompt string."""
        return None

    @abstractmethod
    def output_to_chat_completion_response(self, raw_llm_output):
        """Convert the LLM's raw output string into a ChatCompletion response."""
        return None

View File

@@ -0,0 +1,54 @@
# Generation-settings preset. It is imported by the chat completion proxy but
# not referenced in the code shown there; requests are currently built from
# SIMPLE instead.
# NOTE(review): the key names look like text-generation-webui generation
# parameters — confirm names and value ranges against the backend's API docs.
# NOTE(review): do_sample is False and temperature is 0, yet mirostat_mode is 2
# and seed is -1 — confirm the backend treats this combination as deterministic.
DETERMINISTIC = {
'max_new_tokens': 250,
'do_sample': False,
'temperature': 0,
'top_p': 0,
'typical_p': 1,
'repetition_penalty': 1.18,
'repetition_penalty_range': 0,
'encoder_repetition_penalty': 1,
'top_k': 1,
'min_length': 0,
'no_repeat_ngram_size': 0,
'num_beams': 1,
'penalty_alpha': 0,
'length_penalty': 1,
'early_stopping': False,
'guidance_scale': 1,
'negative_prompt': '',
'seed': -1,
'add_bos_token': True,
# Stop strings match the "\nUSER:" / "\nASSISTANT:" turn prefixes that the
# Airoboros wrapper writes into the prompt.
'stopping_strings': [
'\nUSER:',
'\nASSISTANT:',
# Disabled candidate stop strings, kept for reference:
# '\n' +
# '</s>',
# '<|',
# '\n#',
# '\n\n\n',
],
'truncation_length': 4096,
'ban_eos_token': False,
'skip_special_tokens': True,
'top_a': 0,
'tfs': 1,
'epsilon_cutoff': 0,
'eta_cutoff': 0,
'mirostat_mode': 2,
'mirostat_tau': 4,
'mirostat_eta': 0.1,
'use_mancer': False
}
# Minimal settings preset: only stop strings and a truncation length.
# get_chat_completion uses it as the JSON request body, with a 'prompt' entry
# added.
SIMPLE = {
# Stop strings match the "\nUSER:" / "\nASSISTANT:" turn prefixes that the
# Airoboros wrapper writes into the prompt.
'stopping_strings': [
'\nUSER:',
'\nASSISTANT:',
# Disabled candidate stop strings, kept for reference:
# '\n' +
# '</s>',
# '<|',
# '\n#',
# '\n\n\n',
],
'truncation_length': 4096,
}

View File

@@ -3,7 +3,13 @@ import random
import os
import time
from .local_llm.chat_completion_proxy import get_chat_completion
HOST = os.getenv('OPENAI_API_BASE')
HOST_TYPE = os.getenv('BACKEND_TYPE') # default None == ChatCompletion
import openai
if HOST is not None:
openai.api_base = HOST
def retry_with_exponential_backoff(
@@ -102,10 +108,17 @@ def aretry_with_exponential_backoff(
@aretry_with_exponential_backoff
async def acompletions_with_backoff(**kwargs):
    """Send a chat-completion request to whichever backend is configured.

    When ``HOST_TYPE`` (the ``BACKEND_TYPE`` environment variable) is set, the
    request goes to the local-LLM proxy ``get_chat_completion``. Otherwise it
    goes to ``openai.ChatCompletion.acreate``, adding ``deployment_id`` when
    ``AZURE_OPENAI_DEPLOYMENT`` is set.

    Args:
        **kwargs: Forwarded unchanged to the selected backend call.

    Returns:
        Whatever the selected backend call returns.
    """
    # Local model. This check must come before any OpenAI call; an
    # unconditional OpenAI return ahead of it would make this path unreachable.
    if HOST_TYPE is not None:
        return await get_chat_completion(**kwargs)

    # OpenAI / Azure model
    azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT')
    if azure_openai_deployment is not None:
        kwargs['deployment_id'] = azure_openai_deployment
    return await openai.ChatCompletion.acreate(**kwargs)
@aretry_with_exponential_backoff