Add load and load_and_attach functions to memgpt autogen agent. (#430)

* Add load and load_and_attach functions to memgpt autogen agent.

* Only recompute files if dataset does not exist.
This commit is contained in:
Wes
2023-11-14 23:51:21 -07:00
committed by GitHub
parent a23ba80ac8
commit 2597ff2eb8
2 changed files with 26 additions and 1 deletions

View File

@@ -12,6 +12,7 @@ from memgpt.personas import personas
from memgpt.humans import humans
from memgpt.config import AgentConfig
from memgpt.cli.cli import attach
from memgpt.cli.cli_load import load_directory, load_webpage, load_index, load_database, load_vector_database
from memgpt.connectors.storage import StorageConnector
@@ -172,6 +173,22 @@ class MemGPTAgent(ConversableAgent):
self._is_termination_msg = is_termination_msg if is_termination_msg is not None else (lambda x: x == "TERMINATE")
def load(self, name: str, type: str, **kwargs):
# call load function based on type
match type:
case "directory":
load_directory(name=name, **kwargs)
case "webpage":
load_webpage(name=name, **kwargs)
case "index":
load_index(name=name, **kwargs)
case "database":
load_database(name=name, **kwargs)
case "vector_database":
load_vector_database(name=name, **kwargs)
case _:
raise ValueError(f"Invalid data source type {type}")
def attach(self, data_source: str):
# attach new data
attach(self.agent.config.name, data_source)
@@ -182,6 +199,15 @@ class MemGPTAgent(ConversableAgent):
# reload agent with new data source
self.agent.persistence_manager.archival_memory.storage = StorageConnector.get_storage_connector(agent_config=self.agent.config)
def load_and_attach(self, name: str, type: str, force=False, **kwargs):
# check if data source already exists
if name in StorageConnector.list_loaded_data() and not force:
print(f"Data source {name} already exists. Use force=True to overwrite.")
self.attach(name)
else:
self.load(name, type, **kwargs)
self.attach(name)
def format_other_agent_message(self, msg):
if "name" in msg:
user_message = f"{msg['name']}: {msg['content']}"

View File

@@ -93,7 +93,6 @@ def load_directory(
assert input_dir is not None, "Must provide input directory if recursive is True."
if input_dir is not None:
assert len(input_files) == 0, "Either load in a list of files OR a directory."
reader = SimpleDirectoryReader(
input_dir=input_dir,
recursive=recursive,