patch save/load, add model flag

This commit is contained in:
cpacker
2023-10-14 17:58:35 -07:00
parent 3557bed760
commit a081b0015d

20
main.py
View File

@@ -21,6 +21,7 @@ from memgpt.persistence_manager import InMemoryStateManager as persistence_manag
FLAGS = flags.FLAGS
flags.DEFINE_string("persona", default=personas.DEFAULT, required=False, help="Specify persona")
flags.DEFINE_string("human", default=humans.DEFAULT, required=False, help="Specify human")
flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model")
flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence")
flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
@@ -42,7 +43,7 @@ async def main():
logging.getLogger().setLevel(logging.DEBUG)
print("Running... [exit by typing 'exit']")
memgpt_agent = presets.use_preset(presets.DEFAULT, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager())
memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager())
print_messages = interface.print_messages
await print_messages(memgpt_agent.messages)
@@ -82,9 +83,14 @@ async def main():
elif user_input.lower() == "/savechat":
filename = utils.get_local_time().replace(' ', '_').replace(':', '_')
filename = f"{filename}.pkl"
with open(os.path.join('saved_chats', filename), 'wb') as f:
pickle.dump(memgpt_agent.messages, f)
print(f"Saved messages to: {filename}")
try:
if not os.path.exists("saved_chats"):
os.makedirs("saved_chats")
with open(os.path.join('saved_chats', filename), 'wb') as f:
pickle.dump(memgpt_agent.messages, f)
print(f"Saved messages to: {filename}")
except Exception as e:
print(f"Saving chat to {filename} failed with: {e}")
continue
elif user_input.lower() == "/save":
@@ -92,10 +98,12 @@ async def main():
filename = f"{filename}.json"
filename = os.path.join('saved_state', filename)
try:
if not os.path.exists("saved_state"):
os.makedirs("saved_state")
memgpt_agent.save_to_json_file(filename)
print(f"Saved checkpoint to: {filename}")
except Exception as e:
print(f"Saving to {filename} failed with: {e}")
print(f"Saving state to {filename} failed with: {e}")
continue
elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
@@ -107,6 +115,8 @@ async def main():
print(f"Loaded checkpoint {filename}")
except Exception as e:
print(f"Loading {filename} failed with: {e}")
else:
print(f"/load error: no checkpoint specified")
continue
elif user_input.lower() == "/dump":