Files
letta-server/memgpt/connectors/local.py

142 lines
5.2 KiB
Python

from typing import Optional, List, Iterator
from memgpt.config import AgentConfig, MemGPTConfig
from tqdm import tqdm
import re
import pickle
import os
from typing import List, Optional
from llama_index import (
VectorStoreIndex,
EmptyIndex,
ServiceContext,
)
from llama_index.retrievers import VectorIndexRetriever
from llama_index.schema import TextNode
from memgpt.constants import MEMGPT_DIR
from memgpt.config import MemGPTConfig
from memgpt.connectors.storage import StorageConnector, Passage
from memgpt.config import AgentConfig, MemGPTConfig
class LocalStorageConnector(StorageConnector):
    """Local storage connector based on LlamaIndex.

    Passages are kept as LlamaIndex ``TextNode`` objects in an in-memory
    ``VectorStoreIndex`` (an ``EmptyIndex`` is used as a placeholder while
    nothing has been stored). ``save()`` pickles the nodes to
    ``<save_directory>/nodes.pkl`` and ``__init__`` reloads them from there.
    """

    def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
        """Create the connector and load any previously saved nodes.

        Args:
            name: Name of an archival data source; nodes are stored under
                ``{MEMGPT_DIR}/archival/{name}``.
            agent_config: Agent configuration; used only when ``name`` is
                ``None``, in which case the agent's name and
                ``save_agent_index_dir()`` are used.

        Raises:
            ValueError: If neither ``name`` nor ``agent_config`` is given.
        """
        # Function-scope import kept from the original code
        # (presumably avoids an import cycle -- TODO confirm).
        from memgpt.embeddings import embedding_model

        if name is None and agent_config is None:
            # Previously this surfaced as an opaque AttributeError on None.
            raise ValueError("Either name or agent_config must be provided")

        config = MemGPTConfig.load()

        # TODO: add asserts to avoid both being passed
        # (for now `name` takes precedence when both are given)
        if name is None:
            self.name = agent_config.name
            self.save_directory = agent_config.save_agent_index_dir()
        else:
            self.name = name
            self.save_directory = f"{MEMGPT_DIR}/archival/{name}"

        # llama index contexts
        self.embed_model = embedding_model()
        self.service_context = ServiceContext.from_defaults(
            llm=None,
            embed_model=self.embed_model,
            chunk_size=config.embedding_chunk_size,
        )

        # load/create index
        self.save_path = f"{self.save_directory}/nodes.pkl"
        if os.path.exists(self.save_path):
            # SECURITY NOTE: unpickling can execute arbitrary code; this file
            # must only ever be one written by `save()` on a trusted machine.
            with open(self.save_path, "rb") as f:
                self.nodes = pickle.load(f)
        else:
            self.nodes = []

        # create vectorindex; pass the service context so the reloaded index
        # uses the same embedding model as indexes built by insert()
        if self.nodes:
            self.index = VectorStoreIndex(self.nodes, service_context=self.service_context)
        else:
            self.index = EmptyIndex()

    def _add_to_index(self, nodes: List[TextNode]) -> None:
        """Record ``nodes`` in ``self.nodes`` and add them to the index.

        The ``EmptyIndex`` placeholder cannot accept nodes, so the first
        insertion replaces it with a ``VectorStoreIndex`` built from all nodes.
        """
        self.nodes += nodes
        if isinstance(self.index, EmptyIndex):
            self.index = VectorStoreIndex(self.nodes, service_context=self.service_context, show_progress=True)
        else:
            self.index.insert_nodes(nodes)

    def get_nodes(self) -> List[TextNode]:
        """Get llama index nodes.

        Returns:
            A new ``TextNode`` (text + embedding only) for every document in
            the index's docstore; an empty list while the index is still the
            ``EmptyIndex`` placeholder.
        """
        if isinstance(self.index, EmptyIndex):
            # The placeholder has no vector store / docstore to read from.
            return []

        # NOTE: relies on private LlamaIndex attributes.
        embed_dict = self.index._vector_store._data.embedding_dict
        node_dict = self.index._docstore.docs

        nodes = []
        for node_id, node in node_dict.items():
            vector = embed_dict[node_id]
            # Also attach the embedding to the docstore node itself.
            node.embedding = vector
            nodes.append(TextNode(text=node.text, embedding=vector))
        return nodes

    def add_nodes(self, nodes: List[TextNode]):
        """Append ``nodes`` and rebuild the vector index from all nodes."""
        self.nodes += nodes
        self.index = VectorStoreIndex(self.nodes, service_context=self.service_context)

    def get_all_paginated(self, page_size: int = 100) -> Iterator[List[Passage]]:
        """Get all passages in the index, yielded in pages.

        Args:
            page_size: Maximum number of passages per yielded list.

        Yields:
            Lists of at most ``page_size`` passages; progress is reported
            through ``tqdm``.
        """
        nodes = self.get_nodes()
        for start in tqdm(range(0, len(nodes), page_size)):
            page = nodes[start : start + page_size]
            yield [Passage(text=node.text, embedding=node.embedding) for node in page]

    def get_all(self, limit: int) -> List[Passage]:
        """Return up to ``limit`` passages from the index.

        Args:
            limit: Maximum number of passages to return; a value ``<= 0``
                yields an empty list.
        """
        passages: List[Passage] = []
        if limit <= 0:
            # The limit check below runs after the append, so without this
            # guard a non-positive limit would still return one passage.
            return passages

        for node in self.get_nodes():
            assert node.embedding is not None, "Node embedding is None"
            passages.append(Passage(text=node.text, embedding=node.embedding))
            if len(passages) >= limit:
                break
        return passages

    def get(self, id: str) -> Passage:
        """Look up a passage by id.

        Not implemented for the local connector: always returns ``None``.
        """
        return None

    def insert(self, passage: Passage):
        """Insert a single passage into the index."""
        self._add_to_index([TextNode(text=passage.text, embedding=passage.embedding)])

    def insert_many(self, passages: List[Passage]):
        """Insert several passages into the index.

        When inserting into an existing (non-empty) index, the node count is
        checked afterwards as a sanity check.
        """
        if not passages:
            # Nothing to do; avoids building an index from zero nodes.
            return

        nodes = [TextNode(text=passage.text, embedding=passage.embedding) for passage in passages]

        if isinstance(self.index, EmptyIndex):
            self._add_to_index(nodes)
            return

        orig_size = len(self.get_nodes())
        self._add_to_index(nodes)
        # Count once: get_nodes() walks the whole docstore.
        new_size = len(self.get_nodes())
        expected_size = orig_size + len(passages)
        assert new_size == expected_size, f"expected {expected_size} nodes, got {new_size} nodes"

    def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
        """Retrieve the passages most similar to ``query``.

        Args:
            query: Query text handed to the LlamaIndex retriever.
            query_vec: Not used by this connector; only the query text is
                passed to the retriever.
            top_k: Number of results requested from the retriever.

        Returns:
            The retrieved passages, or an empty list if nothing is indexed.
        """
        if isinstance(self.index, EmptyIndex):  # empty index
            return []

        # TODO: this may be super slow?
        # the nice thing about creating this here is that now we can save the persistent storage manager
        retriever = VectorIndexRetriever(
            index=self.index,  # does this get refreshed?
            similarity_top_k=top_k,
        )
        retrieved = retriever.retrieve(query)
        return [Passage(embedding=node.embedding, text=node.text) for node in retrieved]

    def save(self):
        """Persist the indexed nodes to ``self.save_path`` with pickle.

        ``self.nodes`` is refreshed from the index first, and the save
        directory is created if it does not exist.
        """
        self.nodes = self.get_nodes()
        os.makedirs(self.save_directory, exist_ok=True)
        # Context manager guarantees the data is flushed and the handle closed.
        with open(self.save_path, "wb") as f:
            pickle.dump(self.nodes, f)

    @staticmethod
    def list_loaded_data():
        """List the entry names under ``{MEMGPT_DIR}/archival``.

        Returns:
            The directory entry names, or an empty list if the archival
            directory does not exist.
        """
        archival_dir = os.path.join(MEMGPT_DIR, "archival")
        try:
            # listdir already yields bare entry names (no directory part).
            return os.listdir(archival_dir)
        except FileNotFoundError:
            return []

    def size(self):
        """Return the number of nodes currently in the index."""
        return len(self.get_nodes())