fix: Various fixes to support source/agent deletion (#797)

This commit is contained in:
Charles Packer
2024-01-08 17:02:04 -08:00
committed by GitHub
2 changed files with 26 additions and 16 deletions

View File

@@ -22,6 +22,7 @@ from memgpt.openai_tools import openai_get_model_list, azure_openai_get_model_li
from memgpt.server.utils import shorten_key_middle
from memgpt.data_types import User, LLMConfig, EmbeddingConfig
from memgpt.metadata import MetadataStore
from memgpt.agent_store.storage import StorageConnector, TableType
app = typer.Typer()
@@ -647,22 +648,20 @@ def add(
def delete(option: str, name: str):
"""Delete a source from the archival memory."""
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
ms = MetadataStore(config)
assert ms.get_user(user_id=user_id), f"User {user_id} does not exist"
try:
# delete from metadata
if option == "source":
conn = StorageConnector.get_metadata_storage_connector(TableType.DATA_SOURCES)
# Check if the source exists
if conn.get_all({"name": name}) == []:
raise ValueError(f"No source named '{name}'")
conn.delete({"name": name})
# It should now be deleted
assert conn.get_all({"name": name}) == [], f"Expected no sources named {name}, but got {conn.get_all({'name': name})}"
# delete metadata
source = ms.get_source(source_name=name, user_id=user_id)
ms.delete_source(source_id=source.id)
# delete from passages
conn = StorageConnector.get_storage_connector(TableType.PASSAGES)
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id=user_id)
conn.delete({"data_source": name})
assert (
@@ -670,6 +669,20 @@ def delete(option: str, name: str):
), f"Expected no passages with source {name}, but got {conn.get_all({'data_source': name})}"
# TODO: should we also delete from agents?
elif option == "agent":
agent = ms.get_agent(agent_name=name, user_id=user_id)
# recall memory
recall_conn = StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, config, user_id=user_id, agent_id=agent.id)
recall_conn.delete({"agent_id": agent.id})
# archival memory
archival_conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id=user_id, agent_id=agent.id)
archival_conn.delete({"agent_id": agent.id})
# metadata
ms.delete_agent(agent_id=agent.id)
else:
raise ValueError(f"Option {option} not implemented")

View File

@@ -42,19 +42,16 @@ def store_docs(name, docs, user_id=None, show_progress=True):
# record data source metadata
ms = MetadataStore(config)
user = ms.get_user(user_id)
print("USER", user)
if user is None:
raise ValueError(f"Cannot find user {user_id} in metadata store. Please run 'memgpt configure'.")
data_source = Source(user_id=user.id, name=name, created_at=datetime.now())
if not ms.get_source(user_id=user.id, source_name=name):
print("Trying to add...")
ms.create_source(data_source)
print("Created source", data_source)
else:
print(f"Source {name} for user {user.id} already exists")
# compute and record passages
print("USER ID", user.id)
storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user.id)
print("embedding config", user.default_embedding_config, user.default_embedding_config.embedding_dim)
embed_model = embedding_model(user.default_embedding_config)
orig_size = storage.size()