From 0705d2464d5800a2351902d1529a151589b865cd Mon Sep 17 00:00:00 2001 From: Vivian Fang Date: Sun, 22 Oct 2023 22:29:15 -0700 Subject: [PATCH] cli improvements using questionary --- .gitignore | 3 + config.py | 280 +++++++++++++++++++++++++++++++++++++++++++++++ main.py | 173 +++++++++++++++++++---------- requirements.txt | 1 + 4 files changed, 401 insertions(+), 56 deletions(-) create mode 100644 config.py diff --git a/.gitignore b/.gitignore index 3ba9ba08..709e8f50 100644 --- a/.gitignore +++ b/.gitignore @@ -79,3 +79,6 @@ dmypy.json # Pyre type checker .pyre/ + +# MemGPT config files +configs/ diff --git a/config.py b/config.py new file mode 100644 index 00000000..2c0d0be6 --- /dev/null +++ b/config.py @@ -0,0 +1,280 @@ +import asyncio +import glob +import json +import os +import textwrap + +import interface + +import questionary + +from colorama import Fore, Style, init +from rich.console import Console + +console = Console() + +from typing import List, Type + +import memgpt.utils as utils + +from memgpt.personas.personas import get_persona_text +from memgpt.humans.humans import get_human_text + +model_choices = [ + questionary.Choice("gpt-4"), + questionary.Choice( + "gpt-3.5-turbo (experimental! 
class Config:
    """Run configuration for the MemGPT CLI.

    A ``Config`` is built either from legacy command-line flag values
    (:meth:`legacy_flags_init`) or interactively / from a saved JSON file
    (:meth:`config_init`).  The fields listed in ``_PERSISTED_FIELDS`` are
    what gets written to and read back from disk; ``index`` and
    ``archival_database`` are runtime-only results of
    :meth:`configure_archival_storage`.
    """

    personas_dir = os.path.join("memgpt", "personas", "examples")
    humans_dir = os.path.join("memgpt", "humans", "examples")
    configs_dir = "configs"

    # Single source of truth for the JSON schema, shared by to_dict() and
    # load_config() so the two cannot drift apart.  Order is the key order
    # of the written file.
    _PERSISTED_FIELDS = (
        "model",
        "memgpt_persona",
        "human_persona",
        "preload_archival",
        "archival_storage_files",
        "archival_storage_index",
        "compute_embeddings",
        "load_type",
        "agent_save_file",
        "persistence_manager_save_file",
    )

    def __init__(self):
        # Every attribute that to_dict() / write_config() reads gets a
        # default here, so a Config created through any constructor path
        # (including a bare Config()) can always be serialised.
        self.config_file = None
        self.model = None
        self.memgpt_persona = None
        self.human_persona = None
        self.preload_archival = False
        self.load_type = None
        self.archival_storage_files = None
        self.archival_storage_index = None
        self.compute_embeddings = False
        self.agent_save_file = None
        self.persistence_manager_save_file = None
        # Runtime-only state filled in by configure_archival_storage().
        self.index = None
        self.archival_database = None

    @classmethod
    async def legacy_flags_init(
        cls: Type["Config"],
        model: str,
        memgpt_persona: str,
        human_persona: str,
        load_type: str = None,
        archival_storage_files: str = None,
        archival_storage_index: str = None,
        compute_embeddings: bool = False,
    ):
        """Build a config from legacy command-line flag values.

        Args:
            model: Model name to run with.
            memgpt_persona: Name of the MemGPT persona.
            human_persona: Name of the human persona.
            load_type: Kind of archival source (e.g. ``"folder"``/``"sql"``).
            archival_storage_files: Path or glob of files to preload.
            archival_storage_index: Location of a previously saved index.
            compute_embeddings: Whether to compute embeddings over the files.

        Returns:
            The populated ``Config``.  Archival storage is only configured
            when files or an index were actually supplied.
        """
        self = cls()
        self.model = model
        self.memgpt_persona = memgpt_persona
        self.human_persona = human_persona
        self.load_type = load_type
        self.archival_storage_files = archival_storage_files
        self.archival_storage_index = archival_storage_index
        self.compute_embeddings = compute_embeddings
        self.preload_archival = bool(
            archival_storage_files or archival_storage_index
        )

        recompute_embeddings = self.compute_embeddings
        if self.archival_storage_index:
            if self.archival_storage_files:
                # The question object must be asked (and awaited); using the
                # un-asked object as a boolean is always truthy.
                recompute_embeddings = await questionary.confirm(
                    f"Would you like to recompute embeddings? Do this if your files have changed.\nFiles:{self.archival_storage_files}",
                    default=False,
                ).ask_async()
            else:
                # Only an index was given: there are no files to recompute from.
                recompute_embeddings = False

        if self.preload_archival:
            await self.configure_archival_storage(recompute_embeddings)
        return self

    @classmethod
    async def config_init(cls: Type["Config"], config_file: str = None):
        """Build a config from a saved file or by prompting the user.

        If ``config_file`` is not given, the most recently modified file in
        ``Config.configs_dir`` is offered.  When a file ends up selected it is
        loaded (optionally recomputing embeddings) and re-saved; otherwise the
        user is asked for every setting and the answers are written to a new
        config file.

        Args:
            config_file: Path of a saved config to load, or ``None``.

        Returns:
            The populated ``Config``.
        """
        self = cls()
        self.config_file = config_file
        if self.config_file is None:
            cfg = Config.get_most_recent_config()
            use_cfg = False
            if cfg:
                print(
                    f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}"
                )
                use_cfg = await questionary.confirm(
                    f"Use most recent config file '{cfg}'?"
                ).ask_async()
            if use_cfg:
                self.config_file = cfg

        if self.config_file:
            self.load_config(self.config_file)
            recompute_embeddings = False
            if self.compute_embeddings:
                if self.archival_storage_index:
                    recompute_embeddings = await questionary.confirm(
                        f"Would you like to recompute embeddings? Do this if your files have changed.\n Files: {self.archival_storage_files}",
                        default=False,
                    ).ask_async()
                else:
                    # Embeddings were requested but no index was ever saved.
                    recompute_embeddings = True
            if self.load_type:
                await self.configure_archival_storage(recompute_embeddings)
                self.write_config()
            return self

        print(
            f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}"
        )

        self.model = await questionary.select(
            "Which model would you like to use?",
            model_choices,
            default=model_choices[0],
        ).ask_async()

        self.memgpt_persona = await questionary.select(
            "Which persona would you like MemGPT to use?",
            Config.get_memgpt_personas(),
        ).ask_async()
        print(self.memgpt_persona)

        self.human_persona = await questionary.select(
            "Which persona would you like to use?",
            Config.get_user_personas(),
        ).ask_async()

        self.archival_storage_index = None
        self.preload_archival = await questionary.confirm(
            "Would you like to preload anything into MemGPT's archival memory?"
        ).ask_async()
        if self.preload_archival:
            self.load_type = await questionary.select(
                "What would you like to load?",
                choices=[
                    questionary.Choice("A folder or file", value="folder"),
                    questionary.Choice("A SQL database", value="sql"),
                    questionary.Choice("A glob pattern", value="glob"),
                ],
            ).ask_async()
            if self.load_type == "folder" or self.load_type == "sql":
                archival_storage_path = await questionary.path(
                    "Please enter the folder or file (tab for autocomplete):"
                ).ask_async()
                if os.path.isdir(archival_storage_path):
                    # A directory is stored as a glob over its direct entries.
                    self.archival_storage_files = os.path.join(
                        archival_storage_path, "*"
                    )
                else:
                    self.archival_storage_files = archival_storage_path
            else:
                self.archival_storage_files = await questionary.path(
                    "Please enter the glob pattern (tab for autocomplete):"
                ).ask_async()
            self.compute_embeddings = await questionary.confirm(
                "Would you like to compute embeddings over these files to enable embeddings search?"
            ).ask_async()
            await self.configure_archival_storage(self.compute_embeddings)

        self.write_config()
        return self

    async def configure_archival_storage(self, recompute_embeddings):
        """Populate ``archival_database`` (and ``index``) from the settings.

        Args:
            recompute_embeddings: When true, embeddings are (re)computed from
                ``archival_storage_files`` and ``archival_storage_index`` is
                replaced by the value the utils helper returns.

        After the optional recompute, a set ``archival_storage_index`` is
        loaded into ``index`` / ``archival_database``; otherwise the files are
        loaded without an index.  With neither files nor an index, nothing is
        loaded.
        """
        if recompute_embeddings:
            self.archival_storage_index = (
                await utils.prepare_archival_index_from_files_compute_embeddings(
                    self.archival_storage_files
                )
            )
        if self.archival_storage_index:
            # Load whenever an index exists, not only when compute_embeddings
            # is set: a caller-supplied index must populate self.index too.
            self.index, self.archival_database = utils.prepare_archival_index(
                self.archival_storage_index
            )
        elif self.archival_storage_files:
            self.archival_database = utils.prepare_archival_index_from_files(
                self.archival_storage_files
            )

    def to_dict(self):
        """Return the persisted fields as a JSON-serialisable ``dict``."""
        return {name: getattr(self, name) for name in self._PERSISTED_FIELDS}

    def load_config(self, config_file):
        """Read ``config_file`` (JSON) and copy its fields onto this object.

        Raises:
            KeyError: If the file lacks one of the persisted fields.
        """
        with open(config_file, "rt") as f:
            cfg = json.load(f)
        for name in self._PERSISTED_FIELDS:
            setattr(self, name, cfg[name])

    def write_config(self, configs_dir=None):
        """Write the persisted fields to ``self.config_file`` as JSON.

        Args:
            configs_dir: Directory to create and, when no ``config_file`` is
                set yet, to place a new timestamp-named file in.  Defaults to
                ``Config.configs_dir``.
        """
        if configs_dir is None:
            configs_dir = Config.configs_dir
        os.makedirs(configs_dir, exist_ok=True)
        if self.config_file is None:
            # Spaces and colons are replaced to keep the timestamp usable as
            # a file name.
            filename = os.path.join(
                configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_")
            )
            self.config_file = f"{filename}.json"
        with open(self.config_file, "wt") as f:
            json.dump(self.to_dict(), f, indent=4)
        print(
            f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}"
        )

    @staticmethod
    def get_memgpt_personas(dir_path=None):
        """Return menu choices for the personas in ``dir_path``.

        Defaults to ``Config.personas_dir``; choice text comes from
        ``get_persona_text``.
        """
        if dir_path is None:
            dir_path = Config.personas_dir
        return Config.get_persona_choices(
            Config.get_personas(dir_path), get_persona_text
        )

    @staticmethod
    def get_user_personas(dir_path=None):
        """Return menu choices for the human personas in ``dir_path``.

        Defaults to ``Config.humans_dir``; choice text comes from
        ``get_human_text``.
        """
        if dir_path is None:
            dir_path = Config.humans_dir
        return Config.get_persona_choices(
            Config.get_personas(dir_path), get_human_text
        )

    @staticmethod
    def get_personas(dir_path) -> List[str]:
        """Return the sorted base names (no extension) of ``*.txt`` files."""
        paths = sorted(glob.glob(os.path.join(dir_path, "*.txt")))
        return [os.path.splitext(os.path.basename(p))[0] for p in paths]

    @staticmethod
    def get_persona_choices(personas, text_getter):
        """Build one ``questionary.Choice`` per persona name.

        Args:
            personas: Iterable of persona names; each name is the choice value.
            text_getter: Callable mapping a name to its text, shown (indented
                and possibly truncated) below the name.
        """
        return [
            questionary.Choice(
                title=[
                    ("class:question", f"{p}"),
                    ("class:text", f"\n{indent(text_getter(p))}"),
                ],
                value=p,
            )
            for p in personas
        ]

    @staticmethod
    def get_most_recent_config(configs_dir=None):
        """Return the most recently modified file in ``configs_dir``.

        Args:
            configs_dir: Directory to scan; defaults to ``Config.configs_dir``.

        Returns:
            The path of the newest regular file, or ``None`` when the
            directory is missing or contains no files.
        """
        if configs_dir is None:
            configs_dir = Config.configs_dir
        try:
            entries = os.listdir(configs_dir)
        except FileNotFoundError:
            # First run: nothing has created the configs directory yet.
            return None
        candidates = (os.path.join(configs_dir, entry) for entry in entries)
        files = [path for path in candidates if os.path.isfile(path)]
        if not files:
            return None
        return max(files, key=os.path.getmtime)


