fix: fix CLI commands by migrating to Python client (#1563)

This commit is contained in:
Sarah Wooders
2024-07-23 19:52:18 -07:00
committed by GitHub
parent c504529018
commit fe0818d34c
7 changed files with 119 additions and 85 deletions

View File

@@ -10,7 +10,6 @@ from prettytable.colortable import ColorTable, Themes
from tqdm import tqdm
from memgpt import utils
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.constants import LLM_MAX_TOKENS, MEMGPT_DIR
from memgpt.credentials import SUPPORTED_AUTH_TYPES, MemGPTCredentials
@@ -38,7 +37,6 @@ from memgpt.local_llm.constants import (
)
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.metadata import MetadataStore
from memgpt.models.pydantic_models import PersonaModel
from memgpt.server.utils import shorten_key_middle
app = typer.Typer()
@@ -1096,18 +1094,16 @@ class ListChoice(str, Enum):
def list(arg: Annotated[ListChoice, typer.Argument]):
from memgpt.client.client import create_client
config = MemGPTConfig.load()
ms = MetadataStore(config)
user_id = uuid.UUID(config.anon_clientid)
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
table = ColorTable(theme=Themes.OCEAN)
if arg == ListChoice.agents:
"""List all agents"""
table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"]
for agent in tqdm(ms.list_agents(user_id=user_id)):
source_ids = ms.list_attached_sources(agent_id=agent.id)
for agent in tqdm(client.list_agents()):
# TODO: add this function
source_ids = client.list_attached_sources(agent_id=agent.id)
assert all([source_id is not None and isinstance(source_id, uuid.UUID) for source_id in source_ids])
sources = [ms.get_source(source_id=source_id) for source_id in source_ids]
sources = [client.get_source(source_id=source_id) for source_id in source_ids]
assert all([source is not None and isinstance(source, Source)] for source in sources)
source_names = [source.name for source in sources if source is not None]
table.add_row(
@@ -1116,8 +1112,8 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
agent.llm_config.model,
agent.embedding_config.embedding_model,
agent.embedding_config.embedding_dim,
agent.persona,
agent.human,
agent._metadata.get("persona", ""),
agent._metadata.get("human", ""),
",".join(source_names),
utils.format_datetime(agent.created_at),
]
@@ -1132,25 +1128,21 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
elif arg == ListChoice.personas:
"""List all personas"""
table.field_names = ["Name", "Text"]
for persona in ms.list_personas(user_id=user_id):
for persona in client.list_personas():
table.add_row([persona.name, persona.text.replace("\n", "")[:100]])
print(table)
elif arg == ListChoice.sources:
"""List all data sources"""
# create table
table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At"]
# TODO: eventually look accross all storage connections
# TODO: add data source stats
# TODO: connect to agents
# get all sources
for source in ms.list_sources(user_id=user_id):
for source in client.list_sources():
# get attached agents
agent_ids = ms.list_attached_agents(source_id=source.id)
agent_states = [ms.get_agent(agent_id=agent_id) for agent_id in agent_ids]
agent_names = [agent_state.name for agent_state in agent_states if agent_state is not None]
table.add_row(
[
source.name,
@@ -1158,13 +1150,13 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
source.embedding_model,
source.embedding_dim,
utils.format_datetime(source.created_at),
",".join(agent_names),
]
)
print(table)
else:
raise ValueError(f"Unknown argument {arg}")
return table
@app.command()
@@ -1177,25 +1169,20 @@ def add(
"""Add a person/human"""
from memgpt.client.client import create_client
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
ms = MetadataStore(config)
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
if filename: # read from file
assert text is None, "Cannot specify both text and filename"
with open(filename, "r", encoding="utf-8") as f:
text = f.read()
if option == "persona":
persona = ms.get_persona(name=name)
persona = client.get_persona(name)
if persona:
# config if user wants to overwrite
if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask():
return
persona.text = text
ms.update_persona(persona)
client.update_persona(name=name, text=text)
else:
persona = PersonaModel(name=name, text=text, user_id=user_id)
ms.add_persona(persona)
client.create_persona(name=name, text=text)
elif option == "human":
human = client.get_human(name=name)
@@ -1203,10 +1190,9 @@ def add(
# config if user wants to overwrite
if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask():
return
human.text = text
client.update_human(human)
client.update_human(name=name, text=text)
else:
human = client.create_human(name=name, human=text)
human = client.create_human(name=name, text=text)
else:
raise ValueError(f"Unknown kind {option}")
@@ -1216,53 +1202,24 @@ def delete(option: str, name: str):
"""Delete a source from the archival memory."""
from memgpt.client.client import create_client
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_API_KEY"))
ms = MetadataStore(config)
assert ms.get_user(user_id=user_id), f"User {user_id} does not exist"
try:
# delete from metadata
if option == "source":
# delete metadata
source = ms.get_source(source_name=name, user_id=user_id)
source = client.get_source(name)
assert source is not None, f"Source {name} does not exist"
ms.delete_source(source_id=source.id)
# delete from passages
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id=user_id)
conn.delete({"data_source": name})
assert (
conn.get_all({"data_source": name}) == []
), f"Expected no passages with source {name}, but got {conn.get_all({'data_source': name})}"
# TODO: should we also delete from agents?
client.delete_source(source_id=source.id)
elif option == "agent":
agent = ms.get_agent(agent_name=name, user_id=user_id)
assert agent is not None, f"Agent {name} for user_id {user_id} does not exist"
# recall memory
recall_conn = StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, config, user_id=user_id, agent_id=agent.id)
recall_conn.delete({"agent_id": agent.id})
# archival memory
archival_conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id=user_id, agent_id=agent.id)
archival_conn.delete({"agent_id": agent.id})
# metadata
ms.delete_agent(agent_id=agent.id)
client.delete_agent(name=name)
elif option == "human":
human = client.get_human(name=name)
assert human is not None, f"Human {name} does not exist"
client.delete_human(name=name)
elif option == "persona":
persona = ms.get_persona(name=name)
persona = client.get_persona(name=name)
assert persona is not None, f"Persona {name} does not exist"
ms.delete_persona(name=name)
assert ms.get_persona(name=name) is None, f"Persona {name} still exists"
client.delete_persona(name=name)
else:
raise ValueError(f"Option {option} not implemented")

