Merge pull request #146 from sarahwooders/main

Support loading data into archival with Llama Index connectors
This commit is contained in:
Sarah Wooders
2023-10-27 13:28:26 -07:00
committed by GitHub
8 changed files with 1764 additions and 16 deletions

View File

@@ -0,0 +1,111 @@
"""
This file contains functions for loading data into MemGPT's archival storage.
Data can be loaded with the following command, once a load function is defined:
```
memgpt load <data-connector-type> --name <dataset-name> [ADDITIONAL ARGS]
```
"""
from llama_index import download_loader
from typing import List
import os
import typer
from memgpt.constants import MEMGPT_DIR
from memgpt.utils import estimate_openai_cost, get_index, save_index
# Typer application that collects the data-connector commands registered below
# via `@app.command(...)` (one command per connector type).
app = typer.Typer()
@app.command("directory")
def load_directory(
    name: str = typer.Option(help="Name of dataset to load."),
    input_dir: str = typer.Option(None, help="Path to directory containing dataset."),
    input_files: List[str] = typer.Option(None, help="List of paths to files containing dataset."),
    recursive: bool = typer.Option(False, help="Recursively search for files in directory."),
):
    """Load files from a directory or an explicit file list into an archival index."""
    from llama_index import SimpleDirectoryReader

    if recursive:
        assert input_dir is not None, "Must provide input directory if recursive is True."
        reader = SimpleDirectoryReader(
            input_dir=input_dir,
            recursive=True,
        )
    else:
        # Forward input_dir as well: it used to be dropped here, so passing only
        # an input directory without the recursive flag could never load anything.
        # NOTE(review): when both are given the reader is expected to prefer
        # input_files, matching the previous behaviour -- confirm against the
        # installed llama_index version.
        reader = SimpleDirectoryReader(
            input_dir=input_dir,
            input_files=input_files,
        )

    # load docs
    print("Loading data...")
    docs = reader.load_data()

    # embed docs
    print("Indexing documents...")
    index = get_index(name, docs)

    # persist the index under the dataset name so it can be loaded later
    save_index(index, name)
@app.command("webpage")
def load_webpage(
    name: str = typer.Option(help="Name of dataset to load."),
    urls: List[str] = typer.Option(None, help="List of urls to load."),
):
    """Load one or more web pages (converted to text) into an archival index."""
    from llama_index import SimpleWebPageReader

    # load docs
    docs = SimpleWebPageReader(html_to_text=True).load_data(urls)

    # embed docs
    print("Indexing documents...")
    # get_index takes (name, docs); the dataset name was missing from this call,
    # which made the command fail with a TypeError before indexing anything.
    index = get_index(name, docs)

    # persist the index under the dataset name so it can be loaded later
    save_index(index, name)
@app.command("database")
def load_database(
    name: str = typer.Option(help="Name of dataset to load."),
    query: str = typer.Option(help="Database query."),
    dump_path: str = typer.Option(None, help="Path to dump file."),
    scheme: str = typer.Option(None, help="Database scheme."),
    host: str = typer.Option(None, help="Database host."),
    port: int = typer.Option(None, help="Database port."),
    user: str = typer.Option(None, help="Database user."),
    password: str = typer.Option(None, help="Database password."),
    dbname: str = typer.Option(None, help="Database name."),
):
    """Load the rows returned by a SQL query into an archival index."""
    from llama_index.readers.database import DatabaseReader

    if dump_path is not None:
        # A dump file takes precedence: it is opened as a SQLite database and
        # any connection parameters that were also supplied are ignored.
        from sqlalchemy import create_engine

        engine = create_engine(f"sqlite:///{dump_path}")
        db = DatabaseReader(engine=engine)
    else:
        # Without a dump file every connection parameter is required.
        required = (
            ("scheme", scheme),
            ("host", host),
            ("port", port),
            ("user", user),
            ("password", password),
            ("name", dbname),
        )
        for label, value in required:
            assert value is not None, f"Must provide database {label}."

        db = DatabaseReader(
            scheme=scheme,  # Database Scheme
            host=host,  # Database Host
            port=port,  # Database Port
            user=user,  # Database User
            password=password,  # Database Password
            dbname=dbname,  # Database Name
        )

    # load data
    docs = db.load_data(query=query)

    # embed docs and persist the index under the dataset name
    index = get_index(name, docs)
    save_index(index, name)

