basic proof of concept tested on airoboros 70b 2.1
This commit is contained in:
3
memgpt/local_llm/README.md
Normal file
3
memgpt/local_llm/README.md
Normal file
@@ -0,0 +1,3 @@
|
||||
## TODO
|
||||
|
||||
Add instructions on how to add support for other function-calling LLMs and other LLM backends.
|
||||
0
memgpt/local_llm/__init__.py
Normal file
0
memgpt/local_llm/__init__.py
Normal file
88
memgpt/local_llm/chat_completion_proxy.py
Normal file
88
memgpt/local_llm/chat_completion_proxy.py
Normal file
@@ -0,0 +1,88 @@
|
||||
"""MemGPT sends a ChatCompletion request
|
||||
|
||||
Under the hood, we use the functions argument to turn
|
||||
"""
|
||||
|
||||
|
||||
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
|
||||
|
||||
import os
|
||||
import json
|
||||
import requests
|
||||
|
||||
from .webui_settings import DETERMINISTIC, SIMPLE
|
||||
from .llm_chat_completion_wrappers import airoboros
|
||||
|
||||
HOST = os.getenv('OPENAI_API_BASE')
|
||||
HOST_TYPE = os.getenv('BACKEND_TYPE') # default None == ChatCompletion
|
||||
|
||||
|
||||
class DotDict(dict):
    """Dict subclass whose keys are also readable/writable as attributes.

    Mimics the attribute-style access of the OpenAI response objects
    (e.g. ``response.choices[0].message.content``).
    """

    def __getattr__(self, name):
        # Only called when normal attribute lookup fails; missing keys
        # resolve to None instead of raising AttributeError (dict.get default).
        return self.get(name)

    def __setattr__(self, name, value):
        # Attribute assignment writes straight into the underlying dict.
        self[name] = value
|
||||
|
||||
|
||||
async def get_chat_completion(
    model,  # unused: the model is fixed to whatever is running on the backend
    messages,
    functions,
    function_call="auto",
):
    """Drop-in replacement for the agent's ChatCompletion call on a local LLM backend.

    Turns the ChatCompletion-style ``messages`` + ``functions`` into a single
    prompt string via an LLM-specific wrapper, POSTs it to the backend's
    ``/v1/generate`` endpoint, and wraps the raw completion text back into an
    OpenAI-style response object (``response.choices[0].message...``).

    Raises:
        ValueError: if ``function_call`` != "auto", or BACKEND_TYPE is not 'webui'.
        Exception: if the HTTP request returns a non-200 status code.
    """
    if function_call != "auto":
        raise ValueError(f"function_call == {function_call} not supported (auto only)")

    # TODO dispatch on `model` once more wrappers exist; the original guard was
    # `if True or model == 'airoboros_v2.1':`, i.e. always airoboros.
    llm_wrapper = airoboros.Airoboros21Wrapper()

    # First step: turn the message sequence into a prompt that the model expects
    prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)

    if HOST_TYPE != 'webui':
        raise ValueError(HOST_TYPE)

    # Copy the settings template: the original did `request = SIMPLE`, which
    # mutated the shared module-level dict (leaking 'prompt' across calls).
    request = SIMPLE.copy()
    request['prompt'] = prompt

    try:
        URI = f'{HOST}/v1/generate'
        response = requests.post(URI, json=request)
        if response.status_code == 200:
            result = response.json()['results'][0]['text']
            print(f"json API response.text: {result}")
        else:
            raise Exception(f"API call got non-200 response code: {response.status_code}")

        # Second step: parse the raw completion back into a ChatCompletion message
        chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
        print(json.dumps(chat_completion_result, indent=2))

        # unpack with response.choices[0].message.content
        response = DotDict({
            'model': None,
            'choices': [DotDict({
                'message': DotDict(chat_completion_result),
                'finish_reason': 'stop',  # TODO vary based on webui response
            })],
            'usage': DotDict({
                # TODO report real token counts if the backend exposes them
                'prompt_tokens': 0,
                'completion_tokens': 0,
                'total_tokens': 0,
            })
        })
        return response

    except Exception as e:
        # TODO retries / friendlier error surface
        raise e
