patch save/load, add model flag
This commit is contained in:
20
main.py
20
main.py
@@ -21,6 +21,7 @@ from memgpt.persistence_manager import InMemoryStateManager as persistence_manag
|
||||
FLAGS = flags.FLAGS
|
||||
flags.DEFINE_string("persona", default=personas.DEFAULT, required=False, help="Specify persona")
|
||||
flags.DEFINE_string("human", default=humans.DEFAULT, required=False, help="Specify human")
|
||||
flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model")
|
||||
flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence")
|
||||
flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
|
||||
|
||||
@@ -42,7 +43,7 @@ async def main():
|
||||
logging.getLogger().setLevel(logging.DEBUG)
|
||||
print("Running... [exit by typing 'exit']")
|
||||
|
||||
memgpt_agent = presets.use_preset(presets.DEFAULT, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager())
|
||||
memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager())
|
||||
print_messages = interface.print_messages
|
||||
await print_messages(memgpt_agent.messages)
|
||||
|
||||
@@ -82,9 +83,14 @@ async def main():
|
||||
elif user_input.lower() == "/savechat":
|
||||
filename = utils.get_local_time().replace(' ', '_').replace(':', '_')
|
||||
filename = f"{filename}.pkl"
|
||||
with open(os.path.join('saved_chats', filename), 'wb') as f:
|
||||
pickle.dump(memgpt_agent.messages, f)
|
||||
print(f"Saved messages to: {filename}")
|
||||
try:
|
||||
if not os.path.exists("saved_chats"):
|
||||
os.makedirs("saved_chats")
|
||||
with open(os.path.join('saved_chats', filename), 'wb') as f:
|
||||
pickle.dump(memgpt_agent.messages, f)
|
||||
print(f"Saved messages to: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Saving chat to {filename} failed with: {e}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/save":
|
||||
@@ -92,10 +98,12 @@ async def main():
|
||||
filename = f"{filename}.json"
|
||||
filename = os.path.join('saved_state', filename)
|
||||
try:
|
||||
if not os.path.exists("saved_state"):
|
||||
os.makedirs("saved_state")
|
||||
memgpt_agent.save_to_json_file(filename)
|
||||
print(f"Saved checkpoint to: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Saving to {filename} failed with: {e}")
|
||||
print(f"Saving state to {filename} failed with: {e}")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
|
||||
@@ -107,6 +115,8 @@ async def main():
|
||||
print(f"Loaded checkpoint {filename}")
|
||||
except Exception as e:
|
||||
print(f"Loading {filename} failed with: {e}")
|
||||
else:
|
||||
print(f"/load error: no checkpoint specified")
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/dump":
|
||||
|
||||
Reference in New Issue
Block a user