243 lines
9.7 KiB
Python
243 lines
9.7 KiB
Python
from abc import ABC, abstractmethod
|
|
import os
|
|
import pickle
|
|
from memgpt.config import AgentConfig
|
|
from .memory import (
|
|
DummyRecallMemory,
|
|
DummyRecallMemoryWithEmbeddings,
|
|
DummyArchivalMemory,
|
|
DummyArchivalMemoryWithEmbeddings,
|
|
DummyArchivalMemoryWithFaiss,
|
|
EmbeddingArchivalMemory,
|
|
)
|
|
from .utils import get_local_time, printd
|
|
|
|
|
|
class PersistenceManager(ABC):
    """Abstract interface for objects that track an agent's messages and memory.

    Every method below is abstract: a subclass must override all five before
    it can be instantiated. The base implementations do nothing and return
    ``None``.
    """

    @abstractmethod
    def trim_messages(self, num):
        """Trim the tracked message list using ``num``.

        The concrete managers in this module keep the first message and drop
        the entries that precede index ``num``.
        """

    @abstractmethod
    def prepend_to_messages(self, added_messages):
        """Add ``added_messages`` toward the front of the tracked messages."""

    @abstractmethod
    def append_to_messages(self, added_messages):
        """Add ``added_messages`` to the end of the tracked messages."""

    @abstractmethod
    def swap_system_message(self, new_system_message):
        """Replace the current system message with ``new_system_message``."""

    @abstractmethod
    def update_memory(self, new_memory):
        """Replace the tracked memory object with ``new_memory``."""
|
|
|
|
|
|
class InMemoryStateManager(PersistenceManager):
    """In-memory state manager has nothing to manage, all agents are held in-memory.

    State lives in plain Python attributes:

    * ``memory``: the agent's memory object (set by ``init`` / ``update_memory``).
    * ``messages``: the current message window, as ``{"timestamp", "message"}`` dicts.
    * ``all_messages``: every message ever added, in the same dict format.
    * ``recall_memory`` / ``archival_memory``: built by ``init`` from the
      ``recall_memory_cls`` / ``archival_memory_cls`` class attributes.

    The whole object can be pickled to and from a file via ``save`` / ``load``.
    """

    recall_memory_cls = DummyRecallMemory
    archival_memory_cls = DummyArchivalMemory

    def __init__(self):
        # Memory held in-state useful for debugging stateful versions
        self.memory = None
        self.messages = []
        self.all_messages = []
        # Built by init(); defined up front so reading them before init()
        # yields None instead of raising AttributeError.
        self.recall_memory = None
        self.archival_memory = None

    @staticmethod
    def _tag_with_timestamp(messages):
        """Wrap each message in a ``{"timestamp", "message"}`` dict stamped with the local time."""
        return [{"timestamp": get_local_time(), "message": msg} for msg in messages]

    @staticmethod
    def load(filename):
        """Return the object unpickled from ``filename``.

        SECURITY: ``pickle.load`` can execute arbitrary code while
        deserializing, so only load files that come from a trusted source.
        """
        with open(filename, "rb") as f:
            return pickle.load(f)

    def save(self, filename):
        """Pickle this manager (the whole object) into ``filename``."""
        with open(filename, "wb") as fh:
            pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)

    def init(self, agent):
        """Seed the manager from ``agent`` and build its recall/archival memory.

        ``agent.messages`` is tagged twice so that ``messages`` and
        ``all_messages`` hold independent wrapper dicts.
        """
        printd("Initializing InMemoryStateManager with agent object")
        self.all_messages = self._tag_with_timestamp(agent.messages.copy())
        self.messages = self._tag_with_timestamp(agent.messages.copy())
        self.memory = agent.memory
        printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
        printd(f"InMemoryStateManager.messages.len = {len(self.messages)}")

        # Persistence manager also handles DB-related state
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
        self.archival_memory_db = []
        self.archival_memory = self.archival_memory_cls(archival_memory_database=self.archival_memory_db)

    def trim_messages(self, num):
        """Keep the first message and drop the entries before index ``num``.

        ``all_messages`` is left untouched. Raises ``IndexError`` when
        ``messages`` is empty.
        """
        # printd(f"InMemoryStateManager.trim_messages")
        self.messages = [self.messages[0]] + self.messages[num:]

    def prepend_to_messages(self, added_messages):
        """Insert ``added_messages`` right after the first message.

        The tagged messages are also appended to ``all_messages``. Raises
        ``IndexError`` when ``messages`` is empty.
        """
        # first tag with timestamps
        added_messages = self._tag_with_timestamp(added_messages)

        printd("InMemoryStateManager.prepend_to_message")
        self.messages = [self.messages[0]] + added_messages + self.messages[1:]
        self.all_messages.extend(added_messages)

    def append_to_messages(self, added_messages):
        """Append ``added_messages`` to both ``messages`` and ``all_messages``."""
        # first tag with timestamps
        added_messages = self._tag_with_timestamp(added_messages)

        printd("InMemoryStateManager.append_to_messages")
        self.messages = self.messages + added_messages
        self.all_messages.extend(added_messages)

    def swap_system_message(self, new_system_message):
        """Replace ``messages[0]`` with ``new_system_message`` and log it in ``all_messages``.

        Raises ``IndexError`` when ``messages`` is empty.
        """
        # first tag with timestamps
        new_system_message = {"timestamp": get_local_time(), "message": new_system_message}

        printd("InMemoryStateManager.swap_system_message")
        self.messages[0] = new_system_message
        self.all_messages.append(new_system_message)

    def update_memory(self, new_memory):
        """Replace the stored memory object with ``new_memory``."""
        printd("InMemoryStateManager.update_memory")
        self.memory = new_memory
