fix: fix CLI commands by migrating to Python client (#1563)
This commit is contained in:
@@ -10,7 +10,6 @@ from prettytable.colortable import ColorTable, Themes
|
||||
from tqdm import tqdm
|
||||
|
||||
from memgpt import utils
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import LLM_MAX_TOKENS, MEMGPT_DIR
|
||||
from memgpt.credentials import SUPPORTED_AUTH_TYPES, MemGPTCredentials
|
||||
@@ -38,7 +37,6 @@ from memgpt.local_llm.constants import (
|
||||
)
|
||||
from memgpt.local_llm.utils import get_available_wrappers
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.models.pydantic_models import PersonaModel
|
||||
from memgpt.server.utils import shorten_key_middle
|
||||
|
||||
app = typer.Typer()
|
||||
@@ -1096,18 +1094,16 @@ class ListChoice(str, Enum):
|
||||
def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
from memgpt.client.client import create_client
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
ms = MetadataStore(config)
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
|
||||
table = ColorTable(theme=Themes.OCEAN)
|
||||
if arg == ListChoice.agents:
|
||||
"""List all agents"""
|
||||
table.field_names = ["Name", "LLM Model", "Embedding Model", "Embedding Dim", "Persona", "Human", "Data Source", "Create Time"]
|
||||
for agent in tqdm(ms.list_agents(user_id=user_id)):
|
||||
source_ids = ms.list_attached_sources(agent_id=agent.id)
|
||||
for agent in tqdm(client.list_agents()):
|
||||
# TODO: add this function
|
||||
source_ids = client.list_attached_sources(agent_id=agent.id)
|
||||
assert all([source_id is not None and isinstance(source_id, uuid.UUID) for source_id in source_ids])
|
||||
sources = [ms.get_source(source_id=source_id) for source_id in source_ids]
|
||||
sources = [client.get_source(source_id=source_id) for source_id in source_ids]
|
||||
assert all([source is not None and isinstance(source, Source)] for source in sources)
|
||||
source_names = [source.name for source in sources if source is not None]
|
||||
table.add_row(
|
||||
@@ -1116,8 +1112,8 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
agent.llm_config.model,
|
||||
agent.embedding_config.embedding_model,
|
||||
agent.embedding_config.embedding_dim,
|
||||
agent.persona,
|
||||
agent.human,
|
||||
agent._metadata.get("persona", ""),
|
||||
agent._metadata.get("human", ""),
|
||||
",".join(source_names),
|
||||
utils.format_datetime(agent.created_at),
|
||||
]
|
||||
@@ -1132,25 +1128,21 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
elif arg == ListChoice.personas:
|
||||
"""List all personas"""
|
||||
table.field_names = ["Name", "Text"]
|
||||
for persona in ms.list_personas(user_id=user_id):
|
||||
for persona in client.list_personas():
|
||||
table.add_row([persona.name, persona.text.replace("\n", "")[:100]])
|
||||
print(table)
|
||||
elif arg == ListChoice.sources:
|
||||
"""List all data sources"""
|
||||
|
||||
# create table
|
||||
table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
|
||||
table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At"]
|
||||
# TODO: eventually look accross all storage connections
|
||||
# TODO: add data source stats
|
||||
# TODO: connect to agents
|
||||
|
||||
# get all sources
|
||||
for source in ms.list_sources(user_id=user_id):
|
||||
for source in client.list_sources():
|
||||
# get attached agents
|
||||
agent_ids = ms.list_attached_agents(source_id=source.id)
|
||||
agent_states = [ms.get_agent(agent_id=agent_id) for agent_id in agent_ids]
|
||||
agent_names = [agent_state.name for agent_state in agent_states if agent_state is not None]
|
||||
|
||||
table.add_row(
|
||||
[
|
||||
source.name,
|
||||
@@ -1158,13 +1150,13 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
source.embedding_model,
|
||||
source.embedding_dim,
|
||||
utils.format_datetime(source.created_at),
|
||||
",".join(agent_names),
|
||||
]
|
||||
)
|
||||
|
||||
print(table)
|
||||
else:
|
||||
raise ValueError(f"Unknown argument {arg}")
|
||||
return table
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -1177,25 +1169,20 @@ def add(
|
||||
"""Add a person/human"""
|
||||
from memgpt.client.client import create_client
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
ms = MetadataStore(config)
|
||||
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_SERVER_PASS"))
|
||||
if filename: # read from file
|
||||
assert text is None, "Cannot specify both text and filename"
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
if option == "persona":
|
||||
persona = ms.get_persona(name=name)
|
||||
persona = client.get_persona(name)
|
||||
if persona:
|
||||
# config if user wants to overwrite
|
||||
if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask():
|
||||
return
|
||||
persona.text = text
|
||||
ms.update_persona(persona)
|
||||
client.update_persona(name=name, text=text)
|
||||
else:
|
||||
persona = PersonaModel(name=name, text=text, user_id=user_id)
|
||||
ms.add_persona(persona)
|
||||
client.create_persona(name=name, text=text)
|
||||
|
||||
elif option == "human":
|
||||
human = client.get_human(name=name)
|
||||
@@ -1203,10 +1190,9 @@ def add(
|
||||
# config if user wants to overwrite
|
||||
if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask():
|
||||
return
|
||||
human.text = text
|
||||
client.update_human(human)
|
||||
client.update_human(name=name, text=text)
|
||||
else:
|
||||
human = client.create_human(name=name, human=text)
|
||||
human = client.create_human(name=name, text=text)
|
||||
else:
|
||||
raise ValueError(f"Unknown kind {option}")
|
||||
|
||||
@@ -1216,53 +1202,24 @@ def delete(option: str, name: str):
|
||||
"""Delete a source from the archival memory."""
|
||||
from memgpt.client.client import create_client
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
client = create_client(base_url=os.getenv("MEMGPT_BASE_URL"), token=os.getenv("MEMGPT_API_KEY"))
|
||||
ms = MetadataStore(config)
|
||||
assert ms.get_user(user_id=user_id), f"User {user_id} does not exist"
|
||||
|
||||
try:
|
||||
# delete from metadata
|
||||
if option == "source":
|
||||
# delete metadata
|
||||
source = ms.get_source(source_name=name, user_id=user_id)
|
||||
source = client.get_source(name)
|
||||
assert source is not None, f"Source {name} does not exist"
|
||||
ms.delete_source(source_id=source.id)
|
||||
|
||||
# delete from passages
|
||||
conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id=user_id)
|
||||
conn.delete({"data_source": name})
|
||||
|
||||
assert (
|
||||
conn.get_all({"data_source": name}) == []
|
||||
), f"Expected no passages with source {name}, but got {conn.get_all({'data_source': name})}"
|
||||
|
||||
# TODO: should we also delete from agents?
|
||||
client.delete_source(source_id=source.id)
|
||||
elif option == "agent":
|
||||
agent = ms.get_agent(agent_name=name, user_id=user_id)
|
||||
assert agent is not None, f"Agent {name} for user_id {user_id} does not exist"
|
||||
|
||||
# recall memory
|
||||
recall_conn = StorageConnector.get_storage_connector(TableType.RECALL_MEMORY, config, user_id=user_id, agent_id=agent.id)
|
||||
recall_conn.delete({"agent_id": agent.id})
|
||||
|
||||
# archival memory
|
||||
archival_conn = StorageConnector.get_storage_connector(TableType.ARCHIVAL_MEMORY, config, user_id=user_id, agent_id=agent.id)
|
||||
archival_conn.delete({"agent_id": agent.id})
|
||||
|
||||
# metadata
|
||||
ms.delete_agent(agent_id=agent.id)
|
||||
|
||||
client.delete_agent(name=name)
|
||||
elif option == "human":
|
||||
human = client.get_human(name=name)
|
||||
assert human is not None, f"Human {name} does not exist"
|
||||
client.delete_human(name=name)
|
||||
elif option == "persona":
|
||||
persona = ms.get_persona(name=name)
|
||||
persona = client.get_persona(name=name)
|
||||
assert persona is not None, f"Persona {name} does not exist"
|
||||
ms.delete_persona(name=name)
|
||||
assert ms.get_persona(name=name) is None, f"Persona {name} still exists"
|
||||
client.delete_persona(name=name)
|
||||
else:
|
||||
raise ValueError(f"Option {option} not implemented")
|
||||
|
||||
|
||||
@@ -166,7 +166,7 @@ class AbstractClient(object):
|
||||
"""List all humans."""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_human(self, name: str, human: str):
|
||||
def create_human(self, name: str, text: str):
|
||||
"""Create a human."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -174,7 +174,7 @@ class AbstractClient(object):
|
||||
"""List all personas."""
|
||||
raise NotImplementedError
|
||||
|
||||
def create_persona(self, name: str, persona: str):
|
||||
def create_persona(self, name: str, text: str):
|
||||
"""Create a persona."""
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -498,8 +498,8 @@ class RESTClient(AbstractClient):
|
||||
response = requests.get(f"{self.base_url}/api/humans", headers=self.headers)
|
||||
return ListHumansResponse(**response.json())
|
||||
|
||||
def create_human(self, name: str, human: str) -> HumanModel:
|
||||
data = {"name": name, "text": human}
|
||||
def create_human(self, name: str, text: str) -> HumanModel:
|
||||
data = {"name": name, "text": text}
|
||||
response = requests.post(f"{self.base_url}/api/humans", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create human: {response.text}")
|
||||
@@ -509,8 +509,8 @@ class RESTClient(AbstractClient):
|
||||
response = requests.get(f"{self.base_url}/api/personas", headers=self.headers)
|
||||
return ListPersonasResponse(**response.json())
|
||||
|
||||
def create_persona(self, name: str, persona: str) -> PersonaModel:
|
||||
data = {"name": name, "text": persona}
|
||||
def create_persona(self, name: str, text: str) -> PersonaModel:
|
||||
data = {"name": name, "text": text}
|
||||
response = requests.post(f"{self.base_url}/api/personas", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create persona: {response.text}")
|
||||
@@ -834,11 +834,11 @@ class LocalClient(AbstractClient):
|
||||
|
||||
# humans / personas
|
||||
|
||||
def create_human(self, name: str, human: str):
|
||||
return self.server.add_human(HumanModel(name=name, text=human, user_id=self.user_id))
|
||||
def create_human(self, name: str, text: str):
|
||||
return self.server.add_human(HumanModel(name=name, text=text, user_id=self.user_id))
|
||||
|
||||
def create_persona(self, name: str, persona: str):
|
||||
return self.server.add_persona(PersonaModel(name=name, text=persona, user_id=self.user_id))
|
||||
def create_persona(self, name: str, text: str):
|
||||
return self.server.add_persona(PersonaModel(name=name, text=text, user_id=self.user_id))
|
||||
|
||||
def list_humans(self):
|
||||
return self.server.list_humans(user_id=self.user_id if self.user_id else self.user_id)
|
||||
@@ -846,8 +846,10 @@ class LocalClient(AbstractClient):
|
||||
def get_human(self, name: str):
|
||||
return self.server.get_human(name=name, user_id=self.user_id)
|
||||
|
||||
def update_human(self, human: HumanModel):
|
||||
return self.server.update_human(human=human)
|
||||
def update_human(self, name: str, text: str):
|
||||
human = self.get_human(name)
|
||||
human.text = text
|
||||
return self.server.update_human(human)
|
||||
|
||||
def delete_human(self, name: str):
|
||||
return self.server.delete_human(name, self.user_id)
|
||||
@@ -858,8 +860,10 @@ class LocalClient(AbstractClient):
|
||||
def get_persona(self, name: str):
|
||||
return self.server.get_persona(name=name, user_id=self.user_id)
|
||||
|
||||
def update_persona(self, persona: PersonaModel):
|
||||
return self.server.update_persona(persona=persona)
|
||||
def update_persona(self, name: str, text: str):
|
||||
persona = self.get_persona(name)
|
||||
persona.text = text
|
||||
return self.server.update_persona(persona)
|
||||
|
||||
def delete_persona(self, name: str):
|
||||
return self.server.delete_persona(name, self.user_id)
|
||||
@@ -926,9 +930,19 @@ class LocalClient(AbstractClient):
|
||||
def create_source(self, name: str):
|
||||
return self.server.create_source(user_id=self.user_id, name=name)
|
||||
|
||||
def delete_source(self, source_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None):
|
||||
# TODO: delete source data
|
||||
self.server.delete_source(user_id=self.user.id, source_id=source_id, source_name=source_name)
|
||||
|
||||
def get_source(self, source_id: Optional[uuid.UUID] = None, source_name: Optional[str] = None):
|
||||
return self.server.ms.get_source(user_id=self.user_id, source_id=source_id, source_name=source_name)
|
||||
|
||||
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
||||
self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id)
|
||||
|
||||
def list_sources(self):
|
||||
return self.server.list_all_sources(user_id=self.user_id)
|
||||
|
||||
def get_agent_archival_memory(
|
||||
self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
|
||||
):
|
||||
@@ -973,3 +987,6 @@ class LocalClient(AbstractClient):
|
||||
)
|
||||
|
||||
return ListModelsResponse(models=[llm_config])
|
||||
|
||||
def list_attached_sources(self, agent_id: uuid.UUID):
|
||||
return self.server.list_attached_sources(agent_id=agent_id)
|
||||
|
||||
@@ -191,3 +191,7 @@ class DocumentModel(BaseModel):
|
||||
data_source: str = Field(..., description="The data source of the document.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the document.", primary_key=True)
|
||||
metadata: Optional[Dict] = Field({}, description="The metadata of the document.")
|
||||
|
||||
|
||||
class UserModel(BaseModel):
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user.")
|
||||
|
||||
@@ -51,7 +51,7 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
This endpoint retrieves a list of all agents and their configurations associated with the specified user ID.
|
||||
"""
|
||||
interface.clear()
|
||||
agents_data = server.list_agents(user_id=user_id)
|
||||
agents_data = server.list_agents_legacy(user_id=user_id)
|
||||
return ListAgentsResponse(**agents_data)
|
||||
|
||||
@router.post("/agents", tags=["agents"], response_model=CreateAgentResponse)
|
||||
|
||||
@@ -217,7 +217,7 @@ class SyncServer(LockingServer):
|
||||
|
||||
# Initialize the connection to the DB
|
||||
self.config = MemGPTConfig.load()
|
||||
logger.info(f"loading configuration from '{self.config.config_path}'")
|
||||
logger.debug(f"loading configuration from '{self.config.config_path}'")
|
||||
assert self.config.persona is not None, "Persona must be set in the config"
|
||||
assert self.config.human is not None, "Human must be set in the config"
|
||||
|
||||
@@ -805,6 +805,8 @@ class SyncServer(LockingServer):
|
||||
user_id: uuid.UUID,
|
||||
agent_id: uuid.UUID,
|
||||
):
|
||||
# TODO: delete agent data
|
||||
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
@@ -867,10 +869,21 @@ class SyncServer(LockingServer):
|
||||
}
|
||||
return agent_config
|
||||
|
||||
# TODO make return type pydantic
|
||||
def list_agents(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
) -> List[AgentState]:
|
||||
"""List all available agents to a user"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
agents_states = self.ms.list_agents(user_id=user_id)
|
||||
return agents_states
|
||||
|
||||
# TODO make return type pydantic
|
||||
def list_agents_legacy(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
) -> dict:
|
||||
"""List all available agents to a user"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
@@ -1191,12 +1204,18 @@ class SyncServer(LockingServer):
|
||||
# TODO: mark what is in-context versus not
|
||||
return cursor, json_records
|
||||
|
||||
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> AgentState:
|
||||
def get_agent_config(self, user_id: uuid.UUID, agent_id: Optional[uuid.UUID], agent_name: Optional[str] = None) -> AgentState:
|
||||
"""Return the config of an agent"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
if agent_id:
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
else:
|
||||
agent_state = self.ms.get_agent(agent_name=agent_name, user_id=user_id)
|
||||
if agent_state is None:
|
||||
raise ValueError(f"Agent agent_name={agent_name} does not exist")
|
||||
agent_id = agent_state.id
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
|
||||
@@ -4,6 +4,10 @@ import sys
|
||||
|
||||
subprocess.check_call([sys.executable, "-m", "pip", "install", "pexpect"])
|
||||
import pexpect
|
||||
from prettytable.colortable import ColorTable
|
||||
|
||||
from memgpt.cli.cli_config import ListChoice, add, delete
|
||||
from memgpt.cli.cli_config import list as list_command
|
||||
|
||||
from .constants import TIMEOUT
|
||||
from .utils import create_config
|
||||
@@ -11,6 +15,39 @@ from .utils import create_config
|
||||
# def test_configure_memgpt():
|
||||
# configure_memgpt()
|
||||
|
||||
options = [ListChoice.agents, ListChoice.sources, ListChoice.humans, ListChoice.personas]
|
||||
|
||||
|
||||
def test_cli_list():
|
||||
for option in options:
|
||||
output = list_command(arg=option)
|
||||
# check if is a list
|
||||
assert isinstance(output, ColorTable)
|
||||
|
||||
|
||||
def test_cli_config():
|
||||
|
||||
# test add
|
||||
for option in ["human", "persona"]:
|
||||
|
||||
# create initial
|
||||
add(option=option, name="test", text="test data")
|
||||
|
||||
## update
|
||||
# filename = "test.txt"
|
||||
# open(filename, "w").write("test data new")
|
||||
# child = pexpect.spawn(f"poetry run memgpt add --{str(option)} {filename} --name test --strip-ui")
|
||||
# child.expect("Human test already exists. Overwrite?", timeout=TIMEOUT)
|
||||
# child.sendline()
|
||||
# child.expect(pexpect.EOF, timeout=TIMEOUT) # Wait for child to exit
|
||||
# child.close()
|
||||
|
||||
for row in list_command(arg=ListChoice.humans if option == "human" else ListChoice.personas):
|
||||
if row[0] == "test":
|
||||
assert "test data" in row
|
||||
# delete
|
||||
delete(option=option, name="test")
|
||||
|
||||
|
||||
def test_save_load():
|
||||
# configure_memgpt() # rely on configure running first^
|
||||
|
||||
@@ -196,14 +196,14 @@ def test_humans_personas(client, agent):
|
||||
persona_name = "TestPersona"
|
||||
if client.get_persona(persona_name):
|
||||
client.delete_persona(persona_name)
|
||||
persona = client.create_persona(name=persona_name, persona="Persona text")
|
||||
persona = client.create_persona(name=persona_name, text="Persona text")
|
||||
assert persona.name == persona_name
|
||||
assert persona.text == "Persona text", "Creating persona failed"
|
||||
|
||||
human_name = "TestHuman"
|
||||
if client.get_human(human_name):
|
||||
client.delete_human(human_name)
|
||||
human = client.create_human(name=human_name, human="Human text")
|
||||
human = client.create_human(name=human_name, text="Human text")
|
||||
assert human.name == human_name
|
||||
assert human.text == "Human text", "Creating human failed"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user