View File

@@ -28,15 +28,16 @@ from memgpt.persistence_manager import (
from memgpt.config import Config
from memgpt.constants import MEMGPT_DIR
from memgpt.connectors import connector
from memgpt.openai_tools import (
configure_azure_support,
check_azure_embeddings,
get_set_azure_env_vars,
)
import asyncio
app = typer.Typer()
# Mount the connector module's Typer app as the `load` sub-command group.
app.add_typer(connector.app, name="load")
def clear_line():
@@ -109,7 +110,7 @@ def load(memgpt_agent, filename):
print(f"/load warning: loading persistence manager from {filename} failed with: {e}")
@app.command()
@app.callback(invoke_without_command=True) # make default command
def run(
persona: str = typer.Option(None, help="Specify persona"),
human: str = typer.Option(None, help="Specify human"),

View File

@@ -1,14 +1,26 @@
from abc import ABC, abstractmethod
import os
import datetime
import re
import faiss
import numpy as np
from typing import Optional, List, Tuple
from .constants import MESSAGE_SUMMARY_WARNING_TOKENS
from .constants import MESSAGE_SUMMARY_WARNING_TOKENS, MEMGPT_DIR
from .utils import cosine_similarity, get_local_time, printd, count_tokens
from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
from .openai_tools import acompletions_with_backoff as acreate, async_get_embedding_with_backoff
from llama_index import (
VectorStoreIndex,
get_response_synthesizer,
load_index_from_storage,
StorageContext,
)
from llama_index.retrievers import VectorIndexRetriever
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.indices.postprocessor import SimilarityPostprocessor
class CoreMemory(object):
"""Held in-context inside the system message
@@ -128,10 +140,26 @@ async def summarize_messages(
class ArchivalMemory(ABC):
@abstractmethod
def insert(self, memory_string):
"""Insert new archival memory
:param memory_string: Memory string to insert
:type memory_string: str
"""
pass
@abstractmethod
def search(self, query_string, count=None, start=None):
def search(self, query_string, count=None, start=None) -> Tuple[List[str], int]:
"""Search archival memory
:param query_string: Query string
:type query_string: str
:param count: Number of results to return (None for all)
:type count: Optional[int]
:param start: Offset to start returning results from (None if 0)
:type start: Optional[int]
:return: Tuple of (list of results, total number of results)
"""
pass
@abstractmethod
@@ -515,3 +543,51 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
return matches[start:], len(matches)
else:
return matches, len(matches)
class LocalArchivalMemory(ArchivalMemory):
    """Archival memory built on top of Llama Index"""

    def __init__(self, archival_memory_database: Optional[str] = None, top_k: Optional[int] = 100):
        """Init function for archival memory

        :param archival_memory_database: name of dataset to pre-fill archival with
            (None starts from an empty index)
        :type archival_memory_database: Optional[str]
        :param top_k: maximum number of nodes fetched from the retriever per query
        :type top_k: Optional[int]
        """
        if archival_memory_database is not None:
            # TODO: load from ~/.memgpt/archival
            directory = f"{MEMGPT_DIR}/archival/{archival_memory_database}"
            assert os.path.exists(directory), f"Archival memory database {archival_memory_database} does not exist"
            storage_context = StorageContext.from_defaults(persist_dir=directory)
            self.index = load_index_from_storage(storage_context)
        else:
            # Start from an empty vector index (the previous `VectorIndex()` was an
            # undefined name and raised NameError whenever no dataset was given).
            self.index = VectorStoreIndex.from_documents([])

        self.top_k = top_k
        self.retriever = VectorIndexRetriever(
            index=self.index,  # does this get refreshed?
            similarity_top_k=self.top_k,
        )

        # Maps query string -> retrieved nodes, so paging through the same query
        # does not hit the retriever again.
        # TODO: have some mechanism for cleanup otherwise will lead to OOM
        self.cache = {}

    async def insert(self, memory_string):
        """Insert new archival memory

        :param memory_string: Memory string to insert
        :type memory_string: str
        """
        # Local import: Document is not among the names imported at module level.
        from llama_index import Document

        # The index inserts Document objects, not raw strings.
        self.index.insert(Document(text=memory_string))
        # Cached retrievals predate the new memory, so drop them.
        self.cache.clear()

    async def search(self, query_string, count=None, start=None):
        """Search archival memory

        :param query_string: Query string
        :type query_string: str
        :param count: Number of results to return (None for up to top_k)
        :type count: Optional[int]
        :param start: Offset to start returning results from (None if 0)
        :type start: Optional[int]
        :return: Tuple of (list of results, total number of results)
        """
        start = start if start else 0
        count = count if count else self.top_k

        if query_string not in self.cache:
            self.cache[query_string] = self.retriever.retrieve(query_string)
        matches = self.cache[query_string]

        # Slice exactly one page. `count` used to be turned into an end offset and
        # then added to `start` again, returning too many results for start > 0.
        page = matches[start : start + count]
        results = [{"timestamp": get_local_time(), "content": node.node.text} for node in page]

        # Report the total number of matches (not the page size) so callers can paginate.
        return results, len(matches)

    def __repr__(self) -> str:
        # Return the description instead of printing it and returning "".
        return f"{type(self).__name__}(ref_doc_info={self.index.ref_doc_info})"

View File

@@ -7,6 +7,7 @@ from .memory import (
DummyArchivalMemory,
DummyArchivalMemoryWithEmbeddings,
DummyArchivalMemoryWithFaiss,
LocalArchivalMemory,
)
from .utils import get_local_time, printd
@@ -100,6 +101,74 @@ class InMemoryStateManager(PersistenceManager):
self.memory = new_memory
class LocalStateManager(PersistenceManager):
    """State manager that holds messages in memory and uses a locally stored index for archival memory"""

    recall_memory_cls = DummyRecallMemory
    archival_memory_cls = LocalArchivalMemory

    def __init__(self, archival_memory_db=None):
        """Create the manager and its archival memory

        :param archival_memory_db: name of the dataset to pre-fill archival memory with (None for empty)
        :type archival_memory_db: Optional[str]
        """
        # Memory held in-state useful for debugging stateful versions
        self.memory = None
        self.messages = []
        self.all_messages = []
        # Go through the class attribute so subclasses can swap the archival backend.
        self.archival_memory = self.archival_memory_cls(archival_memory_database=archival_memory_db)

    @staticmethod
    def load(filename):
        """Load a previously saved manager from a pickle file"""
        # NOTE(review): pickle.load executes arbitrary code from the file; only
        # load state files from trusted locations.
        with open(filename, "rb") as f:
            return pickle.load(f)

    def save(self, filename):
        """Pickle this manager to `filename`"""
        with open(filename, "wb") as fh:
            pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)

    def init(self, agent):
        """Copy the agent's messages and memory into this manager and build recall memory"""
        printd("Initializing LocalStateManager with agent object")
        self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
        self.memory = agent.memory
        printd(f"LocalStateManager.all_messages.len = {len(self.all_messages)}")
        printd(f"LocalStateManager.messages.len = {len(self.messages)}")

        # Persistence manager also handles DB-related state
        self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)

        # TODO: init archival memory here?

    def trim_messages(self, num):
        """Drop messages 1..num-1 from the active window, keeping the first (system) message"""
        # printd(f"LocalStateManager.trim_messages")
        self.messages = [self.messages[0]] + self.messages[num:]

    def prepend_to_messages(self, added_messages):
        """Insert messages right after the first message and record them in the full history"""
        # first tag with timestamps
        added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]

        printd("LocalStateManager.prepend_to_message")
        self.messages = [self.messages[0]] + added_messages + self.messages[1:]
        self.all_messages.extend(added_messages)

    def append_to_messages(self, added_messages):
        """Append messages to the active window and to the full history"""
        # first tag with timestamps
        added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]

        printd("LocalStateManager.append_to_messages")
        self.messages = self.messages + added_messages
        self.all_messages.extend(added_messages)

    def swap_system_message(self, new_system_message):
        """Replace the first message of the active window and record the new one in the full history"""
        # first tag with timestamps
        new_system_message = {"timestamp": get_local_time(), "message": new_system_message}

        printd("LocalStateManager.swap_system_message")
        self.messages[0] = new_system_message
        self.all_messages.append(new_system_message)

    def update_memory(self, new_memory):
        """Replace the stored memory object"""
        printd("LocalStateManager.update_memory")
        self.memory = new_memory
