From 15540c24ac328c995fb11f89d2e228cde508ab37 Mon Sep 17 00:00:00 2001 From: Vivian Fang Date: Sun, 15 Oct 2023 16:38:35 -0700 Subject: [PATCH] fix paging bug, implement llamaindex api search on top of memgpt --- interface.py | 8 ++- main.py | 10 ++- memgpt/agent.py | 6 +- memgpt/memory.py | 81 ++++++++++++++++++++++++ memgpt/persistence_manager.py | 26 +++++++- memgpt/personas/examples/docqa/README.md | 13 ++++ memgpt/personas/examples/memgpt_doc.txt | 5 +- memgpt/utils.py | 19 ++++++ 8 files changed, 158 insertions(+), 10 deletions(-) create mode 100644 memgpt/personas/examples/docqa/README.md diff --git a/interface.py b/interface.py index be9951dc..b8729b1e 100644 --- a/interface.py +++ b/interface.py @@ -71,9 +71,13 @@ async def function_message(msg): print(f'{Fore.RED}{Style.BRIGHT}⚡🧠 [function] {Fore.RED}updating memory with {function_name}{Style.RESET_ALL}:') try: msg_dict = eval(function_args) - print(f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t→ {msg_dict["new_content"]}') + if function_name == 'archival_memory_search': + print(f'{Fore.RED}\tquery: {msg_dict["query"]}, page: {msg_dict["page"]}') + else: + print(f'{Fore.RED}{Style.BRIGHT}\t{Fore.RED} {msg_dict["old_content"]}\n\t{Fore.GREEN}→ {msg_dict["new_content"]}') except Exception as e: - print(e) + printd(e) + printd(msg_dict) pass else: printd(f"Warning: did not recognize function message") diff --git a/main.py b/main.py index 97e32e74..a3824f78 100644 --- a/main.py +++ b/main.py @@ -16,7 +16,7 @@ import memgpt.presets as presets import memgpt.constants as constants import memgpt.personas.personas as personas import memgpt.humans.humans as humans -from memgpt.persistence_manager import InMemoryStateManager as persistence_manager +from memgpt.persistence_manager import InMemoryStateManager, InMemoryStateManagerWithFaiss FLAGS = flags.FLAGS flags.DEFINE_string("persona", default=personas.DEFAULT, required=False, help="Specify persona") @@ -24,6 +24,7 @@ flags.DEFINE_string("human", default=humans.DEFAULT, required=False, help="Speci flags.DEFINE_string("model", default=constants.DEFAULT_MEMGPT_MODEL, required=False, help="Specify the LLM model") flags.DEFINE_boolean("first", default=False, required=False, help="Use -first to send the first message in the sequence") flags.DEFINE_boolean("debug", default=False, required=False, help="Use -debug to enable debugging output") +flags.DEFINE_string("archival_storage_faiss_path", default="", required=False, help="Specify archival storage to load (a folder with a .index and .json describing documents to be loaded)") def clear_line(): @@ -43,7 +44,12 @@ async def main(): logging.getLogger().setLevel(logging.DEBUG) print("Running... [exit by typing '/exit']") - memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(), interface, persistence_manager()) + if FLAGS.archival_storage_faiss_path: + index, archival_database = utils.prepare_archival_index(FLAGS.archival_storage_faiss_path) + persistence_manager = InMemoryStateManagerWithFaiss(index, archival_database) + else: + persistence_manager = InMemoryStateManager() + memgpt_agent = presets.use_preset(presets.DEFAULT, FLAGS.model, personas.get_persona_text(FLAGS.persona), humans.get_human_text(FLAGS.human), interface, persistence_manager) print_messages = interface.print_messages await print_messages(memgpt_agent.messages) diff --git a/memgpt/agent.py b/memgpt/agent.py index 593b5e75..a64c29ec 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -624,7 +624,7 @@ class AgentAsync(object): return None async def recall_memory_search(self, query, count=5, page=0): - results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page) + results, total = await self.persistence_manager.recall_memory.text_search(query, count=count, start=page*count) num_pages = math.ceil(total / count) - 1 # 0 index if len(results) == 0: results_str = f"No results found." @@ -635,7 +635,7 @@ class AgentAsync(object): return results_str async def recall_memory_search_date(self, start_date, end_date, count=5, page=0): - results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page) + results, total = await self.persistence_manager.recall_memory.date_search(start_date, end_date, count=count, start=page*count) num_pages = math.ceil(total / count) - 1 # 0 index if len(results) == 0: results_str = f"No results found." @@ -650,7 +650,7 @@ class AgentAsync(object): return None async def archival_memory_search(self, query, count=5, page=0): - results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page) + results, total = await self.persistence_manager.archival_memory.search(query, count=count, start=page*count) num_pages = math.ceil(total / count) - 1 # 0 index if len(results) == 0: results_str = f"No results found." diff --git a/memgpt/memory.py b/memgpt/memory.py index 272dd683..fb064959 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -1,6 +1,8 @@ from abc import ABC, abstractmethod import datetime import re +import faiss +import numpy as np from .utils import cosine_similarity, get_local_time, printd from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM @@ -239,6 +241,85 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory): return matches, len(matches) +class DummyArchivalMemoryWithFaiss(DummyArchivalMemory): + """Dummy in-memory version of an archival memory database, using a FAISS + index for fast nearest-neighbors embedding search. + + Archival memory is effectively "infinite" overflow for core memory, + and is read-only via string queries. + + Archival Memory: A more structured and deep storage space for the AI's reflections, + insights, or any other data that doesn't fit into the active memory but + is essential enough not to be left only to the recall memory. + """ + + def __init__(self, index=None, archival_memory_database=None, embedding_model='text-embedding-ada-002', k=100): + if index is None: + self.index = faiss.IndexFlatL2(1536) # openai embedding vector size. + else: + self.index = index + self.k = k + self._archive = [] if archival_memory_database is None else archival_memory_database # consists of {'content': str} dicts + self.embedding_model = embedding_model + self.embeddings_dict = {} + self.search_results = {} + + def __len__(self): + return len(self._archive) + + async def insert(self, memory_string, embedding=None): + if embedding is None: + # Get the embedding + embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model) + print(f"Got an embedding, type {type(embedding)}, len {len(embedding)}") + + self._archive.append({ + # can eventually upgrade to adding semantic tags, etc + 'timestamp': get_local_time(), + 'content': memory_string, + }) + embedding = np.array([embedding]).astype('float32') + self.index.add(embedding) + + async def search(self, query_string, count=None, start=None): + """Simple embedding-based search (inefficient, no caching)""" + # see: https://github.com/openai/openai-cookbook/blob/main/examples/Semantic_text_search_using_embeddings.ipynb + + # query_embedding = get_embedding(query_string, model=self.embedding_model) + # our wrapped version supports backoff/rate-limits + if query_string in self.embeddings_dict: + query_embedding = self.embeddings_dict[query_string] + search_result = self.search_results[query_string] + else: + query_embedding = await async_get_embedding_with_backoff(query_string, model=self.embedding_model) + _, indices = self.index.search(np.array([np.array(query_embedding, dtype=np.float32)]), self.k) + search_result = [self._archive[idx] if idx < len(self._archive) else "" for idx in indices[0]] + self.embeddings_dict[query_string] = query_embedding + self.search_results[query_string] = search_result + + if start is not None and count is not None: + toprint = search_result[start:start+count] + else: + if len(search_result) >= 5: + toprint = search_result[:5] + else: + toprint = search_result + printd(f"archive_memory.search (vector-based): search for query '{query_string}' returned the following results ({start}--{start+5}/{len(search_result)}) and scores:\n{str([t[:60] if len(t) > 60 else t for t in toprint])}") + + # Extract the sorted archive without the scores + matches = search_result + + # start/count support paging through results + if start is not None and count is not None: + return matches[start:start+count], len(matches) + elif start is None and count is not None: + return matches[:count], len(matches) + elif start is not None and count is None: + return matches[start:], len(matches) + else: + return matches, len(matches) + + class RecallMemory(ABC): @abstractmethod diff --git a/memgpt/persistence_manager.py b/memgpt/persistence_manager.py index 3c8e24f2..575741b3 100644 --- a/memgpt/persistence_manager.py +++ b/memgpt/persistence_manager.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings +from .memory import DummyRecallMemory, DummyRecallMemoryWithEmbeddings, DummyArchivalMemory, DummyArchivalMemoryWithEmbeddings, DummyArchivalMemoryWithFaiss from .utils import get_local_time, printd @@ -88,4 +88,26 @@ class InMemoryStateManager(PersistenceManager): class InMemoryStateManagerWithEmbeddings(InMemoryStateManager): archival_memory_cls = DummyArchivalMemoryWithEmbeddings - recall_memory_cls = DummyRecallMemoryWithEmbeddings \ No newline at end of file + recall_memory_cls = DummyRecallMemoryWithEmbeddings + +class InMemoryStateManagerWithFaiss(InMemoryStateManager): + archival_memory_cls = DummyArchivalMemoryWithFaiss + recall_memory_cls = DummyRecallMemoryWithEmbeddings + + def __init__(self, archival_index, archival_memory_db, a_k=100): + super().__init__() + self.archival_index = archival_index + self.archival_memory_db = archival_memory_db + self.a_k = a_k + + def init(self, agent): + print(f"Initializing InMemoryStateManager with agent object") + self.all_messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] + self.messages = [{'timestamp': get_local_time(), 'message': msg} for msg in agent.messages.copy()] + self.memory = agent.memory + print(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}") + print(f"InMemoryStateManager.messages.len = {len(self.messages)}") + + # Persistence manager also handles DB-related state + self.recall_memory = self.recall_memory_cls(message_database=self.all_messages) + self.archival_memory = self.archival_memory_cls(index=self.archival_index, archival_memory_database=self.archival_memory_db, k=self.a_k) diff --git a/memgpt/personas/examples/docqa/README.md b/memgpt/personas/examples/docqa/README.md new file mode 100644 index 00000000..1daa7f81 --- /dev/null +++ b/memgpt/personas/examples/docqa/README.md @@ -0,0 +1,13 @@ +# MemGPT Search over LlamaIndex API Docs + +1. + a. Download embeddings and docs index from XYZ. + -- OR -- + b. Build the index: + 1. Build llama_index API docs with `make text`. Instructions [here](https://github.com/run-llama/llama_index/blob/main/docs/DOCS_README.md). Copy over the generated `_build/text` folder to this directory. + 2. Generate embeddings and FAISS index. + ```bash + python3 scrape_docs.py + python3 generate_embeddings_for_docs.py all_docs.jsonl + python3 build_index.py --embedding_files all_docs.embeddings.jsonl --output_index_file all_docs.index + ``` \ No newline at end of file diff --git a/memgpt/personas/examples/memgpt_doc.txt b/memgpt/personas/examples/memgpt_doc.txt index 0a850b99..9af2c7f6 100644 --- a/memgpt/personas/examples/memgpt_doc.txt +++ b/memgpt/personas/examples/memgpt_doc.txt @@ -1,3 +1,6 @@ My name is MemGPT. I am an AI assistant designed to help human users with document analysis. -I can use this space in my core memory to keep track of my current tasks and goals. \ No newline at end of file +I can use this space in my core memory to keep track of my current tasks and goals. + +The answer to the human's question will usually be located somewhere in your archival memory, so keep paging through results until you find enough information to construct an answer. +Do not respond to the human until you have arrived at an answer. \ No newline at end of file diff --git a/memgpt/utils.py b/memgpt/utils.py index d9b1a4ae..e008cd93 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -4,6 +4,8 @@ import demjson3 as demjson import numpy as np import json import pytz +import os +import faiss # DEBUG = True @@ -61,3 +63,20 @@ def parse_json(string): except demjson.JSONDecodeError as e: print(f"Error parsing json with demjson package: {e}") raise e + +def prepare_archival_index(folder): + index_file = os.path.join(folder, "all_docs.index") + index = faiss.read_index(index_file) + + archival_database_file = os.path.join(folder, "all_docs.jsonl") + archival_database = [] + with open(archival_database_file, 'rt') as f: + all_data = [json.loads(line) for line in f] + for doc in all_data: + total = len(doc) + for i, passage in enumerate(doc): + archival_database.append({ + 'content': f"[Title: {passage['title']}, {i}/{total}] {passage['text']}", + 'timestamp': get_local_time(), + }) + return index, archival_database \ No newline at end of file