Merge pull request #27 from cpacker/persistence_save_fix

fixed bug where persistence manager was not saving in demo CLI
This commit is contained in:
Charles Packer
2023-10-17 23:42:17 -07:00
committed by GitHub
3 changed files with 35 additions and 4 deletions

18
main.py
View File

@@ -132,6 +132,15 @@ async def main():
print(f"Saved checkpoint to: {filename}")
except Exception as e:
print(f"Saving state to {filename} failed with: {e}")
# save the persistence manager too
filename = filename.replace('.json', '.persistence.pickle')
try:
memgpt_agent.persistence_manager.save(filename)
print(f"Saved persistence manager to: {filename}")
except Exception as e:
print(f"Saving persistence manager to {filename} failed with: {e}")
continue
elif user_input.lower() == "/load" or user_input.lower().startswith("/load "):
@@ -145,6 +154,15 @@ async def main():
print(f"Loading {filename} failed with: {e}")
else:
print(f"/load error: no checkpoint specified")
# need to load persistence manager too
filename = filename.replace('.json', '.persistence.pickle')
try:
memgpt_agent.persistence_manager = InMemoryStateManager.load(filename) # TODO(fixme):for different types of persistence managers that require different load/save methods
print(f"Loaded persistence manager from: {filename}")
except Exception as e:
print(f"/load error: loading persistence manager from {filename} failed with: {e}")
continue
elif user_input.lower() == "/dump":

View File

@@ -250,7 +250,7 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
"""Dummy in-memory version of an archival memory database, using a FAISS
index for fast nearest-neighbors embedding search.
Archival memory is effectively "infinite" overflow for core memory,
and is read-only via string queries.
@@ -291,7 +291,7 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
"""Simple embedding-based search (inefficient, no caching)"""
# see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb
# query_embedding = get_embedding(query_string, model=self.embedding_model)
# query_embedding = get_embedding(query_string, model=self.embedding_model)
# our wrapped version supports backoff/rate-limits
if query_string in self.embeddings_dict:
query_embedding = self.embeddings_dict[query_string]

View File

@@ -1,4 +1,5 @@
from abc import ABC, abstractmethod
import pickle
from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss
from .utils import get_local_time, printd
@@ -39,6 +40,15 @@ class InMemoryStateManager(PersistenceManager):
self.messages = []
self.all_messages = []
@staticmethod
def load(filename):
with open(filename, 'rb') as f:
return pickle.load(f)
def save(self, filename):
with open(filename, 'wb') as fh:
pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)
def init(self, agent):
printd(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
@@ -91,7 +101,7 @@ class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
def __init__(self, archival_memory_db):
self.archival_memory_db = archival_memory_db
def init(self, agent):
print(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
@@ -117,7 +127,10 @@ class InMemoryStateManagerWithFaiss(InMemoryStateManager):
self.archival_index = archival_index
self.archival_memory_db = archival_memory_db
self.a_k = a_k
def save(self, _filename):
raise NotImplementedError
def init(self, agent):
print(f"Initializing InMemoryStateManager with agent object")
self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]