diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py
index 3681ea98..31431c6d 100644
--- a/memgpt/cli/cli.py
+++ b/memgpt/cli/cli.py
@@ -485,14 +485,16 @@ def attach(
 ):
     try:
         # loads the data contained in data source into the agent's memory
-        from memgpt.connectors.storage import StorageConnector
+        from memgpt.connectors.storage import StorageConnector, TableType
         from tqdm import tqdm
 
         agent_config = AgentConfig.load(agent)
 
         # get storage connectors
-        source_storage = StorageConnector.get_archival_storage_connector(name=data_source)
-        dest_storage = StorageConnector.get_archival_storage_connector(agent_config=agent_config)
+        # NOTE(review): `data_source` is no longer passed to the source connector, so
+        # size() below appears to count every stored passage -- confirm filtering by source
+        source_storage = StorageConnector.get_storage_connector(table_type=TableType.PASSAGES)
+        dest_storage = StorageConnector.get_storage_connector(table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config)
 
         size = source_storage.size()
         typer.secho(f"Ingesting {size} passages into {agent_config.name}", fg=typer.colors.GREEN)
@@ -501,6 +503,12 @@ def attach(
         passages = []
         for i in tqdm(range(0, size, page_size)):
             passages = next(generator)
+
+            # need to associate each passage with the agent (for filtering)
+            for passage in passages:
+                passage.agent_id = agent_config.name
+
+            # insert into agent archival memory
             dest_storage.insert_many(passages)
 
         # save destination storage
diff --git a/memgpt/connectors/chroma.py b/memgpt/connectors/chroma.py
index a1c5ac2b..8d27728f 100644
--- a/memgpt/connectors/chroma.py
+++ b/memgpt/connectors/chroma.py
@@ -42,13 +42,24 @@ class ChromaStorageConnector(StorageConnector):
             filter_conditions = self.filters
 
         # convert to chroma format
-        chroma_filters = {"$and": []}
+        chroma_filters = []
         ids = []
         for key, value in filter_conditions.items():
             if key == "id":
                 ids = [str(value)]
                 continue
-            chroma_filters["$and"].append({key: {"$eq": value}})
+            chroma_filters.append({key: {"$eq": value}})
+
+        if not chroma_filters:
+            # only an "id" condition (or none): indexing [0] would raise IndexError;
+            # None is Chroma's "no where-filter" value -- TODO confirm callers accept it
+            chroma_filters = None
+        elif len(chroma_filters) == 1:
+            # Chroma's "$and" needs at least two operands, so pass a lone condition bare
+            chroma_filters = chroma_filters[0]
+        else:
+            chroma_filters = {"$and": chroma_filters}
+
         return ids, chroma_filters
 
     def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
diff --git a/memgpt/connectors/storage.py b/memgpt/connectors/storage.py
index 81d97037..0564b12e 100644
--- a/memgpt/connectors/storage.py
+++ b/memgpt/connectors/storage.py
@@ -114,10 +114,12 @@ class StorageConnector:
 
         # read from config if not provided
         if storage_type is None:
-            if table_type == TableType.ARCHIVAL_MEMORY:
+            if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES:
                 storage_type = MemGPTConfig.load().archival_storage_type
             elif table_type == TableType.RECALL_MEMORY:
                 storage_type = MemGPTConfig.load().recall_storage_type
+            elif table_type == TableType.DATA_SOURCES or table_type == TableType.USERS or table_type == TableType.AGENTS:
+                storage_type = MemGPTConfig.load().metadata_storage_type
             # TODO: other tables
 
         if storage_type == "postgres":
diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py
index 78c02454..5c0240d9 100644
--- a/tests/test_load_archival.py
+++ b/tests/test_load_archival.py
@@ -9,6 +9,9 @@ from datasets import load_dataset
 
 # import memgpt
 from memgpt.cli.cli_load import load_directory, load_database, load_webpage
+from memgpt.cli.cli import attach
+from memgpt.constants import DEFAULT_MEMGPT_MODEL, DEFAULT_PERSONA, DEFAULT_HUMAN
+from memgpt.config import AgentConfig, MemGPTConfig
 
 # import memgpt.presets as presets
 # import memgpt.personas.personas as personas
@@ -65,6 +68,28 @@ def test_load_directory(metadata_storage_connector, passage_storage_connector):
     sources = data_source_conn.get_all()
     print("All sources", [s.name for s in sources])
 
+    # test loading into an agent
+    # create agent
+    agent_config = AgentConfig(
+        name="test_agent",
+        persona=DEFAULT_PERSONA,
+        human=DEFAULT_HUMAN,
+        model=DEFAULT_MEMGPT_MODEL,
+    )
+    agent_config.save()
+    # create storage connector
+    conn = StorageConnector.get_storage_connector(
+        storage_type=passage_storage_connector, table_type=TableType.ARCHIVAL_MEMORY, agent_config=agent_config
+    )
+    assert conn.size() == 0
+
+    # attach data
+    attach(agent=agent_config.name, data_source=name)
+
+    # test to see if contained in storage
+    assert len(passages) == conn.size()
+    assert len(passages) == len(conn.get_all({"data_source": name}))
+
     # test: delete source
     data_source_conn.delete({"name": name})
     passages_conn.delete({"data_source": name})