diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 3f73205b..45fc1606 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -10,7 +10,6 @@ from prettytable.colortable import ColorTable, Themes from tqdm import tqdm from memgpt import utils -from memgpt.agent_store.storage import StorageConnector, TableType from memgpt.config import MemGPTConfig from memgpt.constants import LLM_MAX_TOKENS, MEMGPT_DIR from memgpt.credentials import SUPPORTED_AUTH_TYPES, MemGPTCredentials @@ -38,7 +37,6 @@ from memgpt.local_llm.constants import ( ) from memgpt.local_llm.utils import get_available_wrappers from memgpt.metadata import MetadataStore -from memgpt.models.pydantic_models import PersonaModel from memgpt.server.utils import shorten_key_middle app = typer.Typer() @@ -1096,18 +1094,16 @@ class ListChoice(str, Enum): def list(arg: Annotated[ListChoice, typer.Argument]): from memgpt.client.client import create_client - config = MemGPTConfig.load() - ms = MetadataStore(config) - user_id = uuid.UUID(config.anon_clientid) client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS")) table = ColorTable(theme=Themes.OCEAN) if arg == ListChoice.agents: """List all agents""" table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"] - for agent in tqdm(ms.list_agents(user_id=user_id)): - source_ids = ms.list_attached_sources(agent_id=agent.id) + for agent in tqdm(client.list_agents()): + # TODO: add this function + source_ids = client.list_attached_sources(agent_id=agent.id) assert all([source_id is not None and isinstance(source_id, uuid.UUID) for source_id in source_ids]) - sources = [ms.get_source(source_id=source_id) for source_id in source_ids] + sources = [client.get_source(source_id=source_id) for source_id in source_ids] assert all([source is not None and isinstance(source, Source)] for source in sources) source_names = [source.name for source in sources if source is not None] table.add_row( @@ -1116,8 +1112,8 @@ def list(arg: Annotated[ListChoice, typer.Argument]): agent.llm_config.model, agent.embedding_config.embedding_model, agent.embedding_config.embedding_dim, - agent.persona, - agent.human, + agent._metadata.get("persona", ""), + agent._metadata.get("human", ""), ",".join(source_names), utils.format_datetime(agent.created_at), ] @@ -1132,25 +1128,21 @@ def list(arg: Annotated[ListChoice, typer.Argument]): elif arg == ListChoice.personas: """List all personas""" table.field_names = ["Name", "Text"] - for persona in ms.list_personas(user_id=user_id): + for persona in client.list_personas(): table.add_row([persona.name, persona.text.replace("\n", "")[:100]]) print(table) elif arg == ListChoice.sources: """List all data sources""" # create table - table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At", "Agents"] + table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At"] # TODO: eventually look accross all storage connections # TODO: add data source stats # TODO: connect to agents # get all sources - for source in ms.list_sources(user_id=user_id): + for source in client.list_sources(): # get attached agents - agent_ids = ms.list_attached_agents(source_id=source.id) - agent_states = [ms.get_agent(agent_id=agent_id) for agent_id in agent_ids] - agent_names = [agent_state.name for agent_state in agent_states if agent_state is not None] - table.add_row( [ source.name, @@ -1158,13 +1150,13 @@ def list(arg: Annotated[ListChoice, typer.Argument]): source.embedding_model, source.embedding_dim, utils.format_datetime(source.created_at), - ",".join(agent_names), ] ) print(table) else: raise ValueError(f"Unknown argument {arg}") + return table @app.command() @@ -1177,25 +1169,20 @@ def add( """Add a person/human""" from memgpt.client.client import create_client - config = MemGPTConfig.load() - user_id = uuid.UUID(config.anon_clientid) - ms = MetadataStore(config) client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS")) if filename: # read from file assert text is None, "Cannot specify both text and filename" with open(filename, "r", encoding="utf-8") as f: text = f.read() if option == "persona": - persona = ms.get_persona(name=name) + persona = client.get_persona(name) if persona: # config if user wants to overwrite if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask(): return - persona.text = text - ms.update_persona(persona) + client.update_persona(name=name, text=text) else: - persona = PersonaModel(name=name, text=text, user_id=user_id) - ms.add_persona(persona) + client.create_persona(name=name, text=text) elif option == "human": human = client.get_human(name=name) @@ -1203,10 +1190,9 @@ def add( # config if user wants to overwrite if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask(): return - human.text = text - client.update_human(human) + client.update_human(name=name, text=text) else: - human = client.create_human(name=name, human=text) + human = client.create_human(name=name, text=text) else: raise ValueError(f"Unknown kind {option}") @@ -1216,53 +1202,24 @@ def delete(option: str, name: str): """Delete a source from the archival memory.""" from memgpt.client.client import create_client - config = MemGPTConfig.load() - user_id = uuid.UUID(config.anon_clientid) client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_API_KEY")) - ms = MetadataStore(config) - assert ms.get_user(user_id=user_id), f"User {user_id} does not exist" - try: # delete from metadata if option == "source": # delete metadata - source = ms.get_source(source_name=name, user_id=user_id) + source = client.get_source(name) assert source is not None, f"Source {name} does not exist" - ms.delete_source(source_id=source.id) - - # delete from passages - conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id=user_id) - conn.delete({"data_source": name}) - - assert ( - conn.get_all({"data_source": name}) == [] - ), f"Expected no passages with source {name}, but got {conn.get_all({'data_source': name})}" - - # TODO: should we also delete from agents? + client.delete_source(source_id=source.id) elif option == "agent": - agent = ms.get_agent(agent_name=name, user_id=user_id) - assert agent is not None, f"Agent {name} for user_id {user_id} does not exist" - - # recall memory - recall_conn = StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, config, user_id=user_id, agent_id=agent.id) - recall_conn.delete({"agent_id": agent.id}) - - # archival memory - archival_conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id=user_id, agent_id=agent.id) - archival_conn.delete({"agent_id": agent.id}) - - # metadata - ms.delete_agent(agent_id=agent.id) - + client.delete_agent(name=name) elif option == "human": human = client.get_human(name=name) assert human is not None, f"Human {name} does not exist" client.delete_human(name=name) elif option == "persona": - persona = ms.get_persona(name=name) + persona = client.get_persona(name=name) assert persona is not None, f"Persona {name} does not exist" - ms.delete_persona(name=name) - assert ms.get_persona(name=name) is None, f"Persona {name} still exists" + client.delete_persona(name=name) else: raise ValueError(f"Option {option} not implemented") diff --git a/memgpt/client/client.py b/memgpt/client/client.py index 17ea6478..e92c0e9b 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -166,7 +166,7 @@ class AbstractClient(object): """List all humans.""" raise NotImplementedError - def create_human(self, name: str, human: str): + def create_human(self, name: str, text: str): """Create a human.""" raise NotImplementedError @@ -174,7 +174,7 @@ class AbstractClient(object): """List all personas.""" raise NotImplementedError - def create_persona(self, name: str, persona: str): + def create_persona(self, name: str, text: str): """Create a persona.""" raise NotImplementedError @@ -498,8 +498,8 @@ class RESTClient(AbstractClient): response = requests.get(f"{self.base_url}/api/humans", headers=self.headers) return ListHumansResponse(**response.json()) - def create_human(self, name: str, human: str) -> HumanModel: - data = {"name": name, "text": human} + def create_human(self, name: str, text: str) -> HumanModel: + data = {"name": name, "text": text} response = requests.post(f"{self.base_url}/api/humans", json=data, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create human: {response.text}") @@ -509,8 +509,8 @@ class RESTClient(AbstractClient): response = requests.get(f"{self.base_url}/api/personas", headers=self.headers) return ListPersonasResponse(**response.json()) - def create_persona(self, name: str, persona: str) -> PersonaModel: - data = {"name": name, "text": persona} + def create_persona(self, name: str, text: str) -> PersonaModel: + data = {"name": name, "text": text} response = requests.post(f"{self.base_url}/api/personas", json=data, headers=self.headers) if response.status_code != 200: raise ValueError(f"Failed to create persona: {response.text}") @@ -834,11 +834,11 @@ class LocalClient(AbstractClient): # humans / personas - def create_human(self, name: str, human: str): - return self.server.add_human(HumanModel(name=name, text=human, user_id=self.user_id)) + def create_human(self, name: str, text: str): + return self.server.add_human(HumanModel(name=name, text=text, user_id=self.user_id)) - def create_persona(self, name: str, persona: str): - return self.server.add_persona(PersonaModel(name=name, text=persona, user_id=self.user_id)) + def create_persona(self, name: str, text: str): + return self.server.add_persona(PersonaModel(name=name, text=text, user_id=self.user_id)) def list_humans(self): return self.server.list_humans(user_id=self.user_id if self.user_id else self.user_id) @@ -846,8 +846,10 @@ class LocalClient(AbstractClient): def get_human(self, name: str): return self.server.get_human(name=name, user_id=self.user_id) - def update_human(self, human: HumanModel): - return self.server.update_human(human=human) + def update_human(self, name: str, text: str): + human = self.get_human(name) + human.text = text + return self.server.update_human(human) def delete_human(self, name: str): return self.server.delete_human(name, self.user_id) @@ -858,8 +860,10 @@ class LocalClient(AbstractClient): def get_persona(self, name: str): return self.server.get_persona(name=name, user_id=self.user_id) - def update_persona(self, persona: PersonaModel): - return self.server.update_persona(persona=persona) + def update_persona(self, name: str, text: str): + persona = self.get_persona(name) + persona.text = text + return self.server.update_persona(persona) def delete_persona(self, name: str): return self.server.delete_persona(name, self.user_id) @@ -926,9 +930,19 @@ class LocalClient(AbstractClient): def create_source(self, name: str): return self.server.create_source(user_id=self.user_id, name=name) + def delete_source(self, source_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None): + # TODO: delete source data + self.server.delete_source(user_id=self.user.id, source_id=source_id, source_name=source_name) + + def get_source(self, source_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None): + return self.server.ms.get_source(user_id=self.user_id, source_id=source_id, source_name=source_name) + def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID): self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id) + def list_sources(self): + return self.server.list_all_sources(user_id=self.user_id) + def get_agent_archival_memory( self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000 ): @@ -973,3 +987,6 @@ class LocalClient(AbstractClient): ) return ListModelsResponse(models=[llm_config]) + + def list_attached_sources(self, agent_id: uuid.UUID): + return self.server.list_attached_sources(agent_id=agent_id) diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index a7e78196..4de9244a 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -191,3 +191,7 @@ class DocumentModel(BaseModel): data_source: str = Field(..., description="The data source of the document.") id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True) metadata: Optional[Dict] = Field({}, description="The metadata of the document.") + + +class UserModel(BaseModel): + user_id: uuid.UUID = Field(..., description="The unique identifier of the user.") diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py index cabcae9a..eab528e0 100644 --- a/memgpt/server/rest_api/agents/index.py +++ b/memgpt/server/rest_api/agents/index.py @@ -51,7 +51,7 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p This endpoint retrieves a list of all agents and their configurations associated with the specified user ID. """ interface.clear() - agents_data = server.list_agents(user_id=user_id) + agents_data = server.list_agents_legacy(user_id=user_id) return ListAgentsResponse(**agents_data) @router.post("/agents", tags=["agents"], response_model=CreateAgentResponse) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index d2c7956d..e9366d5b 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -217,7 +217,7 @@ class SyncServer(LockingServer): # Initialize the connection to the DB self.config = MemGPTConfig.load() - logger.info(f"loading configuration from '{self.config.config_path}'") + logger.debug(f"loading configuration from '{self.config.config_path}'") assert self.config.persona is not None, "Persona must be set in the config" assert self.config.human is not None, "Human must be set in the config" @@ -805,6 +805,8 @@ class SyncServer(LockingServer): user_id: uuid.UUID, agent_id: uuid.UUID, ): + # TODO: delete agent data + if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: @@ -867,10 +869,21 @@ class SyncServer(LockingServer): } return agent_config - # TODO make return type pydantic def list_agents( self, user_id: uuid.UUID, + ) -> List[AgentState]: + """List all available agents to a user""" + if self.ms.get_user(user_id=user_id) is None: + raise ValueError(f"User user_id={user_id} does not exist") + + agents_states = self.ms.list_agents(user_id=user_id) + return agents_states + + # TODO make return type pydantic + def list_agents_legacy( + self, + user_id: uuid.UUID, ) -> dict: """List all available agents to a user""" if self.ms.get_user(user_id=user_id) is None: @@ -1191,12 +1204,18 @@ class SyncServer(LockingServer): # TODO: mark what is in-context versus not return cursor, json_records - def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> AgentState: + def get_agent_config(self, user_id: uuid.UUID, agent_id: Optional[uuid.UUID], agent_name: Optional[str] = None) -> AgentState: """Return the config of an agent""" if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") - if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: - raise ValueError(f"Agent agent_id={agent_id} does not exist") + if agent_id: + if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: + raise ValueError(f"Agent agent_id={agent_id} does not exist") + else: + agent_state = self.ms.get_agent(agent_name=agent_name, user_id=user_id) + if agent_state is None: + raise ValueError(f"Agent agent_name={agent_name} does not exist") + agent_id = agent_state.id # Get the agent object (loaded in memory) memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id) diff --git a/tests/test_cli.py b/tests/test_cli.py index f15ab7fc..58617240 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -4,6 +4,10 @@ import sys subprocess.check_call([sys.executable, "-m", "pip", "install", "pexpect"]) import pexpect +from prettytable.colortable import ColorTable + +from memgpt.cli.cli_config import ListChoice, add, delete +from memgpt.cli.cli_config import list as list_command from .constants import TIMEOUT from .utils import create_config @@ -11,6 +15,39 @@ from .utils import create_config # def test_configure_memgpt(): # configure_memgpt() +options = [ListChoice.agents, ListChoice.sources, ListChoice.humans, ListChoice.personas] + + +def test_cli_list(): + for option in options: + output = list_command(arg=option) + # check if is a list + assert isinstance(output, ColorTable) + + +def test_cli_config(): + + # test add + for option in ["human", "persona"]: + + # create initial + add(option=option, name="test", text="test data") + + ## update + # filename = "test.txt" + # open(filename, "w").write("test data new") + # child = pexpect.spawn(f"poetry run memgpt add --{str(option)} {filename} --name test --strip-ui") + # child.expect("Human test already exists. Overwrite?", timeout=TIMEOUT) + # child.sendline() + # child.expect(pexpect.EOF, timeout=TIMEOUT) # Wait for child to exit + # child.close() + + for row in list_command(arg=ListChoice.humans if option == "human" else ListChoice.personas): + if row[0] == "test": + assert "test data" in row + # delete + delete(option=option, name="test") + def test_save_load(): # configure_memgpt() # rely on configure running first^ diff --git a/tests/test_client.py b/tests/test_client.py index 9e855e62..1b406c2c 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -196,14 +196,14 @@ def test_humans_personas(client, agent): persona_name = "TestPersona" if client.get_persona(persona_name): client.delete_persona(persona_name) - persona = client.create_persona(name=persona_name, persona="Persona text") + persona = client.create_persona(name=persona_name, text="Persona text") assert persona.name == persona_name assert persona.text == "Persona text", "Creating persona failed" human_name = "TestHuman" if client.get_human(human_name): client.delete_human(human_name) - human = client.create_human(name=human_name, human="Human text") + human = client.create_human(name=human_name, text="Human text") assert human.name == human_name assert human.text == "Human text", "Creating human failed"