# memgpt/connectors/local.py
# (Local storage connectors for MemGPT — file-viewer metadata header removed;
# the original lines "Files", "314 lines", "12 KiB" were paste residue and
# not valid Python.)
from typing import Optional, List, Iterator
import shutil
from memgpt.config import AgentConfig, MemGPTConfig
from tqdm import tqdm
import re
import pickle
import os
import json
import glob
from typing import List, Optional, Dict
from abc import abstractmethod
from llama_index import VectorStoreIndex, ServiceContext, set_global_service_context
from llama_index.indices.empty.base import EmptyIndex
from llama_index.retrievers import VectorIndexRetriever
from llama_index.schema import TextNode
from memgpt.constants import MEMGPT_DIR
from memgpt.data_types import Record
from memgpt.config import MemGPTConfig
from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.config import AgentConfig, MemGPTConfig
from memgpt.utils import printd, get_local_time, parse_formatted_time
from memgpt.data_types import Message, Passage, Record
# NOTE: Everything below until InMemoryStorageConnector is dead code — the
# LlamaIndex-backed VectorIndexStorageConnector was commented out wholesale
# and is retained for reference only. Consider deleting once the replacement
# connectors are stable.
# class VectorIndexStorageConnector(StorageConnector):
# """Local storage connector based on LlamaIndex"""
# def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
# super().__init__(table_type=table_type, agent_config=agent_config)
# config = MemGPTConfig.load()
## TODO: add asserts to avoid both being passed
# if agent_config is not None:
# self.name = agent_config.name
# self.save_directory = agent_config.save_agent_index_dir()
# else:
# self.name = name
# self.save_directory = f"{MEMGPT_DIR}/archival/{name}"
## llama index contexts
# self.embed_model = embedding_model()
# self.service_context = ServiceContext.from_defaults(llm=None, embed_model=self.embed_model, chunk_size=config.embedding_chunk_size)
# set_global_service_context(self.service_context)
## load/create index
# self.save_path = f"{self.save_directory}/nodes.pkl"
# if os.path.exists(self.save_path):
# self.nodes = pickle.load(open(self.save_path, "rb"))
# else:
# self.nodes = []
## create vectorindex
# if len(self.nodes):
# self.index = VectorStoreIndex(self.nodes)
# else:
# self.index = EmptyIndex()
# def load(self, filters: Dict):
## load correct version based off filters
# if "agent_id" in filters and filters["agent_id"] is not None:
## load agent archival memory
# save_directory = self.agent_config.save_agent_index_dir()
# elif "data_source" in filters and filters["data_source"] is not None:
# name = filters["data_source"]
# save_directory = f"{MEMGPT_DIR}/archival/{name}"
# else:
# raise ValueError(f"Cannot load index without agent_id or data_source {filters}")
# save_path = f"{save_directory}/nodes.pkl"
# if os.path.exists(save_path):
# nodes = pickle.load(open(save_path, "rb"))
# else:
# nodes = []
## create vectorindex
# if len(self.nodes):
# self.index = VectorStoreIndex(self.nodes)
# else:
# self.index = EmptyIndex()
# def get_nodes(self) -> List[TextNode]:
# """Get llama index nodes"""
# embed_dict = self.index._vector_store._data.embedding_dict
# node_dict = self.index._docstore.docs
# nodes = []
# for node_id, node in node_dict.items():
# vector = embed_dict[node_id]
# node.embedding = vector
# nodes.append(TextNode(text=node.text, embedding=vector))
# return nodes
# def add_nodes(self, nodes: List[TextNode]):
# self.nodes += nodes
# self.index = VectorStoreIndex(self.nodes)
# def get_all_paginated(self, page_size: int = 100) -> Iterator[List[Passage]]:
# """Get all passages in the index"""
# nodes = self.get_nodes()
# for i in tqdm(range(0, len(nodes), page_size)):
# yield [Passage(text=node.text, embedding=node.embedding) for node in nodes[i : i + page_size]]
# def get_all(self, limit: int) -> List[Passage]:
# passages = []
# for node in self.get_nodes():
# assert node.embedding is not None, f"Node embedding is None"
# passages.append(Passage(text=node.text, embedding=node.embedding))
# if len(passages) >= limit:
# break
# return passages
# def get(self, id: str) -> Passage:
# pass
# def insert(self, passage: Passage):
# nodes = [TextNode(text=passage.text, embedding=passage.embedding)]
# self.nodes += nodes
# if isinstance(self.index, EmptyIndex):
# self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
# else:
# self.index.insert_nodes(nodes)
# def insert_many(self, passages: List[Passage]):
# nodes = [TextNode(text=passage.text, embedding=passage.embedding) for passage in passages]
# self.nodes += nodes
# if isinstance(self.index, EmptyIndex):
# self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
# else:
# orig_size = len(self.get_nodes())
# self.index.insert_nodes(nodes)
# assert len(self.get_nodes()) == orig_size + len(
# passages
# ), f"expected {orig_size + len(passages)} nodes, got {len(self.get_nodes())} nodes"
# def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
# if isinstance(self.index, EmptyIndex): # empty index
# return []
## TODO: this may be super slow?
## the nice thing about creating this here is that now we can save the persistent storage manager
# retriever = VectorIndexRetriever(
# index=self.index, # does this get refreshed?
# similarity_top_k=top_k,
# )
# nodes = retriever.retrieve(query)
# results = [Passage(embedding=node.embedding, text=node.text) for node in nodes]
# return results
# def save(self):
## assert len(self.nodes) == len(self.get_nodes()), f"Expected {len(self.nodes)} nodes, got {len(self.get_nodes())} nodes"
# self.nodes = self.get_nodes()
# os.makedirs(self.save_directory, exist_ok=True)
# pickle.dump(self.nodes, open(self.save_path, "wb"))
# @staticmethod
# def list_loaded_data():
# sources = []
# for data_source_file in os.listdir(os.path.join(MEMGPT_DIR, "archival")):
# name = os.path.basename(data_source_file)
# sources.append(name)
# return sources
# def size(self):
# return len(self.get_nodes())
class InMemoryStorageConnector(StorageConnector):
    """Really dumb in-memory storage connector so we can have a unified
    StorageConnector interface — keeps everything in a plain Python list.

    Backwards compatible with the previous on-disk format for recall memory:
    agent state is a timestamped ``*.json`` checkpoint plus a matching
    ``*.persistence.pickle`` file holding the serialized message history.
    """

    # TODO: maybe replace this with sqlite?

    def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
        """Load the most recent saved message history for ``agent_config``.

        Args:
            table_type: must be TableType.RECALL_MEMORY (the only supported type).
            agent_config: the agent whose saved state should be loaded; required.

        Raises:
            ValueError: if ``table_type`` is unsupported, if ``agent_config`` is
                None (data sources are not supported by this connector), or if
                the agent's save directory exists but holds no .json checkpoints.
        """
        super().__init__(table_type=table_type, agent_config=agent_config)
        config = MemGPTConfig.load()

        # Only recall memory is supported by this connector.
        self.supported_types = [TableType.RECALL_MEMORY]
        if table_type not in self.supported_types:
            raise ValueError(f"Table type {table_type} not supported by InMemoryStorageConnector")

        self.agent_config = agent_config
        if agent_config is None:
            # is a data source
            raise ValueError("Cannot load data source from InMemoryStorageConnector")
        else:
            directory = agent_config.save_state_dir()
            if os.path.exists(directory):
                print(f"Loading saved agent {agent_config.name} from {directory}")
                json_files = glob.glob(os.path.join(directory, "*.json"))  # all .json checkpoints for this agent
                if not json_files:
                    print(f"/load error: no .json checkpoint files found")
                    raise ValueError(f"Cannot load {agent_config.name} - no saved checkpoints found in {directory}")

                # Sort files based on modified timestamp, with the latest file being the first.
                filename = max(json_files, key=os.path.getmtime)
                # Load (and thereby validate) the checkpoint JSON; only the
                # checkpoint's name is used below to locate the pickle.
                with open(filename, "r") as f:
                    state = json.load(f)

                # load persistence manager
                filename = os.path.basename(filename).replace(".json", ".persistence.pickle")
                directory = agent_config.save_persistence_manager_dir()
                printd(f"Loading persistence manager from {os.path.join(directory, filename)}")
                # BUG FIX: previously this opened the bare basename (resolved
                # against the CWD) instead of the path inside the persistence
                # manager directory that the printd above reports.
                with open(os.path.join(directory, filename), "rb") as f:
                    data = pickle.load(f)
                self.rows = data["all_messages"]
            else:
                print(f"Creating new agent {agent_config.name}")
                self.rows = []

        # convert stored JSON dicts to Message records
        self.rows = [self.json_to_message(m) for m in self.rows]

    def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]:
        """Yield rows in pages of ``page_size``; ``filters`` is ignored.

        NOTE(review): yields one empty page when there are no rows (kept for
        backwards compatibility with existing callers).
        """
        offset = 0
        while True:
            yield self.rows[offset : offset + page_size]
            offset += page_size
            if offset >= len(self.rows):
                break

    def get_all(self, limit: Optional[int] = None, filters: Optional[Dict] = {}) -> List[Record]:
        """Return all rows, truncated to ``limit`` if given; ``filters`` is ignored."""
        if limit:
            return self.rows[:limit]
        return self.rows

    def get(self, id: str) -> Record:
        """Return the row with the given id, or None if no row matches."""
        match_row = [row for row in self.rows if row.id == id]
        if len(match_row) == 0:
            return None
        assert len(match_row) == 1, f"Expected 1 match, got {len(match_row)} matches"
        return match_row[0]

    def insert(self, record: Record):
        """Append a single record."""
        self.rows.append(record)

    def insert_many(self, records: List[Record]):
        """Append a batch of records."""
        self.rows += records

    def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
        """Vector similarity search is not supported in-memory."""
        raise NotImplementedError

    def json_to_message(self, message_json) -> Message:
        """Convert agent message JSON into Message object.

        Inverse of message_to_json. Optional function fields default to None
        when absent from the stored dict.
        """
        timestamp = message_json["timestamp"]
        message = message_json["message"]
        # NOTE(review): assumes self.config (with anon_clientid) is set by the
        # StorageConnector base class — confirm against storage.py.
        return Message(
            user_id=self.config.anon_clientid,
            agent_id=self.agent_config.name,
            role=message["role"],
            text=message["content"],
            model=self.agent_config.model,
            created_at=parse_formatted_time(timestamp),
            function_name=message.get("function_name"),
            function_args=message.get("function_args"),
            function_response=message.get("function_response"),
            id=message.get("id"),
        )

    def message_to_json(self, message: Message) -> Dict:
        """Convert Message object into JSON (the on-disk checkpoint shape)."""
        return {
            "timestamp": message.created_at.strftime("%Y-%m-%d %H:%M:%S %Z%z"),
            "message": {
                "role": message.role,
                "content": message.text,
                "function_name": message.function_name,
                "function_args": message.function_args,
                "function_response": message.function_response,
                "id": message.id,
            },
        }

    def save(self):
        """Save state of storage connector to a timestamped persistence pickle."""
        timestamp = get_local_time().replace(" ", "_").replace(":", "_")
        filename = f"{timestamp}.persistence.pickle"
        # NOTE(review): uses self.config.save_persistence_manager_dir() here but
        # self.agent_config elsewhere (delete_table) — verify which is intended.
        os.makedirs(self.config.save_persistence_manager_dir(), exist_ok=True)
        filename = os.path.join(self.config.save_persistence_manager_dir(), filename)
        all_messages = [self.message_to_json(m) for m in self.rows]
        with open(filename, "wb") as fh:
            ## TODO: fix this hacky solution to pickle the retriever
            pickle.dump(
                {
                    "all_messages": all_messages,
                },
                fh,
                protocol=pickle.HIGHEST_PROTOCOL,
            )
        # Log the path, not the (closed) file-object repr.
        printd(f"Saved state to {filename}")

    def size(self, filters: Optional[Dict] = {}) -> int:
        """Return the number of stored rows; ``filters`` is ignored."""
        return len(self.rows)

    def query_date(self, start_date, end_date) -> List[Record]:
        """Return rows whose created_at falls within [start_date, end_date]."""
        return [row for row in self.rows if row.created_at >= start_date and row.created_at <= end_date]

    def query_text(self, query: str) -> List[Record]:
        """Case-insensitive substring search over non-system/function messages."""
        return [row for row in self.rows if row.role not in ["system", "function"] and query.lower() in row.text.lower()]

    def delete(self, filters: Optional[Dict] = {}):
        """Row-level deletion is not supported; use delete_table."""
        raise NotImplementedError

    def delete_table(self, filters: Optional[Dict] = {}):
        """Remove the agent's saved state and persistence directories from disk."""
        if os.path.exists(self.agent_config.save_state_dir()):
            shutil.rmtree(self.agent_config.save_state_dir())
        if os.path.exists(self.agent_config.save_persistence_manager_dir()):
            shutil.rmtree(self.agent_config.save_persistence_manager_dir())