Support attaching data sources to agents for storage refactor

This commit is contained in:
Sarah Wooders
2023-12-24 13:46:27 +04:00
parent 7a14e2020a
commit d4ddf549e3
4 changed files with 48 additions and 7 deletions

View File

@@ -485,14 +485,14 @@ def attach(
):
try:
# loads the data contained in data source into the agent's memory
from memgpt.connectors.storage import StorageConnector
from memgpt.connectors.storage import StorageConnector, TableType
from tqdm import tqdm
agent_config = AgentConfig.load(agent)
# get storage connectors
source_storage = StorageConnector.get_archival_storage_connector(name=data_source)
dest_storage = StorageConnector.get_archival_storage_connector(agent_config=agent_config)
# get storage connectors
source_storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES)
dest_storage = StorageConnector.get_storage_connector(table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config)
size = source_storage.size()
typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN)
@@ -501,6 +501,13 @@ def attach(
passages = []
for i in tqdm(range(0, size, page_size)):
passages = next(generator)
print("inserting", passages)
# need to associate passage with agent (for filtering)
for passage in passages:
passage.agent_id = agent_config.name
# insert into agent archival memory
dest_storage.insert_many(passages)
# save destination storage

View File

@@ -42,13 +42,19 @@ class ChromaStorageConnector(StorageConnector):
filter_conditions = self.filters
# convert to chroma format
chroma_filters = {"$and": []}
chroma_filters = []
ids = []
for key, value in filter_conditions.items():
if key == "id":
ids = [str(value)]
continue
chroma_filters["$and"].append({key: {"$eq": value}})
chroma_filters.append({key: {"$eq": value}})
if len(chroma_filters) > 1:
chroma_filters = {"$and": chroma_filters}
else:
chroma_filters = chroma_filters[0]
return ids, chroma_filters
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
@@ -133,6 +139,7 @@ class ChromaStorageConnector(StorageConnector):
def insert_many(self, records: List[Record], show_progress=True):
ids, documents, embeddings, metadatas = self.format_records(records)
print("Inserting", ids)
if not any(embeddings):
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
else:

View File

@@ -114,10 +114,12 @@ class StorageConnector:
# read from config if not provided
if storage_type is None:
if table_type == TableType.ARCHIVAL_MEMORY:
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
storage_type = MemGPTConfig.load().archival_storage_type
elif table_type == TableType.RECALL_MEMORY:
storage_type = MemGPTConfig.load().recall_storage_type
elif table_type == TableType.DATA_SOURCES or table_type == TableType.USERS or table_type == TableType.AGENTS:
storage_type = MemGPTConfig.load().metadata_storage_type
# TODO: other tables
if storage_type == "postgres":

View File

@@ -9,6 +9,9 @@ from datasets import load_dataset
# import memgpt
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
from memgpt.cli.cli import attach
from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
from memgpt.config import AgentConfig, MemGPTConfig
# import memgpt.presets as presets
# import memgpt.personas.personas as personas
@@ -65,6 +68,28 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
sources = data_source_conn.get_all()
print("All sources", [s.name for s in sources])
# test loading into an agent
# create agent
agent_config = AgentConfig(
name="test_agent",
persona=DEFAULT_PERSONA,
human=DEFAULT_HUMAN,
model=DEFAULT_MEMGPT_MODEL,
)
agent_config.save()
# create storage connector
conn = StorageConnector.get_storage_connector(
storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
)
assert conn.size() == 0
# attach data
attach(agent=agent_config.name, data_source=name)
# test to see if contained in storage
assert len(passages) == conn.size()
assert len(passages) == len(conn.get_all({"data_source": name}))
# test: delete source
data_source_conn.delete({"name": name})
passages_conn.delete({"data_source": name})