Refactor local-LLM backend: split the webui API client and generation settings into their own modules, and update the airoboros wrapper (section separators, opening-brace prefix, function-return messages).

This commit is contained in:
Charles Packer
2023-10-23 00:41:10 -07:00
parent f4ae08f6f5
commit faaa9a04fa
6 changed files with 150 additions and 124 deletions

View File

@@ -1,30 +1,16 @@
"""MemGPT sends a ChatCompletion request
Under the hood, we use the functions argument to turn
"""
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
import os
import json
import requests
import json
from .webui_settings import DETERMINISTIC, SIMPLE
from .webui.api import get_webui_completion
from .llm_chat_completion_wrappers import airoboros
from .utils import DotDict
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
class DotDict(dict):
    """Allow dot access on properties similar to OpenAI response object"""

    def __getattr__(self, attr):
        # Missing keys resolve to None (dict.get default) instead of raising
        # AttributeError — permissive access, matching API-response objects.
        return self.get(attr)

    def __setattr__(self, key, value):
        # Attribute writes are stored as dict items so both access styles agree.
        self[key] = value
DEBUG = True
async def get_chat_completion(
@@ -36,60 +22,52 @@ async def get_chat_completion(
if function_call != "auto":
raise ValueError(f"function_call == {function_call} not supported (auto only)")
if True or model == "airoboros_v2.1":
if model == "airoboros_v2.1":
llm_wrapper = airoboros.Airoboros21Wrapper()
else:
# Warn the user that we're using the fallback
print(
f"Warning: could not find an LLM wrapper for {model}, using the airoboros wrapper"
)
llm_wrapper = airoboros.Airoboros21Wrapper()
# First step: turn the message sequence into a prompt that the model expects
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
# print(prompt)
if HOST_TYPE != "webui":
raise ValueError(HOST_TYPE)
request = SIMPLE
request["prompt"] = prompt
if DEBUG:
print(prompt)
try:
URI = f"{HOST}/v1/generate"
response = requests.post(URI, json=request)
if response.status_code == 200:
# result = response.json()['results'][0]['history']
result = response.json()
# print(f"raw API response: {result}")
result = result["results"][0]["text"]
print(f"json API response.text: {result}")
if HOST_TYPE == "webui":
result = get_webui_completion(prompt)
else:
raise Exception(f"API call got non-200 response code")
raise ValueError(HOST_TYPE)
except requests.exceptions.ConnectionError as e:
raise ValueError(f"Was unable to connect to host {HOST}")
# cleaned_result, chatcompletion_result = parse_st_json_output(result)
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
if DEBUG:
print(json.dumps(chat_completion_result, indent=2))
# print(cleaned_result)
# unpack with response.choices[0].message.content
response = DotDict(
{
"model": None,
"choices": [
DotDict(
{
"message": DotDict(chat_completion_result),
"finish_reason": "stop", # TODO vary based on webui response
}
)
],
"usage": DotDict(
# unpack with response.choices[0].message.content
response = DotDict(
{
"model": None,
"choices": [
DotDict(
{
# TODO fix
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
"message": DotDict(chat_completion_result),
"finish_reason": "stop", # TODO vary based on backend response
}
),
}
)
return response
except Exception as e:
# TODO
raise e
)
],
"usage": DotDict(
{
# TODO fix, actually use real info
"prompt_tokens": 0,
"completion_tokens": 0,
"total_tokens": 0,
}
),
}
)
return response

View File

@@ -12,12 +12,16 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
def __init__(
self,
simplify_json_content=True,
include_assistant_prefix=True,
clean_function_args=True,
include_assistant_prefix=True,
include_opening_brace_in_prefix=True,
include_section_separators=True,
):
self.simplify_json_content = simplify_json_content
self.include_assistant_prefix = include_assistant_prefix
self.clean_func_args = clean_function_args
self.include_assistant_prefix = include_assistant_prefix
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
self.include_section_separators = include_section_separators
def chat_completion_to_prompt(self, messages, functions):
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
@@ -77,11 +81,41 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
# TODO we're ignoring schema['parameters']['required']
return func_str
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
prompt += f"\nAvailable functions:"
for function_dict in functions:
prompt += f"\n{create_function_description(function_dict)}"
def create_function_call(function_call):
"""Go from ChatCompletion to Airoboros style function trace (in prompt)
ChatCompletion data (inside message['function_call']):
"function_call": {
"name": ...
"arguments": {
"arg1": val1,
...
}
Airoboros output:
{
"function": "send_message",
"params": {
"message": "Hello there! I am Sam, an AI developed by Liminal Corp. How can I assist you today?"
}
}
"""
airo_func_call = {
"function": function_call["name"],
"params": json.loads(function_call["arguments"]),
}
return json.dumps(airo_func_call, indent=2)
# Add a sep for the conversation
if self.include_section_separators:
prompt += "\n### INPUT"
# Last are the user/assistant messages
for message in messages[1:]:
assert message["role"] in ["user", "assistant", "function"], message
@@ -96,16 +130,25 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
prompt += f"\nUSER: {message['content']}"
elif message["role"] == "assistant":
prompt += f"\nASSISTANT: {message['content']}"
# need to add the function call if there was one
if message["function_call"]:
prompt += f"\n{create_function_call(message['function_call'])}"
elif message["role"] == "function":
# TODO
continue
# TODO find a good way to add this
# prompt += f"\nASSISTANT: (function return) {message['content']}"
prompt += f"\nFUNCTION RETURN: {message['content']}"
continue
else:
raise ValueError(message)
# Add a sep for the response
if self.include_section_separators:
prompt += "\n### RESPONSE"
if self.include_assistant_prefix:
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
prompt += f"\nASSISTANT:"
if self.include_opening_brance_in_prefix:
prompt += "\n{"
return prompt
@@ -135,7 +178,13 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
}
}
"""
function_json_output = json.loads(raw_llm_output)
if self.include_opening_brance_in_prefix and raw_llm_output[0] != "{":
raw_llm_output = "{" + raw_llm_output
try:
function_json_output = json.loads(raw_llm_output)
except Exception as e:
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
function_name = function_json_output["function"]
function_parameters = function_json_output["params"]

View File

@@ -0,0 +1,8 @@
class DotDict(dict):
    """Dict subclass exposing keys as attributes, mimicking the OpenAI response object."""

    def __getattr__(self, name):
        # Unknown keys yield None rather than AttributeError, like dict.get.
        return self.get(name)

    def __setattr__(self, name, value):
        # Write-through: attribute assignment becomes an item assignment.
        self[name] = value

View File

@@ -0,0 +1,33 @@
import os
import requests
from .settings import SIMPLE
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
WEBUI_API_SUFFIX = "/v1/generate"
DEBUG = True
def get_webui_completion(prompt, settings=SIMPLE):
    """Request a raw completion from a text-generation-webui server.

    See https://github.com/oobabooga/text-generation-webui for instructions
    on how to run the LLM web server.

    Args:
        prompt: Fully-formatted prompt string to send to the model.
        settings: Generation parameters (stop strings, truncation length, ...).
            Defaults to the module-level SIMPLE preset.

    Returns:
        str: the generated text of the first result.

    Raises:
        Exception: if the API returns a non-200 status code.
        requests.exceptions.RequestException: transport-level failures,
            propagated unchanged from requests.post.
    """
    # Copy before inserting the prompt: the original assigned
    # `request = settings`, which aliased the shared SIMPLE default and
    # leaked each call's prompt into the module-level preset (and into
    # every later call that used the default).
    request = dict(settings)
    request["prompt"] = prompt

    URI = f"{HOST}{WEBUI_API_SUFFIX}"
    # The bare `except: raise` wrapper in the original was a no-op, so
    # connection errors now propagate directly.
    response = requests.post(URI, json=request)
    if response.status_code == 200:
        result = response.json()["results"][0]["text"]
        if DEBUG:
            print(f"json API response.text: {result}")
        return result
    # TODO handle gracefully (retry/backoff) instead of failing hard
    raise Exception(f"API call got non-200 response code")

View File

@@ -0,0 +1,12 @@
# Minimal generation preset for the text-generation-webui /v1/generate
# endpoint: only stop strings and truncation length are pinned; every other
# parameter falls back to the server's defaults.
SIMPLE = {
    # Conversation-turn markers that terminate generation.
    "stopping_strings": [
        "\nUSER:",
        "\nASSISTANT:",
        # '\n' +
        # '</s>',
        # '<|',
        # '\n#',
        # '\n\n\n',
    ],
    "truncation_length": 4096,  # assuming llama2 models
}

View File

@@ -1,54 +0,0 @@
# Deterministic generation preset for text-generation-webui: sampling is
# disabled and top_k pinned to 1 so repeated runs on the same prompt produce
# identical output.
DETERMINISTIC = {
    "max_new_tokens": 250,
    "do_sample": False,  # disable sampling -> greedy decoding
    "temperature": 0,
    "top_p": 0,
    "typical_p": 1,
    "repetition_penalty": 1.18,
    "repetition_penalty_range": 0,
    "encoder_repetition_penalty": 1,
    "top_k": 1,  # only the single most likely token is eligible
    "min_length": 0,
    "no_repeat_ngram_size": 0,
    "num_beams": 1,
    "penalty_alpha": 0,
    "length_penalty": 1,
    "early_stopping": False,
    "guidance_scale": 1,
    "negative_prompt": "",
    "seed": -1,
    "add_bos_token": True,
    # Conversation-turn markers that terminate generation.
    "stopping_strings": [
        "\nUSER:",
        "\nASSISTANT:",
        # '\n' +
        # '</s>',
        # '<|',
        # '\n#',
        # '\n\n\n',
    ],
    "truncation_length": 4096,  # assumes a llama2-class context window
    "ban_eos_token": False,
    "skip_special_tokens": True,
    "top_a": 0,
    "tfs": 1,
    "epsilon_cutoff": 0,
    "eta_cutoff": 0,
    # NOTE(review): mirostat enabled alongside do_sample=False looks
    # contradictory for a "deterministic" preset — confirm intent.
    "mirostat_mode": 2,
    "mirostat_tau": 4,
    "mirostat_eta": 0.1,
    "use_mancer": False,
}
# Minimal preset: only stop strings and truncation length are pinned; all
# other generation parameters fall back to the webui server's defaults.
SIMPLE = {
    # Conversation-turn markers that terminate generation.
    "stopping_strings": [
        "\nUSER:",
        "\nASSISTANT:",
        # '\n' +
        # '</s>',
        # '<|',
        # '\n#',
        # '\n\n\n',
    ],
    "truncation_length": 4096,  # assumes a llama2-class context window
}