View File

@@ -166,7 +166,7 @@ class AbstractClient(object):
"""List all humans."""
raise NotImplementedError
def create_human(self, name: str, human: str):
def create_human(self, name: str, text: str):
"""Create a human."""
raise NotImplementedError
@@ -174,7 +174,7 @@ class AbstractClient(object):
"""List all personas."""
raise NotImplementedError
def create_persona(self, name: str, persona: str):
def create_persona(self, name: str, text: str):
"""Create a persona."""
raise NotImplementedError
@@ -498,8 +498,8 @@ class RESTClient(AbstractClient):
response = requests.get(f"{self.base_url}/api/humans", headers=self.headers)
return ListHumansResponse(**response.json())
def create_human(self, name: str, human: str) -> HumanModel:
data = {"name": name, "text": human}
def create_human(self, name: str, text: str) -> HumanModel:
data = {"name": name, "text": text}
response = requests.post(f"{self.base_url}/api/humans", json=data, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create human: {response.text}")
@@ -509,8 +509,8 @@ class RESTClient(AbstractClient):
response = requests.get(f"{self.base_url}/api/personas", headers=self.headers)
return ListPersonasResponse(**response.json())
def create_persona(self, name: str, persona: str) -> PersonaModel:
data = {"name": name, "text": persona}
def create_persona(self, name: str, text: str) -> PersonaModel:
data = {"name": name, "text": text}
response = requests.post(f"{self.base_url}/api/personas", json=data, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create persona: {response.text}")
@@ -834,11 +834,11 @@ class LocalClient(AbstractClient):
# humans / personas
def create_human(self, name: str, human: str):
return self.server.add_human(HumanModel(name=name, text=human, user_id=self.user_id))
def create_human(self, name: str, text: str):
return self.server.add_human(HumanModel(name=name, text=text, user_id=self.user_id))
def create_persona(self, name: str, persona: str):
return self.server.add_persona(PersonaModel(name=name, text=persona, user_id=self.user_id))
def create_persona(self, name: str, text: str):
return self.server.add_persona(PersonaModel(name=name, text=text, user_id=self.user_id))
def list_humans(self):
return self.server.list_humans(user_id=self.user_id if self.user_id else self.user_id)
@@ -846,8 +846,10 @@ class LocalClient(AbstractClient):
def get_human(self, name: str):
return self.server.get_human(name=name, user_id=self.user_id)
def update_human(self, human: HumanModel):
return self.server.update_human(human=human)
def update_human(self, name: str, text: str):
human = self.get_human(name)
human.text = text
return self.server.update_human(human)
def delete_human(self, name: str):
return self.server.delete_human(name, self.user_id)
@@ -858,8 +860,10 @@ class LocalClient(AbstractClient):
def get_persona(self, name: str):
return self.server.get_persona(name=name, user_id=self.user_id)
def update_persona(self, persona: PersonaModel):
return self.server.update_persona(persona=persona)
def update_persona(self, name: str, text: str):
persona = self.get_persona(name)
persona.text = text
return self.server.update_persona(persona)
def delete_persona(self, name: str):
return self.server.delete_persona(name, self.user_id)
@@ -926,9 +930,19 @@ class LocalClient(AbstractClient):
def create_source(self, name: str):
return self.server.create_source(user_id=self.user_id, name=name)
def delete_source(self, source_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None):
# TODO: delete source data
self.server.delete_source(user_id=self.user.id, source_id=source_id, source_name=source_name)
def get_source(self, source_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None):
return self.server.ms.get_source(user_id=self.user_id, source_id=source_id, source_name=source_name)
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id)
def list_sources(self):
return self.server.list_all_sources(user_id=self.user_id)
def get_agent_archival_memory(
self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
):
@@ -973,3 +987,6 @@ class LocalClient(AbstractClient):
)
return ListModelsResponse(models=[llm_config])
def list_attached_sources(self, agent_id: uuid.UUID):
return self.server.list_attached_sources(agent_id=agent_id)

