Merge branch 'cpacker:main' into main

This commit is contained in:
Ansh Babbar
2023-10-24 11:40:30 +05:30
committed by GitHub
18 changed files with 1170 additions and 171 deletions

View File

@@ -23,5 +23,6 @@ jobs:
python -m pip install --upgrade pip
pip install -r requirements.txt
- name: Run main.py
run: python main.py
- name: Run main.py with input
run: |
echo -e "\n\n\nn" | python main.py

3
.gitignore vendored
View File

@@ -79,3 +79,6 @@ dmypy.json
# Pyre type checker
.pyre/
# MemGPT config files
configs/

View File

@@ -5,7 +5,9 @@
<div align="center">
<strong>Try out our MemGPT chatbot on <a href="https://discord.gg/9GEQrxmVyE">Discord</a>!</strong>
<strong>⭐ NEW: You can now run MemGPT with <a href="https://github.com/cpacker/MemGPT/discussions/67">local LLMs</a> and <a href="https://github.com/cpacker/MemGPT/discussions/65">AutoGen</a>! ⭐ </strong>
[![Discord](https://img.shields.io/discord/1161736243340640419?label=Discord&logo=discord&logoColor=5865F2&style=flat-square&color=5865F2)](https://discord.gg/9GEQrxmVyE)
[![arXiv 2310.08560](https://img.shields.io/badge/arXiv-2310.08560-B31B1B?logo=arxiv&style=flat-square)](https://arxiv.org/abs/2310.08560)
@@ -75,13 +77,6 @@ Install dependencies:
pip install -r requirements.txt
```
Extra step for Windows:
```sh
# only needed on Windows
pip install pyreadline3
```
Add your OpenAI API key to your environment:
```sh
@@ -115,6 +110,12 @@ python main.py --use_azure_openai
To create a new starter user or starter persona (that MemGPT gets initialized with), create a new `.txt` file in [/memgpt/humans/examples](/memgpt/humans/examples) or [/memgpt/personas/examples](/memgpt/personas/examples), then use the `--persona` or `--human` flag when running `main.py`. For example:
```sh
# assuming you created a new file /memgpt/humans/examples/me.txt
python main.py
# Select me.txt during configuration process
```
-- OR --
```sh
# assuming you created a new file /memgpt/humans/examples/me.txt
python main.py --human me.txt
@@ -123,6 +124,11 @@ python main.py --human me.txt
### GPT-3.5 support
You can run MemGPT with GPT-3.5 as the LLM instead of GPT-4:
```sh
python main.py
# Select gpt-3.5 during configuration process
```
-- OR --
```sh
python main.py --model gpt-3.5-turbo
```
@@ -130,7 +136,19 @@ python main.py --model gpt-3.5-turbo
Please report any bugs you encounter regarding MemGPT running on GPT-3.5 to https://github.com/cpacker/MemGPT/issues/59.
### Local LLM support
You can run MemGPT with local LLMs too. See the [instructions here](/memgpt/local_llm), and report any bugs or suggested improvements at https://github.com/cpacker/MemGPT/discussions/67.
### `main.py` flags
```text
--first
allows you to send the first message in the chat (by default, MemGPT will send the first message)
--debug
enables debugging output
```
<details>
<summary>Configure via legacy flags</summary>
```text
--model
@@ -139,10 +157,6 @@ Please report any bugs you encounter regarding MemGPT running on GPT-3.5 to htt
load a specific persona file
--human
load a specific human file
--first
allows you to send the first message in the chat (by default, MemGPT will send the first message)
--debug
enables debugging output
--archival_storage_faiss_path=<ARCHIVAL_STORAGE_FAISS_PATH>
load in document database (backed by FAISS index)
--archival_storage_files="<ARCHIVAL_STORAGE_FILES_GLOB_PATTERN>"
@@ -152,6 +166,8 @@ Please report any bugs you encounter regarding MemGPT running on GPT-3.5 to htt
--archival_storage_sqldb=<SQLDB_PATH>
load in SQL database
```
</details>
### Interactive CLI commands
@@ -160,8 +176,6 @@ These are the commands for the CLI, **not the Discord bot**! The Discord bot has
While using MemGPT via the CLI (not Discord!) you can run various commands:
```text
//
enter multiline input mode (type // again when done)
/exit
exit the CLI
/save
@@ -293,6 +307,6 @@ Datasets used in our [paper](https://arxiv.org/abs/2310.08560) can be downloaded
- [x] Integration tests
- [x] Integrate with AutoGen ([discussion](https://github.com/cpacker/MemGPT/discussions/65))
- [x] Add official gpt-3.5-turbo support ([discussion](https://github.com/cpacker/MemGPT/discussions/66))
- [x] CLI UI improvements ([issue](https://github.com/cpacker/MemGPT/issues/11))
- [x] Add support for other LLM backends ([issue](https://github.com/cpacker/MemGPT/issues/18), [discussion](https://github.com/cpacker/MemGPT/discussions/67))
- [ ] Release MemGPT family of open models (e.g., fine-tuned Mistral) ([discussion](https://github.com/cpacker/MemGPT/discussions/67))
- [ ] CLI UI improvements ([issue](https://github.com/cpacker/MemGPT/issues/11))
- [ ] Add support for other LLM backends ([issue](https://github.com/cpacker/MemGPT/issues/18))

307
config.py Normal file
View File

@@ -0,0 +1,307 @@
import glob
import json
import os
import textwrap
import interface
import questionary
from colorama import Fore, Style
from typing import List, Type
import memgpt.utils as utils
from memgpt.personas.personas import get_persona_text
from memgpt.humans.humans import get_human_text
# Entries offered by the "Which model would you like to use?" prompt.
# The first entry is passed as that prompt's default selection.
model_choices = [
    questionary.Choice(title="gpt-4"),
    questionary.Choice(
        title="gpt-3.5-turbo (experimental! function-calling performance is not quite at the level of gpt-4 yet)",
        value="gpt-3.5-turbo",
    ),
]
class Config:
    """Run configuration for MemGPT, built interactively, from flags, or from a JSON file.

    Instances are normally created through one of the two async alternate
    constructors:

    * ``legacy_flags_init`` -- values come straight from command-line flags.
    * ``config_init`` -- values come from a saved JSON config file or from an
      interactive ``questionary`` session.

    Every attribute that other methods read (``to_dict``, ``write_config``) is
    initialised in ``__init__`` so that a partially populated instance never
    raises ``AttributeError``.
    """

    personas_dir = os.path.join("memgpt", "personas", "examples")
    humans_dir = os.path.join("memgpt", "humans", "examples")
    configs_dir = "configs"

    def __init__(self):
        # Model / persona selections (filled in by the alternate constructors
        # or by load_config()).
        self.model = None
        self.memgpt_persona = None
        self.human_persona = None

        # Archival-memory preload settings.
        self.preload_archival = False
        self.load_type = None
        self.archival_storage_files = None
        self.archival_storage_index = None
        self.compute_embeddings = False

        # Paths recorded by the caller after a checkpoint is saved.
        self.agent_save_file = None
        self.persistence_manager_save_file = None

        # A custom API base; when set, embeddings computation is skipped
        # (see configure_archival_storage).
        self.host = os.getenv("OPENAI_API_BASE")

        # Runtime-only state (not serialised by to_dict()).
        self.index = None
        self.archival_database = None
        self.config_file = None

    @classmethod
    async def legacy_flags_init(
        cls: Type["Config"],
        model: str,
        memgpt_persona: str,
        human_persona: str,
        load_type: str = None,
        archival_storage_files: str = None,
        archival_storage_index: str = None,
        compute_embeddings: bool = False,
    ):
        """Build a config directly from explicit (command-line style) values.

        If ``archival_storage_files`` is given, archival storage is prepared
        before returning. Embeddings are never recomputed when an existing
        ``archival_storage_index`` is supplied.

        Returns:
            The populated ``Config`` instance.
        """
        self = cls()
        self.model = model
        self.memgpt_persona = memgpt_persona
        self.human_persona = human_persona
        self.load_type = load_type
        self.archival_storage_files = archival_storage_files
        self.archival_storage_index = archival_storage_index
        self.compute_embeddings = compute_embeddings
        # Keep the serialised config truthful about whether anything is preloaded.
        self.preload_archival = bool(archival_storage_files)

        recompute_embeddings = self.compute_embeddings
        if self.archival_storage_index:
            recompute_embeddings = False  # TODO Legacy support -- can't recompute embeddings on a path that's not specified.
        if self.archival_storage_files:
            await self.configure_archival_storage(recompute_embeddings)
        return self

    @classmethod
    async def config_init(cls: Type["Config"], config_file: str = None):
        """Build a config from a saved file or, failing that, interactively.

        When ``config_file`` is ``None`` the most recently modified file in
        ``configs_dir`` is offered to the user. If a file ends up selected it
        is loaded; otherwise the user is walked through a series of prompts.
        In both cases the resulting config is written back to disk.

        Returns:
            The populated ``Config`` instance.
        """
        self = cls()
        self.config_file = config_file

        if self.config_file is None:
            cfg = cls.get_most_recent_config()
            use_cfg = False
            if cfg:
                print(
                    f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}"
                )
                use_cfg = await questionary.confirm(
                    f"Use most recent config file '{cfg}'?"
                ).ask_async()
            if use_cfg:
                self.config_file = cfg

        if self.config_file:
            self.load_config(self.config_file)
            recompute_embeddings = False
            if self.compute_embeddings:
                if self.archival_storage_index:
                    # An index already exists: only rebuild it on request.
                    recompute_embeddings = await questionary.confirm(
                        f"Would you like to recompute embeddings? Do this if your files have changed.\n Files: {self.archival_storage_files}",
                        default=False,
                    ).ask_async()
                else:
                    recompute_embeddings = True
            if self.load_type:
                await self.configure_archival_storage(recompute_embeddings)
            self.write_config()
            return self

        print(
            f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}"
        )

        self.model = await questionary.select(
            "Which model would you like to use?",
            model_choices,
            default=model_choices[0],
        ).ask_async()

        self.memgpt_persona = await questionary.select(
            "Which persona would you like MemGPT to use?",
            cls.get_memgpt_personas(),
        ).ask_async()
        # NOTE(review): looks like leftover debug output -- confirm before removing.
        print(self.memgpt_persona)

        self.human_persona = await questionary.select(
            "Which user would you like to use?",
            cls.get_user_personas(),
        ).ask_async()

        self.archival_storage_index = None
        self.preload_archival = await questionary.confirm(
            "Would you like to preload anything into MemGPT's archival memory?"
        ).ask_async()

        if self.preload_archival:
            self.load_type = await questionary.select(
                "What would you like to load?",
                choices=[
                    questionary.Choice("A folder or file", value="folder"),
                    questionary.Choice("A SQL database", value="sql"),
                    questionary.Choice("A glob pattern", value="glob"),
                ],
            ).ask_async()
            if self.load_type == "folder" or self.load_type == "sql":
                archival_storage_path = await questionary.path(
                    "Please enter the folder or file (tab for autocomplete):"
                ).ask_async()
                if os.path.isdir(archival_storage_path):
                    # A directory is turned into a glob over its direct children.
                    self.archival_storage_files = os.path.join(
                        archival_storage_path, "*"
                    )
                else:
                    self.archival_storage_files = archival_storage_path
            else:
                self.archival_storage_files = await questionary.path(
                    "Please enter the glob pattern (tab for autocomplete):"
                ).ask_async()
            self.compute_embeddings = await questionary.confirm(
                "Would you like to compute embeddings over these files to enable embeddings search?"
            ).ask_async()
            await self.configure_archival_storage(self.compute_embeddings)

        self.write_config()
        return self

    async def configure_archival_storage(self, recompute_embeddings):
        """Populate ``archival_database`` (and ``index`` when embeddings are used).

        Args:
            recompute_embeddings: when true, (re)build the embeddings index
                from ``archival_storage_files`` -- unless ``host`` is set, in
                which case a warning is shown and no index is built.
        """
        if recompute_embeddings:
            if self.host:
                interface.warning_message(
                    "⛔️ Embeddings on a non-OpenAI endpoint are not yet supported, falling back to substring matching search."
                )
            else:
                self.archival_storage_index = (
                    await utils.prepare_archival_index_from_files_compute_embeddings(
                        self.archival_storage_files
                    )
                )

        if self.compute_embeddings and self.archival_storage_index:
            self.index, self.archival_database = utils.prepare_archival_index(
                self.archival_storage_index
            )
        else:
            self.archival_database = utils.prepare_archival_index_from_files(
                self.archival_storage_files
            )

    def to_dict(self):
        """Return the JSON-serialisable subset of the configuration."""
        return {
            "model": self.model,
            "memgpt_persona": self.memgpt_persona,
            "human_persona": self.human_persona,
            "preload_archival": self.preload_archival,
            "archival_storage_files": self.archival_storage_files,
            "archival_storage_index": self.archival_storage_index,
            "compute_embeddings": self.compute_embeddings,
            "load_type": self.load_type,
            "agent_save_file": self.agent_save_file,
            "persistence_manager_save_file": self.persistence_manager_save_file,
            "host": self.host,
        }

    def load_config(self, config_file):
        """Overwrite this instance's settings from the JSON file ``config_file``.

        Raises:
            OSError: if the file cannot be opened.
            KeyError: if the file lacks one of the keys written by ``to_dict``.
        """
        with open(config_file, "rt", encoding="utf-8") as f:
            cfg = json.load(f)
        self.model = cfg["model"]
        self.memgpt_persona = cfg["memgpt_persona"]
        self.human_persona = cfg["human_persona"]
        self.preload_archival = cfg["preload_archival"]
        self.archival_storage_files = cfg["archival_storage_files"]
        self.archival_storage_index = cfg["archival_storage_index"]
        self.compute_embeddings = cfg["compute_embeddings"]
        self.load_type = cfg["load_type"]
        self.agent_save_file = cfg["agent_save_file"]
        self.persistence_manager_save_file = cfg["persistence_manager_save_file"]
        self.host = cfg["host"]

    def write_config(self, configs_dir=None):
        """Serialise the config to ``self.config_file`` as indented JSON.

        If no config file has been chosen yet, a timestamp-based filename is
        created under ``configs_dir`` (default: ``Config.configs_dir``) and
        remembered for subsequent writes.
        """
        if configs_dir is None:
            configs_dir = Config.configs_dir
        os.makedirs(configs_dir, exist_ok=True)
        if self.config_file is None:
            filename = os.path.join(
                configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_")
            )
            self.config_file = f"{filename}.json"
        with open(self.config_file, "wt", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=4)
        print(
            f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}"
        )

    @staticmethod
    def get_memgpt_personas(dir_path=None):
        """Return the prompt choices for MemGPT personas.

        Custom personas (any ``.txt`` stem in ``dir_path`` that is not a
        built-in name) are listed first, in sorted order, followed by the
        built-in ones and a disabled hint entry.
        """
        if dir_path is None:
            dir_path = Config.personas_dir
        all_personas = Config.get_personas(dir_path)
        default_personas = [
            "sam",
            "sam_pov",
            "memgpt_starter",
            "memgpt_doc",
            "sam_simple_pov_gpt35",
        ]
        # get_personas() returns sorted stems; filtering (rather than a set
        # difference) keeps the menu order stable between runs.
        custom_personas = [p for p in all_personas if p not in default_personas]
        return Config.get_persona_choices(
            custom_personas + default_personas, get_persona_text
        ) + [
            questionary.Separator(),
            questionary.Choice(
                f"📝 You can create your own personas by adding .txt files to {dir_path}.",
                disabled=True,
            ),
        ]

    @staticmethod
    def get_user_personas(dir_path=None):
        """Return the prompt choices for human (user) profiles.

        Custom profiles are listed first, in sorted order, followed by the
        built-in ones and a disabled hint entry.
        """
        if dir_path is None:
            dir_path = Config.humans_dir
        all_personas = Config.get_personas(dir_path)
        default_personas = ["basic", "cs_phd"]
        # Same stable-ordering rationale as get_memgpt_personas().
        custom_personas = [p for p in all_personas if p not in default_personas]
        return Config.get_persona_choices(
            custom_personas + default_personas, get_human_text
        ) + [
            questionary.Separator(),
            questionary.Choice(
                f"📝 You can create your own human profiles by adding .txt files to {dir_path}.",
                disabled=True,
            ),
        ]

    @staticmethod
    def get_personas(dir_path) -> List[str]:
        """Return the sorted filename stems of every ``*.txt`` file in ``dir_path``."""
        files = sorted(glob.glob(os.path.join(dir_path, "*.txt")))
        return [os.path.splitext(os.path.basename(f))[0] for f in files]

    @staticmethod
    def get_persona_choices(personas, text_getter):
        """Build one ``questionary.Choice`` per persona name.

        Each choice's title is the name followed by a preview of the text
        returned by ``text_getter(name)``; the choice's value is the name.
        """
        return [
            questionary.Choice(
                title=[
                    ("class:question", f"{p}"),
                    ("class:text", f"\n{indent(text_getter(p))}"),
                ],
                value=p,
            )
            for p in personas
        ]

    @staticmethod
    def get_most_recent_config(configs_dir=None):
        """Return the most recently modified file in ``configs_dir``, or ``None``.

        The directory is created if it does not exist yet. Only regular files
        directly inside the directory are considered.
        """
        if configs_dir is None:
            configs_dir = Config.configs_dir
        os.makedirs(configs_dir, exist_ok=True)
        files = [
            os.path.join(configs_dir, f)
            for f in os.listdir(configs_dir)
            if os.path.isfile(os.path.join(configs_dir, f))
        ]
        # max() with default=None covers the empty-directory case.
        return max(files, key=os.path.getmtime, default=None)
def indent(text, num_lines=5):
    """Wrap ``text`` at 100 columns and prefix every resulting line with a space.

    When wrapping yields more than ``num_lines`` lines, only the first
    ``num_lines - 1`` are kept, followed by a ``"... (truncated)"`` marker
    and the final wrapped line.
    """
    wrapped = textwrap.fill(text, width=100).split("\n")
    if len(wrapped) > num_lines:
        head = wrapped[: num_lines - 1]
        wrapped = [*head, "... (truncated)", wrapped[-1]]
    pad = " "
    return pad + f"\n{pad}".join(wrapped)

View File

@@ -10,131 +10,172 @@ init(autoreset=True)
# DEBUG = True # puts full message outputs in the terminal
DEBUG = False # only dumps important messages in the terminal
def important_message(msg):
print(f'{Fore.MAGENTA}{Style.BRIGHT}{msg}{Style.RESET_ALL}')
print(f"{Fore.MAGENTA}{Style.BRIGHT}{msg}{Style.RESET_ALL}")
def warning_message(msg):
print(f"{Fore.RED}{Style.BRIGHT}{msg}{Style.RESET_ALL}")
async def internal_monologue(msg):
# ANSI escape code for italic is '\x1B[3m'
print(f'\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}')
print(f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}")
async def assistant_message(msg):
print(f'{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}')
print(f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}")
async def memory_message(msg):
print(f'{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}')
print(
f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{msg}{Style.RESET_ALL}"
)
async def system_message(msg):
printd(f'{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}')
printd(
f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}"
)
async def user_message(msg, raw=False):
if isinstance(msg, str):
if raw:
printd(f'{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}')
printd(f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}")
return
else:
try:
msg_json = json.loads(msg)
except:
printd(f"Warning: failed to parse user message into json")
printd(f'{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}')
printd(
f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg}{Style.RESET_ALL}"
)
return
if msg_json['type'] == 'user_message':
msg_json.pop('type')
printd(f'{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}')
elif msg_json['type'] == 'heartbeat':
if msg_json["type"] == "user_message":
msg_json.pop("type")
printd(f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}")
elif msg_json["type"] == "heartbeat":
if DEBUG:
msg_json.pop('type')
printd(f'{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}')
elif msg_json['type'] == 'system_message':
msg_json.pop('type')
printd(f'{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}')
msg_json.pop("type")
printd(
f"{Fore.GREEN}{Style.BRIGHT}💓 {Fore.GREEN}{msg_json}{Style.RESET_ALL}"
)
elif msg_json["type"] == "system_message":
msg_json.pop("type")
printd(f"{Fore.GREEN}{Style.BRIGHT}🖥️ {Fore.GREEN}{msg_json}{Style.RESET_ALL}")
else:
printd(f'{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}')
printd(f"{Fore.GREEN}{Style.BRIGHT}🧑 {Fore.GREEN}{msg_json}{Style.RESET_ALL}")
async def function_message(msg):
# Print a function-call / function-result message to the terminal.
# Dict messages are printed as-is; string messages are dispatched on their
# prefix ("Success: ", "Error: ", "Running "); any other string is parsed as
# JSON and printed in green when it carries {"status": "OK"}.
# NOTE(review): this listing is a diff hunk, so several statements below
# appear twice (single-quoted version followed by its reformatted version).
if isinstance(msg, dict):
printd(f'{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}')
printd(f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}")
return
if msg.startswith('Success: '):
printd(f'{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}')
elif msg.startswith('Error: '):
printd(f'{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}')
elif msg.startswith('Running '):
if msg.startswith("Success: "):
printd(
f"{Fore.RED}{Style.BRIGHT}⚡🟢 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
elif msg.startswith("Error: "):
printd(
f"{Fore.RED}{Style.BRIGHT}⚡🔴 [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
elif msg.startswith("Running "):
if DEBUG:
printd(f'{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}')
printd(
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
else:
if 'memory' in msg:
match = re.search(r'Running (\w+)\((.*)\)', msg)
if "memory" in msg:
# Split "Running <name>(<args>)" into the function name and its raw argument text.
match = re.search(r"Running (\w+)\((.*)\)", msg)
if match:
function_name = match.group(1)
function_args = match.group(2)
print(f'{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:')
print(
f"{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:"
)
try:
# NOTE(review): eval() executes the argument text captured from msg; if msg
# can contain untrusted content this runs arbitrary code -- consider
# ast.literal_eval or json.loads instead.
msg_dict = eval(function_args)
if function_name == 'archival_memory_search':
print(f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}')
if function_name == "archival_memory_search":
print(
f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}'
)
else:
print(f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}{msg_dict["new_content"]}')
print(
f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}{msg_dict["new_content"]}'
)
except Exception as e:
printd(e)
# NOTE(review): msg_dict is first assigned by the eval() call above; if eval()
# itself raised, msg_dict is unbound here and this line raises inside the handler.
printd(msg_dict)
pass
else:
printd(f"Warning: did not recognize function message")
printd(f'{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}')
elif 'send_message' in msg:
printd(
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
elif "send_message" in msg:
# ignore in debug mode
pass
else:
printd(f'{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}')
printd(
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
else:
try:
# Any other string is expected to be a JSON payload; only an "OK" status is printed here.
msg_dict = json.loads(msg)
if "status" in msg_dict and msg_dict["status"] == "OK":
printd(f'{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}')
printd(
f"{Fore.GREEN}{Style.BRIGHT}⚡ [function] {Fore.GREEN}{msg}{Style.RESET_ALL}"
)
except Exception:
printd(f"Warning: did not recognize function message {type(msg)} {msg}")
printd(f'{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}')
printd(
f"{Fore.RED}{Style.BRIGHT}⚡ [function] {Fore.RED}{msg}{Style.RESET_ALL}"
)
async def print_messages(message_sequence):
for msg in message_sequence:
role = msg['role']
content = msg['content']
role = msg["role"]
content = msg["content"]
if role == 'system':
if role == "system":
await system_message(content)
elif role == 'assistant':
elif role == "assistant":
# Differentiate between internal monologue, function calls, and messages
if msg.get('function_call'):
if msg.get("function_call"):
if content is not None:
await internal_monologue(content)
await function_message(msg['function_call'])
await function_message(msg["function_call"])
# assistant_message(content)
else:
await internal_monologue(content)
elif role == 'user':
elif role == "user":
await user_message(content)
elif role == 'function':
elif role == "function":
await function_message(content)
else:
print(f'Unknown role: {content}')
print(f"Unknown role: {content}")
async def print_messages_simple(message_sequence):
for msg in message_sequence:
role = msg['role']
content = msg['content']
role = msg["role"]
content = msg["content"]
if role == 'system':
if role == "system":
await system_message(content)
elif role == 'assistant':
elif role == "assistant":
await assistant_message(content)
elif role == 'user':
elif role == "user":
await user_message(content, raw=True)
else:
print(f'Unknown role: {content}')
print(f"Unknown role: {content}")
async def print_messages_raw(message_sequence):
for msg in message_sequence:

352
main.py
View File

@@ -5,9 +5,11 @@ import glob
import os
import sys
import pickle
import readline
import questionary
from rich.console import Console
console = Console()
import interface # for printing to terminal
@@ -18,56 +20,104 @@ import memgpt.presets as presets
import memgpt.constants as constants
import memgpt.personas.personas as personas
import memgpt.humans.humans as humans
from memgpt.persistence_manager import InMemoryStateManager, InMemoryStateManagerWithPreloadedArchivalMemory, InMemoryStateManagerWithFaiss
from memgpt.persistence_manager import (
InMemoryStateManager,
InMemoryStateManagerWithPreloadedArchivalMemory,
InMemoryStateManagerWithFaiss,
)
from config import Config
FLAGS = flags.FLAGS
flags.DEFINE_string("persona", default=None, required=False, help="Specify persona")
flags.DEFINE_string("human", default=None, required=False, help="Specify human")
flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model")
flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence")
flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
flags.DEFINE_boolean("no_verify", default=False, required=False, help="Bypass message verification")
flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage with FAISS index to load (a folder with a .index and .json describing documents to be loaded)")
flags.DEFINE_string("archival_storage_files", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern)")
flags.DEFINE_string("archival_storage_files_compute_embeddings", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern), and compute embeddings over them")
flags.DEFINE_string("archival_storage_sqldb", default="", required=False, help="Specify SQL database to pre-load into archival memory")
flags.DEFINE_string(
"model",
default=constants.DEFAULT_MEMGPT_MODEL,
required=False,
help="Specify the LLM model",
)
flags.DEFINE_boolean(
"first",
default=False,
required=False,
help="Use -first to send the first message in the sequence",
)
flags.DEFINE_boolean(
"debug", default=False, required=False, help="Use -debug to enable debugging output"
)
flags.DEFINE_boolean(
"no_verify", default=False, required=False, help="Bypass message verification"
)
flags.DEFINE_string(
"archival_storage_faiss_path",
default="",
required=False,
help="Specify archival storage with FAISS index to load (a folder with a .index and .json describing documents to be loaded)",
)
flags.DEFINE_string(
"archival_storage_files",
default="",
required=False,
help="Specify files to pre-load into archival memory (glob pattern)",
)
flags.DEFINE_string(
"archival_storage_files_compute_embeddings",
default="",
required=False,
help="Specify files to pre-load into archival memory (glob pattern), and compute embeddings over them",
)
flags.DEFINE_string(
"archival_storage_sqldb",
default="",
required=False,
help="Specify SQL database to pre-load into archival memory",
)
# Support for Azure OpenAI (see: https://github.com/openai/openai-python#microsoft-azure-endpoints)
flags.DEFINE_boolean("use_azure_openai", default=False, required=False, help="Use Azure OpenAI (requires additional environment variables)")
flags.DEFINE_boolean(
"use_azure_openai",
default=False,
required=False,
help="Use Azure OpenAI (requires additional environment variables)",
)
def clear_line():
if os.name == 'nt': # for windows
if os.name == "nt": # for windows
console.print("\033[A\033[K", end="")
else: # for linux
sys.stdout.write("\033[2K\033[G")
sys.stdout.flush()
def save(memgpt_agent):
filename = utils.get_local_time().replace(' ', '_').replace(':', '_')
def save(memgpt_agent, cfg):
filename = utils.get_local_time().replace(" ", "_").replace(":", "_")
filename = f"{filename}.json"
filename = os.path.join('saved_state', filename)
filename = os.path.join("saved_state", filename)
try:
if not os.path.exists("saved_state"):
os.makedirs("saved_state")
memgpt_agent.save_to_json_file(filename)
print(f"Saved checkpoint to: {filename}")
cfg.agent_save_file = filename
except Exception as e:
print(f"Saving state to {filename} failed with: {e}")
# save the persistence manager too
filename = filename.replace('.json', '.persistence.pickle')
filename = filename.replace(".json", ".persistence.pickle")
try:
memgpt_agent.persistence_manager.save(filename)
print(f"Saved persistence manager to: {filename}")
cfg.persistence_manager_save_file = filename
except Exception as e:
print(f"Saving persistence manager to {filename} failed with: {e}")
cfg.write_config()
def load(memgpt_agent, filename):
if filename is not None:
if filename[-5:] != '.json':
filename += '.json'
if filename[-5:] != ".json":
filename += ".json"
try:
memgpt_agent.load_from_json_file_inplace(filename)
print(f"Loaded checkpoint {filename}")
@@ -75,8 +125,12 @@ def load(memgpt_agent, filename):
print(f"Loading {filename} failed with: {e}")
else:
# Load the latest file
print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead")
json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory.
print(
f"/load warning: no checkpoint specified, loading most recent checkpoint instead"
)
json_files = glob.glob(
"saved_state/*.json"
) # This will list all .json files in the current directory.
# Check if there are any json files.
if not json_files:
@@ -91,12 +145,16 @@ def load(memgpt_agent, filename):
print(f"Loading {filename} failed with: {e}")
# need to load persistence manager too
filename = filename.replace('.json', '.persistence.pickle')
filename = filename.replace(".json", ".persistence.pickle")
try:
memgpt_agent.persistence_manager = InMemoryStateManager.load(filename) # TODO(fixme):for different types of persistence managers that require different load/save methods
memgpt_agent.persistence_manager = InMemoryStateManager.load(
filename
) # TODO(fixme):for different types of persistence managers that require different load/save methods
print(f"Loaded persistence manager from {filename}")
except Exception as e:
print(f"/load warning: loading persistence manager from {filename} failed with: {e}")
print(
f"/load warning: loading persistence manager from {filename} failed with: {e}"
)
async def main():
@@ -104,93 +162,199 @@ async def main():
logging.getLogger().setLevel(logging.CRITICAL)
if FLAGS.debug:
logging.getLogger().setLevel(logging.DEBUG)
print("Running... [exit by typing '/exit']")
if any(
(
FLAGS.persona,
FLAGS.human,
FLAGS.model != constants.DEFAULT_MEMGPT_MODEL,
FLAGS.archival_storage_faiss_path,
FLAGS.archival_storage_files,
FLAGS.archival_storage_files_compute_embeddings,
FLAGS.archival_storage_sqldb,
)
):
interface.important_message("⚙️ Using legacy command line arguments.")
model = FLAGS.model
if model is None:
model = constants.DEFAULT_MEMGPT_MODEL
memgpt_persona = FLAGS.persona
if memgpt_persona is None:
memgpt_persona = (
personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT
)
human_persona = FLAGS.human
if human_persona is None:
human_persona = humans.DEFAULT
if FLAGS.archival_storage_files:
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
load_type="folder",
archival_storage_files=FLAGS.archival_storage_files,
compute_embeddings=False,
)
elif FLAGS.archival_storage_faiss_path:
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
load_type="folder",
archival_storage_files=FLAGS.archival_storage_faiss_path,
archival_storage_index=FLAGS.archival_storage_faiss_path,
compute_embeddings=True,
)
elif FLAGS.archival_storage_files_compute_embeddings:
print(model)
print(memgpt_persona)
print(human_persona)
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
load_type="folder",
archival_storage_files=FLAGS.archival_storage_files_compute_embeddings,
compute_embeddings=True,
)
elif FLAGS.archival_storage_sqldb:
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
load_type="sql",
archival_storage_files=FLAGS.archival_storage_sqldb,
compute_embeddings=False,
)
else:
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
)
else:
cfg = await Config.config_init()
interface.important_message("Running... [exit by typing '/exit']")
if cfg.model != constants.DEFAULT_MEMGPT_MODEL:
interface.warning_message(
f"⛔️ Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
)
# Azure OpenAI support
if FLAGS.use_azure_openai:
azure_openai_key = os.getenv('AZURE_OPENAI_KEY')
azure_openai_endpoint = os.getenv('AZURE_OPENAI_ENDPOINT')
azure_openai_version = os.getenv('AZURE_OPENAI_VERSION')
azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT')
if None in [azure_openai_key, azure_openai_endpoint, azure_openai_version, azure_openai_deployment]:
print(f"Error: missing Azure OpenAI environment variables. Please see README section on Azure.")
azure_openai_key = os.getenv("AZURE_OPENAI_KEY")
azure_openai_endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
azure_openai_version = os.getenv("AZURE_OPENAI_VERSION")
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if None in [
azure_openai_key,
azure_openai_endpoint,
azure_openai_version,
azure_openai_deployment,
]:
print(
f"Error: missing Azure OpenAI environment variables. Please see README section on Azure."
)
return
import openai
openai.api_type = "azure"
openai.api_key = azure_openai_key
openai.api_base = azure_openai_endpoint
openai.api_version = azure_openai_version
# deployment gets passed into chatcompletion
else:
azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT')
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
print(f"Error: AZURE_OPENAI_DEPLOYMENT should not be set if --use_azure_openai is False")
print(
f"Error: AZURE_OPENAI_DEPLOYMENT should not be set if --use_azure_openai is False"
)
return
if FLAGS.model != constants.DEFAULT_MEMGPT_MODEL:
interface.important_message(f"Warning - you are running MemGPT with {FLAGS.model}, which is not officially supported (yet). Expect bugs!")
if FLAGS.archival_storage_faiss_path:
index, archival_database = utils.prepare_archival_index(FLAGS.archival_storage_faiss_path)
persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
elif FLAGS.archival_storage_files:
archival_database = utils.prepare_archival_index_from_files(FLAGS.archival_storage_files)
print(f"Preloaded {len(archival_database)} chunks into archival memory.")
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(archival_database)
elif FLAGS.archival_storage_files_compute_embeddings:
faiss_save_dir = await utils.prepare_archival_index_from_files_compute_embeddings(FLAGS.archival_storage_files_compute_embeddings)
interface.important_message(f"To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={faiss_save_dir} (if your files haven't changed).")
index, archival_database = utils.prepare_archival_index(faiss_save_dir)
persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
if cfg.index:
persistence_manager = InMemoryStateManagerWithFaiss(
cfg.index, cfg.archival_database
)
elif cfg.archival_storage_files:
print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.")
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(
cfg.archival_database
)
else:
persistence_manager = InMemoryStateManager()
# Moved defaults out of FLAGS so that we can dynamically select the default persona based on model
chosen_human = FLAGS.human if FLAGS.human is not None else humans.DEFAULT
chosen_persona = FLAGS.persona if FLAGS.persona is not None else (personas.GPT35_DEFAULT if 'gpt-3.5' in FLAGS.model else personas.DEFAULT)
if FLAGS.archival_storage_files_compute_embeddings:
interface.important_message(
f"(legacy) To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={cfg.archival_storage_index} (if your files haven't changed)."
)
memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(chosen_persona), humans.get_human_text(chosen_human), interface, persistence_manager)
# Moved defaults out of FLAGS so that we can dynamically select the default persona based on model
chosen_human = cfg.human_persona
chosen_persona = cfg.memgpt_persona
memgpt_agent = presets.use_preset(
presets.DEFAULT,
cfg.model,
personas.get_persona_text(chosen_persona),
humans.get_human_text(chosen_human),
interface,
persistence_manager,
)
print_messages = interface.print_messages
await print_messages(memgpt_agent.messages)
counter = 0
user_input = None
skip_next_user_input = False
user_message = None
USER_GOES_FIRST = FLAGS.first
if FLAGS.archival_storage_sqldb:
if not os.path.exists(FLAGS.archival_storage_sqldb):
print(f"File {FLAGS.archival_storage_sqldb} does not exist")
if cfg.load_type == "sql": # TODO: move this into config.py in a clean manner
if not os.path.exists(cfg.archival_storage_files):
print(f"File {cfg.archival_storage_files} does not exist")
return
# Ingest data from file into archival storage
else:
print(f"Database found! Loading database into archival memory")
data_list = utils.read_database_as_list(FLAGS.archival_storage_sqldb)
data_list = utils.read_database_as_list(cfg.archival_storage_files)
user_message = f"Your archival memory has been loaded with a SQL database called {data_list[0]}, which contains schema {data_list[1]}. Remember to refer to this first while answering any user questions!"
for row in data_list:
await memgpt_agent.persistence_manager.archival_memory.insert(row)
print(f"Database loaded into archival memory.")
if cfg.agent_save_file:
load_save_file = await questionary.confirm(
f"Load in saved agent '{cfg.agent_save_file}'?"
).ask_async()
if load_save_file:
load(memgpt_agent, cfg.agent_save_file)
# auto-exit for
if "GITHUB_ACTIONS" in os.environ:
return
if not USER_GOES_FIRST:
console.input('[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]')
console.input(
"[bold cyan]Hit enter to begin (will request first MemGPT message)[/bold cyan]"
)
clear_line()
print()
while True:
if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
# Ask for user input
user_input = console.input("[bold cyan]Enter your message:[/bold cyan] ")
# user_input = console.input("[bold cyan]Enter your message:[/bold cyan] ")
user_input = await questionary.text(
"Enter your message:",
multiline=True,
qmark=">",
).ask_async()
clear_line()
if user_input.startswith('!'):
if user_input.startswith("!"):
print(f"Commands for CLI begin with '/' not '!'")
continue
@@ -201,34 +365,21 @@ async def main():
# Handle CLI commands
# Commands to not get passed as input to MemGPT
if user_input.startswith('/'):
if user_input == "//":
print("Entering multiline mode, type // when done")
user_input_list = []
while True:
user_input = console.input("[bold cyan]>[/bold cyan] ")
clear_line()
if user_input == "//":
break
else:
user_input_list.append(user_input)
# pass multiline inputs to MemGPT
user_message = system.package_user_message("\n".join(user_input_list))
elif user_input.lower() == "/exit":
if user_input.startswith("/"):
if user_input.lower() == "/exit":
# autosave
save(memgpt_agent=memgpt_agent)
save(memgpt_agent=memgpt_agent, cfg=cfg)
break
elif user_input.lower() == "/savechat":
filename = utils.get_local_time().replace(' ', '_').replace(':', '_')
filename = (
utils.get_local_time().replace(" ", "_").replace(":", "_")
)
filename = f"{filename}.pkl"
try:
if not os.path.exists("saved_chats"):
os.makedirs("saved_chats")
with open(os.path.join('saved_chats', filename), 'wb') as f:
with open(os.path.join("saved_chats", filename), "wb") as f:
pickle.dump(memgpt_agent.messages, f)
print(f"Saved messages to: {filename}")
except Exception as e:
@@ -236,10 +387,12 @@ async def main():
continue
elif user_input.lower() == "/save":
save(memgpt_agent=memgpt_agent)
save(memgpt_agent=memgpt_agent, cfg=cfg)
continue
elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
elif user_input.lower() == "/load" or user_input.lower().startswith(
"/load "
):
command = user_input.strip().split()
filename = command[1] if len(command) > 1 else None
load(memgpt_agent=memgpt_agent, filename=filename)
@@ -265,17 +418,23 @@ async def main():
continue
elif user_input.lower() == "/model":
if memgpt_agent.model == 'gpt-4':
memgpt_agent.model = 'gpt-3.5-turbo'
elif memgpt_agent.model == 'gpt-3.5-turbo':
memgpt_agent.model = 'gpt-4'
if memgpt_agent.model == "gpt-4":
memgpt_agent.model = "gpt-3.5-turbo"
elif memgpt_agent.model == "gpt-3.5-turbo":
memgpt_agent.model = "gpt-4"
print(f"Updated model to:\n{str(memgpt_agent.model)}")
continue
elif user_input.lower() == "/pop" or user_input.lower().startswith("/pop "):
elif user_input.lower() == "/pop" or user_input.lower().startswith(
"/pop "
):
# Check if there's an additional argument that's an integer
command = user_input.strip().split()
amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 2
amount = (
int(command[1])
if len(command) > 1 and command[1].isdigit()
else 2
)
print(f"Popping last {amount} messages from stack")
for _ in range(min(amount, len(memgpt_agent.messages))):
memgpt_agent.messages.pop()
@@ -304,14 +463,23 @@ async def main():
skip_next_user_input = False
with console.status("[bold cyan]Thinking...") as status:
new_messages, heartbeat_request, function_failed, token_warning = await memgpt_agent.step(user_message, first_message=False, skip_verify=FLAGS.no_verify)
(
new_messages,
heartbeat_request,
function_failed,
token_warning,
) = await memgpt_agent.step(
user_message, first_message=False, skip_verify=FLAGS.no_verify
)
# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
if token_warning:
user_message = system.get_token_limit_warning()
skip_next_user_input = True
elif function_failed:
user_message = system.get_heartbeat(constants.FUNC_FAILED_HEARTBEAT_MESSAGE)
user_message = system.get_heartbeat(
constants.FUNC_FAILED_HEARTBEAT_MESSAGE
)
skip_next_user_input = True
elif heartbeat_request:
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
@@ -322,10 +490,10 @@ async def main():
print("Finished.")
if __name__ == '__main__':
if __name__ == "__main__":
def run(argv):
loop = asyncio.get_event_loop()
loop.run_until_complete(main())
app.run(run)
app.run(run)

View File

@@ -80,7 +80,7 @@ class MemGPTAgent(ConversableAgent):
def pretty_concat(messages):
"""AutoGen expects a single response, but MemGPT may take many steps.
To accomadate AutoGen, concatenate all of MemGPT's steps into one and return as a single message.
To accommodate AutoGen, concatenate all of MemGPT's steps into one and return as a single message.
"""
ret = {
'role': 'assistant',

103
memgpt/local_llm/README.md Normal file
View File

@@ -0,0 +1,103 @@
⁉️ Need help configuring local LLMs with MemGPT? Ask for help on [our Discord](https://discord.gg/9GEQrxmVyE) or [post on the GitHub discussion](https://github.com/cpacker/MemGPT/discussions/67).
👀 If you have a hosted ChatCompletion-compatible endpoint that works with function calling, you can simply set `OPENAI_API_BASE` (`export OPENAI_API_BASE=...`) to the IP+port of your endpoint. **As of 10/22/2023, most ChatCompletion endpoints do *NOT* support function calls, so if you want to play with MemGPT and open models, you probably need to follow the instructions below.**
🙋 Our examples assume that you're using [oobabooga web UI](https://github.com/oobabooga/text-generation-webui#starting-the-web-ui) to put your LLMs behind a web server. If you need help setting this up, check the instructions [here](https://github.com/oobabooga/text-generation-webui#starting-the-web-ui). More LLM web server support to come soon (tell us what you use and we'll add it)!
---
# How to connect MemGPT to non-OpenAI LLMs
**If you have an LLM that is function-call finetuned**:
- Implement a wrapper class for that model
- The wrapper class needs to implement two functions:
- One to go from ChatCompletion messages/functions schema to a prompt string
- And one to go from raw LLM outputs to a ChatCompletion response
- Put that model behind a server (e.g. using WebUI) and set `OPENAI_API_BASE`
```python
class LLMChatCompletionWrapper(ABC):
@abstractmethod
def chat_completion_to_prompt(self, messages, functions):
"""Go from ChatCompletion to a single prompt string"""
pass
@abstractmethod
def output_to_chat_completion_response(self, raw_llm_output):
"""Turn the LLM output string into a ChatCompletion response"""
pass
```
## Example with [Airoboros](https://huggingface.co/jondurbin/airoboros-l2-70b-2.1) (llama2 finetune)
To help you get started, we've implemented an example wrapper class for a popular llama2 model **finetuned on function calling** (Airoboros). We want MemGPT to run well on open models as much as you do, so we'll be actively updating this page with more examples. Additionally, we welcome contributions from the community! If you find an open LLM that works well with MemGPT, please open a PR with a model wrapper and we'll merge it ASAP.
```python
class Airoboros21Wrapper(LLMChatCompletionWrapper):
"""Wrapper for Airoboros 70b v2.1: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1"""
def chat_completion_to_prompt(self, messages, functions):
"""
Examples for how airoboros expects its prompt inputs: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format
Examples for how airoboros expects to see function schemas: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling
"""
def output_to_chat_completion_response(self, raw_llm_output):
"""Turn raw LLM output into a ChatCompletion style response with:
"message" = {
"role": "assistant",
"content": ...,
"function_call": {
"name": ...
"arguments": {
"arg1": val1,
...
}
}
}
"""
```
See full file [here](llm_chat_completion_wrappers/airoboros.py). WebUI exposes a lot of parameters that can dramatically change LLM outputs; to change these, you can modify the [WebUI settings file](/memgpt/local_llm/webui/settings.py).
### Running the example
```sh
# running airoboros behind a textgen webui server
export OPENAI_API_BASE=<pointing at webui server>
export BACKEND_TYPE=webui
# using --no_verify because this airoboros example does not output inner monologue, just functions
# airoboros is able to properly call `send_message`
$ python3 main.py --no_verify
Running... [exit by typing '/exit']
💭 Bootup sequence complete. Persona activated. Testing messaging functionality.
💭 None
🤖 Welcome! My name is Sam. How can I assist you today?
Enter your message: My name is Brad, not Chad...
💭 None
⚡🧠 [function] updating memory with core_memory_replace:
First name: Chad
→ First name: Brad
```
---
## Status of ChatCompletion w/ function calling and open LLMs
MemGPT uses function calling to do memory management. With [OpenAI's ChatCompletion API](https://platform.openai.com/docs/api-reference/chat/), you can pass in a function schema in the `functions` keyword arg, and the API response will include a `function_call` field that includes the function name and the function arguments (generated JSON). How this works under the hood is your `functions` keyword is combined with the `messages` and `system` to form one big string input to the transformer, and the output of the transformer is parsed to extract the JSON function call.
In the future, more open LLMs and LLM servers (that can host OpenAI-compatible ChatCompletion endpoints) may start including parsing code to do this automatically as standard practice. However, in the meantime, when you see a model that says it supports “function calling”, like Airoboros, it doesn't mean that you can just load Airoboros into a ChatCompletion-compatible endpoint like WebUI, and then use the same OpenAI API call and it'll just work.
1. When a model page says it supports function calling, they probably mean that the model was finetuned on some function call data (not that you can just use ChatCompletion with functions out-of-the-box). Remember, LLMs are just string-in-string-out, so there are many ways to format the function call data. E.g. Airoboros formats the function schema in YAML style (see https://huggingface.co/jondurbin/airoboros-l2-70b-3.1.2#agentfunction-calling) and the output is in JSON style. To get this to work behind a ChatCompletion API, you still have to do the parsing from `functions` keyword arg (containing the schema) to the model's expected schema style in the prompt (YAML for Airoboros), and you have to run some code to extract the function call (JSON for Airoboros) and package it cleanly as a `function_call` field in the response.
2. Partly because of how complex it is to support function calling, most (all?) of the community projects that do OpenAI ChatCompletion endpoints for arbitrary open LLMs do not support function calling, because if they did, they would need to write model-specific parsing code for each one.
## What is all this extra code for?
Because of the poor state of function calling support in existing ChatCompletion API serving code, we instead provide a light wrapper on top of ChatCompletion that adds parsers to handle function calling support. These parsers need to be specific to the model you're using (or at least specific to the way it was trained on function calling). We hope that our example code will help the community add additional compatibility of MemGPT with more function-calling LLMs - we will also add more model support as we test more models and find those that work well enough to run MemGPT's function set.
To run the example of MemGPT with Airoboros, you'll need to host the model behind some LLM web server (for example [webui](https://github.com/oobabooga/text-generation-webui#starting-the-web-ui)). Then, all you need to do is point MemGPT to this API endpoint by setting the environment variables `OPENAI_API_BASE` and `BACKEND_TYPE`. Now, instead of calling ChatCompletion on OpenAI's API, MemGPT will use its own ChatCompletion wrapper that parses the system, messages, and function arguments into a format that Airoboros has been finetuned on, and once Airoboros generates a string output, MemGPT will parse the response to extract a potential function call (knowing what we know about Airoboros's expected function call output).

View File

View File

@@ -0,0 +1,77 @@
"""Key idea: create drop-in replacement for agent's ChatCompletion call that runs on an OpenLLM backend"""
import os
import requests
import json
from .webui.api import get_webui_completion
from .llm_chat_completion_wrappers import airoboros
from .utils import DotDict
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
DEBUG = False
async def get_chat_completion(
    model,  # no model, since the model is fixed to whatever you set in your own backend
    messages,
    functions,
    function_call="auto",
):
    """Run a ChatCompletion-style request against a local LLM backend.

    The messages and function schemas are rendered into a single prompt by a
    model-specific wrapper, sent to the backend, and the raw text that comes
    back is parsed into a ChatCompletion-shaped response.

    Args:
        model: Name used to pick the prompt wrapper. Only "airoboros_v2.1" is
            recognised; any other value falls back to the same wrapper after
            printing a warning.
        messages: ChatCompletion-style message list.
        functions: ChatCompletion-style function schemas.
        function_call: Must be "auto".

    Returns:
        A DotDict with "model", "choices" and "usage" entries; the parsed
        message is reachable as ``response.choices[0].message``.

    Raises:
        ValueError: If ``function_call`` is not "auto", or the backend host
            cannot be reached.
        Exception: If the backend returns an empty completion.
    """
    if function_call != "auto":
        raise ValueError(f"function_call == {function_call} not supported (auto only)")

    # Only one wrapper exists so far, so an unknown model name just earns a warning.
    if model != "airoboros_v2.1":
        print(
            f"Warning: could not find an LLM wrapper for {model}, using the airoboros wrapper"
        )
    llm_wrapper = airoboros.Airoboros21Wrapper()

    # First step: turn the message sequence into a prompt that the model expects
    prompt = llm_wrapper.chat_completion_to_prompt(messages, functions)
    if DEBUG:
        print(prompt)

    # webui is the only supported backend; report the value actually configured
    # instead of claiming that nothing was set.
    if HOST_TYPE != "webui":
        print(
            f"Warning: BACKEND_TYPE={HOST_TYPE!r} is not a supported backend, defaulting to webui"
        )
    try:
        result = get_webui_completion(prompt)
    except requests.exceptions.ConnectionError as e:
        # Chain the original error so the underlying network failure stays visible.
        raise ValueError(f"Was unable to connect to host {HOST}") from e

    if result is None or result == "":
        raise Exception(f"Got back an empty response string from {HOST}")

    chat_completion_result = llm_wrapper.output_to_chat_completion_response(result)
    if DEBUG:
        print(json.dumps(chat_completion_result, indent=2))

    # unpack with response.choices[0].message.content
    choice = DotDict(
        {
            "message": DotDict(chat_completion_result),
            "finish_reason": "stop",  # TODO vary based on backend response
        }
    )
    usage = DotDict(
        {
            # TODO fix, actually use real info
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
        }
    )
    return DotDict({"model": None, "choices": [choice], "usage": usage})

View File

@@ -0,0 +1,204 @@
import json
from .wrapper_base import LLMChatCompletionWrapper
class Airoboros21Wrapper(LLMChatCompletionWrapper):
    """Wrapper for Airoboros 70b v2.1: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1

    Note: this wrapper formats a prompt that only generates JSON, no inner thoughts
    """

    def __init__(
        self,
        simplify_json_content=True,
        clean_function_args=True,
        include_assistant_prefix=True,
        include_opening_brace_in_prefix=True,
        include_section_separators=True,
    ):
        """Configure how prompts are built and how outputs are post-processed.

        Args:
            simplify_json_content: For user messages whose content is JSON with
                a "message" key, put only that value in the prompt.
            clean_function_args: Run ``clean_function_args`` on parsed outputs.
            include_assistant_prefix: End the prompt with "ASSISTANT:".
            include_opening_brace_in_prefix: Also end the prompt with an opening
                "{" (only applied together with the assistant prefix), and
                restore that brace when parsing the output.
            include_section_separators: Emit "### INPUT" / "### RESPONSE" lines.
        """
        self.simplify_json_content = simplify_json_content
        self.clean_func_args = clean_function_args
        self.include_assistant_prefix = include_assistant_prefix
        # The attribute keeps its original (misspelled) name so that any code
        # reading it from outside this class continues to work.
        self.include_opening_brance_in_prefix = include_opening_brace_in_prefix
        self.include_section_separators = include_section_separators

    @staticmethod
    def _create_function_description(schema):
        """Render one OpenAI-style function schema in the Airoboros listing style."""
        func_str = ""
        func_str += f"{schema['name']}:"
        func_str += f"\n description: {schema['description']}"
        func_str += "\n params:"
        for param_k, param_v in schema["parameters"]["properties"].items():
            # TODO we're ignoring type
            func_str += f"\n {param_k}: {param_v['description']}"
        # TODO we're ignoring schema['parameters']['required']
        return func_str

    @staticmethod
    def _create_function_call(function_call):
        """Go from ChatCompletion to Airoboros style function trace (in prompt)

        ChatCompletion data (inside message['function_call']):
            "function_call": {
                "name": ...
                "arguments": {
                    "arg1": val1,
                    ...
                }

        Airoboros output:
            {
              "function": "send_message",
              "params": {
                "message": "Hello there! I am Sam, an AI developed by Liminal Corp. How can I assist you today?"
              }
            }
        """
        airo_func_call = {
            "function": function_call["name"],
            # "arguments" is a JSON-encoded string, so decode it before nesting.
            "params": json.loads(function_call["arguments"]),
        }
        return json.dumps(airo_func_call, indent=2)

    def _user_content(self, content):
        """Return the text to show for a user message.

        With ``simplify_json_content`` enabled, content that parses as JSON and
        has a "message" key is reduced to that value; anything else is
        returned unchanged.
        """
        if not self.simplify_json_content:
            return content
        try:
            return json.loads(content)["message"]
        except (ValueError, KeyError, TypeError):
            # Not JSON, or JSON without a "message" entry: use the raw content.
            return content

    def chat_completion_to_prompt(self, messages, functions):
        """Example for airoboros: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#prompt-format

            A chat.
            USER: {prompt}
            ASSISTANT:

        Functions support: https://huggingface.co/jondurbin/airoboros-l2-70b-2.1#agentfunction-calling

            As an AI assistant, please select the most suitable function and parameters from the list of available functions below, based on the user's input. Provide your response in JSON format.

            Input: I want to know how many times 'Python' is mentioned in my text file.

            Available functions:
            file_analytics:
              description: This tool performs various operations on a text file.
              params:
                action: The operation we want to perform on the data, such as "count_occurrences", "find_line", etc.
                filters:
                  keyword: The word or phrase we want to search for.

        OpenAI functions schema style:

            {
                "name": "send_message",
                "description": "Sends a message to the human user",
                "parameters": {
                    "type": "object",
                    "properties": {
                        # https://json-schema.org/understanding-json-schema/reference/array.html
                        "message": {
                            "type": "string",
                            "description": "Message contents. All unicode (including emojis) are supported.",
                        },
                    },
                    "required": ["message"],
                }
            },

        Args:
            messages: ChatCompletion-style messages; the first must have role
                "system", the rest "user", "assistant" or "function".
            functions: OpenAI-style function schemas to advertise in the prompt.

        Returns:
            The prompt string to send to the model.
        """
        prompt = ""

        # System instructions go first
        assert messages[0]["role"] == "system"
        prompt += messages[0]["content"]

        # Next is the functions preamble. The wording says "ongoing
        # conversation" where the Airoboros example says "user's input".
        prompt += "\nPlease select the most suitable function and parameters from the list of available functions below, based on the ongoing conversation. Provide your response in JSON format."
        prompt += "\nAvailable functions:"
        for function_dict in functions:
            prompt += f"\n{self._create_function_description(function_dict)}"

        # Add a sep for the conversation
        if self.include_section_separators:
            prompt += "\n### INPUT"

        # Last are the user/assistant messages
        for message in messages[1:]:
            assert message["role"] in ["user", "assistant", "function"], message

            if message["role"] == "user":
                # Always emit the user turn, whether or not simplification is on.
                prompt += f"\nUSER: {self._user_content(message['content'])}"
            elif message["role"] == "assistant":
                prompt += f"\nASSISTANT: {message['content']}"
                # need to add the function call if there was one
                # (.get: plain assistant messages may have no "function_call" key)
                if message.get("function_call"):
                    prompt += f"\n{self._create_function_call(message['function_call'])}"
            elif message["role"] == "function":
                # TODO find a good way to add this
                prompt += f"\nFUNCTION RETURN: {message['content']}"
            else:
                raise ValueError(message)

        # Add a sep for the response
        if self.include_section_separators:
            prompt += "\n### RESPONSE"

        if self.include_assistant_prefix:
            prompt += "\nASSISTANT:"
            if self.include_opening_brance_in_prefix:
                prompt += "\n{"

        return prompt

    def clean_function_args(self, function_name, function_args):
        """Some basic MemGPT-specific cleaning of function args

        Returns the function name together with a shallow copy of the
        arguments; for "send_message" the "request_heartbeat" argument is
        removed from the copy.
        """
        cleaned_function_name = function_name
        cleaned_function_args = function_args.copy()

        if function_name == "send_message":
            # strip request_heartbeat
            cleaned_function_args.pop("request_heartbeat", None)

        # TODO more cleaning to fix errors LLM makes
        return cleaned_function_name, cleaned_function_args

    def output_to_chat_completion_response(self, raw_llm_output):
        """Turn raw LLM output into a ChatCompletion style response with:

            "message" = {
                "role": "assistant",
                "content": ...,
                "function_call": {
                    "name": ...
                    "arguments": {
                        "arg1": val1,
                        ...
                    }
                }
            }

        Raises:
            Exception: If the output cannot be decoded as JSON.
        """
        # The prompt may already end with "{", in which case the completion
        # continues after it. Restore the brace unless the model repeated it;
        # startswith() is also safe on an empty string, unlike indexing [0].
        if self.include_opening_brance_in_prefix and not raw_llm_output.lstrip().startswith("{"):
            raw_llm_output = "{" + raw_llm_output

        try:
            function_json_output = json.loads(raw_llm_output)
        except ValueError as e:
            raise Exception(f"Failed to decode JSON from LLM output:\n{raw_llm_output}") from e

        function_name = function_json_output["function"]
        function_parameters = function_json_output["params"]

        if self.clean_func_args:
            function_name, function_parameters = self.clean_function_args(
                function_name, function_parameters
            )

        message = {
            "role": "assistant",
            "content": None,
            "function_call": {
                "name": function_name,
                # ChatCompletion carries the arguments as a JSON string
                "arguments": json.dumps(function_parameters),
            },
        }
        return message

View File

@@ -0,0 +1,13 @@
from abc import ABC, abstractmethod
class LLMChatCompletionWrapper(ABC):
    """Abstract adapter between the ChatCompletion format and one model's raw text format.

    Subclasses implement both directions: building the prompt string and
    parsing the model's output string.
    """

    @abstractmethod
    def chat_completion_to_prompt(self, messages, functions):
        """Build a single prompt string from ChatCompletion messages and functions."""
        return None

    @abstractmethod
    def output_to_chat_completion_response(self, raw_llm_output):
        """Convert the raw LLM output string into a ChatCompletion response."""
        return None

View File

@@ -0,0 +1,8 @@
class DotDict(dict):
    """Allow dot access on properties similar to OpenAI response object

    Reading a key that is not present yields ``None`` instead of raising.
    Assigning an attribute stores it as a dictionary item.
    """

    def __getattr__(self, attr):
        # __getattr__ is only consulted after normal attribute lookup fails.
        # Dunder names are probed by hasattr() and by protocols such as
        # copy/pickle, which expect AttributeError when a hook is absent;
        # answering None would make a missing hook look present.
        if attr.startswith("__") and attr.endswith("__"):
            raise AttributeError(attr)
        return self.get(attr)

    def __setattr__(self, key, value):
        self[key] = value

View File

@@ -0,0 +1,33 @@
import os
import requests
from .settings import SIMPLE
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
WEBUI_API_SUFFIX = "/api/v1/generate"
DEBUG = False
def get_webui_completion(prompt, settings=SIMPLE):
    """See https://github.com/oobabooga/text-generation-webui for instructions on how to run the LLM web server

    Posts the prompt, together with the generation settings, to the web UI
    generate endpoint and returns the generated text.

    Args:
        prompt: Full prompt string to complete.
        settings: Generation settings (stop strings, truncation length, ...)
            sent along with the prompt. The mapping is not modified.

    Returns:
        The text of the first result in the server's JSON response.

    Raises:
        ValueError: If OPENAI_API_BASE is not set.
        Exception: If the server answers with a non-200 status code.
    """
    if HOST is None:
        raise ValueError(
            "OPENAI_API_BASE is not set; it must point at the web UI server"
        )

    # Settings for the generation, includes the prompt + stop tokens, max length, etc.
    # Build a fresh dict so the shared default settings are never mutated.
    request = {**settings, "prompt": prompt}

    URI = f"{HOST.strip('/')}{WEBUI_API_SUFFIX}"
    response = requests.post(URI, json=request)
    if response.status_code != 200:
        # TODO handle gracefully
        raise Exception(f"API call got non-200 response code for address: {URI}")

    result = response.json()["results"][0]["text"]
    if DEBUG:
        print(f"json API response.text: {result}")
    return result

View File

@@ -0,0 +1,12 @@
# Minimal generation settings: where generation stops and the context
# length to truncate to.
# Other stop strings that were tried but are left disabled:
#   '\n' + '</s>', '<|', '\n#', '\n\n\n'
SIMPLE = dict(
    stopping_strings=["\nUSER:", "\nASSISTANT:"],
    truncation_length=4096,  # assuming llama2 models
)

View File

@@ -3,8 +3,16 @@ import random
import os
import time
from .local_llm.chat_completion_proxy import get_chat_completion
HOST = os.getenv("OPENAI_API_BASE")
HOST_TYPE = os.getenv("BACKEND_TYPE") # default None == ChatCompletion
import openai
if HOST is not None:
openai.api_base = HOST
def retry_with_exponential_backoff(
func,
@@ -102,18 +110,24 @@ def aretry_with_exponential_backoff(
@aretry_with_exponential_backoff
async def acompletions_with_backoff(**kwargs):
azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT')
if azure_openai_deployment is not None:
kwargs['deployment_id'] = azure_openai_deployment
return await openai.ChatCompletion.acreate(**kwargs)
# Local model
if HOST_TYPE is not None:
return await get_chat_completion(**kwargs)
# OpenAI / Azure model
else:
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
kwargs["deployment_id"] = azure_openai_deployment
return await openai.ChatCompletion.acreate(**kwargs)
@aretry_with_exponential_backoff
async def acreate_embedding_with_backoff(**kwargs):
"""Wrapper around Embedding.acreate w/ backoff"""
azure_openai_deployment = os.getenv('AZURE_OPENAI_DEPLOYMENT')
azure_openai_deployment = os.getenv("AZURE_OPENAI_DEPLOYMENT")
if azure_openai_deployment is not None:
kwargs['deployment_id'] = azure_openai_deployment
kwargs["deployment_id"] = azure_openai_deployment
return await openai.Embedding.acreate(**kwargs)
@@ -121,6 +135,6 @@ async def async_get_embedding_with_backoff(text, model="text-embedding-ada-002")
"""To get text embeddings, import/call this function
It specifies defaults + handles rate-limiting + is async"""
text = text.replace("\n", " ")
response = await acreate_embedding_with_backoff(input = [text], model=model)
embedding = response['data'][0]['embedding']
return embedding
response = await acreate_embedding_with_backoff(input=[text], model=model)
embedding = response["data"][0]["embedding"]
return embedding

View File

@@ -9,6 +9,7 @@ pybars3
pymupdf
python-dotenv
pytz
questionary
rich
tiktoken
timezonefinder