CLI improvements using questionary

Vivian Fang
2023-10-22 22:29:15 -07:00
parent 388a1262b1
commit 0705d2464d
4 changed files with 401 additions and 56 deletions
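The prompts below all follow the same questionary pattern: build a question object, then await its answer with ask_async() so it cooperates with the async main loop. A minimal sketch of that pattern (demo() and the prompt wording here are illustrative, not part of the diff):

    import asyncio

    import questionary

    async def demo():
        # Build the question first, then await the answer with ask_async().
        model = await questionary.select(
            "Which model would you like to use?",
            choices=["gpt-4", "gpt-3.5-turbo"],
        ).ask_async()
        proceed = await questionary.confirm(f"Use {model}?", default=True).ask_async()
        print(model, proceed)

    asyncio.run(demo())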

3
.gitignore vendored

@@ -79,3 +79,6 @@ dmypy.json
# Pyre type checker
.pyre/
# MemGPT config files
configs/

280
config.py Normal file

@@ -0,0 +1,280 @@
import asyncio
import glob
import json
import os
import textwrap
from typing import List, Type

import questionary
from colorama import Fore, Style, init
from rich.console import Console

import interface
import memgpt.utils as utils
from memgpt.humans.humans import get_human_text
from memgpt.personas.personas import get_persona_text

console = Console()
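# Model options offered during interactive setup; the first entry (gpt-4) is the default.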
model_choices = [
questionary.Choice("gpt-4"),
questionary.Choice(
"gpt-3.5-turbo (experimental! function-calling performance is not quite at the level of gpt-4 yet)",
value="gpt-3.5-turbo",
),
]
class Config:
personas_dir = os.path.join("memgpt", "personas", "examples")
humans_dir = os.path.join("memgpt", "humans", "examples")
configs_dir = "configs"
def __init__(self):
self.load_type = None
self.archival_storage_files = None
self.compute_embeddings = False
self.agent_save_file = None
self.persistence_manager_save_file = None
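    # Build a Config from the legacy command-line flags (the pre-questionary path).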
@classmethod
async def legacy_flags_init(
        cls: Type["Config"],
model: str,
memgpt_persona: str,
human_persona: str,
load_type: str = None,
archival_storage_files: str = None,
archival_storage_index: str = None,
compute_embeddings: bool = False,
):
self = cls()
self.model = model
self.memgpt_persona = memgpt_persona
self.human_persona = human_persona
self.load_type = load_type
self.archival_storage_files = archival_storage_files
self.archival_storage_index = archival_storage_index
self.compute_embeddings = compute_embeddings
recompute_embeddings = self.compute_embeddings
if self.archival_storage_index:
            recompute_embeddings = await questionary.confirm(
                f"Would you like to recompute embeddings? Do this if your files have changed.\nFiles: {self.archival_storage_files}",
                default=False,
            ).ask_async()
await self.configure_archival_storage(recompute_embeddings)
return self
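    # Interactive setup: offer the most recent saved config, otherwise walk the
    # user through prompts and save the result.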
@classmethod
async def config_init(cls: Type["Config"], config_file: str = None):
self = cls()
self.config_file = config_file
if self.config_file is None:
cfg = Config.get_most_recent_config()
use_cfg = False
if cfg:
print(
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}"
)
use_cfg = await questionary.confirm(
f"Use most recent config file '{cfg}'?"
).ask_async()
if use_cfg:
self.config_file = cfg
if self.config_file:
self.load_config(self.config_file)
recompute_embeddings = False
if self.compute_embeddings:
if self.archival_storage_index:
recompute_embeddings = await questionary.confirm(
f"Would you like to recompute embeddings? Do this if your files have changed.\n Files: {self.archival_storage_files}",
default=False,
).ask_async()
else:
recompute_embeddings = True
if self.load_type:
await self.configure_archival_storage(recompute_embeddings)
self.write_config()
return self
print(
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}"
)
self.model = await questionary.select(
"Which model would you like to use?",
model_choices,
default=model_choices[0],
).ask_async()
self.memgpt_persona = await questionary.select(
"Which persona would you like MemGPT to use?",
Config.get_memgpt_personas(),
).ask_async()
self.human_persona = await questionary.select(
"Which persona would you like to use?",
Config.get_user_personas(),
).ask_async()
self.archival_storage_index = None
self.preload_archival = await questionary.confirm(
"Would you like to preload anything into MemGPT's archival memory?"
).ask_async()
if self.preload_archival:
self.load_type = await questionary.select(
"What would you like to load?",
choices=[
questionary.Choice("A folder or file", value="folder"),
questionary.Choice("A SQL database", value="sql"),
questionary.Choice("A glob pattern", value="glob"),
],
).ask_async()
if self.load_type == "folder" or self.load_type == "sql":
archival_storage_path = await questionary.path(
"Please enter the folder or file (tab for autocomplete):"
).ask_async()
if os.path.isdir(archival_storage_path):
self.archival_storage_files = os.path.join(
archival_storage_path, "*"
)
else:
self.archival_storage_files = archival_storage_path
else:
self.archival_storage_files = await questionary.path(
"Please enter the glob pattern (tab for autocomplete):"
).ask_async()
self.compute_embeddings = await questionary.confirm(
"Would you like to compute embeddings over these files to enable embeddings search?"
).ask_async()
await self.configure_archival_storage(self.compute_embeddings)
self.write_config()
return self
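    # (Re)compute the archival embeddings index if requested, then load the
    # archival database for this session.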
async def configure_archival_storage(self, recompute_embeddings):
if recompute_embeddings:
self.archival_storage_index = (
await utils.prepare_archival_index_from_files_compute_embeddings(
self.archival_storage_files
)
)
        # Load the faiss index whenever one exists (precomputed or just built),
        # so a saved index works even when compute_embeddings is False.
        if self.archival_storage_index:
            self.index, self.archival_database = utils.prepare_archival_index(
                self.archival_storage_index
            )
else:
self.archival_database = utils.prepare_archival_index_from_files(
self.archival_storage_files
)
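    # Snapshot of every setting that gets persisted to configs/*.json.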
def to_dict(self):
return {
"model": self.model,
"memgpt_persona": self.memgpt_persona,
"human_persona": self.human_persona,
"preload_archival": self.preload_archival,
"archival_storage_files": self.archival_storage_files,
"archival_storage_index": self.archival_storage_index,
"compute_embeddings": self.compute_embeddings,
"load_type": self.load_type,
"agent_save_file": self.agent_save_file,
"persistence_manager_save_file": self.persistence_manager_save_file,
}
def load_config(self, config_file):
with open(config_file, "rt") as f:
cfg = json.load(f)
self.model = cfg["model"]
self.memgpt_persona = cfg["memgpt_persona"]
self.human_persona = cfg["human_persona"]
self.preload_archival = cfg["preload_archival"]
self.archival_storage_files = cfg["archival_storage_files"]
self.archival_storage_index = cfg["archival_storage_index"]
self.compute_embeddings = cfg["compute_embeddings"]
self.load_type = cfg["load_type"]
self.agent_save_file = cfg["agent_save_file"]
self.persistence_manager_save_file = cfg["persistence_manager_save_file"]
def write_config(self, configs_dir=None):
if configs_dir is None:
configs_dir = Config.configs_dir
os.makedirs(configs_dir, exist_ok=True)
if self.config_file is None:
filename = os.path.join(
configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_")
)
self.config_file = f"{filename}.json"
with open(self.config_file, "wt") as f:
json.dump(self.to_dict(), f, indent=4)
print(
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}"
)
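    # The two persona pickers below share get_personas() / get_persona_choices().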
@staticmethod
def get_memgpt_personas(dir_path=None):
if dir_path is None:
dir_path = Config.personas_dir
all_personas = Config.get_personas(dir_path)
        return Config.get_persona_choices(all_personas, get_persona_text)
@staticmethod
def get_user_personas(dir_path=None):
if dir_path is None:
dir_path = Config.humans_dir
all_personas = Config.get_personas(dir_path)
        return Config.get_persona_choices(all_personas, get_human_text)
@staticmethod
def get_personas(dir_path) -> List[str]:
files = sorted(glob.glob(os.path.join(dir_path, "*.txt")))
stems = []
for f in files:
filename = os.path.basename(f)
stem, _ = os.path.splitext(filename)
stems.append(stem)
return stems
@staticmethod
def get_persona_choices(personas, text_getter):
return [
questionary.Choice(
title=[
("class:question", f"{p}"),
("class:text", f"\n{indent(text_getter(p))}"),
],
value=p,
)
for p in personas
]
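    # Newest saved config in configs/ by modification time, or None if there are none.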
    @staticmethod
    def get_most_recent_config(configs_dir=None):
        if configs_dir is None:
            configs_dir = Config.configs_dir
        # First run: nothing has been saved yet, so there is no directory to list.
        if not os.path.isdir(configs_dir):
            return None
        files = [
            os.path.join(configs_dir, f)
            for f in os.listdir(configs_dir)
            if os.path.isfile(os.path.join(configs_dir, f))
        ]
        # Return the file with the most recent modification time
        if len(files) == 0:
            return None
        return max(files, key=os.path.getmtime)
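# Wrap persona text for display in the chooser, truncating long blurbs.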
def indent(text, num_lines=5):
lines = textwrap.fill(text, width=100).split("\n")
if len(lines) > num_lines:
lines = lines[: num_lines - 1] + ["... (truncated)", lines[-1]]
return " " + "\n ".join(lines)
config = Config()
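For reference, write_config() serializes to_dict() with json.dump into configs/<timestamp>.json. Right after a fresh interactive setup with no archival preload, the saved dict looks roughly like this (the persona stems are illustrative):

    {
        "model": "gpt-4",
        "memgpt_persona": "sam",        # illustrative persona stem
        "human_persona": "basic",       # illustrative human stem
        "preload_archival": False,
        "archival_storage_files": None,
        "archival_storage_index": None,
        "compute_embeddings": False,
        "load_type": None,
        "agent_save_file": None,
        "persistence_manager_save_file": None,
    }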

173
main.py

@@ -5,7 +5,8 @@ import glob
import os
import sys
import pickle
import readline
import questionary
from rich.console import Console
@@ -25,6 +26,8 @@ from memgpt.persistence_manager import (
InMemoryStateManagerWithFaiss,
)
from config import Config
FLAGS = flags.FLAGS
flags.DEFINE_string("persona", default=None, required=False, help="Specify persona")
flags.DEFINE_string("human", default=None, required=False, help="Specify human")
@@ -87,7 +90,7 @@ def clear_line():
sys.stdout.flush()
def save(memgpt_agent):
def save(memgpt_agent, cfg):
filename = utils.get_local_time().replace(" ", "_").replace(":", "_")
filename = f"{filename}.json"
filename = os.path.join("saved_state", filename)
@@ -96,6 +99,7 @@ def save(memgpt_agent):
os.makedirs("saved_state")
memgpt_agent.save_to_json_file(filename)
print(f"Saved checkpoint to: {filename}")
cfg.agent_save_file = filename
except Exception as e:
print(f"Saving state to {filename} failed with: {e}")
@@ -104,8 +108,10 @@ def save(memgpt_agent):
try:
memgpt_agent.persistence_manager.save(filename)
print(f"Saved persistence manager to: {filename}")
cfg.persistence_manager_save_file = filename
except Exception as e:
print(f"Saving persistence manager to {filename} failed with: {e}")
cfg.write_config()
def load(memgpt_agent, filename):
@@ -156,6 +162,79 @@ async def main():
logging.getLogger().setLevel(logging.CRITICAL)
if FLAGS.debug:
logging.getLogger().setLevel(logging.DEBUG)
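    # Legacy path: translate the old CLI flags into a Config; otherwise run the
    # interactive questionary setup.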
if any(
(
FLAGS.persona,
FLAGS.human,
FLAGS.model != constants.DEFAULT_MEMGPT_MODEL,
FLAGS.archival_storage_faiss_path,
FLAGS.archival_storage_files,
FLAGS.archival_storage_files_compute_embeddings,
FLAGS.archival_storage_sqldb,
)
):
interface.important_message("⚙️ Using legacy command line arguments.")
model = FLAGS.model
if model is None:
model = constants.DEFAULT_MEMGPT_MODEL
memgpt_persona = FLAGS.persona
if memgpt_persona is None:
memgpt_persona = (
personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT
)
human_persona = FLAGS.human
if human_persona is None:
human_persona = humans.DEFAULT
if FLAGS.archival_storage_files:
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
load_type="folder",
archival_storage_files=FLAGS.archival_storage_files,
compute_embeddings=False,
)
elif FLAGS.archival_storage_faiss_path:
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
load_type="folder",
                archival_storage_index=FLAGS.archival_storage_faiss_path,
compute_embeddings=False,
)
elif FLAGS.archival_storage_files_compute_embeddings:
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
load_type="folder",
archival_storage_files=FLAGS.archival_storage_files_compute_embeddings,
compute_embeddings=True,
)
elif FLAGS.archival_storage_sqldb:
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
load_type="sql",
archival_storage_files=FLAGS.archival_storage_sqldb,
compute_embeddings=False,
)
else:
cfg = await Config.legacy_flags_init(
model,
memgpt_persona,
human_persona,
)
else:
cfg = await Config.config_init()
print("Running... [exit by typing '/exit']")
# Azure OpenAI support
@@ -190,49 +269,35 @@ async def main():
)
return
if FLAGS.model != constants.DEFAULT_MEMGPT_MODEL:
if cfg.model != constants.DEFAULT_MEMGPT_MODEL:
interface.important_message(
f"Warning - you are running MemGPT with {FLAGS.model}, which is not officially supported (yet). Expect bugs!"
f"Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
)
if FLAGS.archival_storage_faiss_path:
index, archival_database = utils.prepare_archival_index(
FLAGS.archival_storage_faiss_path
if cfg.archival_storage_index:
persistence_manager = InMemoryStateManagerWithFaiss(
cfg.index, cfg.archival_database
)
persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
elif FLAGS.archival_storage_files:
archival_database = utils.prepare_archival_index_from_files(
FLAGS.archival_storage_files
)
print(f"Preloaded {len(archival_database)} chunks into archival memory.")
elif cfg.archival_storage_files:
print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.")
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(
archival_database
cfg.archival_database
)
elif FLAGS.archival_storage_files_compute_embeddings:
faiss_save_dir = (
await utils.prepare_archival_index_from_files_compute_embeddings(
FLAGS.archival_storage_files_compute_embeddings
)
)
interface.important_message(
f"To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={faiss_save_dir} (if your files haven't changed)."
)
index, archival_database = utils.prepare_archival_index(faiss_save_dir)
persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
else:
persistence_manager = InMemoryStateManager()
if FLAGS.archival_storage_files_compute_embeddings:
interface.important_message(
f"(legacy) To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={cfg.archival_storage_index} (if your files haven't changed)."
)
# Moved defaults out of FLAGS so that we can dynamically select the default persona based on model
chosen_human = FLAGS.human if FLAGS.human is not None else humans.DEFAULT
chosen_persona = (
FLAGS.persona
if FLAGS.persona is not None
else (personas.GPT35_DEFAULT if "gpt-3.5" in FLAGS.model else personas.DEFAULT)
)
chosen_human = cfg.human_persona
chosen_persona = cfg.memgpt_persona
memgpt_agent = presets.use_preset(
presets.DEFAULT,
FLAGS.model,
cfg.model,
personas.get_persona_text(chosen_persona),
humans.get_human_text(chosen_human),
interface,
@@ -247,19 +312,26 @@ async def main():
user_message = None
USER_GOES_FIRST = FLAGS.first
if FLAGS.archival_storage_sqldb:
if not os.path.exists(FLAGS.archival_storage_sqldb):
print(f"File {FLAGS.archival_storage_sqldb} does not exist")
if cfg.load_type == "sql": # TODO: move this into config.py in a clean manner
if not os.path.exists(cfg.archival_storage_files):
print(f"File {cfg.archival_storage_files} does not exist")
return
# Ingest data from file into archival storage
else:
print(f"Database found! Loading database into archival memory")
data_list = utils.read_database_as_list(FLAGS.archival_storage_sqldb)
data_list = utils.read_database_as_list(cfg.archival_storage_files)
user_message = f"Your archival memory has been loaded with a SQL database called {data_list[0]}, which contains schema {data_list[1]}. Remember to refer to this first while answering any user questions!"
for row in data_list:
await memgpt_agent.persistence_manager.archival_memory.insert(row)
print(f"Database loaded into archival memory.")
if cfg.agent_save_file:
load_save_file = await questionary.confirm(
f"Load in saved agent '{cfg.agent_save_file}'?"
).ask_async()
if load_save_file:
load(memgpt_agent, cfg.agent_save_file)
    # auto-exit for GitHub Actions (CI)
if "GITHUB_ACTIONS" in os.environ:
return
@@ -274,7 +346,12 @@ async def main():
while True:
if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
# Ask for user input
user_input = console.input("[bold cyan]Enter your message:[/bold cyan] ")
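            # multiline=True: submit the input with Alt+Enter (or Esc then Enter).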
user_input = await questionary.text(
"Enter your message:",
multiline=True,
qmark=">",
).ask_async()
clear_line()
if user_input.startswith("!"):
@@ -289,25 +366,9 @@ async def main():
# Handle CLI commands
# Commands to not get passed as input to MemGPT
if user_input.startswith("/"):
if user_input == "//":
print("Entering multiline mode, type // when done")
user_input_list = []
while True:
user_input = console.input("[bold cyan]>[/bold cyan] ")
clear_line()
if user_input == "//":
break
else:
user_input_list.append(user_input)
# pass multiline inputs to MemGPT
user_message = system.package_user_message(
"\n".join(user_input_list)
)
elif user_input.lower() == "/exit":
if user_input.lower() == "/exit":
# autosave
save(memgpt_agent=memgpt_agent)
save(memgpt_agent=memgpt_agent, cfg=cfg)
break
elif user_input.lower() == "/savechat":
@@ -326,7 +387,7 @@ async def main():
continue
elif user_input.lower() == "/save":
save(memgpt_agent=memgpt_agent)
save(memgpt_agent=memgpt_agent, cfg=cfg)
continue
elif user_input.lower() == "/load" or user_input.lower().startswith(

1
requirements.txt

@@ -9,6 +9,7 @@ pybars3
pymupdf
python-dotenv
pytz
questionary
rich
tiktoken
timezonefinder