diff --git a/memgpt/agent.py b/memgpt/agent.py index 5bb286ed..ca50b4c4 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -185,12 +185,6 @@ class Agent(object): self.system, self.memory, ) - # Keep track of the total number of messages throughout all time - self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) - # self.messages_total_init = self.messages_total - self.messages_total_init = len(self._messages) - 1 - printd(f"Agent initialized, self.messages_total={self.messages_total}") - # Interface must implement: # - internal_monologue # - assistant_message @@ -209,6 +203,12 @@ class Agent(object): # creates a new agent object in the database self.persistence_manager.init(self) + # Keep track of the total number of messages throughout all time + self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system) + # self.messages_total_init = self.messages_total + self.messages_total_init = len(self._messages) - 1 + printd(f"Agent initialized, self.messages_total={self.messages_total}") + # State needed for heartbeat pausing self.pause_heartbeats_start = None self.pause_heartbeats_minutes = 0 diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index b37a38df..d10e593a 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -12,8 +12,13 @@ from typing import List from tqdm import tqdm import typer from memgpt.embeddings import embedding_model -from memgpt.connectors.storage import StorageConnector, Passage +from memgpt.connectors.storage import StorageConnector from memgpt.config import MemGPTConfig +from memgpt.data_types import Source, Passage, Document +from memgpt.utils import get_local_time +from memgpt.connectors.storage import StorageConnector, TableType + +from datetime import datetime from llama_index import ( VectorStoreIndex, @@ -28,8 +33,19 @@ app = typer.Typer() def store_docs(name, docs, show_progress=True): """Common function for 
embedding and storing documents""" - storage = StorageConnector.get_archival_storage_connector(name=name) config = MemGPTConfig.load() + + # record data source metadata + data_source = Source(user_id=config.anon_clientid, name=name, created_at=datetime.now()) + metadata_conn = StorageConnector.get_metadata_storage_connector(TableType.DATA_SOURCES) + if len(metadata_conn.get_all({"name": name})) > 0: + print(f"Data source {name} already exists in metadata, skipping.") + # TODO: should this error, or just add more data to this source? + else: + metadata_conn.insert(data_source) + + # compute and record passages + storage = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=config.archival_storage_type) embed_model = embedding_model() # use llama index to run embeddings code @@ -38,6 +54,8 @@ def store_docs(name, docs, show_progress=True): embed_dict = index._vector_store._data.embedding_dict node_dict = index._docstore.docs + # TODO: add document store + # gather passages passages = [] for node_id, node in tqdm(node_dict.items()): @@ -47,7 +65,15 @@ def store_docs(name, docs, show_progress=True): assert ( len(node.embedding) == config.embedding_dim ), f"Expected embedding dimension {config.embedding_dim}, got {len(node.embedding)}: {node.embedding}" - passages.append(Passage(text=text, embedding=vector)) + passages.append( + Passage( + user_id=config.anon_clientid, + text=text, + data_source=name, + embedding=node.embedding, + metadata=None, + ) + ) # insert into storage storage.insert_many(passages) diff --git a/memgpt/config.py b/memgpt/config.py index 5997ce14..1f29ded1 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -78,6 +78,11 @@ class MemGPTConfig: recall_storage_path: str = MEMGPT_DIR recall_storage_uri: str = None # TODO: eventually allow external vector DB + # database configs: metadata storage (sources, agents, data sources) + metadata_storage_type: str = "sqlite" + metadata_storage_path: str = MEMGPT_DIR + 
metadata_storage_uri: str = None + # database configs: agent state persistence_manager_type: str = None # in-memory, db persistence_manager_save_file: str = None # local file @@ -139,6 +144,9 @@ class MemGPTConfig: "recall_storage_type": get_field(config, "recall_storage", "type"), "recall_storage_path": get_field(config, "recall_storage", "path"), "recall_storage_uri": get_field(config, "recall_storage", "uri"), + "metadata_storage_type": get_field(config, "metadata_storage", "type"), + "metadata_storage_path": get_field(config, "metadata_storage", "path"), + "metadata_storage_uri": get_field(config, "metadata_storage", "uri"), "anon_clientid": get_field(config, "client", "anon_clientid"), "config_path": config_path, "memgpt_version": get_field(config, "version", "memgpt_version"), @@ -197,6 +205,11 @@ class MemGPTConfig: set_field(config, "recall_storage", "path", self.recall_storage_path) set_field(config, "recall_storage", "uri", self.recall_storage_uri) + # metadata storage + set_field(config, "metadata_storage", "type", self.metadata_storage_type) + set_field(config, "metadata_storage", "path", self.metadata_storage_path) + set_field(config, "metadata_storage", "uri", self.metadata_storage_uri) + # set version set_field(config, "version", "memgpt_version", memgpt.__version__) diff --git a/memgpt/connectors/chroma.py b/memgpt/connectors/chroma.py index 1fbd0c32..a1c5ac2b 100644 --- a/memgpt/connectors/chroma.py +++ b/memgpt/connectors/chroma.py @@ -19,7 +19,7 @@ class ChromaStorageConnector(StorageConnector): super().__init__(table_type=table_type, agent_config=agent_config) config = MemGPTConfig.load() - assert table_type == TableType.ARCHIVAL_MEMORY, "Chroma only supports archival memory" + assert table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES, "Chroma only supports archival memory" # create chroma client if config.archival_storage_path: @@ -51,7 +51,7 @@ class ChromaStorageConnector(StorageConnector): 
chroma_filters["$and"].append({key: {"$eq": value}}) return ids, chroma_filters - def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]: + def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]: offset = 0 ids, filters = self.get_filters(filters) while True: @@ -87,7 +87,7 @@ class ChromaStorageConnector(StorageConnector): for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"]) ] - def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]: + def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]: ids, filters = self.get_filters(filters) if self.collection.count() == 0: return [] @@ -114,7 +114,7 @@ class ChromaStorageConnector(StorageConnector): metadata.pop("embedding") if "created_at" in metadata: metadata["created_at"] = datetime_to_timestamp(metadata["created_at"]) - if "metadata" in metadata: + if "metadata" in metadata and metadata["metadata"] is not None: record_metadata = dict(metadata["metadata"]) metadata.pop("metadata") else: diff --git a/memgpt/connectors/db.py b/memgpt/connectors/db.py index ce82c55e..a302d6a7 100644 --- a/memgpt/connectors/db.py +++ b/memgpt/connectors/db.py @@ -27,7 +27,7 @@ from memgpt.connectors.storage import StorageConnector, TableType from memgpt.config import AgentConfig, MemGPTConfig from memgpt.constants import MEMGPT_DIR from memgpt.utils import printd -from memgpt.data_types import Record, Message, Passage +from memgpt.data_types import Record, Message, Passage, Source from datetime import datetime @@ -153,6 +153,30 @@ def get_db_model(table_name: str, table_type: TableType): class_name = f"{table_name.capitalize()}Model" Model = type(class_name, (MessageModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}}) return Model + elif table_type == TableType.DATA_SOURCES: + + class SourceModel(Base): + """Defines data model 
for storing Sources (data source name and metadata)""" + + __abstract__ = True # this line is necessary + + # Assuming passage_id is the primary key + # id = Column(UUID(as_uuid=True), primary_key=True, default=uuid.uuid4) + id = Column(CommonUUID, primary_key=True, default=uuid.uuid4) + user_id = Column(String, nullable=False) + name = Column(String, nullable=False) + created_at = Column(DateTime(timezone=True), server_default=func.now()) + + def __repr__(self): + return f"<Source(name='{self.name}')>" + + def to_record(self): + return Source(id=self.id, user_id=self.user_id, name=self.name, created_at=self.created_at) + + """Create database model for table_name""" + class_name = f"{table_name.capitalize()}Model" + Model = type(class_name, (SourceModel,), {"__tablename__": table_name, "__table_args__": {"extend_existing": True}}) + return Model else: raise ValueError(f"Table type {table_type} not implemented") @@ -171,7 +195,7 @@ class SQLStorageConnector(StorageConnector): return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()] - def get_all_paginated(self, page_size: int, filters: Optional[Dict] = {}) -> Iterator[List[Record]]: + def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]: session = self.Session() offset = 0 filters = self.get_filters(filters) @@ -189,9 +213,10 @@ # Increment the offset to get the next chunk in the next iteration offset += page_size - def get_all(self, limit=10, filters: Optional[Dict] = {}) -> List[Record]: + def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]: session = self.Session() filters = self.get_filters(filters) + printd(f"get_all limit={limit}") db_records = session.query(self.db_model).filter(*filters).limit(limit).all() return [record.to_record() for record in db_records] @@ -287,14 +312,18 @@ class PostgresStorageConnector(SQLStorageConnector): super().__init__(table_type=table_type,
agent_config=agent_config) # get storage URI - if table_type == TableType.ARCHIVAL_MEMORY: + if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: self.uri = self.config.archival_storage_uri if self.config.archival_storage_uri is None: - raise ValueError(f"Must specifiy archival_storage_uri in config {config.config_path}") + raise ValueError(f"Must specify archival_storage_uri in config {self.config.config_path}") elif table_type == TableType.RECALL_MEMORY: self.uri = self.config.recall_storage_uri if self.config.recall_storage_uri is None: - raise ValueError(f"Must specifiy recall_storage_uri in config {config.config_path}") + raise ValueError(f"Must specify recall_storage_uri in config {self.config.config_path}") + elif table_type == TableType.DATA_SOURCES: + self.uri = self.config.metadata_storage_uri + if self.config.metadata_storage_uri is None: + raise ValueError(f"Must specify metadata_storage_uri in config {self.config.config_path}") else: raise ValueError(f"Table type {table_type} not implemented") # create table @@ -348,13 +377,17 @@ class SQLLiteStorageConnector(SQLStorageConnector): super().__init__(table_type=table_type, agent_config=agent_config) # get storage URI - if table_type == TableType.ARCHIVAL_MEMORY: + if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: raise ValueError(f"Table type {table_type} not implemented") elif table_type == TableType.RECALL_MEMORY: # TODO: eventually implement URI option self.path = self.config.recall_storage_path if self.path is None: raise ValueError(f"Must specifiy recall_storage_path in config {self.config.recall_storage_path}") + elif table_type == TableType.DATA_SOURCES: + self.path = self.config.metadata_storage_path + if self.path is None: + raise ValueError(f"Must specify metadata_storage_path in config {self.config.metadata_storage_path}") else: raise ValueError(f"Table type {table_type} not implemented") diff --git a/memgpt/connectors/storage.py 
b/memgpt/connectors/storage.py index 06715208..81d97037 100644 --- a/memgpt/connectors/storage.py +++ b/memgpt/connectors/storage.py @@ -13,7 +13,7 @@ from tqdm import tqdm from memgpt.config import AgentConfig, MemGPTConfig -from memgpt.data_types import Record, Passage, Document, Message +from memgpt.data_types import Record, Passage, Document, Message, Source from memgpt.utils import printd @@ -30,10 +30,15 @@ class TableType: # table names used by MemGPT + +# agent tables RECALL_TABLE_NAME = "memgpt_recall_memory_agent" # agent memory ARCHIVAL_TABLE_NAME = "memgpt_archival_memory_agent" # agent memory -PASSAGE_TABLE_NAME = "memgpt_passages" # loads data sources -DOCUMENT_TABLE_NAME = "memgpt_documents" + +# external data source tables +SOURCE_TABLE_NAME = "memgpt_sources" # metadata for loaded data source +PASSAGE_TABLE_NAME = "memgpt_passages" # chunked/embedded passages (from source) +DOCUMENT_TABLE_NAME = "memgpt_documents" # original documents (from source) class StorageConnector: @@ -45,10 +50,12 @@ class StorageConnector: self.table_type = table_type # get object type - if table_type == TableType.ARCHIVAL_MEMORY: + if table_type == TableType.ARCHIVAL_MEMORY or table_type == TableType.PASSAGES: self.type = Passage elif table_type == TableType.RECALL_MEMORY: self.type = Message + elif table_type == TableType.DATA_SOURCES: + self.type = Source else: raise ValueError(f"Table type {table_type} not implemented") @@ -60,10 +67,11 @@ class StorageConnector: if self.table_type == TableType.ARCHIVAL_MEMORY or self.table_type == TableType.RECALL_MEMORY: # agent-specific table self.filters = {"user_id": self.user_id, "agent_id": self.agent_config.name} - - # setup base filters for user-specific tables - if self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS: + elif self.table_type == TableType.PASSAGES or self.table_type == TableType.DOCUMENTS or self.table_type == TableType.DATA_SOURCES: + # setup base filters for user-specific tables 
self.filters = {"user_id": self.user_id} + else: + self.filters = {} def get_filters(self, filters: Optional[Dict] = {}): # get all filters for query @@ -71,7 +79,6 @@ class StorageConnector: filter_conditions = {**self.filters, **filters} else: filter_conditions = self.filters - print("FILTERS", filter_conditions) return filter_conditions def generate_table_name(self, agent_config: AgentConfig, table_type: TableType): @@ -97,14 +104,14 @@ class StorageConnector: return PASSAGE_TABLE_NAME elif table_type == TableType.DOCUMENTS: return DOCUMENT_TABLE_NAME + elif table_type == TableType.DATA_SOURCES: + return SOURCE_TABLE_NAME else: raise ValueError(f"Table type {table_type} not implemented") @staticmethod def get_storage_connector(table_type: TableType, storage_type: Optional[str] = None, agent_config: Optional[AgentConfig] = None): - print("STORAGE", storage_type, table_type) - # read from config if not provided if storage_type is None: if table_type == TableType.ARCHIVAL_MEMORY: @@ -112,7 +119,6 @@ class StorageConnector: elif table_type == TableType.RECALL_MEMORY: storage_type = MemGPTConfig.load().recall_storage_type # TODO: other tables - print("read storage from config") if storage_type == "postgres": from memgpt.connectors.db import PostgresStorageConnector @@ -148,6 +154,11 @@ class StorageConnector: def get_recall_storage_connector(agent_config: Optional[AgentConfig] = None): return StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, agent_config=agent_config) + @staticmethod + def get_metadata_storage_connector(table_type: TableType): + storage_type = MemGPTConfig.load().metadata_storage_type + return StorageConnector.get_storage_connector(table_type, storage_type=storage_type) + @staticmethod def list_loaded_data(storage_type: Optional[str] = None): # TODO: modify this to simply list loaded data from a given user diff --git a/memgpt/data_types.py b/memgpt/data_types.py index b14d0f80..885b2ed0 100644 --- a/memgpt/data_types.py +++ 
b/memgpt/data_types.py @@ -15,21 +15,13 @@ class Record: Memory units are searched over by functions defined in the memory classes """ - def __init__(self, user_id: str, agent_id: str, text: str, id: Optional[str] = None): - self.user_id = user_id - self.agent_id = agent_id - self.text = text + def __init__(self, id: Optional[str] = None): if id is None: self.id = uuid.uuid4() else: self.id = id assert isinstance(self.id, uuid.UUID), f"UUID {self.id} must be a UUID type" - # todo: generate unique uuid - # todo: self.role = role (?) - - # def __repr__(self): - # pass class Message(Record): @@ -50,7 +42,10 @@ class Message(Record): embedding: Optional[np.ndarray] = None, id: Optional[str] = None, ): - super().__init__(user_id, agent_id, text, id) + super().__init__(id) + self.user_id = user_id + self.agent_id = agent_id + self.text = text self.model = model # model name (e.g. gpt-4) self.created_at = created_at @@ -74,7 +69,8 @@ class Document(Record): """A document represent a document loaded into MemGPT, which is broken down into passages.""" def __init__(self, user_id: str, text: str, data_source: str, document_id: Optional[str] = None): - super().__init__(user_id, agent_id, text, id) + super().__init__(document_id) + self.user_id = user_id self.text = text self.document_id = document_id self.data_source = data_source @@ -101,8 +97,10 @@ class Passage(Record): id: Optional[str] = None, metadata: Optional[dict] = {}, ): - super().__init__(user_id, agent_id, text, id) - print(self.text) + super().__init__(id) + self.user_id = user_id + self.agent_id = agent_id + self.text = text self.data_source = data_source self.embedding = embedding self.doc_id = doc_id @@ -110,3 +108,17 @@ # def __repr__(self): # pass + + +class Source(Record): + def __init__( + self, + user_id: str, + name: str, + created_at: Optional[str] = None, + id: Optional[str] = None, + ): + super().__init__(id) + self.name = name + self.user_id = user_id + self.created_at = created_at diff 
--git a/memgpt/memory.py b/memgpt/memory.py index 482860ce..f100147c 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -341,6 +341,9 @@ class BaseRecallMemory(RecallMemory): def save(self): self.storage.save() + def size(self): + return self.storage.size() + class EmbeddingArchivalMemory(ArchivalMemory): """Archival memory with embedding based search""" diff --git a/tests/test_load_archival.py b/tests/test_load_archival.py index ef260c10..78c02454 100644 --- a/tests/test_load_archival.py +++ b/tests/test_load_archival.py @@ -1,12 +1,14 @@ # import tempfile # import asyncio import os +import pytest +from memgpt.connectors.storage import StorageConnector, TableType # import asyncio -# from datasets import load_dataset +from datasets import load_dataset # import memgpt -# from memgpt.cli.cli_load import load_directory, load_database, load_webpage +from memgpt.cli.cli_load import load_directory, load_database, load_webpage # import memgpt.presets as presets # import memgpt.personas.personas as personas @@ -18,205 +20,53 @@ import os # import memgpt.interface # for printing to terminal -def test_postgres(): - return +# @pytest.mark.parametrize("storage_connector", ["sqllite", "postgres"]) +@pytest.mark.parametrize("metadata_storage_connector", ["sqlite"]) +@pytest.mark.parametrize("passage_storage_connector", ["chroma"]) +def test_load_directory(metadata_storage_connector, passage_storage_connector): - # override config path with enviornment variable - # TODO: make into temporary file - os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg" - print("env", os.getenv("MEMGPT_CONFIG_PATH")) - config = memgpt.config.MemGPTConfig(archival_storage_type="postgres", config_path=os.getenv("MEMGPT_CONFIG_PATH")) - print(config) - config.save() - # exit() + data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES) + passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, 
storage_type=passage_storage_connector) - name = "tmp_hf_dataset2" + # load hugging face dataset + # dataset_name = "MemGPT/example_short_stories" + # dataset = load_dataset(dataset_name) - dataset = load_dataset("MemGPT/example_short_stories") + # cache_dir = os.getenv("HF_DATASETS_CACHE") + # if cache_dir is None: + # # Construct the default path if the environment variable is not set. + # cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets") + # print("HF Directory", cache_dir) + name = "test_dataset" + cache_dir = "CONTRIBUTING.md" - cache_dir = os.getenv("HF_DATASETS_CACHE") - if cache_dir is None: - # Construct the default path if the environment variable is not set. - cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets") + # clear out data + data_source_conn.delete_table() + passages_conn.delete_table() + data_source_conn = StorageConnector.get_storage_connector(storage_type=metadata_storage_connector, table_type=TableType.DATA_SOURCES) + passages_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, storage_type=passage_storage_connector) - load_directory( - name=name, - input_dir=cache_dir, - recursive=True, - ) + # test: load directory + load_directory(name=name, input_dir=None, input_files=[cache_dir], recursive=False) # cache_dir, + # test to see if contained in storage + sources = data_source_conn.get_all({"name": name}) + assert len(sources) == 1, f"Expected 1 source, but got {len(sources)}" + assert sources[0].name == name, f"Expected name {name}, but got {sources[0].name}" + print("Source", sources) -def test_lancedb(): - return + # test to see if contained in storage + passages = passages_conn.get_all({"data_source": name}) + assert len(passages) > 0, f"Expected >0 passages, but got {len(passages)}" + assert [p.data_source == name for p in passages] + print("Passages", passages) - subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"]) - import lancedb 
# Try to import again after installing + # test: listing sources + sources = data_source_conn.get_all() + print("All sources", [s.name for s in sources]) - # override config path with enviornment variable - # TODO: make into temporary file - os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg" - print("env", os.getenv("MEMGPT_CONFIG_PATH")) - config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb", config_path=os.getenv("MEMGPT_CONFIG_PATH")) - print(config) - config.save() - - # loading dataset from hugging face - name = "tmp_hf_dataset" - - dataset = load_dataset("MemGPT/example_short_stories") - - cache_dir = os.getenv("HF_DATASETS_CACHE") - if cache_dir is None: - # Construct the default path if the environment variable is not set. - cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets") - - config = memgpt.config.MemGPTConfig(archival_storage_type="lancedb") - - load_directory( - name=name, - input_dir=cache_dir, - recursive=True, - ) - - -def test_chroma(): - return - - subprocess.check_call([sys.executable, "-m", "pip", "install", "chromadb"]) - import chromadb # Try to import again after installing - - # override config path with enviornment variable - # TODO: make into temporary file - os.environ["MEMGPT_CONFIG_PATH"] = "test_config.cfg" - print("env", os.getenv("MEMGPT_CONFIG_PATH")) - config = memgpt.config.MemGPTConfig(archival_storage_type="chroma", config_path=os.getenv("MEMGPT_CONFIG_PATH")) - print(config) - config.save() - # exit() - - name = "tmp_hf_dataset" - - dataset = load_dataset("MemGPT/example_short_stories") - - cache_dir = os.getenv("HF_DATASETS_CACHE") - if cache_dir is None: - # Construct the default path if the environment variable is not set. 
- cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets") - - config = memgpt.config.MemGPTConfig(archival_storage_type="chroma") - - load_directory( - name=name, - input_dir=cache_dir, - recursive=True, - ) - - -def test_load_directory(): - return - # downloading hugging face dataset (if does not exist) - dataset = load_dataset("MemGPT/example_short_stories") - - cache_dir = os.getenv("HF_DATASETS_CACHE") - - if cache_dir is None: - # Construct the default path if the environment variable is not set. - cache_dir = os.path.join(os.path.expanduser("~"), ".cache", "huggingface", "datasets") - - # load directory - print("Loading dataset into index...") - print(cache_dir) - load_directory( - name="tmp_hf_dataset", - input_dir=cache_dir, - recursive=True, - ) - - # create agents with defaults - agent_config = AgentConfig( - persona=personas.DEFAULT, - human=humans.DEFAULT, - model=DEFAULT_MEMGPT_MODEL, - data_source="tmp_hf_dataset", - ) - - # create state manager based off loaded data - persistence_manager = LocalStateManager(agent_config=agent_config) - - # create agent - memgpt_agent = presets.use_preset( - presets.DEFAULT_PRESET, - agent_config, - DEFAULT_MEMGPT_MODEL, - personas.get_persona_text(personas.DEFAULT), - humans.get_human_text(humans.DEFAULT), - memgpt.interface, - persistence_manager, - ) - - def query(q): - res = asyncio.run(memgpt_agent.archival_memory_search(q)) - return res - - results = query("cinderella be getting sick") - assert "Cinderella" in results, f"Expected 'Cinderella' in results, but got {results}" - - -def test_load_webpage(): - pass - - -def test_load_database(): - return - from sqlalchemy import create_engine, MetaData - import pandas as pd - - db_path = "memgpt/personas/examples/sqldb/test.db" - engine = create_engine(f"sqlite:///{db_path}") - - # Create a MetaData object and reflect the database to get table information. 
- metadata = MetaData() - metadata.reflect(bind=engine) - - # Get a list of table names from the reflected metadata. - table_names = metadata.tables.keys() - - print(table_names) - - # Define a SQL query to retrieve data from a table (replace 'your_table_name' with your actual table name). - query = f"SELECT * FROM {list(table_names)[0]}" - - # Use Pandas to read data from the database into a DataFrame. - df = pd.read_sql_query(query, engine) - print(df) - - load_database( - name="tmp_db_dataset", - # engine=engine, - dump_path=db_path, - query=f"SELECT * FROM {list(table_names)[0]}", - ) - - # create agents with defaults - agent_config = AgentConfig( - persona=personas.DEFAULT, - human=humans.DEFAULT, - model=DEFAULT_MEMGPT_MODEL, - data_source="tmp_hf_dataset", - ) - - # create state manager based off loaded data - persistence_manager = LocalStateManager(agent_config=agent_config) - - # create agent - memgpt_agent = presets.use_preset( - presets.DEFAULT, - agent_config, - DEFAULT_MEMGPT_MODEL, - personas.get_persona_text(personas.DEFAULT), - humans.get_human_text(humans.DEFAULT), - memgpt.interface, - persistence_manager, - ) - print("Successfully loaded into index") - assert True + # test: delete source + data_source_conn.delete({"name": name}) + passages_conn.delete({"data_source": name}) + assert len(data_source_conn.get_all({"name": name})) == 0 + assert len(passages_conn.get_all({"data_source": name})) == 0