|
||||
146
memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
Normal file
146
memgpt/local_llm/llm_chat_completion_wrappers/airoboros.py
Normal file
@@ -0,0 +1,146 @@
|
||||
import json
|
||||
|
||||
from .wrapper_base import LLMChatCompletionWrapper
|
||||
|
||||
|
||||
class Airoboros21Wrapper(LLMChatCompletionWrapper):
    """Wrapper for Airoboros 70b v2.1: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1

    Renders a ChatCompletion-style (messages, functions) request as the single
    prompt string the airoboros format expects, and parses the model's JSON
    function-call output back into a ChatCompletion-style message dict.
    """

    def __init__(self, simplify_json_content=True, include_assistant_prefix=True, clean_function_args=True):
        # If True, user messages whose content is a JSON payload are unwrapped
        # to the inner 'message' text before being placed in the prompt.
        self.simplify_json_content = simplify_json_content
        # If True, end the prompt with "ASSISTANT:" to cue the model's turn.
        self.include_assistant_prefix = include_assistant_prefix
        # If True, post-process model-returned function args (see clean_function_args()).
        self.clean_func_args = clean_function_args

    def chat_completion_to_prompt(self, messages, functions):
        """Go from a ChatCompletion request to a single airoboros-style prompt string.

        Prompt format: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format

            A chat.
            USER: {prompt}
            ASSISTANT:

        Function support: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling

            As an AI assistant, please select the most suitable function and parameters
            from the list of available functions below, based on the user's input.
            Provide your response in JSON format.

            Input: I want to know how many times 'Python' is mentioned in my text file.

            Available functions:
            file_analytics:
            description: This tool performs various operations on a text file.
            params:
            action: The operation we want to perform on the data, such as "count_occurrences", "find_line", etc.
            filters:
            keyword: The word or phrase we want to search for.

        OpenAI functions schema style (input format of ``functions``):

            {
                "name": "send_message",
                "description": "Sends a message to the human user",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "message": {
                            "type": "string",
                            "description": "Message contents. All unicode (including emojis) are supported.",
                        },
                    },
                    "required": ["message"],
                }
            },
        """
        prompt = ""

        # System instructions go first
        assert messages[0]['role'] == 'system'
        prompt += messages[0]['content']

        # Next is the functions preamble
        def create_function_description(schema):
            # Render one OpenAI-schema function in the airoboros YAML-ish style
            func_str = ""
            func_str += f"{schema['name']}:"
            func_str += f"\n description: {schema['description']}"
            func_str += f"\n params:"
            for param_k, param_v in schema['parameters']['properties'].items():
                # TODO we're ignoring type
                func_str += f"\n {param_k}: {param_v['description']}"
            # TODO we're ignoring schema['parameters']['required']
            return func_str

        prompt += f"\nPlease select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format."
        prompt += f"\nAvailable functions:"
        for function_dict in functions:
            prompt += f"\n{create_function_description(function_dict)}"

        # Last are the user/assistant messages
        for message in messages[1:]:
            assert message['role'] in ['user', 'assistant', 'function'], message

            if message['role'] == 'user':
                if self.simplify_json_content:
                    try:
                        content_json = json.loads(message['content'])
                        content_simple = content_json['message']
                        prompt += f"\nUSER: {content_simple}"
                    # Narrowed from a bare `except:` — only the failure modes of
                    # json.loads / the 'message' lookup should trigger the fallback
                    # (json.JSONDecodeError is a ValueError subclass).
                    except (ValueError, KeyError, TypeError):
                        prompt += f"\nUSER: {message['content']}"
                else:
                    # BUGFIX: the original dropped user messages entirely when
                    # simplify_json_content was False.
                    prompt += f"\nUSER: {message['content']}"
            elif message['role'] == 'assistant':
                prompt += f"\nASSISTANT: {message['content']}"
            elif message['role'] == 'function':
                # TODO function-return messages are not rendered into the prompt yet
                continue
                # prompt += f"\nASSISTANT: (function return) {message['content']}"
            else:
                raise ValueError(message)

        if self.include_assistant_prefix:
            prompt += f"\nASSISTANT:"

        return prompt

    def clean_function_args(self, function_name, function_args):
        """Some basic MemGPT-specific cleaning of function args.

        Returns a (name, args) pair; the inputs are not mutated.
        """
        cleaned_function_name = function_name
        cleaned_function_args = function_args.copy()

        if function_name == 'send_message':
            # strip request_heartbeat
            cleaned_function_args.pop('request_heartbeat', None)

        # TODO more cleaning to fix errors LLM makes
        return cleaned_function_name, cleaned_function_args

    def output_to_chat_completion_response(self, raw_llm_output):
        """Turn raw LLM output into a ChatCompletion-style response message.

        Expects ``raw_llm_output`` to be a JSON string of the form
        ``{"function": ..., "params": {...}}`` and returns::

            {
                "role": "assistant",
                "content": None,
                "function_call": {
                    "name": ...,
                    "arguments": "<json-encoded params>",
                }
            }

        Raises:
            json.JSONDecodeError: if the model output is not valid JSON.
            KeyError: if 'function' or 'params' is missing from the output.
        """
        function_json_output = json.loads(raw_llm_output)
        function_name = function_json_output['function']
        function_parameters = function_json_output['params']

        if self.clean_func_args:
            function_name, function_parameters = self.clean_function_args(function_name, function_parameters)

        message = {
            'role': 'assistant',
            'content': None,
            'function_call': {
                'name': function_name,
                # ChatCompletion convention: arguments are a JSON *string*, not a dict
                'arguments': json.dumps(function_parameters),
            }
        }
        return message
