From a081b0015d049b23d7506a72eb13d8a7caec2fa4 Mon Sep 17 00:00:00 2001 From: cpacker Date: Sat, 14 Oct 2023 17:58:35 -0700 Subject: [PATCH] patch save/load, add model flag --- main.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/main.py b/main.py index f54513a9..636c6bf6 100644 --- a/main.py +++ b/main.py @@ -21,6 +21,7 @@ from memgpt.persistence_manager import InMemoryStateManager as persistence_manag FLAGS = flags.FLAGS flags.DEFINE_string("persona", default=personas.DEFAULT, required=False, help="Specify persona") flags.DEFINE_string("human", default=humans.DEFAULT, required=False, help="Specify human") +flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model") flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence") flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output") @@ -42,7 +43,7 @@ async def main(): logging.getLogger().setLevel(logging.DEBUG) print("Running... [exit by typing 'exit']") - memgpt_agent = presets.use_preset(presets.DEFAULT, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager()) + memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager()) print_messages = interface.print_messages await print_messages(memgpt_agent.messages) @@ -82,9 +83,14 @@ async def main(): elif user_input.lower() == "/savechat": filename = utils.get_local_time().replace(' ', '_').replace(':', '_') filename = f"{filename}.pkl" - with open(os.path.join('saved_chats', filename), 'wb') as f: - pickle.dump(memgpt_agent.messages, f) - print(f"Saved messages to: {filename}") + try: + if not os.path.exists("saved_chats"): + os.makedirs("saved_chats") + with open(os.path.join('saved_chats', filename), 'wb') as f: + pickle.dump(memgpt_agent.messages, f) + print(f"Saved messages to: {filename}") + except Exception as e: + print(f"Saving chat to {filename} failed with: {e}") continue elif user_input.lower() == "/save": @@ -92,10 +98,12 @@ async def main(): filename = f"{filename}.json" filename = os.path.join('saved_state', filename) try: + if not os.path.exists("saved_state"): + os.makedirs("saved_state") memgpt_agent.save_to_json_file(filename) print(f"Saved checkpoint to: {filename}") except Exception as e: - print(f"Saving to {filename} failed with: {e}") + print(f"Saving state to {filename} failed with: {e}") continue elif user_input.lower() == "/load" or user_input.lower().startswith("/load "): @@ -107,6 +115,8 @@ async def main(): print(f"Loaded checkpoint {filename}") except Exception as e: print(f"Loading {filename} failed with: {e}") + else: + print(f"/load error: no checkpoint specified") continue elif user_input.lower() == "/dump":