class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
archival_memory_cls = DummyArchivalMemory
recall_memory_cls = DummyRecallMemory

View File

@@ -1,5 +1,4 @@
from datetime import datetime
import asyncio
import csv
import difflib
@@ -14,8 +13,11 @@ import glob
import sqlite3
import fitz
from tqdm import tqdm
import typer
from memgpt.openai_tools import async_get_embedding_with_backoff
from memgpt.constants import MEMGPT_DIR
from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext
from llama_index.embeddings import OpenAIEmbedding
def count_tokens(s: str, model: str = "gpt-4") -> int:
@@ -327,3 +329,96 @@ def read_database_as_list(database_name):
except Exception as e:
result_list.append(f"Error: {str(e)}")
return result_list
def estimate_openai_cost(docs):
    """Estimate OpenAI embedding cost

    Builds a throwaway index with a mock embedding model while counting the
    tokens that would have been sent for embedding.

    :param docs: Documents to be embedded
    :type docs: List[Document]
    :return: Estimated cost
    :rtype: float
    """
    from llama_index import MockEmbedding
    from llama_index.callbacks import CallbackManager, TokenCountingHandler
    import tiktoken

    token_counter = TokenCountingHandler(tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode)

    # Use a dedicated service context for the dry run rather than installing the
    # mock embedding model as the process-wide default, which would leak into
    # any later indexing that relies on the global context.
    service_context = ServiceContext.from_defaults(
        embed_model=MockEmbedding(embed_dim=1536),
        callback_manager=CallbackManager([token_counter]),
    )
    VectorStoreIndex.from_documents(docs, service_context=service_context)

    # estimate cost: 0.0001 per 1000 embedding tokens
    # (presumably the OpenAI embedding price in USD -- TODO confirm it is current)
    cost = 0.0001 * token_counter.total_embedding_token_count / 1000
    token_counter.reset_counts()
    return cost