View File

@@ -191,3 +191,7 @@ class DocumentModel(BaseModel):
data_source: str = Field(..., description="The data source of the document.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True)
metadata: Optional[Dict] = Field({}, description="The metadata of the document.")
class UserModel(BaseModel):
user_id: uuid.UUID = Field(..., description="The unique identifier of the user.")

View File

@@ -51,7 +51,7 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
"""
interface.clear()
agents_data = server.list_agents(user_id=user_id)
agents_data = server.list_agents_legacy(user_id=user_id)
return ListAgentsResponse(**agents_data)
@router.post("/agents", tags=["agents"], response_model=CreateAgentResponse)

View File

@@ -217,7 +217,7 @@ class SyncServer(LockingServer):
# Initialize the connection to the DB
self.config = MemGPTConfig.load()
logger.info(f"loading configuration from '{self.config.config_path}'")
logger.debug(f"loading configuration from '{self.config.config_path}'")
assert self.config.persona is not None, "Persona must be set in the config"
assert self.config.human is not None, "Human must be set in the config"
@@ -805,6 +805,8 @@ class SyncServer(LockingServer):
user_id: uuid.UUID,
agent_id: uuid.UUID,
):
# TODO: delete agent data
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
@@ -867,10 +869,21 @@ class SyncServer(LockingServer):
}
return agent_config
# TODO make return type pydantic
def list_agents(
self,
user_id: uuid.UUID,
) -> List[AgentState]:
"""List all available agents to a user"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
agents_states = self.ms.list_agents(user_id=user_id)
return agents_states
# TODO make return type pydantic
def list_agents_legacy(
self,
user_id: uuid.UUID,
) -> dict:
"""List all available agents to a user"""
if self.ms.get_user(user_id=user_id) is None:
@@ -1191,12 +1204,18 @@ class SyncServer(LockingServer):
# TODO: mark what is in-context versus not
return cursor, json_records
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> AgentState:
def get_agent_config(self, user_id: uuid.UUID, agent_id: Optional[uuid.UUID], agent_name: Optional[str] = None) -> AgentState:
"""Return the config of an agent"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
if agent_id:
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
else:
agent_state = self.ms.get_agent(agent_name=agent_name, user_id=user_id)
if agent_state is None:
raise ValueError(f"Agent agent_name={agent_name} does not exist")
agent_id = agent_state.id
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)

View File

@@ -4,6 +4,10 @@ import sys
subprocess.check_call([sys.executable, "-m", "pip", "install", "pexpect"])
import pexpect
from prettytable.colortable import ColorTable
from memgpt.cli.cli_config import ListChoice, add, delete
from memgpt.cli.cli_config import list as list_command
from .constants import TIMEOUT
from .utils import create_config
@@ -11,6 +15,39 @@ from .utils import create_config
# def test_configure_memgpt():
# configure_memgpt()
options = [ListChoice.agents, ListChoice.sources, ListChoice.humans, ListChoice.personas]
def test_cli_list():
for option in options:
output = list_command(arg=option)
# check if is a list
assert isinstance(output, ColorTable)
def test_cli_config():
# test add
for option in ["human", "persona"]:
# create initial
add(option=option, name="test", text="test data")
## update
# filename = "test.txt"
# open(filename, "w").write("test data new")
# child = pexpect.spawn(f"poetry run memgpt add --{str(option)} {filename} --name test --strip-ui")
# child.expect("Human test already exists. Overwrite?", timeout=TIMEOUT)
# child.sendline()
# child.expect(pexpect.EOF, timeout=TIMEOUT) # Wait for child to exit
# child.close()
for row in list_command(arg=ListChoice.humans if option == "human" else ListChoice.personas):
if row[0] == "test":
assert "test data" in row
# delete
delete(option=option, name="test")
def test_save_load():
# configure_memgpt() # rely on configure running first^

View File

@@ -196,14 +196,14 @@ def test_humans_personas(client, agent):
persona_name = "TestPersona"
if client.get_persona(persona_name):
client.delete_persona(persona_name)
persona = client.create_persona(name=persona_name, persona="Persona text")
persona = client.create_persona(name=persona_name, text="Persona text")
assert persona.name == persona_name
assert persona.text == "Persona text", "Creating persona failed"
human_name = "TestHuman"
if client.get_human(human_name):
client.delete_human(human_name)
human = client.create_human(name=human_name, human="Human text")
human = client.create_human(name=human_name, text="Human text")
assert human.name == human_name
assert human.text == "Human text", "Creating human failed"