add flag for preloading files
README.md
@@ -76,6 +76,8 @@ python main.py --human me.txt
   enables debugging output
 --archival_storage_faiss_path=<ARCHIVAL_STORAGE_FAISS_PATH>
   load in document database (backed by FAISS index)
+--archival_storage_files="<ARCHIVAL_STORAGE_FILES_GLOB>"
+  pre-load files into archival memory
 ```
 
 ### Interactive CLI commands
main.py
@@ -16,7 +16,7 @@ import memgpt.presets as presets
 import memgpt.constants as constants
 import memgpt.personas.personas as personas
 import memgpt.humans.humans as humans
-from memgpt.persistence_manager import InMemoryStateManager, InMemoryStateManagerWithFaiss
+from memgpt.persistence_manager import InMemoryStateManager, InMemoryStateManagerWithPreloadedArchivalMemory, InMemoryStateManagerWithFaiss
 
 FLAGS = flags.FLAGS
 flags.DEFINE_string("persona", default=personas.DEFAULT, required=False, help="Specify persona")
@@ -24,7 +24,8 @@ flags.DEFINE_string("human", default=humans.DEFAULT, required=False, help="Speci
 flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model")
 flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence")
 flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output")
-flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage to load (a folder with a .index and .json describing documents to be loaded)")
+flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage with FAISS index to load (a folder with a .index and .json describing documents to be loaded)")
+flags.DEFINE_string("archival_storage_files", default="", required=False, help="Specify files to pre-load into archival memory (glob pattern)")
 
 
 def clear_line():
@@ -47,6 +48,10 @@ async def main():
     if FLAGS.archival_storage_faiss_path:
         index, archival_database = utils.prepare_archival_index(FLAGS.archival_storage_faiss_path)
         persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database)
+    elif FLAGS.archival_storage_files:
+        archival_database = utils.prepare_archival_index_from_files(FLAGS.archival_storage_files)
+        print(f"Preloaded {len(archival_database)} chunks into archival memory.")
+        persistence_manager = InMemoryStateManagerWithPreloadedArchivalMemory(archival_database)
     else:
         persistence_manager = InMemoryStateManager()
     memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(FLAGS.human), interface, persistence_manager)
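The new flag is plain absl-py plumbing: the empty-string default makes `FLAGS.archival_storage_files` falsy, so the `elif` only fires when a glob was actually supplied, and `--archival_storage_faiss_path` keeps precedence when both are set. A minimal standalone sketch of that flow (the print messages are stand-ins, not code from this commit):

```python
from absl import app, flags

FLAGS = flags.FLAGS
flags.DEFINE_string("archival_storage_files", default="", required=False,
                    help="Specify files to pre-load into archival memory (glob pattern)")

def main(argv):
    # The empty-string default is falsy, so the flag doubles as an on/off switch.
    if FLAGS.archival_storage_files:
        print(f"would preload files matching {FLAGS.archival_storage_files}")
    else:
        print("no preload requested")

if __name__ == "__main__":
    app.run(main)
```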
memgpt/persistence_manager.py
@@ -85,11 +85,29 @@ class InMemoryStateManager(PersistenceManager):
         self.memory = new_memory
 
 
-class InMemoryStateManagerWithEmbeddings(InMemoryStateManager):
+class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
+    archival_memory_cls = DummyArchivalMemory
+    recall_memory_cls = DummyRecallMemory
+
+    def __init__(self, archival_memory_db):
+        self.archival_memory_db = archival_memory_db
+
+    def init(self, agent):
+        print(f"Initializing InMemoryStateManager with agent object")
+        self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
+        self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()]
+        self.memory = agent.memory
+        print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
+        print(f"InMemoryStateManager.messages.len = {len(self.messages)}")
+        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
+        self.archival_memory = self.archival_memory_cls(archival_memory_database=self.archival_memory_db)
+
+
+class InMemoryStateManagerWithEmbeddings(InMemoryStateManager):
     archival_memory_cls = DummyArchivalMemoryWithEmbeddings
     recall_memory_cls = DummyRecallMemoryWithEmbeddings
 
 
 class InMemoryStateManagerWithFaiss(InMemoryStateManager):
     archival_memory_cls = DummyArchivalMemoryWithFaiss
     recall_memory_cls = DummyRecallMemoryWithEmbeddings
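The `archival_memory_database` handed to `DummyArchivalMemory` is the list of `{'content': ..., 'timestamp': ...}` dicts built by `prepare_archival_index_from_files` in `utils.py` below. `DummyArchivalMemory` itself is not part of this diff, so the following is only an assumption about how a non-embedding archival memory might query that list, with a made-up record for illustration:

```python
# Hypothetical sketch: the real DummyArchivalMemory may differ.
# Records follow the shape produced by prepare_archival_index_from_files().
archival_memory_database = [
    {'content': '[File: uber-10k.txt Part 0/412] Uber Technologies, Inc. ...',
     'timestamp': '2023-10-10 09:15:00 AM'},
]

def naive_search(database, query, page=0, page_size=5):
    """Case-insensitive substring match over 'content', paginated."""
    hits = [rec for rec in database if query.lower() in rec['content'].lower()]
    return hits[page * page_size:(page + 1) * page_size]

print(naive_search(archival_memory_database, 'uber'))  # -> [the record above]
```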
memgpt/personas/examples/preload_archival/README.md (new file)
@@ -0,0 +1,16 @@
+# Preloading Archival Memory with Files
+
+MemGPT enables you to chat with your data locally -- this example walks through loading documents into MemGPT's archival memory.
+
+To run our example, where you can search over the SEC 10-K filings of Uber, Lyft, and Airbnb:
+
+1. Download the .txt files from [HuggingFace](https://huggingface.co/datasets/MemGPT/example-sec-filings/tree/main) and place them in this directory.
+
+2. In the root `MemGPT` directory, run
+```bash
+python3 main.py --archival_storage_files="memgpt/personas/examples/preload_archival/*.txt" --persona=memgpt_doc --human=basic
+```
+
+If you would like to load your own local files into MemGPT's archival memory, run the command above, replacing `--archival_storage_files="memgpt/personas/examples/preload_archival/*.txt"` with your own file glob expression (enclosed in quotes).
+
+## Demo
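A note on the quotes around the glob: they keep the shell from expanding the pattern, so `main.py` receives it verbatim and Python's `glob` module (used by the new `utils.py` helper below) does the expansion in-process. A quick illustration:

```python
import glob

# The quoted flag value arrives as the raw pattern; glob expands it here,
# inside the process, rather than the shell touching it on the command line.
pattern = "memgpt/personas/examples/preload_archival/*.txt"
print(glob.glob(pattern))  # e.g. ['.../uber-10k.txt', '.../lyft-10k.txt']
```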
memgpt/utils.py
@@ -7,6 +7,7 @@ import pytz
 import os
 import faiss
 import tiktoken
+import glob
 
 def count_tokens(s: str, model: str = "gpt-4") -> int:
     encoding = tiktoken.encoding_for_model(model)
@@ -83,4 +84,56 @@ def prepare_archival_index(folder):
             'content': f"[Title: {passage['title']}, {i}/{total}] {passage['text']}",
             'timestamp': get_local_time(),
         })
-    return index, archival_database
+    return index, archival_database
+
+def read_in_chunks(file_object, chunk_size):
+    while True:
+        data = file_object.read(chunk_size)
+        if not data:
+            break
+        yield data
+
+def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='gpt-4'):
+    encoding = tiktoken.encoding_for_model(model)
+    files = glob.glob(glob_pattern)
+    archival_database = []
+    for file in files:
+        timestamp = os.path.getmtime(file)
+        formatted_time = datetime.fromtimestamp(timestamp).strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
+        with open(file, 'r') as f:
+            lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)]
+        chunks = []
+        curr_chunk = []
+        curr_token_ct = 0
+        for line in lines:
+            line = line.rstrip()
+            line = line.lstrip()
+            try:
+                line_token_ct = len(encoding.encode(line))
+            except Exception as e:
+                line_token_ct = len(line.split(' ')) / .75
+                print(f"Could not encode line {line}, estimating it to be {line_token_ct} tokens")
+            if line_token_ct > tkns_per_chunk:
+                if len(curr_chunk) > 0:
+                    chunks.append(''.join(curr_chunk))
+                    curr_chunk = []
+                    curr_token_ct = 0
+                chunks.append(line[:3200])
+                continue
+            curr_token_ct += line_token_ct
+            curr_chunk.append(line)
+            if curr_token_ct > tkns_per_chunk:
+                chunks.append(''.join(curr_chunk))
+                curr_chunk = []
+                curr_token_ct = 0
+
+        if len(curr_chunk) > 0:
+            chunks.append(''.join(curr_chunk))
+
+        file_stem = file.split('/')[-1]
+        for i, chunk in enumerate(chunks):
+            archival_database.append({
+                'content': f"[File: {file_stem} Part {i}/{len(chunks)}] {chunk}",
+                'timestamp': formatted_time,
+            })
+    return archival_database
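The chunker's constants follow the common rough heuristic of about 4 characters per token: `read_in_chunks(f, tkns_per_chunk*4)` pulls 1200-character blocks (roughly 300 tokens; despite the variable name, `lines` holds fixed-size blocks, not text lines), the fallback estimate assumes about 0.75 words per token, and `line[:3200]` caps an oversized block near 800 tokens. A small usage sketch under assumptions not in the diff (that the helper is importable as `memgpt.utils` and that `tiktoken` is installed; `notes.txt` is a stand-in file):

```python
from memgpt.utils import prepare_archival_index_from_files

# Create a throwaway file to chunk (hypothetical content).
with open('notes.txt', 'w') as f:
    f.write('MemGPT preloads documents into archival memory for later search.\n' * 500)

db = prepare_archival_index_from_files('notes.txt', tkns_per_chunk=300)
print(len(db), 'chunks')      # each chunk lands near the 300-token budget
print(db[0]['content'][:60])  # "[File: notes.txt Part 0/...] MemGPT preloads ..."
```

One subtlety worth knowing before reusing this: each block is stripped and then joined with `''.join(curr_chunk)`, so whitespace at block boundaries is lost and words straddling a 1200-character boundary can be fused together.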