def get_index(name, docs):
    """Index documents

    If an index called `name` is already stored, the user is asked whether to
    re-index; declining returns the stored index. Otherwise the estimated
    embedding cost is shown and must be confirmed before embedding.

    :param name: Name of the index (dataset)
    :type name: str
    :param docs: Documents to be embedded
    :type docs: List[Document]
    :return: Existing or newly built index
    :rtype: VectorStoreIndex
    :raises SystemExit: if the user declines the cost confirmation
    """
    # check if directory exists
    index_dir = f"{MEMGPT_DIR}/archival/{name}"
    if os.path.exists(index_dir):
        confirm = typer.confirm(typer.style(f"Index with name {name} already exists -- re-index?", fg="yellow"), default=False)
        if not confirm:
            # return existing index
            storage_context = StorageContext.from_defaults(persist_dir=index_dir)
            return load_index_from_storage(storage_context)

    # TODO: support configurable embeddings
    # TODO: read from config how to index (open ai vs. local): then embed_mode="local"

    estimated_cost = estimate_openai_cost(docs)
    # TODO: prettier cost formatting
    confirm = typer.confirm(
        typer.style(f"Open AI embedding cost will be approximately ${estimated_cost} - continue?", fg="yellow"), default=True
    )
    if not confirm:
        typer.secho("Aborting.", fg="red")
        # Raise SystemExit directly: the `exit()` helper is injected by the `site`
        # module for interactive use and is not guaranteed to exist in programs.
        raise SystemExit

    embed_model = OpenAIEmbedding()
    service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=300)
    set_global_service_context(service_context)

    # index documents
    index = VectorStoreIndex.from_documents(docs)
    return index
