autosave on /exit
This commit is contained in:
112
main.py
112
main.py
@@ -43,6 +43,62 @@ def clear_line():
|
||||
sys.stdout.flush()
|
||||
|
||||
|
||||
def save(memgpt_agent):
|
||||
filename = utils.get_local_time().replace(' ', '_').replace(':', '_')
|
||||
filename = f"{filename}.json"
|
||||
filename = os.path.join('saved_state', filename)
|
||||
try:
|
||||
if not os.path.exists("saved_state"):
|
||||
os.makedirs("saved_state")
|
||||
memgpt_agent.save_to_json_file(filename)
|
||||
print(f"Saved checkpoint to: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Saving state to {filename} failed with: {e}")
|
||||
|
||||
# save the persistence manager too
|
||||
filename = filename.replace('.json', '.persistence.pickle')
|
||||
try:
|
||||
memgpt_agent.persistence_manager.save(filename)
|
||||
print(f"Saved persistence manager to: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Saving persistence manager to {filename} failed with: {e}")
|
||||
|
||||
|
||||
def load(memgpt_agent, filename):
|
||||
if filename is not None:
|
||||
if filename[-5:] != '.json':
|
||||
filename += '.json'
|
||||
try:
|
||||
memgpt_agent.load_from_json_file_inplace(filename)
|
||||
print(f"Loaded checkpoint {filename}")
|
||||
except Exception as e:
|
||||
print(f"Loading {filename} failed with: {e}")
|
||||
else:
|
||||
# Load the latest file
|
||||
print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead")
|
||||
json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory.
|
||||
|
||||
# Check if there are any json files.
|
||||
if not json_files:
|
||||
print(f"/load error: no .json checkpoint files found")
|
||||
else:
|
||||
# Sort files based on modified timestamp, with the latest file being the first.
|
||||
filename = max(json_files, key=os.path.getmtime)
|
||||
try:
|
||||
memgpt_agent.load_from_json_file_inplace(filename)
|
||||
print(f"Loaded checkpoint {filename}")
|
||||
except Exception as e:
|
||||
print(f"Loading {filename} failed with: {e}")
|
||||
|
||||
# need to load persistence manager too
|
||||
filename = filename.replace('.json', '.persistence.pickle')
|
||||
try:
|
||||
memgpt_agent.persistence_manager = InMemoryStateManager.load(filename) # TODO(fixme):for different types of persistence managers that require different load/save methods
|
||||
print(f"Loaded persistence manager from {filename}")
|
||||
except Exception as e:
|
||||
print(f"/load warning: loading persistence manager from {filename} failed with: {e}")
|
||||
|
||||
|
||||
async def main():
|
||||
utils.DEBUG = FLAGS.debug
|
||||
logging.getLogger().setLevel(logging.CRITICAL)
|
||||
@@ -162,6 +218,8 @@ async def main():
|
||||
user_message = system.package_user_message("\n".join(user_input_list))
|
||||
|
||||
elif user_input.lower() == "/exit":
|
||||
# autosave
|
||||
save(memgpt_agent=memgpt_agent)
|
||||
break
|
||||
|
||||
elif user_input.lower() == "/savechat":
|
||||
@@ -178,63 +236,13 @@ async def main():
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/save":
|
||||
filename = utils.get_local_time().replace(' ', '_').replace(':', '_')
|
||||
filename = f"{filename}.json"
|
||||
filename = os.path.join('saved_state', filename)
|
||||
try:
|
||||
if not os.path.exists("saved_state"):
|
||||
os.makedirs("saved_state")
|
||||
memgpt_agent.save_to_json_file(filename)
|
||||
print(f"Saved checkpoint to: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Saving state to {filename} failed with: {e}")
|
||||
|
||||
# save the persistence manager too
|
||||
filename = filename.replace('.json', '.persistence.pickle')
|
||||
try:
|
||||
memgpt_agent.persistence_manager.save(filename)
|
||||
print(f"Saved persistence manager to: {filename}")
|
||||
except Exception as e:
|
||||
print(f"Saving persistence manager to {filename} failed with: {e}")
|
||||
|
||||
save(memgpt_agent=memgpt_agent)
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
|
||||
command = user_input.strip().split()
|
||||
filename = command[1] if len(command) > 1 else None
|
||||
if filename is not None:
|
||||
if filename[-5:] != '.json':
|
||||
filename += '.json'
|
||||
try:
|
||||
memgpt_agent.load_from_json_file_inplace(filename)
|
||||
print(f"Loaded checkpoint {filename}")
|
||||
except Exception as e:
|
||||
print(f"Loading {filename} failed with: {e}")
|
||||
else:
|
||||
# Load the latest file
|
||||
print(f"/load warning: no checkpoint specified, loading most recent checkpoint instead")
|
||||
json_files = glob.glob("saved_state/*.json") # This will list all .json files in the current directory.
|
||||
|
||||
# Check if there are any json files.
|
||||
if not json_files:
|
||||
print(f"/load error: no .json checkpoint files found")
|
||||
else:
|
||||
# Sort files based on modified timestamp, with the latest file being the first.
|
||||
filename = max(json_files, key=os.path.getmtime)
|
||||
try:
|
||||
memgpt_agent.load_from_json_file_inplace(filename)
|
||||
print(f"Loaded checkpoint {filename}")
|
||||
except Exception as e:
|
||||
print(f"Loading {filename} failed with: {e}")
|
||||
|
||||
# need to load persistence manager too
|
||||
filename = filename.replace('.json', '.persistence.pickle')
|
||||
try:
|
||||
memgpt_agent.persistence_manager = InMemoryStateManager.load(filename) # TODO(fixme):for different types of persistence managers that require different load/save methods
|
||||
print(f"Loaded persistence manager from {filename}")
|
||||
except Exception as e:
|
||||
print(f"/load warning: loading persistence manager from {filename} failed with: {e}")
|
||||
|
||||
load(memgpt_agent=memgpt_agent, filename=filename)
|
||||
continue
|
||||
|
||||
elif user_input.lower() == "/dump":
|
||||
|
||||
Reference in New Issue
Block a user