refactored + updated the airo wrapper a bit
This commit is contained in:
@@ -1,30 +1,16 @@
|
||||
"""MemGPT sends a ChatCompletion request
|
||||
|
||||
Under the hood, we use the functions argument to turn
|
||||
"""
|
||||
|
||||
|
||||
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
import json
|
||||
|
||||
from .webui_settings import DETERMINISTIC, SIMPLE
|
||||
from .webui.api import get_webui_completion
|
||||
from .llm_chat_completion_wrappers import airoboros
|
||||
from .utils import DotDict
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
|
||||
|
||||
class DotDict(dict):
    """Allow dot access on properties similar to OpenAI response object.

    Keys can be read and written as attributes, e.g.
    ``response.choices[0].message.content``.

    Reading a key that is absent yields ``None`` (same as ``dict.get``)
    instead of raising. Names of the form ``__dunder__`` are the exception:
    they raise ``AttributeError`` so that code probing for optional special
    methods (``hasattr``, ``copy``, ``pickle``) is not handed ``None`` where
    it expects either a callable or a missing attribute.
    """

    def __getattr__(self, attr):
        # __getattr__ only runs after normal attribute lookup has failed.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, key, value):
        # Attribute assignment is stored as a dictionary item.
        self[key] = value
|
||||
DEBUG = True
|
||||
|
||||
|
||||
async def get_chat_completion(
|
||||
@@ -36,60 +22,52 @@ async def get_chat_completion(
|
||||
if function_call != "auto":
|
||||
raise ValueError(f"function_call == {function_call} not supported (auto only)")
|
||||
|
||||
if True or model == "airoboros_v2.1":
|
||||
if model == "airoboros_v2.1":
|
||||
llm_wrapper = airoboros.Airoboros21Wrapper()
|
||||
else:
|
||||
# Warn the user that we're using the fallback
|
||||
print(
|
||||
f"Warning: could not find an LLM wrapper for {model}, using the airoboros wrapper"
|
||||
)
|
||||
llm_wrapper = airoboros.Airoboros21Wrapper()
|
||||
|
||||
# First step: turn the message sequence into a prompt that the model expects
|
||||
prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
|
||||
# print(prompt)
|
||||
|
||||
if HOST_TYPE != "webui":
|
||||
raise ValueError(HOST_TYPE)
|
||||
|
||||
request = SIMPLE
|
||||
request["prompt"] = prompt
|
||||
if DEBUG:
|
||||
print(prompt)
|
||||
|
||||
try:
|
||||
URI = f"{HOST}/v1/generate"
|
||||
response = requests.post(URI, json=request)
|
||||
if response.status_code == 200:
|
||||
# result = response.json()['results'][0]['history']
|
||||
result = response.json()
|
||||
# print(f"raw API response: {result}")
|
||||
result = result["results"][0]["text"]
|
||||
print(f"json API response.text: {result}")
|
||||
if HOST_TYPE == "webui":
|
||||
result = get_webui_completion(prompt)
|
||||
else:
|
||||
raise Exception(f"API call got non-200 response code")
|
||||
raise ValueError(HOST_TYPE)
|
||||
except requests.exceptions.ConnectionError as e:
|
||||
raise ValueError(f"Was unable to connect to host {HOST}")
|
||||
|
||||
# cleaned_result, chatcompletion_result = parse_st_json_output(result)
|
||||
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
|
||||
chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
|
||||
if DEBUG:
|
||||
print(json.dumps(chat_completion_result, indent=2))
|
||||
# print(cleaned_result)
|
||||
|
||||
# unpack with response.choices[0].message.content
|
||||
response = DotDict(
|
||||
{
|
||||
"model": None,
|
||||
"choices": [
|
||||
DotDict(
|
||||
{
|
||||
"message": DotDict(chat_completion_result),
|
||||
"finish_reason": "stop", # TODO vary based on webui response
|
||||
}
|
||||
)
|
||||
],
|
||||
"usage": DotDict(
|
||||
# unpack with response.choices[0].message.content
|
||||
response = DotDict(
|
||||
{
|
||||
"model": None,
|
||||
"choices": [
|
||||
DotDict(
|
||||
{
|
||||
# TODO fix
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
"message": DotDict(chat_completion_result),
|
||||
"finish_reason": "stop", # TODO vary based on backend response
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
except Exception as e:
|
||||
# TODO
|
||||
raise e
|
||||
)
|
||||
],
|
||||
"usage": DotDict(
|
||||
{
|
||||
# TODO fix, actually use real info
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"total_tokens": 0,
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
return response
|
||||
|
||||
@@ -12,12 +12,16 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
def __init__(
|
||||
self,
|
||||
simplify_json_content=True,
|
||||
include_assistant_prefix=True,
|
||||
clean_function_args=True,
|
||||
include_assistant_prefix=True,
|
||||
include_opening_brace_in_prefix=True,
|
||||
include_section_separators=True,
|
||||
):
|
||||
self.simplify_json_content = simplify_json_content
|
||||
self.include_assistant_prefix = include_assistant_prefix
|
||||
self.clean_func_args = clean_function_args
|
||||
self.include_assistant_prefix = include_assistant_prefix
|
||||
self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
|
||||
self.include_section_separators = include_section_separators
|
||||
|
||||
def chat_completion_to_prompt(self, messages, functions):
|
||||
"""Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
|
||||
@@ -77,11 +81,41 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
# TODO we're ignoring schema['parameters']['required']
|
||||
return func_str
|
||||
|
||||
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
|
||||
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
|
||||
prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
|
||||
prompt += f"\nAvailable functions:"
|
||||
for function_dict in functions:
|
||||
prompt += f"\n{create_function_description(function_dict)}"
|
||||
|
||||
def create_function_call(function_call):
    """Go from ChatCompletion to Airoboros style function trace (in prompt)

    ChatCompletion data (inside message['function_call']):
        "function_call": {
            "name": ...
            "arguments": JSON-encoded string of the arguments
                         (an already-decoded dict is accepted as well)
        }

    Airoboros output:
        {
          "function": "send_message",
          "params": {
            "message": "Hello there! I am Sam, an AI developed by Liminal Corp. How can I assist you today?"
          }
        }

    Returns:
        The Airoboros-style call serialized as JSON with 2-space indentation.
    """
    arguments = function_call["arguments"]
    if isinstance(arguments, dict):
        # Already decoded: use as-is instead of failing inside json.loads.
        params = arguments
    else:
        params = json.loads(arguments)

    airo_func_call = {
        "function": function_call["name"],
        "params": params,
    }
    return json.dumps(airo_func_call, indent=2)
|
||||
|
||||
# Add a sep for the conversation
|
||||
if self.include_section_separators:
|
||||
prompt += "\n### INPUT"
|
||||
|
||||
# Last are the user/assistant messages
|
||||
for message in messages[1:]:
|
||||
assert message["role"] in ["user", "assistant", "function"], message
|
||||
@@ -96,16 +130,25 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
prompt += f"\nUSER: {message['content']}"
|
||||
elif message["role"] == "assistant":
|
||||
prompt += f"\nASSISTANT: {message['content']}"
|
||||
# need to add the function call if there was one
|
||||
if message["function_call"]:
|
||||
prompt += f"\n{create_function_call(message['function_call'])}"
|
||||
elif message["role"] == "function":
|
||||
# TODO
|
||||
continue
|
||||
# TODO find a good way to add this
|
||||
# prompt += f"\nASSISTANT: (function return) {message['content']}"
|
||||
prompt += f"\nFUNCTION RETURN: {message['content']}"
|
||||
continue
|
||||
else:
|
||||
raise ValueError(message)
|
||||
|
||||
# Add a sep for the response
|
||||
if self.include_section_separators:
|
||||
prompt += "\n### RESPONSE"
|
||||
|
||||
if self.include_assistant_prefix:
|
||||
# prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
|
||||
prompt += f"\nASSISTANT:"
|
||||
if self.include_opening_brance_in_prefix:
|
||||
prompt += "\n{"
|
||||
|
||||
return prompt
|
||||
|
||||
@@ -135,7 +178,13 @@ class Airoboros21Wrapper(LLMChatCompletionWrapper):
|
||||
}
|
||||
}
|
||||
"""
|
||||
function_json_output = json.loads(raw_llm_output)
|
||||
if self.include_opening_brance_in_prefix and raw_llm_output[0] != "{":
|
||||
raw_llm_output = "{" + raw_llm_output
|
||||
|
||||
try:
|
||||
function_json_output = json.loads(raw_llm_output)
|
||||
except Exception as e:
|
||||
raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}")
|
||||
function_name = function_json_output["function"]
|
||||
function_parameters = function_json_output["params"]
|
||||
|
||||
|
||||
8
memgpt/local_llm/utils.py
Normal file
8
memgpt/local_llm/utils.py
Normal file
@@ -0,0 +1,8 @@
|
||||
class DotDict(dict):
    """Allow dot access on properties similar to OpenAI response object.

    Keys can be read and written as attributes, e.g.
    ``response.choices[0].message.content``.

    Reading a key that is absent yields ``None`` (same as ``dict.get``)
    instead of raising. Names of the form ``__dunder__`` are the exception:
    they raise ``AttributeError`` so that code probing for optional special
    methods (``hasattr``, ``copy``, ``pickle``) is not handed ``None`` where
    it expects either a callable or a missing attribute.
    """

    def __getattr__(self, attr):
        # __getattr__ only runs after normal attribute lookup has failed.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, key, value):
        # Attribute assignment is stored as a dictionary item.
        self[key] = value
|
||||
33
memgpt/local_llm/webui/api.py
Normal file
33
memgpt/local_llm/webui/api.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
import requests
|
||||
|
||||
from .settings import SIMPLE
|
||||
|
||||
HOST = os.getenv("OPENAI_API_BASE")
|
||||
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
|
||||
WEBUI_API_SUFFIX = "/v1/generate"
|
||||
DEBUG = True
|
||||
|
||||
|
||||
def get_webui_completion(prompt, settings=SIMPLE):
    """Request a completion for ``prompt`` from the web UI backend.

    See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server

    Args:
        prompt: Prompt string sent to the server under the "prompt" key.
        settings: Settings for the generation (stop tokens, max length, etc)
            sent together with the prompt. The mapping itself is not modified.

    Returns:
        The "text" field of the first entry of "results" in the JSON response.

    Raises:
        Exception: If the server answers with a non-200 status code. Errors
            raised by ``requests`` (or while decoding the response)
            propagate unchanged.
    """
    # Build the payload on a copy. Writing the prompt directly into `settings`
    # would leave it stored in the shared module-level default (SIMPLE) after
    # the call, leaking state between requests.
    request = {**settings, "prompt": prompt}

    URI = f"{HOST}{WEBUI_API_SUFFIX}"
    response = requests.post(URI, json=request)
    if response.status_code != 200:
        # TODO handle gracefully
        raise Exception("API call got non-200 response code")

    result = response.json()["results"][0]["text"]
    if DEBUG:
        print(f"json API response.text: {result}")

    return result
|
||||
12
memgpt/local_llm/webui/settings.py
Normal file
12
memgpt/local_llm/webui/settings.py
Normal file
@@ -0,0 +1,12 @@
|
||||
# Minimal settings for a generation request: only the stopping strings and
# the truncation length are specified.
SIMPLE = dict(
    # Candidates that are currently disabled:
    #   '\n' +, '</s>', '<|', '\n#', '\n\n\n'
    stopping_strings=["\nUSER:", "\nASSISTANT:"],
    truncation_length=4096,  # assuming llama2 models
)
|
||||
@@ -1,54 +0,0 @@
|
||||
# Full generation-settings preset with sampling switched off
# (do_sample=False, temperature=0, top_k=1).
DETERMINISTIC = dict(
    max_new_tokens=250,
    do_sample=False,
    temperature=0,
    top_p=0,
    typical_p=1,
    repetition_penalty=1.18,
    repetition_penalty_range=0,
    encoder_repetition_penalty=1,
    top_k=1,
    min_length=0,
    no_repeat_ngram_size=0,
    num_beams=1,
    penalty_alpha=0,
    length_penalty=1,
    early_stopping=False,
    guidance_scale=1,
    negative_prompt="",
    seed=-1,
    add_bos_token=True,
    # Candidates that are currently disabled:
    #   '\n' +, '</s>', '<|', '\n#', '\n\n\n'
    stopping_strings=["\nUSER:", "\nASSISTANT:"],
    truncation_length=4096,
    ban_eos_token=False,
    skip_special_tokens=True,
    top_a=0,
    tfs=1,
    epsilon_cutoff=0,
    eta_cutoff=0,
    mirostat_mode=2,
    mirostat_tau=4,
    mirostat_eta=0.1,
    use_mancer=False,
)
|
||||
|
||||
# Minimal settings preset: only the stopping strings and the truncation
# length are specified.
SIMPLE = dict(
    # Candidates that are currently disabled:
    #   '\n' +, '</s>', '<|', '\n#', '\n\n\n'
    stopping_strings=["\nUSER:", "\nASSISTANT:"],
    truncation_length=4096,
)
|
||||
Reference in New Issue
Block a user