add flag for preloading files

This commit is contained in:
Vivian Fang
2023-10-16 16:55:25 -07:00
parent 86d52c4cdf
commit 0e6786a72a
5 changed files with 98 additions and 4 deletions

View File

@@ -76,6 +76,8 @@ python main.py --human me.txt
enables debugging output
--archival_storage_faiss_path=<ARCHIVAL_STORAGE_FAISS_PATH>
load in document database (backed by FAISS index)
--archival_storage_files="<ARCHIVAL_STORAGE_FILES_GLOB>"
pre-load files into archival memory
```
### Interactive CLI commands

View File

@@ -16,7 +16,7 @@ import memgpt.presets as presets
import memgpt.constants as constants
import memgpt.personas.personas as personas
import memgpt.humans.humans as humans
from memgpt.persistence_manager import InMemoryStateManager, InMemoryStateManagerWithFaiss
from memgpt.persistence_manager import InMemoryStateManager, InMemoryStateManagerWithPreloadedArchivalMemory, InMemoryStateManagerWithFaiss
FLAGS = flags.FLAGS
flags.DEFINE_string("persona", default=personas.DEFAULT, required=False, help="Specify persona")
@@ -24,7 +24,8 @@ flags.DEFINE_string("human", default=humans.DEFAULT, required=False, help="Speci
flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model")
flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence")
flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage to load (a folder with a .index and .json describing documents to be loaded)")
flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage with FAISS index to load (a folder with a .index and .json describing documents to be loaded)")
flags.DEFINE_string("archival_storage_files", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern)")
def clear_line():
@@ -47,6 +48,10 @@ async def main():
if FLAGS.archival_storage_faiss_path:
index, archival_database = utils.prepare_archival_index(FLAGS.archival_storage_faiss_path)
persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
elif FLAGS.archival_storage_files:
archival_database = utils.prepare_archival_index_from_files(FLAGS.archival_storage_files)
print(f"Preloaded {len(archival_database)} chunks into archival memory.")
persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(archival_database)
else:
persistence_manager = InMemoryStateManager()
memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(FLAGS.human), interface, persistence_manager)

View File

@@ -85,11 +85,29 @@ class InMemoryStateManager(PersistenceManager):
self.memory = new_memory
class InMemoryStateManagerWithEmbeddings(InMemoryStateManager):
class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
    # State manager whose archival memory is seeded from a pre-built chunk
    # database (e.g. the list produced by utils.prepare_archival_index_from_files)
    # instead of starting empty.
    archival_memory_cls = DummyArchivalMemory
    recall_memory_cls = DummyRecallMemory

    def __init__(self, archival_memory_db):
        # Only stash the preloaded chunks here; all other state is set up in
        # init() once the agent object exists.
        self.archival_memory_db = archival_memory_db

    def init(self, agent):
        # NOTE(review): message says "InMemoryStateManager" but this is the
        # preloaded-archival subclass — consider printing the actual class name.
        print(f"Initializing InMemoryStateManager with agent object")
        # Snapshot the agent's current messages, timestamping each entry.
        self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
        self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
        self.memory = agent.memory
        print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
        print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
        # Key difference from the base class: archival memory is constructed
        # around the preloaded chunk database passed to __init__.
        self.archival_memory = self.archival_memory_cls(archival_memory_database=self.archival_memory_db)
class InMemoryStateManagerWithEmbeddings(InMemoryStateManager):
    # Variant of InMemoryStateManager that swaps in embedding-backed dummy
    # memory classes for both archival and recall memory.
    archival_memory_cls = DummyArchivalMemoryWithEmbeddings
    recall_memory_cls = DummyRecallMemoryWithEmbeddings
class InMemoryStateManagerWithFaiss(InMemoryStateManager):
    # Variant of InMemoryStateManager whose archival memory is backed by a
    # FAISS index; recall memory still uses the embeddings-based dummy class.
    archival_memory_cls = DummyArchivalMemoryWithFaiss
    recall_memory_cls = DummyRecallMemoryWithEmbeddings

View File

@@ -0,0 +1,16 @@
# Preloading Archival Memory with Files
MemGPT enables you to chat with your data locally -- this example gives the workflow for loading documents into MemGPT's archival memory.
To run our example, in which you search over the SEC 10-K filings of Uber, Lyft, and Airbnb:
1. Download the .txt files from [HuggingFace](https://huggingface.co/datasets/MemGPT/example-sec-filings/tree/main) and place them in this directory.
2. In the root `MemGPT` directory, run
```bash
python3 main.py --archival_storage_files="memgpt/personas/examples/preload_archival/*.txt" --persona=memgpt_doc --human=basic
```
If you would like to load your own local files into MemGPT's archival memory, run the command above but replace `--archival_storage_files="memgpt/personas/examples/preload_archival/*.txt"` with your own file glob expression (enclosed in quotes).
## Demo

View File

@@ -7,6 +7,7 @@ import pytz
import os
import faiss
import tiktoken
import glob
def count_tokens(s: str, model: str = "gpt-4") -> int:
encoding = tiktoken.encoding_for_model(model)
@@ -83,4 +84,56 @@ def prepare_archival_index(folder):
'content': f"[Title: {passage['title']}, {i}/{total}] {passage['text']}",
'timestamp': get_local_time(),
})
return index, archival_database
return index, archival_database
def read_in_chunks(file_object, chunk_size):
    """Lazily yield successive reads of at most *chunk_size* units.

    Works for both text and binary file objects: iteration stops as soon as
    read() returns a falsy (empty) value, i.e. at end of file.
    """
    while chunk := file_object.read(chunk_size):
        yield chunk
def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='gpt-4'):
    """Chunk every file matching *glob_pattern* into ~tkns_per_chunk-token
    passages suitable for preloading into archival memory.

    Args:
        glob_pattern: Shell-style glob selecting the files to ingest.
        tkns_per_chunk: Approximate token budget per chunk.
        model: Model name used to select the tiktoken encoding.

    Returns:
        A list of dicts, each with 'content' (chunk text tagged with its
        source file and part number) and 'timestamp' (the file's formatted
        modification time).
    """
    encoding = tiktoken.encoding_for_model(model)
    files = glob.glob(glob_pattern)
    archival_database = []
    for file in files:
        # Use the file's mtime as the chunk timestamp. astimezone() attaches
        # the local timezone so %Z/%z actually render — on a naive datetime
        # they format as empty strings.
        timestamp = os.path.getmtime(file)
        formatted_time = datetime.fromtimestamp(timestamp).astimezone().strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
        # ~4 characters per token, so read pieces roughly one chunk long.
        # NOTE(review): opens with the platform default encoding — consider
        # encoding='utf-8' if the input files are known to be UTF-8.
        with open(file, 'r') as f:
            pieces = list(read_in_chunks(f, tkns_per_chunk * 4))
        chunks = []
        curr_chunk = []
        curr_token_ct = 0
        for piece in pieces:
            piece = piece.strip()
            try:
                piece_token_ct = len(encoding.encode(piece))
            except Exception as e:
                # Fall back to a word-count estimate (~0.75 words per token).
                piece_token_ct = len(piece.split(' ')) / .75
                print(f"Could not encode line {piece}, estimating it to be {piece_token_ct} tokens")
            if piece_token_ct > tkns_per_chunk:
                # Oversized piece: flush whatever is pending, then store the
                # piece as its own chunk, truncated to the chunk's character
                # budget (was a hard-coded 3200; now scales with tkns_per_chunk).
                if len(curr_chunk) > 0:
                    chunks.append(''.join(curr_chunk))
                    curr_chunk = []
                    curr_token_ct = 0
                chunks.append(piece[:tkns_per_chunk * 4])
                continue
            curr_token_ct += piece_token_ct
            curr_chunk.append(piece)
            if curr_token_ct > tkns_per_chunk:
                chunks.append(''.join(curr_chunk))
                curr_chunk = []
                curr_token_ct = 0
        if len(curr_chunk) > 0:
            chunks.append(''.join(curr_chunk))
        # os.path.basename works on every OS; the old '/'-split broke on
        # Windows-style paths.
        file_stem = os.path.basename(file)
        for i, chunk in enumerate(chunks):
            archival_database.append({
                'content': f"[File: {file_stem} Part {i}/{len(chunks)}] {chunk}",
                'timestamp': formatted_time,
            })
    return archival_database