Support attaching data sources to agents for storage refactor
This commit is contained in:
@@ -485,14 +485,14 @@ def attach(
|
||||
):
|
||||
try:
|
||||
# loads the data contained in data source into the agent's memory
|
||||
from memgpt.connectors.storage import StorageConnector
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from tqdm import tqdm
|
||||
|
||||
agent_config = AgentConfig.load(agent)
|
||||
|
||||
# get storage connectors
|
||||
source_storage = StorageConnector.get_archival_storage_connector(name=data_source)
|
||||
dest_storage = StorageConnector.get_archival_storage_connector(agent_config=agent_config)
|
||||
# get storage connectors
|
||||
source_storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES)
|
||||
dest_storage = StorageConnector.get_storage_connector(table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config)
|
||||
|
||||
size = source_storage.size()
|
||||
typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN)
|
||||
@@ -501,6 +501,13 @@ def attach(
|
||||
passages = []
|
||||
for i in tqdm(range(0, size, page_size)):
|
||||
passages = next(generator)
|
||||
print("inserting", passages)
|
||||
|
||||
# need to associate each passage with the agent (for filtering)
|
||||
for passage in passages:
|
||||
passage.agent_id = agent_config.name
|
||||
|
||||
# insert into agent archival memory
|
||||
dest_storage.insert_many(passages)
|
||||
|
||||
# save destination storage
|
||||
|
||||
@@ -42,13 +42,19 @@ class ChromaStorageConnector(StorageConnector):
|
||||
filter_conditions = self.filters
|
||||
|
||||
# convert to chroma format
|
||||
chroma_filters = {"$and": []}
|
||||
chroma_filters = []
|
||||
ids = []
|
||||
for key, value in filter_conditions.items():
|
||||
if key == "id":
|
||||
ids = [str(value)]
|
||||
continue
|
||||
chroma_filters["$and"].append({key: {"$eq": value}})
|
||||
chroma_filters.append({key: {"$eq": value}})
|
||||
|
||||
if len(chroma_filters) > 1:
|
||||
chroma_filters = {"$and": chroma_filters}
|
||||
else:
|
||||
chroma_filters = chroma_filters[0]
|
||||
|
||||
return ids, chroma_filters
|
||||
|
||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
|
||||
@@ -133,6 +139,7 @@ class ChromaStorageConnector(StorageConnector):
|
||||
|
||||
def insert_many(self, records: List[Record], show_progress=True):
|
||||
ids, documents, embeddings, metadatas = self.format_records(records)
|
||||
print("Inserting", ids)
|
||||
if not any(embeddings):
|
||||
self.collection.add(documents=documents, ids=ids, metadatas=metadatas)
|
||||
else:
|
||||
|
||||
@@ -114,10 +114,12 @@ class StorageConnector:
|
||||
|
||||
# read from config if not provided
|
||||
if storage_type is None:
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
|
||||
storage_type = MemGPTConfig.load().archival_storage_type
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
storage_type = MemGPTConfig.load().recall_storage_type
|
||||
elif table_type == TableType.DATA_SOURCES or table_type == TableType.USERS or table_type == TableType.AGENTS:
|
||||
storage_type = MemGPTConfig.load().metadata_storage_type
|
||||
# TODO: other tables
|
||||
|
||||
if storage_type == "postgres":
|
||||
|
||||
@@ -9,6 +9,9 @@ from datasets import load_dataset
|
||||
|
||||
# import memgpt
|
||||
from memgpt.cli.cli_load import load_directory, load_database, load_webpage
|
||||
from memgpt.cli.cli import attach
|
||||
from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
|
||||
from memgpt.config import AgentConfig, MemGPTConfig
|
||||
|
||||
# import memgpt.presets as presets
|
||||
# import memgpt.personas.personas as personas
|
||||
@@ -65,6 +68,28 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
|
||||
sources = data_source_conn.get_all()
|
||||
print("All sources", [s.name for s in sources])
|
||||
|
||||
# test loading into an agent
|
||||
# create agent
|
||||
agent_config = AgentConfig(
|
||||
name="test_agent",
|
||||
persona=DEFAULT_PERSONA,
|
||||
human=DEFAULT_HUMAN,
|
||||
model=DEFAULT_MEMGPT_MODEL,
|
||||
)
|
||||
agent_config.save()
|
||||
# create storage connector
|
||||
conn = StorageConnector.get_storage_connector(
|
||||
storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
|
||||
)
|
||||
assert conn.size() == 0
|
||||
|
||||
# attach data
|
||||
attach(agent=agent_config.name, data_source=name)
|
||||
|
||||
# test to see if contained in storage
|
||||
assert len(passages) == conn.size()
|
||||
assert len(passages) == len(conn.get_all({"data_source": name}))
|
||||
|
||||
# test: delete source
|
||||
data_source_conn.delete({"name": name})
|
||||
passages_conn.delete({"data_source": name})
|
||||
|
||||
Reference in New Issue
Block a user