cli improvements using questionary
This commit is contained in:
3
.gitignore
vendored
3
.gitignore
vendored
@@ -79,3 +79,6 @@ dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# MemGPT config files
|
||||
configs/
|
||||
|
||||
280
config.py
Normal file
280
config.py
Normal file
@@ -0,0 +1,280 @@
|
||||
import asyncio
|
||||
import glob
|
||||
import json
|
||||
import os
|
||||
import textwrap
|
||||
|
||||
import interface
|
||||
|
||||
import questionary
|
||||
|
||||
from colorama import Fore, Style, init
|
||||
from rich.console import Console
|
||||
|
||||
console = Console()
|
||||
|
||||
from typing import List, Type
|
||||
|
||||
import memgpt.utils as utils
|
||||
|
||||
from memgpt.personas.personas import get_persona_text
|
||||
from memgpt.humans.humans import get_human_text
|
||||
|
||||
model_choices = [
|
||||
questionary.Choice("gpt-4"),
|
||||
questionary.Choice(
|
||||
"gpt-3.5-turbo (experimental! function-calling performance is not quite at the level of gpt-4 yet)",
|
||||
value="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class Config:
|
||||
personas_dir = os.path.join("memgpt", "personas", "examples")
|
||||
humans_dir = os.path.join("memgpt", "humans", "examples")
|
||||
configs_dir = "configs"
|
||||
|
||||
def __init__(self):
|
||||
self.load_type = None
|
||||
self.archival_storage_files = None
|
||||
self.compute_embeddings = False
|
||||
self.agent_save_file = None
|
||||
self.persistence_manager_save_file = None
|
||||
|
||||
@classmethod
|
||||
async def legacy_flags_init(
|
||||
cls: Type["config"],
|
||||
model: str,
|
||||
memgpt_persona: str,
|
||||
human_persona: str,
|
||||
load_type: str = None,
|
||||
archival_storage_files: str = None,
|
||||
archival_storage_index: str = None,
|
||||
compute_embeddings: bool = False,
|
||||
):
|
||||
self = cls()
|
||||
self.model = model
|
||||
self.memgpt_persona = memgpt_persona
|
||||
self.human_persona = human_persona
|
||||
self.load_type = load_type
|
||||
self.archival_storage_files = archival_storage_files
|
||||
self.archival_storage_index = archival_storage_index
|
||||
self.compute_embeddings = compute_embeddings
|
||||
recompute_embeddings = self.compute_embeddings
|
||||
if self.archival_storage_index:
|
||||
recompute_embeddings = questionary.confirm(
|
||||
f"Would you like to recompute embeddings? Do this if your files have changed.\nFiles:{self.archival_storage_files}",
|
||||
default=False,
|
||||
)
|
||||
await self.configure_archival_storage(recompute_embeddings)
|
||||
return self
|
||||
|
||||
@classmethod
|
||||
async def config_init(cls: Type["Config"], config_file: str = None):
|
||||
self = cls()
|
||||
self.config_file = config_file
|
||||
if self.config_file is None:
|
||||
cfg = Config.get_most_recent_config()
|
||||
use_cfg = False
|
||||
if cfg:
|
||||
print(
|
||||
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Found saved config file.{Style.RESET_ALL}"
|
||||
)
|
||||
use_cfg = await questionary.confirm(
|
||||
f"Use most recent config file '{cfg}'?"
|
||||
).ask_async()
|
||||
if use_cfg:
|
||||
self.config_file = cfg
|
||||
|
||||
if self.config_file:
|
||||
self.load_config(self.config_file)
|
||||
recompute_embeddings = False
|
||||
if self.compute_embeddings:
|
||||
if self.archival_storage_index:
|
||||
recompute_embeddings = await questionary.confirm(
|
||||
f"Would you like to recompute embeddings? Do this if your files have changed.\n Files: {self.archival_storage_files}",
|
||||
default=False,
|
||||
).ask_async()
|
||||
else:
|
||||
recompute_embeddings = True
|
||||
if self.load_type:
|
||||
await self.configure_archival_storage(recompute_embeddings)
|
||||
self.write_config()
|
||||
return self
|
||||
|
||||
# print("No settings file found, configuring MemGPT...")
|
||||
print(
|
||||
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ No settings file found, configuring MemGPT...{Style.RESET_ALL}"
|
||||
)
|
||||
|
||||
self.model = await questionary.select(
|
||||
"Which model would you like to use?",
|
||||
model_choices,
|
||||
default=model_choices[0],
|
||||
).ask_async()
|
||||
|
||||
self.memgpt_persona = await questionary.select(
|
||||
"Which persona would you like MemGPT to use?",
|
||||
Config.get_memgpt_personas(),
|
||||
).ask_async()
|
||||
print(self.memgpt_persona)
|
||||
|
||||
self.human_persona = await questionary.select(
|
||||
"Which persona would you like to use?",
|
||||
Config.get_user_personas(),
|
||||
).ask_async()
|
||||
|
||||
self.archival_storage_index = None
|
||||
self.preload_archival = await questionary.confirm(
|
||||
"Would you like to preload anything into MemGPT's archival memory?"
|
||||
).ask_async()
|
||||
if self.preload_archival:
|
||||
self.load_type = await questionary.select(
|
||||
"What would you like to load?",
|
||||
choices=[
|
||||
questionary.Choice("A folder or file", value="folder"),
|
||||
questionary.Choice("A SQL database", value="sql"),
|
||||
questionary.Choice("A glob pattern", value="glob"),
|
||||
],
|
||||
).ask_async()
|
||||
if self.load_type == "folder" or self.load_type == "sql":
|
||||
archival_storage_path = await questionary.path(
|
||||
"Please enter the folder or file (tab for autocomplete):"
|
||||
).ask_async()
|
||||
if os.path.isdir(archival_storage_path):
|
||||
self.archival_storage_files = os.path.join(
|
||||
archival_storage_path, "*"
|
||||
)
|
||||
else:
|
||||
self.archival_storage_files = archival_storage_path
|
||||
else:
|
||||
self.archival_storage_files = await questionary.path(
|
||||
"Please enter the glob pattern (tab for autocomplete):"
|
||||
).ask_async()
|
||||
self.compute_embeddings = await questionary.confirm(
|
||||
"Would you like to compute embeddings over these files to enable embeddings search?"
|
||||
).ask_async()
|
||||
await self.configure_archival_storage(self.compute_embeddings)
|
||||
|
||||
self.write_config()
|
||||
return self
|
||||
|
||||
async def configure_archival_storage(self, recompute_embeddings):
|
||||
if recompute_embeddings:
|
||||
self.archival_storage_index = (
|
||||
await utils.prepare_archival_index_from_files_compute_embeddings(
|
||||
self.archival_storage_files
|
||||
)
|
||||
)
|
||||
if self.compute_embeddings and self.archival_storage_index:
|
||||
self.index, self.archival_database = utils.prepare_archival_index(
|
||||
self.archival_storage_index
|
||||
)
|
||||
else:
|
||||
self.archival_database = utils.prepare_archival_index_from_files(
|
||||
self.archival_storage_files
|
||||
)
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"model": self.model,
|
||||
"memgpt_persona": self.memgpt_persona,
|
||||
"human_persona": self.human_persona,
|
||||
"preload_archival": self.preload_archival,
|
||||
"archival_storage_files": self.archival_storage_files,
|
||||
"archival_storage_index": self.archival_storage_index,
|
||||
"compute_embeddings": self.compute_embeddings,
|
||||
"load_type": self.load_type,
|
||||
"agent_save_file": self.agent_save_file,
|
||||
"persistence_manager_save_file": self.persistence_manager_save_file,
|
||||
}
|
||||
|
||||
def load_config(self, config_file):
|
||||
with open(config_file, "rt") as f:
|
||||
cfg = json.load(f)
|
||||
self.model = cfg["model"]
|
||||
self.memgpt_persona = cfg["memgpt_persona"]
|
||||
self.human_persona = cfg["human_persona"]
|
||||
self.preload_archival = cfg["preload_archival"]
|
||||
self.archival_storage_files = cfg["archival_storage_files"]
|
||||
self.archival_storage_index = cfg["archival_storage_index"]
|
||||
self.compute_embeddings = cfg["compute_embeddings"]
|
||||
self.load_type = cfg["load_type"]
|
||||
self.agent_save_file = cfg["agent_save_file"]
|
||||
self.persistence_manager_save_file = cfg["persistence_manager_save_file"]
|
||||
|
||||
def write_config(self, configs_dir=None):
|
||||
if configs_dir is None:
|
||||
configs_dir = Config.configs_dir
|
||||
os.makedirs(configs_dir, exist_ok=True)
|
||||
if self.config_file is None:
|
||||
filename = os.path.join(
|
||||
configs_dir, utils.get_local_time().replace(" ", "_").replace(":", "_")
|
||||
)
|
||||
self.config_file = f"{filename}.json"
|
||||
with open(self.config_file, "wt") as f:
|
||||
json.dump(self.to_dict(), f, indent=4)
|
||||
print(
|
||||
f"{Style.BRIGHT}{Fore.MAGENTA}⚙️ Saved config file to {self.config_file}.{Style.RESET_ALL}"
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def get_memgpt_personas(dir_path=None):
|
||||
if dir_path is None:
|
||||
dir_path = Config.personas_dir
|
||||
all_personas = Config.get_personas(dir_path)
|
||||
return Config.get_persona_choices([p for p in all_personas], get_persona_text)
|
||||
|
||||
@staticmethod
|
||||
def get_user_personas(dir_path=None):
|
||||
if dir_path is None:
|
||||
dir_path = Config.humans_dir
|
||||
all_personas = Config.get_personas(dir_path)
|
||||
return Config.get_persona_choices([p for p in all_personas], get_human_text)
|
||||
|
||||
@staticmethod
|
||||
def get_personas(dir_path) -> List[str]:
|
||||
files = sorted(glob.glob(os.path.join(dir_path, "*.txt")))
|
||||
stems = []
|
||||
for f in files:
|
||||
filename = os.path.basename(f)
|
||||
stem, _ = os.path.splitext(filename)
|
||||
stems.append(stem)
|
||||
return stems
|
||||
|
||||
@staticmethod
|
||||
def get_persona_choices(personas, text_getter):
|
||||
return [
|
||||
questionary.Choice(
|
||||
title=[
|
||||
("class:question", f"{p}"),
|
||||
("class:text", f"\n{indent(text_getter(p))}"),
|
||||
],
|
||||
value=p,
|
||||
)
|
||||
for p in personas
|
||||
]
|
||||
|
||||
@staticmethod
|
||||
def get_most_recent_config(configs_dir=None):
|
||||
if configs_dir is None:
|
||||
configs_dir = Config.configs_dir
|
||||
files = [
|
||||
os.path.join(configs_dir, f)
|
||||
for f in os.listdir(configs_dir)
|
||||
if os.path.isfile(os.path.join(configs_dir, f))
|
||||
]
|
||||
# Return the file with the most recent modification time
|
||||
if len(files) == 0:
|
||||
return None
|
||||
return max(files, key=os.path.getmtime)
|
||||
|
||||
|
||||
def indent(text, num_lines=5):
|
||||
lines = textwrap.fill(text, width=100).split("\n")
|
||||
if len(lines) > num_lines:
|
||||
lines = lines[: num_lines - 1] + ["... (truncated)", lines[-1]]
|
||||
return " " + "\n ".join(lines)
|
||||
|
||||
|
||||
config = Config()
|
||||
173
main.py
173
main.py
@@ -5,7 +5,8 @@ import glob
|
||||
import os
|
||||
import sys
|
||||
import pickle
|
||||
import readline
|
||||
|
||||
import questionary
|
||||
|
||||
from rich.console import Console
|
||||
|
||||
@@ -25,6 +26,8 @@ from memgpt.persistence_manager import (
|
||||
InMemoryStateManagerWithFaiss,
|
||||
)
|
||||
|
||||
from config import Config
|
||||
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("persona", default=None, required=False, help="Specify persona")
|
||||
flags.DEFINE_string("human", default=None, required=False, help="Specify human")
|
||||
@@ -87,7 +90,7 @@ def clear_line():
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def save(memgpt_agent):
|
||||
def save(memgpt_agent, cfg):
|
||||
filename = utils.get_local_time().replace(" ", "_").replace(":", "_")
|
||||
filename = f"{filename}.json"
|
||||
filename = os.path.join("saved_state", filename)
|
||||
@@ -96,6 +99,7 @@ def save(memgpt_agent):
|
||||
os.makedirs("saved_state")
|
||||
memgpt_agent.save_to_json_file(filename)
|
||||
print(f"Saved checkpoint to: {filename}")
|
||||
cfg.agent_save_file = filename
|
||||
except Exception as e:
|
||||
print(f"Saving state to {filename} failed with: {e}")
|
||||
|
||||
@@ -104,8 +108,10 @@ def save(memgpt_agent):
|
||||
try:
|
||||
memgpt_agent.persistence_manager.save(filename)
|
||||
print(f"Saved persistence manager to: {filename}")
|
||||
cfg.persistence_manager_save_file = filename
|
||||
except Exception as e:
|
||||
print(f"Saving persistence manager to {filename} failed with: {e}")
|
||||
cfg.write_config()
|
||||
|
||||
|
||||
def load(memgpt_agent, filename):
|
||||
@@ -156,6 +162,79 @@ async def main():
|
||||
logging.getLogger().setLevel(logging.CRITICAL)
|
||||
if FLAGS.debug:
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
|
||||
if any(
|
||||
(
|
||||
FLAGS.persona,
|
||||
FLAGS.human,
|
||||
FLAGS.model != constants.DEFAULT_MEMGPT_MODEL,
|
||||
FLAGS.archival_storage_faiss_path,
|
||||
FLAGS.archival_storage_files,
|
||||
FLAGS.archival_storage_files_compute_embeddings,
|
||||
FLAGS.archival_storage_sqldb,
|
||||
)
|
||||
):
|
||||
interface.important_message("⚙️ Using legacy command line arguments.")
|
||||
model = FLAGS.model
|
||||
if model is None:
|
||||
model = constants.DEFAULT_MEMGPT_MODEL
|
||||
memgpt_persona = FLAGS.persona
|
||||
if memgpt_persona is None:
|
||||
memgpt_persona = (
|
||||
personas.GPT35_DEFAULT if "gpt-3.5" in model else personas.DEFAULT
|
||||
)
|
||||
human_persona = FLAGS.human
|
||||
if human_persona is None:
|
||||
human_persona = humans.DEFAULT
|
||||
|
||||
if FLAGS.archival_storage_files:
|
||||
cfg = await Config.legacy_flags_init(
|
||||
model,
|
||||
memgpt_persona,
|
||||
human_persona,
|
||||
load_type="folder",
|
||||
archival_storage_files=FLAGS.archival_storage_files,
|
||||
compute_embeddings=False,
|
||||
)
|
||||
elif FLAGS.archival_storage_faiss_path:
|
||||
cfg = await Config.legacy_flags_init(
|
||||
model,
|
||||
memgpt_persona,
|
||||
human_persona,
|
||||
load_type="folder",
|
||||
archival_storage_index=FLAGS.archival_storage_index,
|
||||
compute_embeddings=False,
|
||||
)
|
||||
elif FLAGS.archival_storage_files_compute_embeddings:
|
||||
print(model)
|
||||
print(memgpt_persona)
|
||||
print(human_persona)
|
||||
cfg = await Config.legacy_flags_init(
|
||||
model,
|
||||
memgpt_persona,
|
||||
human_persona,
|
||||
load_type="folder",
|
||||
archival_storage_files=FLAGS.archival_storage_files_compute_embeddings,
|
||||
compute_embeddings=True,
|
||||
)
|
||||
elif FLAGS.archival_storage_sqldb:
|
||||
cfg = await Config.legacy_flags_init(
|
||||
model,
|
||||
memgpt_persona,
|
||||
human_persona,
|
||||
load_type="sql",
|
||||
archival_storage_files=FLAGS.archival_storage_sqldb,
|
||||
compute_embeddings=False,
|
||||
)
|
||||
else:
|
||||
cfg = await Config.legacy_flags_init(
|
||||
model,
|
||||
memgpt_persona,
|
||||
human_persona,
|
||||
)
|
||||
else:
|
||||
cfg = await Config.config_init()
|
||||
|
||||
print("Running... [exit by typing '/exit']")
|
||||
|
||||
# Azure OpenAI support
|
||||
@@ -190,49 +269,35 @@ async def main():
|
||||
)
|
||||
return
|
||||
|
||||
if FLAGS.model != constants.DEFAULT_MEMGPT_MODEL:
|
||||
if cfg.model != constants.DEFAULT_MEMGPT_MODEL:
|
||||
interface.important_message(
|
||||
f"Warning - you are running MemGPT with {FLAGS.model}, which is not officially supported (yet). Expect bugs!"
|
||||
f"Warning - you are running MemGPT with {cfg.model}, which is not officially supported (yet). Expect bugs!"
|
||||
)
|
||||
|
||||
if FLAGS.archival_storage_faiss_path:
|
||||
index, archival_database = utils.prepare_archival_index(
|
||||
FLAGS.archival_storage_faiss_path
|
||||
if cfg.archival_storage_index:
|
||||
persistence_manager = InMemoryStateManagerWithFaiss(
|
||||
cfg.index, cfg.archival_database
|
||||
)
|
||||
persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
|
||||
elif FLAGS.archival_storage_files:
|
||||
archival_database = utils.prepare_archival_index_from_files(
|
||||
FLAGS.archival_storage_files
|
||||
)
|
||||
print(f"Preloaded {len(archival_database)} chunks into archival memory.")
|
||||
elif cfg.archival_storage_files:
|
||||
print(f"Preloaded {len(cfg.archival_database)} chunks into archival memory.")
|
||||
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(
|
||||
archival_database
|
||||
cfg.archival_database
|
||||
)
|
||||
elif FLAGS.archival_storage_files_compute_embeddings:
|
||||
faiss_save_dir = (
|
||||
await utils.prepare_archival_index_from_files_compute_embeddings(
|
||||
FLAGS.archival_storage_files_compute_embeddings
|
||||
)
|
||||
)
|
||||
interface.important_message(
|
||||
f"To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={faiss_save_dir} (if your files haven't changed)."
|
||||
)
|
||||
index, archival_database = utils.prepare_archival_index(faiss_save_dir)
|
||||
persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
|
||||
else:
|
||||
persistence_manager = InMemoryStateManager()
|
||||
|
||||
if FLAGS.archival_storage_files_compute_embeddings:
|
||||
interface.important_message(
|
||||
f"(legacy) To avoid computing embeddings next time, replace --archival_storage_files_compute_embeddings={FLAGS.archival_storage_files_compute_embeddings} with\n\t --archival_storage_faiss_path={cfg.archival_storage_index} (if your files haven't changed)."
|
||||
)
|
||||
|
||||
# Moved defaults out of FLAGS so that we can dynamically select the default persona based on model
|
||||
chosen_human = FLAGS.human if FLAGS.human is not None else humans.DEFAULT
|
||||
chosen_persona = (
|
||||
FLAGS.persona
|
||||
if FLAGS.persona is not None
|
||||
else (personas.GPT35_DEFAULT if "gpt-3.5" in FLAGS.model else personas.DEFAULT)
|
||||
)
|
||||
chosen_human = cfg.human_persona
|
||||
chosen_persona = cfg.memgpt_persona
|
||||
|
||||
memgpt_agent = presets.use_preset(
|
||||
presets.DEFAULT,
|
||||
FLAGS.model,
|
||||
cfg.model,
|
||||
personas.get_persona_text(chosen_persona),
|
||||
humans.get_human_text(chosen_human),
|
||||
interface,
|
||||
@@ -247,19 +312,26 @@ async def main():
|
||||
user_message = None
|
||||
USER_GOES_FIRST = FLAGS.first
|
||||
|
||||
if FLAGS.archival_storage_sqldb:
|
||||
if not os.path.exists(FLAGS.archival_storage_sqldb):
|
||||
print(f"File {FLAGS.archival_storage_sqldb} does not exist")
|
||||
if cfg.load_type == "sql": # TODO: move this into config.py in a clean manner
|
||||
if not os.path.exists(cfg.archival_storage_files):
|
||||
print(f"File {cfg.archival_storage_files} does not exist")
|
||||
return
|
||||
# Ingest data from file into archival storage
|
||||
else:
|
||||
print(f"Database found! Loading database into archival memory")
|
||||
data_list = utils.read_database_as_list(FLAGS.archival_storage_sqldb)
|
||||
data_list = utils.read_database_as_list(cfg.archival_storage_files)
|
||||
user_message = f"Your archival memory has been loaded with a SQL database called {data_list[0]}, which contains schema {data_list[1]}. Remember to refer to this first while answering any user questions!"
|
||||
for row in data_list:
|
||||
await memgpt_agent.persistence_manager.archival_memory.insert(row)
|
||||
print(f"Database loaded into archival memory.")
|
||||
|
||||
if cfg.agent_save_file:
|
||||
load_save_file = await questionary.confirm(
|
||||
f"Load in saved agent '{cfg.agent_save_file}'?"
|
||||
).ask_async()
|
||||
if load_save_file:
|
||||
load(memgpt_agent, cfg.agent_save_file)
|
||||
|
||||
# auto-exit for
|
||||
if "GITHUB_ACTIONS" in os.environ:
|
||||
return
|
||||
@@ -274,7 +346,12 @@ async def main():
|
||||
while True:
|
||||
if not skip_next_user_input and (counter > 0 or USER_GOES_FIRST):
|
||||
# Ask for user input
|
||||
user_input = console.input("[bold cyan]Enter your message:[/bold cyan] ")
|
||||
# user_input = console.input("[bold cyan]Enter your message:[/bold cyan] ")
|
||||
user_input = await questionary.text(
|
||||
"Enter your message:",
|
||||
multiline=True,
|
||||
qmark=">",
|
||||
).ask_async()
|
||||
clear_line()
|
||||
|
||||
if user_input.startswith("!"):
|
||||
@@ -289,25 +366,9 @@ async def main():
|
||||
# Handle CLI commands
|
||||
# Commands to not get passed as input to MemGPT
|
||||
if user_input.startswith("/"):
|
||||
if user_input == "//":
|
||||
print("Entering multiline mode, type // when done")
|
||||
user_input_list = []
|
||||
while True:
|
||||
user_input = console.input("[bold cyan]>[/bold cyan] ")
|
||||
clear_line()
|
||||
if user_input == "//":
|
||||
break
|
||||
else:
|
||||
user_input_list.append(user_input)
|
||||
|
||||
# pass multiline inputs to MemGPT
|
||||
user_message = system.package_user_message(
|
||||
"\n".join(user_input_list)
|
||||
)
|
||||
|
||||
elif user_input.lower() == "/exit":
|
||||
if user_input.lower() == "/exit":
|
||||
# autosave
|
||||
save(memgpt_agent=memgpt_agent)
|
||||
save(memgpt_agent=memgpt_agent, cfg=cfg)
|
||||
break
|
||||
|
||||
elif user_input.lower() == "/savechat":
|
||||
@@ -326,7 +387,7 @@ async def main():
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/save":
|
||||
save(memgpt_agent=memgpt_agent)
|
||||
save(memgpt_agent=memgpt_agent, cfg=cfg)
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/load" or user_input.lower().startswith(
|
||||
|
||||
@@ -9,6 +9,7 @@ pybars3
|
||||
pymupdf
|
||||
python-dotenv
|
||||
pytz
|
||||
questionary
|
||||
rich
|
||||
tiktoken
|
||||
timezonefinder
|
||||
|
||||
Reference in New Issue
Block a user