Merge pull request #146 from sarahwooders/main
Support loading data into archival with Llama Index connectors
This commit is contained in:
111
memgpt/connectors/connector.py
Normal file
111
memgpt/connectors/connector.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""
|
||||
This file contains functions for loading data into MemGPT's archival storage.
|
||||
|
||||
Data can be loaded with the following command, once a load function is defined:
|
||||
```
|
||||
memgpt load <data-connector-type> --name <dataset-name> [ADDITIONAL ARGS]
|
||||
```
|
||||
|
||||
"""
|
||||
|
||||
from llama_index import download_loader
|
||||
from typing import List
|
||||
import os
|
||||
import typer
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.utils import estimate_openai_cost, get_index, save_index
|
||||
|
||||
app = typer.Typer()
|
||||
|
||||
|
||||
@app.command("directory")
|
||||
def load_directory(
|
||||
name: str = typer.Option(help="Name of dataset to load."),
|
||||
input_dir: str = typer.Option(None, help="Path to directory containing dataset."),
|
||||
input_files: List[str] = typer.Option(None, help="List of paths to files containing dataset."),
|
||||
recursive: bool = typer.Option(False, help="Recursively search for files in directory."),
|
||||
):
|
||||
from llama_index import SimpleDirectoryReader
|
||||
|
||||
if recursive:
|
||||
assert input_dir is not None, "Must provide input directory if recursive is True."
|
||||
reader = SimpleDirectoryReader(
|
||||
input_dir=input_dir,
|
||||
recursive=True,
|
||||
)
|
||||
else:
|
||||
reader = SimpleDirectoryReader(input_files=input_files)
|
||||
|
||||
# load docs
|
||||
print("Loading data...")
|
||||
docs = reader.load_data()
|
||||
|
||||
# embed docs
|
||||
print("Indexing documents...")
|
||||
index = get_index(name, docs)
|
||||
# save connector information into .memgpt metadata file
|
||||
save_index(index, name)
|
||||
|
||||
|
||||
@app.command("webpage")
|
||||
def load_webpage(
|
||||
name: str = typer.Option(help="Name of dataset to load."),
|
||||
urls: List[str] = typer.Option(None, help="List of urls to load."),
|
||||
):
|
||||
from llama_index import SimpleWebPageReader
|
||||
|
||||
docs = SimpleWebPageReader(html_to_text=True).load_data(urls)
|
||||
|
||||
# embed docs
|
||||
print("Indexing documents...")
|
||||
index = get_index(docs)
|
||||
# save connector information into .memgpt metadata file
|
||||
save_index(index, name)
|
||||
|
||||
|
||||
@app.command("database")
|
||||
def load_database(
|
||||
name: str = typer.Option(help="Name of dataset to load."),
|
||||
query: str = typer.Option(help="Database query."),
|
||||
dump_path: str = typer.Option(None, help="Path to dump file."),
|
||||
scheme: str = typer.Option(None, help="Database scheme."),
|
||||
host: str = typer.Option(None, help="Database host."),
|
||||
port: int = typer.Option(None, help="Database port."),
|
||||
user: str = typer.Option(None, help="Database user."),
|
||||
password: str = typer.Option(None, help="Database password."),
|
||||
dbname: str = typer.Option(None, help="Database name."),
|
||||
):
|
||||
from llama_index.readers.database import DatabaseReader
|
||||
|
||||
print(dump_path, scheme)
|
||||
|
||||
if dump_path is not None:
|
||||
# read from database dump file
|
||||
from sqlalchemy import create_engine, MetaData
|
||||
|
||||
engine = create_engine(f"sqlite:///{dump_path}")
|
||||
|
||||
db = DatabaseReader(engine=engine)
|
||||
else:
|
||||
assert dump_path is None, "Cannot provide both dump_path and database connection parameters."
|
||||
assert scheme is not None, "Must provide database scheme."
|
||||
assert host is not None, "Must provide database host."
|
||||
assert port is not None, "Must provide database port."
|
||||
assert user is not None, "Must provide database user."
|
||||
assert password is not None, "Must provide database password."
|
||||
assert dbname is not None, "Must provide database name."
|
||||
|
||||
db = DatabaseReader(
|
||||
scheme=scheme, # Database Scheme
|
||||
host=host, # Database Host
|
||||
port=port, # Database Port
|
||||
user=user, # Database User
|
||||
password=password, # Database Password
|
||||
dbname=dbname, # Database Name
|
||||
)
|
||||
|
||||
# load data
|
||||
docs = db.load_data(query=query)
|
||||
|
||||
index = get_index(name, docs)
|
||||
save_index(index, name)
|
||||
@@ -28,15 +28,16 @@ from memgpt.persistence_manager import (
|
||||
|
||||
from memgpt.config import Config
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from memgpt.connectors import connector
|
||||
from memgpt.openai_tools import (
|
||||
configure_azure_support,
|
||||
check_azure_embeddings,
|
||||
get_set_azure_env_vars,
|
||||
)
|
||||
|
||||
import asyncio
|
||||
|
||||
app = typer.Typer()
|
||||
app.add_typer(connector.app, name="load")
|
||||
|
||||
|
||||
def clear_line():
|
||||
@@ -109,7 +110,7 @@ def load(memgpt_agent, filename):
|
||||
print(f"/load warning: loading persistence manager from {filename} failed with: {e}")
|
||||
|
||||
|
||||
@app.command()
|
||||
@app.callback(invoke_without_command=True) # make default command
|
||||
def run(
|
||||
persona: str = typer.Option(None, help="Specify persona"),
|
||||
human: str = typer.Option(None, help="Specify human"),
|
||||
|
||||
@@ -1,14 +1,26 @@
|
||||
from abc import ABC, abstractmethod
|
||||
import os
|
||||
import datetime
|
||||
import re
|
||||
import faiss
|
||||
import numpy as np
|
||||
from typing import Optional, List, Tuple
|
||||
|
||||
from .constants import MESSAGE_SUMMARY_WARNING_TOKENS
|
||||
from .constants import MESSAGE_SUMMARY_WARNING_TOKENS, MEMGPT_DIR
|
||||
from .utils import cosine_similarity, get_local_time, printd, count_tokens
|
||||
from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from .openai_tools import acompletions_with_backoff as acreate, async_get_embedding_with_backoff
|
||||
|
||||
from llama_index import (
|
||||
VectorStoreIndex,
|
||||
get_response_synthesizer,
|
||||
load_index_from_storage,
|
||||
StorageContext,
|
||||
)
|
||||
from llama_index.retrievers import VectorIndexRetriever
|
||||
from llama_index.query_engine import RetrieverQueryEngine
|
||||
from llama_index.indices.postprocessor import SimilarityPostprocessor
|
||||
|
||||
|
||||
class CoreMemory(object):
|
||||
"""Held in-context inside the system message
|
||||
@@ -128,10 +140,26 @@ async def summarize_messages(
|
||||
class ArchivalMemory(ABC):
|
||||
@abstractmethod
|
||||
def insert(self, memory_string):
|
||||
"""Insert new archival memory
|
||||
|
||||
:param memory_string: Memory string to insert
|
||||
:type memory_string: str
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def search(self, query_string, count=None, start=None):
|
||||
def search(self, query_string, count=None, start=None) -> Tuple[List[str], int]:
|
||||
"""Search archival memory
|
||||
|
||||
:param query_string: Query string
|
||||
:type query_string: str
|
||||
:param count: Number of results to return (None for all)
|
||||
:type count: Optional[int]
|
||||
:param start: Offset to start returning results from (None if 0)
|
||||
:type start: Optional[int]
|
||||
|
||||
:return: Tuple of (list of results, total number of results)
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -515,3 +543,51 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
|
||||
return matches[start:], len(matches)
|
||||
else:
|
||||
return matches, len(matches)
|
||||
|
||||
|
||||
class LocalArchivalMemory(ArchivalMemory):
|
||||
"""Archival memory built on top of Llama Index"""
|
||||
|
||||
def __init__(self, archival_memory_database: Optional[str] = None, top_k: Optional[int] = 100):
|
||||
"""Init function for archival memory
|
||||
|
||||
:param archiva_memory_database: name of dataset to pre-fill archival with
|
||||
:type archival_memory_database: str
|
||||
"""
|
||||
|
||||
if archival_memory_database is not None:
|
||||
# TODO: load form ~/.memgpt/archival
|
||||
directory = f"{MEMGPT_DIR}/archival/{archival_memory_database}"
|
||||
assert os.path.exists(directory), f"Archival memory database {archival_memory_database} does not exist"
|
||||
storage_context = StorageContext.from_defaults(persist_dir=directory)
|
||||
self.index = load_index_from_storage(storage_context)
|
||||
else:
|
||||
self.index = VectorIndex()
|
||||
self.top_k = top_k
|
||||
self.retriever = VectorIndexRetriever(
|
||||
index=self.index, # does this get refreshed?
|
||||
similarity_top_k=self.top_k,
|
||||
)
|
||||
# TODO: have some mechanism for cleanup otherwise will lead to OOM
|
||||
self.cache = {}
|
||||
|
||||
async def insert(self, memory_string):
|
||||
self.index.insert(memory_string)
|
||||
|
||||
async def search(self, query_string, count=None, start=None):
|
||||
start = start if start else 0
|
||||
count = count if count else self.top_k
|
||||
count = min(count + start, self.top_k)
|
||||
|
||||
if query_string not in self.cache:
|
||||
self.cache[query_string] = self.retriever.retrieve(query_string)
|
||||
|
||||
results = self.cache[query_string][start : start + count]
|
||||
results = [{"timestamp": get_local_time(), "content": node.node.text} for node in results]
|
||||
# from pprint import pprint
|
||||
# pprint(results)
|
||||
return results, len(results)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
print(self.index.ref_doc_info)
|
||||
return ""
|
||||
|
||||
@@ -7,6 +7,7 @@ from .memory import (
|
||||
DummyArchivalMemory,
|
||||
DummyArchivalMemoryWithEmbeddings,
|
||||
DummyArchivalMemoryWithFaiss,
|
||||
LocalArchivalMemory,
|
||||
)
|
||||
from .utils import get_local_time, printd
|
||||
|
||||
@@ -100,6 +101,74 @@ class InMemoryStateManager(PersistenceManager):
|
||||
self.memory = new_memory
|
||||
|
||||
|
||||
class LocalStateManager(PersistenceManager):
|
||||
"""In-memory state manager has nothing to manage, all agents are held in-memory"""
|
||||
|
||||
recall_memory_cls = DummyRecallMemory
|
||||
archival_memory_cls = LocalArchivalMemory
|
||||
|
||||
def __init__(self, archival_memory_db=None):
|
||||
# Memory held in-state useful for debugging stateful versions
|
||||
self.memory = None
|
||||
self.messages = []
|
||||
self.all_messages = []
|
||||
self.archival_memory = LocalArchivalMemory(archival_memory_database=archival_memory_db)
|
||||
|
||||
@staticmethod
|
||||
def load(filename):
|
||||
with open(filename, "rb") as f:
|
||||
return pickle.load(f)
|
||||
|
||||
def save(self, filename):
|
||||
with open(filename, "wb") as fh:
|
||||
pickle.dump(self, fh, protocol=pickle.HIGHEST_PROTOCOL)
|
||||
|
||||
def init(self, agent):
|
||||
printd(f"Initializing InMemoryStateManager with agent object")
|
||||
self.all_messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.messages = [{"timestamp": get_local_time(), "message": msg} for msg in agent.messages.copy()]
|
||||
self.memory = agent.memory
|
||||
printd(f"InMemoryStateManager.all_messages.len = {len(self.all_messages)}")
|
||||
printd(f"InMemoryStateManager.messages.len = {len(self.messages)}")
|
||||
|
||||
# Persistence manager also handles DB-related state
|
||||
self.recall_memory = self.recall_memory_cls(message_database=self.all_messages)
|
||||
|
||||
# TODO: init archival memory here?
|
||||
|
||||
def trim_messages(self, num):
|
||||
# printd(f"InMemoryStateManager.trim_messages")
|
||||
self.messages = [self.messages[0]] + self.messages[num:]
|
||||
|
||||
def prepend_to_messages(self, added_messages):
|
||||
# first tag with timestamps
|
||||
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
||||
|
||||
printd(f"InMemoryStateManager.prepend_to_message")
|
||||
self.messages = [self.messages[0]] + added_messages + self.messages[1:]
|
||||
self.all_messages.extend(added_messages)
|
||||
|
||||
def append_to_messages(self, added_messages):
|
||||
# first tag with timestamps
|
||||
added_messages = [{"timestamp": get_local_time(), "message": msg} for msg in added_messages]
|
||||
|
||||
printd(f"InMemoryStateManager.append_to_messages")
|
||||
self.messages = self.messages + added_messages
|
||||
self.all_messages.extend(added_messages)
|
||||
|
||||
def swap_system_message(self, new_system_message):
|
||||
# first tag with timestamps
|
||||
new_system_message = {"timestamp": get_local_time(), "message": new_system_message}
|
||||
|
||||
printd(f"InMemoryStateManager.swap_system_message")
|
||||
self.messages[0] = new_system_message
|
||||
self.all_messages.append(new_system_message)
|
||||
|
||||
def update_memory(self, new_memory):
|
||||
printd(f"InMemoryStateManager.update_memory")
|
||||
self.memory = new_memory
|
||||
|
||||
|
||||
class InMemoryStateManagerWithPreloadedArchivalMemory(InMemoryStateManager):
|
||||
archival_memory_cls = DummyArchivalMemory
|
||||
recall_memory_cls = DummyRecallMemory
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from datetime import datetime
|
||||
|
||||
import asyncio
|
||||
import csv
|
||||
import difflib
|
||||
@@ -14,8 +13,11 @@ import glob
|
||||
import sqlite3
|
||||
import fitz
|
||||
from tqdm import tqdm
|
||||
import typer
|
||||
from memgpt.openai_tools import async_get_embedding_with_backoff
|
||||
from memgpt.constants import MEMGPT_DIR
|
||||
from llama_index import set_global_service_context, ServiceContext, VectorStoreIndex, load_index_from_storage, StorageContext
|
||||
from llama_index.embeddings import OpenAIEmbedding
|
||||
|
||||
|
||||
def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
@@ -327,3 +329,96 @@ def read_database_as_list(database_name):
|
||||
except Exception as e:
|
||||
result_list.append(f"Error: {str(e)}")
|
||||
return result_list
|
||||
|
||||
|
||||
def estimate_openai_cost(docs):
|
||||
"""Estimate OpenAI embedding cost
|
||||
|
||||
:param docs: Documents to be embedded
|
||||
:type docs: List[Document]
|
||||
:return: Estimated cost
|
||||
:rtype: float
|
||||
"""
|
||||
from llama_index import MockEmbedding
|
||||
from llama_index.callbacks import CallbackManager, TokenCountingHandler
|
||||
import tiktoken
|
||||
|
||||
embed_model = MockEmbedding(embed_dim=1536)
|
||||
|
||||
token_counter = TokenCountingHandler(tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode)
|
||||
|
||||
callback_manager = CallbackManager([token_counter])
|
||||
|
||||
set_global_service_context(ServiceContext.from_defaults(embed_model=embed_model, callback_manager=callback_manager))
|
||||
index = VectorStoreIndex.from_documents(docs)
|
||||
|
||||
# estimate cost
|
||||
cost = 0.0001 * token_counter.total_embedding_token_count / 1000
|
||||
token_counter.reset_counts()
|
||||
return cost
|
||||
|
||||
|
||||
def get_index(name, docs):
|
||||
"""Index documents
|
||||
|
||||
:param docs: Documents to be embedded
|
||||
:type docs: List[Document]
|
||||
"""
|
||||
|
||||
# check if directory exists
|
||||
dir = f"{MEMGPT_DIR}/archival/{name}"
|
||||
if os.path.exists(dir):
|
||||
confirm = typer.confirm(typer.style(f"Index with name {name} already exists -- re-index?", fg="yellow"), default=False)
|
||||
if not confirm:
|
||||
# return existing index
|
||||
storage_context = StorageContext.from_defaults(persist_dir=dir)
|
||||
return load_index_from_storage(storage_context)
|
||||
|
||||
# TODO: support configurable embeddings
|
||||
# TODO: read from config how to index (open ai vs. local): then embed_mode="local"
|
||||
|
||||
estimated_cost = estimate_openai_cost(docs)
|
||||
# TODO: prettier cost formatting
|
||||
confirm = typer.confirm(
|
||||
typer.style(f"Open AI embedding cost will be approximately ${estimated_cost} - continue?", fg="yellow"), default=True
|
||||
)
|
||||
|
||||
if not confirm:
|
||||
typer.secho("Aborting.", fg="red")
|
||||
exit()
|
||||
|
||||
embed_model = OpenAIEmbedding()
|
||||
service_context = ServiceContext.from_defaults(embed_model=embed_model, chunk_size=300)
|
||||
set_global_service_context(service_context)
|
||||
|
||||
# index documents
|
||||
index = VectorStoreIndex.from_documents(docs)
|
||||
return index
|
||||
|
||||
|
||||
def save_index(index, name):
|
||||
"""Save index to a specificed name in ~/.memgpt
|
||||
|
||||
:param index: Index to save
|
||||
:type index: VectorStoreIndex
|
||||
:param name: Name of index
|
||||
:type name: str
|
||||
"""
|
||||
# save
|
||||
# TODO: load directory from config
|
||||
# TODO: save to vectordb/local depending on config
|
||||
|
||||
dir = f"{MEMGPT_DIR}/archival/{name}"
|
||||
|
||||
## Avoid overwriting
|
||||
## check if directory exists
|
||||
# if os.path.exists(dir):
|
||||
# confirm = typer.confirm(typer.style(f"Index with name {name} already exists -- overwrite?", fg="red"), default=False)
|
||||
# if not confirm:
|
||||
# typer.secho("Aborting.", fg="red")
|
||||
# exit()
|
||||
|
||||
# create directory, even if it already exists
|
||||
os.makedirs(dir, exist_ok=True)
|
||||
index.storage_context.persist(dir)
|
||||
print(dir)
|
||||
|
||||
1305
poetry.lock
generated
1305
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -19,7 +19,7 @@ readme = "README.md"
|
||||
memgpt = "memgpt.main:app"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "<3.13,>=3.9"
|
||||
python = "<3.12,>=3.9"
|
||||
typer = {extras = ["all"], version = "^0.9.0"}
|
||||
questionary = "^2.0.1"
|
||||
demjson3 = "^3.0.6"
|
||||
@@ -31,6 +31,10 @@ pymupdf = "^1.23.5"
|
||||
tqdm = "^4.66.1"
|
||||
openai = "^0.28.1"
|
||||
black = "^23.10.1"
|
||||
pytest = "^7.4.3"
|
||||
llama-index = "^0.8.53.post3"
|
||||
setuptools = "^68.2.2"
|
||||
datasets = "^2.14.6"
|
||||
|
||||
|
||||
[build-system]
|
||||
|
||||
107
tests/test_load_archival.py
Normal file
107
tests/test_load_archival.py
Normal file
@@ -0,0 +1,107 @@
|
||||
import tempfile
|
||||
import asyncio
|
||||
import os
|
||||
from memgpt.connectors.connector import load_directory, load_database, load_webpage
|
||||
import memgpt.agent as agent
|
||||
import memgpt.system as system
|
||||
import memgpt.utils as utils
|
||||
import memgpt.presets as presets
|
||||
import memgpt.constants as constants
|
||||
import memgpt.personas.personas as personas
|
||||
import memgpt.humans.humans as humans
|
||||
from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager
|
||||
from memgpt.config import Config
|
||||
from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL
|
||||
from memgpt.connectors import connector
|
||||
import memgpt.interface # for printing to terminal
|
||||
import asyncio
|
||||
from datasets import load_dataset
|
||||
|
||||
|
||||
def test_load_directory():
|
||||
# downloading hugging face dataset (if does not exist)
|
||||
dataset = load_dataset("MemGPT/example_short_stories")
|
||||
|
||||
cache_dir = os.getenv("HF_DATASETS_CACHE")
|
||||
|
||||
if cache_dir is None:
|
||||
# Construct the default path if the environment variable is not set.
|
||||
cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
|
||||
|
||||
# load directory
|
||||
print("Loading dataset into index...")
|
||||
print(cache_dir)
|
||||
load_directory(
|
||||
name="tmp_hf_dataset",
|
||||
input_dir=cache_dir,
|
||||
recursive=True,
|
||||
)
|
||||
|
||||
# create state manager based off loaded data
|
||||
persistence_manager = LocalStateManager(archival_memory_db="tmp_hf_dataset")
|
||||
|
||||
# create agent
|
||||
memgpt_agent = presets.use_preset(
|
||||
presets.DEFAULT,
|
||||
DEFAULT_MEMGPT_MODEL,
|
||||
personas.get_persona_text(personas.DEFAULT),
|
||||
humans.get_human_text(humans.DEFAULT),
|
||||
memgpt.interface,
|
||||
persistence_manager,
|
||||
)
|
||||
|
||||
def query(q):
|
||||
res = asyncio.run(memgpt_agent.archival_memory_search(q))
|
||||
return res
|
||||
|
||||
results = query("cinderella be getting sick")
|
||||
assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}"
|
||||
|
||||
|
||||
def test_load_webpage():
|
||||
pass
|
||||
|
||||
|
||||
def test_load_database():
|
||||
from sqlalchemy import create_engine, MetaData
|
||||
import pandas as pd
|
||||
|
||||
db_path = "memgpt/personas/examples/sqldb/test.db"
|
||||
engine = create_engine(f"sqlite:///{db_path}")
|
||||
|
||||
# Create a MetaData object and reflect the database to get table information.
|
||||
metadata = MetaData()
|
||||
metadata.reflect(bind=engine)
|
||||
|
||||
# Get a list of table names from the reflected metadata.
|
||||
table_names = metadata.tables.keys()
|
||||
|
||||
print(table_names)
|
||||
|
||||
# Define a SQL query to retrieve data from a table (replace 'your_table_name' with your actual table name).
|
||||
query = f"SELECT * FROM {list(table_names)[0]}"
|
||||
|
||||
# Use Pandas to read data from the database into a DataFrame.
|
||||
df = pd.read_sql_query(query, engine)
|
||||
print(df)
|
||||
|
||||
load_database(
|
||||
name="tmp_db_dataset",
|
||||
# engine=engine,
|
||||
dump_path=db_path,
|
||||
query=f"SELECT * FROM {list(table_names)[0]}",
|
||||
)
|
||||
|
||||
persistence_manager = LocalStateManager(archival_memory_db="tmp_db_dataset")
|
||||
|
||||
# create agent
|
||||
memgpt_agent = presets.use_preset(
|
||||
presets.DEFAULT,
|
||||
DEFAULT_MEMGPT_MODEL,
|
||||
personas.get_persona_text(personas.DEFAULT),
|
||||
humans.get_human_text(humans.DEFAULT),
|
||||
memgpt.interface,
|
||||
persistence_manager,
|
||||
)
|
||||
print("Successfully loaded into index")
|
||||
assert True
|
||||
Reference in New Issue
Block a user