def indent(text, num_lines=5):
    """Wrap ``text`` to 100 columns and prefix every line with a space.

    When the wrapped text has more than ``num_lines`` lines, only the first
    ``num_lines - 1`` lines are kept, followed by a ``"... (truncated)"``
    marker and the final line.
    """
    lines = textwrap.fill(text, width=100).split("\n")
    if len(lines) > num_lines:
        lines = lines[: num_lines - 1] + ["... (truncated)", lines[-1]]
    return " " + "\n ".join(lines)


config = Config()
print(f"Saved persistence manager to: {filename}") + cfg.persistence_manager_save_file = filename except Exception as e: print(f"Saving persistence manager to {filename} failed with: {e}") + cfg.write_config() def load(memgpt_agent, filename): @@ -156,6 +162,79 @@ async def main(): logging.getLogger().setLevel(logging.CRITICAL) if FLAGS.debug: logging.getLogger().setLevel(logging.DEBUG) + + if any( + ( + FLAGS.persona, + FLAGS.human, + FLAGS.model != constants.DEFAULT_MEMGPT_MODEL, + FLAGS.archival_storage_faiss_path, + FLAGS.archival_storage_files, + FLAGS.archival_storage_files_compute_embeddings, + FLAGS.archival_storage_sqldb, + ) + ): + interface.important_message("⚙️ Using legacy command line arguments.") + model = FLAGS.model + if model is None: + model = constants.DEFAULT_MEMGPT_MODEL + memgpt_persona = FLAGS.persona + if memgpt_persona is None: + memgpt_persona = ( + personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT + ) + human_persona = FLAGS.human + if human_persona is None: + human_persona = humans.DEFAULT + + if FLAGS.archival_storage_files: + cfg = await Config.legacy_flags_init( + model, + memgpt_persona, + human_persona, + load_type="folder", + archival_storage_files=FLAGS.archival_storage_files, + compute_embeddings=False, + ) + elif FLAGS.archival_storage_faiss_path: + cfg = await Config.legacy_flags_init( + model, + memgpt_persona, + human_persona, + load_type="folder", + archival_storage_index=FLAGS.archival_storage_index, + compute_embeddings=False, + ) + elif FLAGS.archival_storage_files_compute_embeddings: + print(model) + print(memgpt_persona) + print(human_persona) + cfg = await Config.legacy_flags_init( + model, + memgpt_persona, + human_persona, + load_type="folder", + archival_storage_files=FLAGS.archival_storage_files_compute_embeddings, + compute_embeddings=True, + ) + elif FLAGS.archival_storage_sqldb: + cfg = await Config.legacy_flags_init( + model, + memgpt_persona, + human_persona, + load_type="sql", + 
archival_storage_files=FLAGS.archival_storage_sqldb, + compute_embeddings=False, + ) + else: + cfg = await Config.legacy_flags_init( + model, + memgpt_persona, + human_persona, + ) + else: + cfg = await Config.config_init() + print("Running... [exit by typing '/exit']") # Azure OpenAI support @@ -190,49 +269,35 @@ async def main(): ) return - if FLAGS.model != constants.DEFAULT_MEMGPT_MODEL: + if cfg.model != constants.DEFAULT_MEMGPT_MODEL: interface.important_message( - f"Warning - you are running MemGPT with {FLAGS.model}, which is not officially supported (yet). Expect bugs!" + f"Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!" ) - if FLAGS.archival_storage_faiss_path: - index, archival_database = utils.prepare_archival_index( - FLAGS.archival_storage_faiss_path + if cfg.archival_storage_index: + persistence_manager = InMemoryStateManagerWithFaiss( + cfg.index, cfg.archival_database ) - persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database) - elif FLAGS.archival_storage_files: - archival_database = utils.prepare_archival_index_from_files( - FLAGS.archival_storage_files - ) - print(f"Preloaded {len(archival_database)} chunks into archival memory.") + elif cfg.archival_storage_files: + print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.") persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory( - archival_database + cfg.archival_database ) - elif FLAGS.archival_storage_files_compute_embeddings: - faiss_save_dir = ( - await utils.prepare_archival_index_from_files_compute_embeddings( - FLAGS.archival_storage_files_compute_embeddings - ) - ) - interface.important_message( - f"To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={faiss_save_dir} (if your files haven't changed)." 
- ) - index, archival_database = utils.prepare_archival_index(faiss_save_dir) - persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database) else: persistence_manager = InMemoryStateManager() + if FLAGS.archival_storage_files_compute_embeddings: + interface.important_message( + f"(legacy) To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={cfg.archival_storage_index} (if your files haven't changed)." + ) + # Moved defaults out of FLAGS so that we can dynamically select the default persona based on model - chosen_human = FLAGS.human if FLAGS.human is not None else humans.DEFAULT - chosen_persona = ( - FLAGS.persona - if FLAGS.persona is not None - else (personas.GPT35_DEFAULT if "gpt-3.5" in FLAGS.model else personas.DEFAULT) - ) + chosen_human = cfg.human_persona + chosen_persona = cfg.memgpt_persona memgpt_agent = presets.use_preset( presets.DEFAULT, - FLAGS.model, + cfg.model, personas.get_persona_text(chosen_persona), humans.get_human_text(chosen_human), interface, @@ -247,19 +312,26 @@ async def main(): user_message = None USER_GOES_FIRST = FLAGS.first - if FLAGS.archival_storage_sqldb: - if not os.path.exists(FLAGS.archival_storage_sqldb): - print(f"File {FLAGS.archival_storage_sqldb} does not exist") + if cfg.load_type == "sql": # TODO: move this into config.py in a clean manner + if not os.path.exists(cfg.archival_storage_files): + print(f"File {cfg.archival_storage_files} does not exist") return # Ingest data from file into archival storage else: print(f"Database found! Loading database into archival memory") - data_list = utils.read_database_as_list(FLAGS.archival_storage_sqldb) + data_list = utils.read_database_as_list(cfg.archival_storage_files) user_message = f"Your archival memory has been loaded with a SQL database called {data_list[0]}, which contains schema {data_list[1]}. 
Remember to refer to this first while answering any user questions!" for row in data_list: await memgpt_agent.persistence_manager.archival_memory.insert(row) print(f"Database loaded into archival memory.") + if cfg.agent_save_file: + load_save_file = await questionary.confirm( + f"Load in saved agent '{cfg.agent_save_file}'?" + ).ask_async() + if load_save_file: + load(memgpt_agent, cfg.agent_save_file) + # auto-exit for if "GITHUB_ACTIONS" in os.environ: return @@ -274,7 +346,12 @@ async def main(): while True: if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST): # Ask for user input - user_input = console.input("[bold cyan]Enter your message:[/bold cyan] ") + # user_input = console.input("[bold cyan]Enter your message:[/bold cyan] ") + user_input = await questionary.text( + "Enter your message:", + multiline=True, + qmark=">", + ).ask_async() clear_line() if user_input.startswith("!"): @@ -289,25 +366,9 @@ async def main(): # Handle CLI commands # Commands to not get passed as input to MemGPT if user_input.startswith("/"): - if user_input == "//": - print("Entering multiline mode, type // when done") - user_input_list = [] - while True: - user_input = console.input("[bold cyan]>[/bold cyan] ") - clear_line() - if user_input == "//": - break - else: - user_input_list.append(user_input) - - # pass multiline inputs to MemGPT - user_message = system.package_user_message( - "\n".join(user_input_list) - ) - - elif user_input.lower() == "/exit": + if user_input.lower() == "/exit": # autosave - save(memgpt_agent=memgpt_agent) + save(memgpt_agent=memgpt_agent, cfg=cfg) break elif user_input.lower() == "/savechat": @@ -326,7 +387,7 @@ async def main(): continue elif user_input.lower() == "/save": - save(memgpt_agent=memgpt_agent) + save(memgpt_agent=memgpt_agent, cfg=cfg) continue elif user_input.lower() == "/load" or user_input.lower().startswith( diff --git a/requirements.txt b/requirements.txt index 484bc1e2..cc1c5688 100644 --- a/requirements.txt +++ 
b/requirements.txt @@ -9,6 +9,7 @@ pybars3 pymupdf python-dotenv pytz +questionary rich tiktoken timezonefinder