diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index f4c59ce3..dee58e9c 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -12,6 +12,7 @@ from memgpt.personas import personas from memgpt.humans import humans from memgpt.config import AgentConfig from memgpt.cli.cli import attach +from memgpt.cli.cli_load import load_directory, load_webpage, load_index, load_database, load_vector_database from memgpt.connectors.storage import StorageConnector @@ -172,6 +173,22 @@ class MemGPTAgent(ConversableAgent): self._is_termination_msg = is_termination_msg if is_termination_msg is not None else (lambda x: x == "TERMINATE") + def load(self, name: str, type: str, **kwargs): + # call load function based on type + match type: + case "directory": + load_directory(name=name, **kwargs) + case "webpage": + load_webpage(name=name, **kwargs) + case "index": + load_index(name=name, **kwargs) + case "database": + load_database(name=name, **kwargs) + case "vector_database": + load_vector_database(name=name, **kwargs) + case _: + raise ValueError(f"Invalid data source type {type}") + def attach(self, data_source: str): # attach new data attach(self.agent.config.name, data_source) @@ -182,6 +199,15 @@ class MemGPTAgent(ConversableAgent): # reload agent with new data source self.agent.persistence_manager.archival_memory.storage = StorageConnector.get_storage_connector(agent_config=self.agent.config) + def load_and_attach(self, name: str, type: str, force=False, **kwargs): + # check if data source already exists + if name in StorageConnector.list_loaded_data() and not force: + print(f"Data source {name} already exists. Use force=True to overwrite.") + self.attach(name) + else: + self.load(name, type, **kwargs) + self.attach(name) + def format_other_agent_message(self, msg): if "name" in msg: user_message = f"{msg['name']}: {msg['content']}" diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index 62939a35..5a05d257 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -93,7 +93,6 @@ def load_directory( assert input_dir is not None, "Must provide input directory if recursive is True." if input_dir is not None: - assert len(input_files) == 0, "Either load in a list of files OR a directory." reader = SimpleDirectoryReader( input_dir=input_dir, recursive=recursive,