|
||||
@@ -0,0 +1,14 @@
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class LLMChatCompletionWrapper(ABC):
    """Interface for adapting ChatCompletion-style requests to a raw-text LLM.

    Concrete wrappers (one per model/prompt format) implement the two
    conversions: request -> prompt string, and raw output -> response message.
    """

    @abstractmethod
    def chat_completion_to_prompt(self, messages, functions):
        """Render a ChatCompletion message/function list as one prompt string."""
        ...

    @abstractmethod
    def output_to_chat_completion_response(self, raw_llm_output):
        """Parse the model's raw output string into a ChatCompletion response."""
        ...
|
||||
54
memgpt/local_llm/webui_settings.py
Normal file
54
memgpt/local_llm/webui_settings.py
Normal file
@@ -0,0 +1,54 @@
|
||||
# Generation-parameter presets sent to the text-generation-webui /v1/generate
# endpoint (the chat completion proxy POSTs one of these with a 'prompt' key added).

# Aims for deterministic greedy decoding: do_sample=False with top_k=1.
# NOTE(review): mirostat_mode=2 and seed=-1 are also set here — confirm the
# backend ignores them when sampling is disabled, otherwise output may vary.
DETERMINISTIC = {
    'max_new_tokens': 250,
    'do_sample': False,
    'temperature': 0,
    'top_p': 0,
    'typical_p': 1,
    'repetition_penalty': 1.18,
    'repetition_penalty_range': 0,
    'encoder_repetition_penalty': 1,
    'top_k': 1,
    'min_length': 0,
    'no_repeat_ngram_size': 0,
    'num_beams': 1,
    'penalty_alpha': 0,
    'length_penalty': 1,
    'early_stopping': False,
    'guidance_scale': 1,
    'negative_prompt': '',
    'seed': -1,
    'add_bos_token': True,
    # Stop as soon as the model starts a new conversational turn.
    'stopping_strings': [
        '\nUSER:',
        '\nASSISTANT:',
        # '\n' +
        # '</s>',
        # '<|',
        # '\n#',
        # '\n\n\n',
    ],
    'truncation_length': 4096,
    'ban_eos_token': False,
    'skip_special_tokens': True,
    'top_a': 0,
    'tfs': 1,
    'epsilon_cutoff': 0,
    'eta_cutoff': 0,
    'mirostat_mode': 2,
    'mirostat_tau': 4,
    'mirostat_eta': 0.1,
    'use_mancer': False
}

# Minimal preset: only stop strings and context length are specified;
# all sampling parameters are left to the backend's own defaults.
SIMPLE = {
    # Same turn-boundary stop strings as DETERMINISTIC above.
    'stopping_strings': [
        '\nUSER:',
        '\nASSISTANT:',
        # '\n' +
        # '</s>',
        # '<|',
        # '\n#',
        # '\n\n\n',
    ],
    'truncation_length': 4096,
}
|
||||
@@ -3,7 +3,13 @@ import random
|
||||
import os
|
||||
import time
|
||||
|
||||
from .local_llm.chat_completion_proxy import get_chat_completion
|
||||
HOST = os.getenv('OPENAI_API_BASE')
|
||||
HOST_TYPE = os.getenv('BACKEND_TYPE') # default None == ChatCompletion
|
||||
|
||||
import openai
|
||||
if HOST is not None:
|
||||
openai.api_base = HOST
|
||||
|
||||
|
||||
def retry_with_exponential_backoff(
|
||||
@@ -102,10 +108,17 @@ def aretry_with_exponential_backoff(
|
||||
|
||||
@aretry_with_exponential_backoff
async def acompletions_with_backoff(**kwargs):
    """Async ChatCompletion call with exponential-backoff retry.

    Dispatches to the local LLM proxy when BACKEND_TYPE is set (HOST_TYPE is
    not None); otherwise falls through to OpenAI/Azure ChatCompletion.
    (The diff rendering fused the pre- and post-change bodies, leaving the
    local-model dispatch unreachable after an early return; this is the
    post-change version with the duplicated Azure block removed.)
    """
    # Local model
    if HOST_TYPE is not None:
        return await get_chat_completion(**kwargs)

    # OpenAI / Azure model
    azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT')
    if azure_openai_deployment is not None:
        # Azure routes requests by deployment id rather than model name
        kwargs['deployment_id'] = azure_openai_deployment
    return await openai.ChatCompletion.acreate(**kwargs)
|
||||
|
||||
|
||||
@aretry_with_exponential_backoff
|
||||
|
||||
Reference in New Issue
Block a user