Set get_all limit to None by default and add postgres to archival memory tests
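
The main behavioral change sits in SQLStorageConnector.get_all: the default limit drops from 10 to None, and a LIMIT clause is only applied when a limit is actually passed. Below is a minimal, self-contained sketch of that behavior; the toy PassageModel and the in-memory SQLite engine are illustrative stand-ins, not MemGPT's actual models or connector classes.

from typing import List, Optional

from sqlalchemy import Column, Integer, String, create_engine
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class PassageModel(Base):  # hypothetical stand-in for the connector's db_model
    __tablename__ = "passages"
    id = Column(Integer, primary_key=True)
    text = Column(String)


engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)


def get_all(limit: Optional[int] = None) -> List[PassageModel]:
    """With limit=None (the new default), no LIMIT is applied and all rows come back."""
    session = Session()
    query = session.query(PassageModel)
    if limit:  # only cap the result set when a limit is explicitly requested
        query = query.limit(limit)
    return query.all()

Callers that still want a cap now pass it explicitly, which is why the __repr__ change below switches to self.storage.get_all(limit=limit).
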
@@ -494,10 +494,10 @@ def attach(
source_storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES)
dest_storage = StorageConnector.get_storage_connector(table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config)

size = source_storage.size()
size = source_storage.size({"data_source": data_source})
typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN)
page_size = 100
generator = source_storage.get_all_paginated(page_size=page_size) # yields List[Passage]
generator = source_storage.get_all_paginated(filters={"data_source": data_source}, page_size=page_size) # yields List[Passage]
passages = []
for i in tqdm(range(0, size, page_size)):
passages = next(generator)

@@ -47,6 +47,7 @@ def store_docs(name, docs, show_progress=True):
# compute and record passages
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=config.archival_storage_type)
embed_model = embedding_model()
orig_size = storage.size()

# use llama index to run embeddings code
service_context = ServiceContext.from_defaults(llm=None, embed_model=embed_model, chunk_size=config.embedding_chunk_size)

@@ -77,6 +78,7 @@ def store_docs(name, docs, show_progress=True):

# insert into storage
storage.insert_many(passages)
assert orig_size + len(passages) == storage.size(), f"Expected {orig_size + len(passages)} passages, got {storage.size()}"
storage.save()

@@ -139,7 +139,6 @@ class ChromaStorageConnector(StorageConnector):

def insert_many(self, records: List[Record], show_progress=True):
ids, documents, embeddings, metadatas = self.format_records(records)
print("Inserting", ids)
if not any(embeddings):
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
else:

@@ -191,8 +191,7 @@ class SQLStorageConnector(StorageConnector):
filter_conditions = {**self.filters, **filters}
else:
filter_conditions = self.filters
print("FILTERS", filter_conditions)

print("SQL FILTERS", filter_conditions)
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]

def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:

@@ -213,11 +212,13 @@ class SQLStorageConnector(StorageConnector):
# Increment the offset to get the next chunk in the next iteration
offset += page_size

def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[Record]:
session = self.Session()
filters = self.get_filters(filters)
print("LIMIT", limit)
db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
if limit:
db_records = session.query(self.db_model).filter(*filters).limit(limit).all()
else:
db_records = session.query(self.db_model).filter(*filters).all()
return [record.to_record() for record in db_records]

def get(self, id: str) -> Optional[Record]:

@@ -229,9 +230,9 @@ class SQLStorageConnector(StorageConnector):

def size(self, filters: Optional[Dict] = {}) -> int:
# return size of table
print("size")
session = self.Session()
filters = self.get_filters(filters)
print("ALL FILTERS", filters)
return session.query(self.db_model).filter(*filters).count()

def insert(self, record: Record):

@@ -404,122 +405,3 @@ class SQLLiteStorageConnector(SQLStorageConnector):

sqlite3.register_adapter(uuid.UUID, lambda u: u.bytes_le)
sqlite3.register_converter("UUID", lambda b: uuid.UUID(bytes_le=b))


class LanceDBConnector(StorageConnector):
"""Storage via LanceDB"""

# TODO: this should probably eventually be moved into a parent DB class

def __init__(self, name: Optional[str] = None, agent_config: Optional[AgentConfig] = None):
config = MemGPTConfig.load()
# determine table name
if agent_config:
assert name is None, f"Cannot specify both agent config and name {name}"
self.table_name = self.generate_table_name_agent(agent_config)
elif name:
assert agent_config is None, f"Cannot specify both agent config and name {name}"
self.table_name = self.generate_table_name(name)
else:
raise ValueError("Must specify either agent config or name")

# create table
self.uri = config.archival_storage_uri
if config.archival_storage_uri is None:
raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}")
import lancedb

self.db = lancedb.connect(self.uri)
if self.table_name in self.db.table_names():
self.table = self.db[self.table_name]
else:
self.table = None

def get_all_paginated(self, page_size: int) -> Iterator[List[Passage]]:
ds = self.table.to_lance()
offset = 0
while True:
# Retrieve a chunk of records with the given page_size
db_passages_chunk = ds.to_table(offset=offset, limit=page_size).to_pylist()
# If the chunk is empty, we've retrieved all records
if not db_passages_chunk:
break

# Yield a list of Passage objects converted from the chunk
yield [
Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in db_passages_chunk
]

# Increment the offset to get the next chunk in the next iteration
offset += page_size

def get_all(self, limit=10) -> List[Passage]:
db_passages = self.table.to_lance().to_table(limit=limit).to_pylist()
return [Passage(text=p["text"], embedding=p["vector"], doc_id=p["doc_id"], passage_id=p["passage_id"]) for p in db_passages]

def get(self, id: str) -> Optional[Passage]:
db_passage = self.table.where(f"passage_id={id}").to_list()
if len(db_passage) == 0:
return None
return Passage(
text=db_passage["text"], embedding=db_passage["embedding"], doc_id=db_passage["doc_id"], passage_id=db_passage["passage_id"]
)

def size(self) -> int:
# return size of table
if self.table:
return len(self.table)
else:
printd(f"Table with name {self.table_name} not present")
return 0

def insert(self, passage: Passage):
data = [{"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}]

if self.table is not None:
self.table.add(data)
else:
self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

def insert_many(self, passages: List[Passage], show_progress=True):
data = []
iterable = tqdm(passages) if show_progress else passages
for passage in iterable:
temp_dict = {"doc_id": passage.doc_id, "text": passage.text, "passage_id": passage.passage_id, "vector": passage.embedding}
data.append(temp_dict)

if self.table is not None:
self.table.add(data)
else:
self.table = self.db.create_table(self.table_name, data=data, mode="overwrite")

def query(self, query: str, query_vec: List[float], top_k: int = 10) -> List[Passage]:
# Assuming query_vec is of same length as embeddings inside table
results = self.table.search(query_vec).limit(top_k).to_list()
# Convert the results into Passage objects
passages = [
Passage(text=result["text"], embedding=result["vector"], doc_id=result["doc_id"], passage_id=result["passage_id"])
for result in results
]
return passages

def delete(self):
"""Drop the passage table from the database."""
# Drop the table specified by the PassageModel class
self.db.drop_table(self.table_name)

def save(self):
return

@staticmethod
def list_loaded_data():
config = MemGPTConfig.load()
import lancedb

db = lancedb.connect(config.archival_storage_uri)

tables = db.table_names()
tables = [table for table in tables if table.startswith("memgpt_")]
start_chars = len("memgpt_")
tables = [table[start_chars:] for table in tables]
return tables

@@ -73,6 +73,8 @@ class StorageConnector:
else:
self.filters = {}

print("FILTERS", self.filters)

def get_filters(self, filters: Optional[Dict] = {}):
# get all filters for query
if filters is not None:

@@ -440,7 +440,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
def __repr__(self) -> str:
limit = 10
passages = []
for passage in list(self.storage.get_all(limit)): # TODO: only get first 10
for passage in list(self.storage.get_all(limit=limit)): # TODO: only get first 10
passages.append(str(passage.text))
memory_str = "\n".join(passages)
return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"

@@ -4,50 +4,62 @@ import os
import pytest
from memgpt.connectors.storage import StorageConnector, TableType

# import asyncio
from datasets import load_dataset

# import memgpt
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
from memgpt.cli.cli import attach
from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
from memgpt.config import AgentConfig, MemGPTConfig

# import memgpt.presets as presets
# import memgpt.personas.personas as personas
# import memgpt.humans.humans as humans
# from memgpt.persistence_manager import InMemoryStateManager, LocalStateManager

# # from memgpt.config import AgentConfig
# from memgpt.constants import MEMGPT_DIR, DEFAULT_MEMGPT_MODEL
# import memgpt.interface # for printing to terminal


# @pytest.mark.parametrize("storage_connector", ["sqllite", "postgres"])
@pytest.mark.parametrize("metadata_storage_connector", ["sqlite"])
@pytest.mark.parametrize("passage_storage_connector", ["chroma"])
@pytest.mark.parametrize("metadata_storage_connector", ["sqlite", "postgres"])
@pytest.mark.parametrize("passage_storage_connector", ["chroma", "postgres"])
def test_load_directory(metadata_storage_connector, passage_storage_connector):

# setup config
config = MemGPTConfig()
if metadata_storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.metadata_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.metadata_storage_type = "postgres"
elif metadata_storage_connector == "sqlite":
print("testing sqlite metadata")
# nothing to do (should be config defaults)
else:
raise NotImplementedError(f"Storage type {metadata_storage_connector} not implemented")
if passage_storage_connector == "postgres":
if not os.getenv("PGVECTOR_TEST_DB_URL"):
print("Skipping test, missing PG URI")
return
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
elif passage_storage_connector == "chroma":
print("testing chroma passage storage")
# nothing to do (should be config defaults)
else:
raise NotImplementedError(f"Storage type {passage_storage_connector} not implemented")
config.save()

# setup storage connectors
data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector)

# load hugging face dataset
# dataset_name = "MemGPT/example_short_stories"
# dataset = load_dataset(dataset_name)

# cache_dir = os.getenv("HF_DATASETS_CACHE")
# if cache_dir is None:
# # Construct the default path if the environment variable is not set.
# cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets")
# print("HF Directory", cache_dir)
# load data
name = "test_dataset"
cache_dir = "CONTRIBUTING.md"

# TODO: load two different data sources

# clear out data
data_source_conn.delete_table()
passages_conn.delete_table()
data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES)
passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector)
assert (
data_source_conn.size() == 0
), f"Expected 0 records, got {data_source_conn.size()}: {[vars(r) for r in data_source_conn.get_all()]}"
assert passages_conn.size() == 0, f"Expected 0 records, got {passages_conn.size()}: {[vars(r) for r in passages_conn.get_all()]}"

# test: load directory
load_directory(name=name, input_dir=None, input_files=[cache_dir], recursive=False) # cache_dir,

@@ -59,8 +71,14 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
print("Source", sources)

# test to see if contained in storage
assert (
len(passages_conn.get_all()) == passages_conn.size()
), f"Expected {passages_conn.size()} passages, but got {len(passages_conn.get_all())}"
passages = passages_conn.get_all({"data_source": name})
print("Source", [p.data_source for p in passages])
print("All sources", [p.data_source for p in passages_conn.get_all()])
assert len(passages) > 0, f"Expected >0 passages, but got {len(passages)}"
assert len(passages) == passages_conn.size(), f"Expected {passages_conn.size()} passages, but got {len(passages)}"
assert [p.data_source == name for p in passages]
print("Passages", passages)

@@ -71,7 +89,7 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
# test loading into an agent
# create agent
agent_config = AgentConfig(
name="test_agent",
name="memgpt_test_agent",
persona=DEFAULT_PERSONA,
human=DEFAULT_HUMAN,
model=DEFAULT_MEMGPT_MODEL,

@@ -81,7 +99,11 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
conn = StorageConnector.get_storage_connector(
storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
)
assert conn.size() == 0
conn.delete_table()
conn = StorageConnector.get_storage_connector(
storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
)
assert conn.size() == 0, f"Expected 0 records, got {conn.size()}: {[vars(r) for r in conn.get_all()]}"

# attach data
attach(agent=agent_config.name, data_source=name)

@@ -56,8 +56,7 @@ def generate_messages():
return messages


@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqllite", "lancedb"])
# @pytest.mark.parametrize("storage_connector", ["sqllite"])
@pytest.mark.parametrize("storage_connector", ["postgres", "chroma", "sqlite"])
@pytest.mark.parametrize("table_type", [TableType.RECALL_MEMORY, TableType.ARCHIVAL_MEMORY])
def test_storage(storage_connector, table_type):

@@ -86,9 +85,9 @@ def test_storage(storage_connector, table_type):
return
config.archival_storage_type = "chroma"
config.archival_storage_path = "./test_chroma"
if storage_connector == "sqllite":
if storage_connector == "sqlite":
if table_type == TableType.ARCHIVAL_MEMORY:
print("Skipping test, sqllite only supported for recall memory")
print("Skipping test, sqlite only supported for recall memory")
return
config.recall_storage_type = "local"
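
The new postgres-backed test cases only run when a pgvector test database is configured through the PGVECTOR_TEST_DB_URL environment variable; otherwise they fall back to the skip/return path visible in the test_load_directory hunk above. A hedged sketch of that gating pattern as a reusable helper follows; the helper name is hypothetical, since the commit inlines the check and returns early rather than calling pytest.skip.

import os

import pytest


def maybe_skip_postgres(storage_type: str) -> None:
    """Skip a parametrized case when no pgvector test database is configured."""
    if storage_type == "postgres" and not os.getenv("PGVECTOR_TEST_DB_URL"):
        pytest.skip("missing PGVECTOR_TEST_DB_URL, skipping postgres-backed case")

To exercise the postgres cases locally, PGVECTOR_TEST_DB_URL would need to point at a reachable pgvector-enabled database before the test suite is run.
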