|
|
|
|
|
|
class LocalStateManager(PersistenceManager):
    """State manager that pickles message state to a local file.

    Messages are held in memory as ``{"timestamp", "message"}`` dicts.
    ``save`` writes ``recall_memory``, ``messages`` and ``all_messages`` to a
    pickle file; the archival memory is persisted separately through its own
    ``save()`` and rebuilt from the agent config on ``load``.

    NOTE: the debug messages below still carry the ``InMemoryStateManager``
    label; the text is left unchanged.
    """

    recall_memory_cls = DummyRecallMemory
    archival_memory_cls = EmbeddingArchivalMemory

    def __init__(self, agent_config: AgentConfig):
        # Memory held in-state useful for debugging stateful versions
        self.memory = None
        self.messages = []
        self.all_messages = []
        # Go through the class attribute so archival_memory_cls is honored.
        self.archival_memory = self.archival_memory_cls(agent_config)
        self.recall_memory = None
        self.agent_config = agent_config

    @classmethod
    def load(cls, filename, agent_config: AgentConfig):
        """Load a LocalStateManager from a file.

        ``filename`` must contain the dict written by ``save``. ``memory`` is
        not part of that dict and stays ``None`` on the returned manager.

        SECURITY: ``pickle.load`` can execute arbitrary code while
        deserializing, so only load files that come from a trusted source.
        """
        with open(filename, "rb") as f:
            data = pickle.load(f)

        manager = cls(agent_config)
        manager.all_messages = data["all_messages"]
        manager.messages = data["messages"]
        manager.recall_memory = data["recall_memory"]
        # Archival memory is not in the pickle (save() persists it separately),
        # so it is rebuilt from the config.
        # NOTE(review): cls(agent_config) above already built one; this second
        # construction looks redundant -- confirm before removing.
        manager.archival_memory = cls.archival_memory_cls(agent_config)
        return manager

    def save(self, filename):
        """Persist archival memory via its own ``save()`` and pickle the rest to ``filename``."""
        with open(filename, "wb") as fh:
            ## TODO: fix this hacky solution to pickle the retriever
            self.archival_memory.save()
            pickle.dump(
                {
                    "recall_memory": self.recall_memory,
                    "messages": self.messages,
                    "all_messages": self.all_messages,
                },
                fh,
                protocol=pickle.HIGHEST_PROTOCOL,
            )
            # Report the path that was written, not the file-object repr.
            printd(f"Saved state to {filename}")

    def init(self, agent):
        """Seed the manager from ``agent`` and build its recall memory.

        ``agent.messages`` is tagged twice so that ``messages`` and
        ``all_messages`` hold independent wrapper dicts.
        """
        printd("Initializing InMemoryStateManager with agent object")
        self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        self.memory = agent.memory
        printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
        printd(f"InMemoryStateManager.messages.len = {len(self.messages)}")

        # Persistence manager also handles DB-related state
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)

        # TODO: init archival memory here?

    def trim_messages(self, num):
        """Keep the first message and drop the entries before index ``num``.

        ``all_messages`` is left untouched. Raises ``IndexError`` when
        ``messages`` is empty.
        """
        # printd(f"InMemoryStateManager.trim_messages")
        self.messages = [self.messages[0]] + self.messages[num:]

    def prepend_to_messages(self, added_messages):
        """Insert ``added_messages`` right after the first message.

        The tagged messages are also appended to ``all_messages``. Raises
        ``IndexError`` when ``messages`` is empty.
        """
        # first tag with timestamps
        added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]

        printd("InMemoryStateManager.prepend_to_message")
        self.messages = [self.messages[0]] + added_messages + self.messages[1:]
        self.all_messages.extend(added_messages)

    def append_to_messages(self, added_messages):
        """Append ``added_messages`` to both ``messages`` and ``all_messages``."""
        # first tag with timestamps
        added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]

        printd("InMemoryStateManager.append_to_messages")
        self.messages = self.messages + added_messages
        self.all_messages.extend(added_messages)

    def swap_system_message(self, new_system_message):
        """Replace ``messages[0]`` with ``new_system_message`` and log it in ``all_messages``.

        Raises ``IndexError`` when ``messages`` is empty.
        """
        # first tag with timestamps
        new_system_message = {"timestamp": get_local_time(), "message": new_system_message}

        printd("InMemoryStateManager.swap_system_message")
        self.messages[0] = new_system_message
        self.all_messages.append(new_system_message)

    def update_memory(self, new_memory):
        """Replace the stored memory object with ``new_memory``."""
        printd("InMemoryStateManager.update_memory")
        self.memory = new_memory