def save_index(index, name):
    """Persist an index under the given name inside the MemGPT archival directory

    :param index: Index to save
    :type index: VectorStoreIndex
    :param name: Name of index
    :type name: str
    """
    # TODO: load directory from config
    # TODO: save to vectordb/local depending on config
    target_dir = f"{MEMGPT_DIR}/archival/{name}"

    # An existing directory is reused as-is (no overwrite prompt).
    os.makedirs(target_dir, exist_ok=True)

    storage = index.storage_context
    storage.persist(target_dir)
    print(target_dir)

1305
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -19,7 +19,7 @@ readme = "README.md"
memgpt = "memgpt.main:app"
[tool.poetry.dependencies]
python = "<3.13,>=3.9"
python = "<3.12,>=3.9"
typer = {extras = ["all"], version = "^0.9.0"}
questionary = "^2.0.1"
demjson3 = "^3.0.6"
@@ -31,6 +31,10 @@ pymupdf = "^1.23.5"
tqdm = "^4.66.1"
openai = "^0.28.1"
black = "^23.10.1"
pytest = "^7.4.3"
llama-index = "^0.8.53.post3"
setuptools = "^68.2.2"
datasets = "^2.14.6"
[build-system]

107
tests/test_load_archival.py Normal file
View File

@@ -0,0 +1,107 @@
import tempfile
import asyncio
import os
from memgpt.connectors.connector import load_directory, load_database, load_webpage
import memgpt.agent as agent
import memgpt.system as system
import memgpt.utils as utils
import memgpt.presets as presets
import memgpt.constants as constants
import memgpt.personas.personas as personas
import memgpt.humans.humans as humans
from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager
from memgpt.config import Config
from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL
from memgpt.connectors import connector
import memgpt.interface # for printing to terminal
import asyncio
from datasets import load_dataset
def test_load_directory():
    """Index the Hugging Face datasets cache and check archival search finds the story text."""
    # downloading hugging face dataset (if does not exist)
    load_dataset("MemGPT/example_short_stories")

    # Fall back to the default cache location when the environment variable is unset.
    default_cache = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
    cache_dir = os.getenv("HF_DATASETS_CACHE", default_cache)

    # load directory
    print("Loading dataset into index...")
    print(cache_dir)
    load_directory(name="tmp_hf_dataset", input_dir=cache_dir, recursive=True)

    # create agent on top of a state manager backed by the loaded data
    memgpt_agent = presets.use_preset(
        presets.DEFAULT,
        DEFAULT_MEMGPT_MODEL,
        personas.get_persona_text(personas.DEFAULT),
        humans.get_human_text(humans.DEFAULT),
        memgpt.interface,
        LocalStateManager(archival_memory_db="tmp_hf_dataset"),
    )

    # the agent's archival search is a coroutine, so drive it to completion here
    results = asyncio.run(memgpt_agent.archival_memory_search("cinderella be getting sick"))
    assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}"
def test_load_webpage():
    """Placeholder: the webpage connector has no test yet."""
def test_load_database():
    """Load the first table of the example SQLite database into an index and build an agent on it."""
    from sqlalchemy import create_engine, MetaData
    import pandas as pd

    db_path = "memgpt/personas/examples/sqldb/test.db"
    engine = create_engine(f"sqlite:///{db_path}")

    # Reflect the database schema to discover its tables.
    metadata = MetaData()
    metadata.reflect(bind=engine)
    table_names = metadata.tables.keys()
    print(table_names)

    # Select everything from the first reflected table.
    first_table = list(table_names)[0]
    query = f"SELECT * FROM {first_table}"

    # Sanity-read the table with pandas and show it.
    print(pd.read_sql_query(query, engine))

    load_database(
        name="tmp_db_dataset",
        dump_path=db_path,
        query=query,
    )

    # create agent on top of a state manager backed by the loaded data
    memgpt_agent = presets.use_preset(
        presets.DEFAULT,
        DEFAULT_MEMGPT_MODEL,
        personas.get_persona_text(personas.DEFAULT),
        humans.get_human_text(humans.DEFAULT),
        memgpt.interface,
        LocalStateManager(archival_memory_db="tmp_db_dataset"),
    )
    print("Successfully loaded into index")
    assert True