|
|
|
|
|
|
class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
    """In-memory state manager whose archival database is supplied by the caller.

    Unlike the parent's ``init``, which starts from an empty archival
    database, this class keeps the ``archival_memory_db`` passed to the
    constructor and hands it to ``archival_memory_cls``.
    """

    archival_memory_cls = DummyArchivalMemory
    recall_memory_cls = DummyRecallMemory

    def __init__(self, archival_memory_db):
        # Run the parent initializer first so memory/messages/all_messages
        # exist even before init() is called, then record the preloaded
        # database (assigned last so the parent cannot overwrite it).
        super().__init__()
        self.archival_memory_db = archival_memory_db

    def init(self, agent):
        """Seed the manager from ``agent`` and build recall/archival memory.

        ``agent.messages`` is tagged twice so that ``messages`` and
        ``all_messages`` hold independent wrapper dicts. Progress is written
        with ``print`` (not ``printd``), as before.
        """
        print("Initializing InMemoryStateManager with agent object")
        self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        self.memory = agent.memory
        print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
        print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
        # Reuse the database given to __init__ instead of starting empty.
        self.archival_memory = self.archival_memory_cls(archival_memory_database=self.archival_memory_db)
|
|
|
|
|
|
class InMemoryStateManagerWithEmbeddings(InMemoryStateManager):
    """In-memory state manager wired to the ``*WithEmbeddings`` memory classes.

    Only the two factory attributes differ from the parent; all behavior is
    inherited.
    """

    # Class used by init() to build recall memory.
    recall_memory_cls = DummyRecallMemoryWithEmbeddings
    # Class used by init() to build archival memory.
    archival_memory_cls = DummyArchivalMemoryWithEmbeddings
|
|
|
|
|
|
class InMemoryStateManagerWithFaiss(InMemoryStateManager):
    """In-memory state manager whose archival memory is built from a supplied index.

    The constructor stores the index, the archival database and ``a_k``;
    ``init`` forwards them to ``archival_memory_cls`` as ``index``,
    ``archival_memory_database`` and ``k``. Saving is not supported.
    """

    recall_memory_cls = DummyRecallMemoryWithEmbeddings
    archival_memory_cls = DummyArchivalMemoryWithFaiss

    def __init__(self, archival_index, archival_memory_db, a_k=100):
        super().__init__()
        self.a_k = a_k  # presumably the retrieval top-k -- confirm against DummyArchivalMemoryWithFaiss
        self.archival_memory_db = archival_memory_db
        self.archival_index = archival_index

    def save(self, _filename):
        """Always raises ``NotImplementedError``; this manager cannot be saved."""
        raise NotImplementedError

    def init(self, agent):
        """Seed the manager from ``agent`` and build recall/archival memory.

        Progress is written with ``print`` rather than ``printd``.
        """

        def stamp(raw_messages):
            # Wrap every message in a fresh dict carrying the current local time.
            return [{"timestamp": get_local_time(), "message": item} for item in raw_messages]

        print("Initializing InMemoryStateManager with agent object")
        # Two separate passes so the two lists never share wrapper dicts.
        self.all_messages = stamp(agent.messages.copy())
        self.messages = stamp(agent.messages.copy())
        self.memory = agent.memory
        total_count = len(self.all_messages)
        print(f"InMemoryStateManager.all_messages.len = {total_count}")
        window_count = len(self.messages)
        print(f"InMemoryStateManager.messages.len = {window_count}")

        # Persistence manager also handles DB-related state
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
        archival_kwargs = {
            "index": self.archival_index,
            "archival_memory_database": self.archival_memory_db,
            "k": self.a_k,
        }
        self.archival_memory = self.archival_memory_cls(**archival_kwargs)
|