feat: Add ability to add tags to agents (#1984)
This commit is contained in:
52
alembic/versions/b6d7ca024aa9_add_agents_tags_table.py
Normal file
52
alembic/versions/b6d7ca024aa9_add_agents_tags_table.py
Normal file
@@ -0,0 +1,52 @@
|
||||
"""Add agents tags table
|
||||
|
||||
Revision ID: b6d7ca024aa9
|
||||
Revises: d14ae606614c
|
||||
Create Date: 2024-11-06 10:48:08.424108
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b6d7ca024aa9"
|
||||
down_revision: Union[str, None] = "d14ae606614c"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"agents_tags",
|
||||
sa.Column("agent_id", sa.String(), nullable=False),
|
||||
sa.Column("tag", sa.String(), nullable=False),
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["agent_id"],
|
||||
["agents.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("agent_id", "id"),
|
||||
sa.UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("agents_tags")
|
||||
# ### end Alembic commands ###
|
||||
@@ -334,8 +334,12 @@ class RESTClient(AbstractClient):
|
||||
self._default_llm_config = default_llm_config
|
||||
self._default_embedding_config = default_embedding_config
|
||||
|
||||
def list_agents(self) -> List[AgentState]:
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers)
|
||||
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
|
||||
params = {}
|
||||
if tags:
|
||||
params["tags"] = tags
|
||||
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params)
|
||||
return [AgentState(**agent) for agent in response.json()]
|
||||
|
||||
def agent_exists(self, agent_id: str) -> bool:
|
||||
@@ -480,6 +484,7 @@ class RESTClient(AbstractClient):
|
||||
description: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
@@ -509,6 +514,7 @@ class RESTClient(AbstractClient):
|
||||
name=name,
|
||||
system=system,
|
||||
tools=tools,
|
||||
tags=tags,
|
||||
description=description,
|
||||
metadata_=metadata,
|
||||
llm_config=llm_config,
|
||||
@@ -1617,13 +1623,10 @@ class LocalClient(AbstractClient):
|
||||
self.organization = self.server.get_organization_or_default(self.org_id)
|
||||
|
||||
# agents
|
||||
def list_agents(self) -> List[AgentState]:
|
||||
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
|
||||
self.interface.clear()
|
||||
|
||||
# TODO: fix the server function
|
||||
# return self.server.list_agents(user_id=self.user_id)
|
||||
|
||||
return self.server.ms.list_agents(user_id=self.user_id)
|
||||
return self.server.list_agents(user_id=self.user_id, tags=tags)
|
||||
|
||||
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
|
||||
"""
|
||||
@@ -1757,6 +1760,7 @@ class LocalClient(AbstractClient):
|
||||
description: Optional[str] = None,
|
||||
system: Optional[str] = None,
|
||||
tools: Optional[List[str]] = None,
|
||||
tags: Optional[List[str]] = None,
|
||||
metadata: Optional[Dict] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
@@ -1788,6 +1792,7 @@ class LocalClient(AbstractClient):
|
||||
name=name,
|
||||
system=system,
|
||||
tools=tools,
|
||||
tags=tags,
|
||||
description=description,
|
||||
metadata_=metadata,
|
||||
llm_config=llm_config,
|
||||
@@ -1872,7 +1877,7 @@ class LocalClient(AbstractClient):
|
||||
agent_state (AgentState): State of the agent
|
||||
"""
|
||||
self.interface.clear()
|
||||
return self.server.get_agent(agent_name=agent_name, user_id=self.user_id, agent_id=None)
|
||||
return self.server.get_agent_state(agent_name=agent_name, user_id=self.user_id, agent_id=None)
|
||||
|
||||
def get_agent(self, agent_id: str) -> AgentState:
|
||||
"""
|
||||
|
||||
@@ -493,6 +493,7 @@ class MetadataStore:
|
||||
fields = vars(agent)
|
||||
fields["memory"] = agent.memory.to_dict()
|
||||
del fields["_internal_memory"]
|
||||
del fields["tags"]
|
||||
session.add(AgentModel(**fields))
|
||||
session.commit()
|
||||
|
||||
@@ -531,6 +532,7 @@ class MetadataStore:
|
||||
if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever
|
||||
fields["memory"] = agent.memory.to_dict()
|
||||
del fields["_internal_memory"]
|
||||
del fields["tags"]
|
||||
session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
|
||||
session.commit()
|
||||
|
||||
|
||||
28
letta/orm/agents_tags.py
Normal file
28
letta/orm/agents_tags.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from sqlalchemy import ForeignKey, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
|
||||
|
||||
class AgentsTags(SqlalchemyBase, OrganizationMixin):
|
||||
"""Associates tags with agents, allowing agents to have multiple tags and supporting tag-based filtering."""
|
||||
|
||||
__tablename__ = "agents_tags"
|
||||
__pydantic_model__ = PydanticAgentsTags
|
||||
__table_args__ = (UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),)
|
||||
|
||||
# The agent associated with this tag
|
||||
agent_id = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
|
||||
# The name of the tag
|
||||
tag: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the tag associated with the agent.")
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents_tags")
|
||||
@@ -23,6 +23,7 @@ class Organization(SqlalchemyBase):
|
||||
|
||||
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
|
||||
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
||||
agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
# TODO: Map these relationships later when we actually make these models
|
||||
# below is just a suggestion
|
||||
|
||||
@@ -112,6 +112,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
self.is_deleted = True
|
||||
return self.update(db_session)
|
||||
|
||||
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
|
||||
"""Permanently removes the record from the database."""
|
||||
if actor:
|
||||
logger.info(f"User {actor.id} requested hard deletion of {self.__class__.__name__} with ID {self.id}")
|
||||
|
||||
with db_session as session:
|
||||
try:
|
||||
session.delete(self)
|
||||
session.commit()
|
||||
except Exception as e:
|
||||
session.rollback()
|
||||
logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}")
|
||||
raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}")
|
||||
else:
|
||||
logger.info(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
||||
|
||||
def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
|
||||
if actor:
|
||||
self._set_created_and_updated_by_fields(actor.id)
|
||||
|
||||
@@ -64,6 +64,9 @@ class AgentState(BaseAgent, validate_assignment=True):
|
||||
# tool rules
|
||||
tool_rules: Optional[List[BaseToolRule]] = Field(default=None, description="The list of tool rules.")
|
||||
|
||||
# tags
|
||||
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
|
||||
|
||||
# system prompt
|
||||
system: str = Field(..., description="The system prompt used by the agent.")
|
||||
|
||||
@@ -108,6 +111,7 @@ class CreateAgent(BaseAgent):
|
||||
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
|
||||
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
|
||||
tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.")
|
||||
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
|
||||
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
|
||||
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
|
||||
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
|
||||
@@ -148,6 +152,7 @@ class UpdateAgentState(BaseAgent):
|
||||
id: str = Field(..., description="The id of the agent.")
|
||||
name: Optional[str] = Field(None, description="The name of the agent.")
|
||||
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
|
||||
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
|
||||
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
|
||||
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
|
||||
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")
|
||||
|
||||
33
letta/schemas/agents_tags.py
Normal file
33
letta/schemas/agents_tags.py
Normal file
@@ -0,0 +1,33 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class AgentsTagsBase(LettaBase):
|
||||
__id_prefix__ = "agents_tags"
|
||||
|
||||
|
||||
class AgentsTags(AgentsTagsBase):
|
||||
"""
|
||||
Schema representing the relationship between tags and agents.
|
||||
|
||||
Parameters:
|
||||
agent_id (str): The ID of the associated agent.
|
||||
tag_id (str): The ID of the associated tag.
|
||||
tag_name (str): The name of the tag.
|
||||
created_at (datetime): The date this relationship was created.
|
||||
"""
|
||||
|
||||
id: str = AgentsTagsBase.generate_id_field()
|
||||
agent_id: str = Field(..., description="The ID of the associated agent.")
|
||||
tag: str = Field(..., description="The name of the tag.")
|
||||
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
|
||||
updated_at: Optional[datetime] = Field(None, description="The update date of the tag.")
|
||||
is_deleted: bool = Field(False, description="Whether this tag is deleted or not.")
|
||||
|
||||
|
||||
class AgentsTagsCreate(AgentsTagsBase):
|
||||
tag: str = Field(..., description="The tag name.")
|
||||
@@ -41,6 +41,7 @@ router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
|
||||
def list_agents(
|
||||
name: Optional[str] = Query(None, description="Name of the agent"),
|
||||
tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
@@ -50,7 +51,7 @@ def list_agents(
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
agents = server.list_agents(user_id=actor.id)
|
||||
agents = server.list_agents(user_id=actor.id, tags=tags)
|
||||
# TODO: move this logic to the ORM
|
||||
if name:
|
||||
agents = [a for a in agents if a.name == name]
|
||||
@@ -534,124 +535,3 @@ async def send_message_to_agent(
|
||||
|
||||
traceback.print_exc()
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
|
||||
|
||||
##### MISSING #######
|
||||
|
||||
# @router.post("/{agent_id}/command")
|
||||
# def run_command(
|
||||
# agent_id: "UUID",
|
||||
# command: "AgentCommandRequest",
|
||||
#
|
||||
# server: "SyncServer" = Depends(get_letta_server),
|
||||
# ):
|
||||
# """
|
||||
# Execute a command on a specified agent.
|
||||
|
||||
# This endpoint receives a command to be executed on an agent. It uses the user and agent identifiers to authenticate and route the command appropriately.
|
||||
|
||||
# Raises an HTTPException for any processing errors.
|
||||
# """
|
||||
# actor = server.get_current_user()
|
||||
#
|
||||
# response = server.run_command(user_id=actor.id,
|
||||
# agent_id=agent_id,
|
||||
# command=command.command)
|
||||
|
||||
# return AgentCommandResponse(response=response)
|
||||
|
||||
# @router.get("/{agent_id}/config")
|
||||
# def get_agent_config(
|
||||
# agent_id: "UUID",
|
||||
#
|
||||
# server: "SyncServer" = Depends(get_letta_server),
|
||||
# ):
|
||||
# """
|
||||
# Retrieve the configuration for a specific agent.
|
||||
|
||||
# This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
|
||||
# """
|
||||
# actor = server.get_current_user()
|
||||
#
|
||||
# if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
|
||||
## agent does not exist
|
||||
# raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
|
||||
|
||||
# agent_state = server.get_agent_config(user_id=actor.id, agent_id=agent_id)
|
||||
## get sources
|
||||
# attached_sources = server.list_attached_sources(agent_id=agent_id)
|
||||
|
||||
## configs
|
||||
# llm_config = LLMConfig(**vars(agent_state.llm_config))
|
||||
# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config))
|
||||
|
||||
# return GetAgentResponse(
|
||||
# agent_state=AgentState(
|
||||
# id=agent_state.id,
|
||||
# name=agent_state.name,
|
||||
# user_id=agent_state.user_id,
|
||||
# llm_config=llm_config,
|
||||
# embedding_config=embedding_config,
|
||||
# state=agent_state.state,
|
||||
# created_at=int(agent_state.created_at.timestamp()),
|
||||
# tools=agent_state.tools,
|
||||
# system=agent_state.system,
|
||||
# metadata=agent_state._metadata,
|
||||
# ),
|
||||
# last_run_at=None, # TODO
|
||||
# sources=attached_sources,
|
||||
# )
|
||||
|
||||
# @router.patch("/{agent_id}/rename", response_model=GetAgentResponse)
|
||||
# def update_agent_name(
|
||||
# agent_id: "UUID",
|
||||
# agent_rename: AgentRenameRequest,
|
||||
#
|
||||
# server: "SyncServer" = Depends(get_letta_server),
|
||||
# ):
|
||||
# """
|
||||
# Updates the name of a specific agent.
|
||||
|
||||
# This changes the name of the agent in the database but does NOT edit the agent's persona.
|
||||
# """
|
||||
# valid_name = agent_rename.agent_name
|
||||
# actor = server.get_current_user()
|
||||
#
|
||||
# agent_state = server.rename_agent(user_id=actor.id, agent_id=agent_id, new_agent_name=valid_name)
|
||||
## get sources
|
||||
# attached_sources = server.list_attached_sources(agent_id=agent_id)
|
||||
# llm_config = LLMConfig(**vars(agent_state.llm_config))
|
||||
# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config))
|
||||
|
||||
# return GetAgentResponse(
|
||||
# agent_state=AgentState(
|
||||
# id=agent_state.id,
|
||||
# name=agent_state.name,
|
||||
# user_id=agent_state.user_id,
|
||||
# llm_config=llm_config,
|
||||
# embedding_config=embedding_config,
|
||||
# state=agent_state.state,
|
||||
# created_at=int(agent_state.created_at.timestamp()),
|
||||
# tools=agent_state.tools,
|
||||
# system=agent_state.system,
|
||||
# ),
|
||||
# last_run_at=None, # TODO
|
||||
# sources=attached_sources,
|
||||
# )
|
||||
|
||||
|
||||
# @router.get("/{agent_id}/archival/all", response_model=GetAgentArchivalMemoryResponse)
|
||||
# def get_agent_archival_memory_all(
|
||||
# agent_id: "UUID",
|
||||
#
|
||||
# server: "SyncServer" = Depends(get_letta_server),
|
||||
# ):
|
||||
# """
|
||||
# Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
|
||||
# """
|
||||
# actor = server.get_current_user()
|
||||
#
|
||||
# archival_memories = server.get_all_archival_memories(user_id=actor.id, agent_id=agent_id)
|
||||
# print("archival_memories:", archival_memories)
|
||||
# archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["contents"]) for passage in archival_memories]
|
||||
# return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)
|
||||
|
||||
@@ -82,6 +82,7 @@ from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.schemas.tool import Tool, ToolCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agents_tags_manager import AgentsTagsManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
@@ -248,6 +249,7 @@ class SyncServer(Server):
|
||||
self.organization_manager = OrganizationManager()
|
||||
self.user_manager = UserManager()
|
||||
self.tool_manager = ToolManager()
|
||||
self.agents_tags_manager = AgentsTagsManager()
|
||||
|
||||
# Make default user and org
|
||||
if init_with_default_org_and_user:
|
||||
@@ -969,6 +971,19 @@ class SyncServer(Server):
|
||||
if request.metadata_:
|
||||
letta_agent.agent_state.metadata_ = request.metadata_
|
||||
|
||||
# Manage tag state
|
||||
if request.tags is not None:
|
||||
current_tags = set(self.agents_tags_manager.get_tags_for_agent(agent_id=letta_agent.agent_state.id, actor=actor))
|
||||
target_tags = set(request.tags)
|
||||
|
||||
tags_to_add = target_tags - current_tags
|
||||
tags_to_remove = current_tags - target_tags
|
||||
|
||||
for tag in tags_to_add:
|
||||
self.agents_tags_manager.add_tag_to_agent(agent_id=letta_agent.agent_state.id, tag=tag, actor=actor)
|
||||
for tag in tags_to_remove:
|
||||
self.agents_tags_manager.delete_tag_from_agent(agent_id=letta_agent.agent_state.id, tag=tag, actor=actor)
|
||||
|
||||
# save the agent
|
||||
assert isinstance(letta_agent.memory, Memory)
|
||||
save_agent(letta_agent, self.ms)
|
||||
@@ -1079,16 +1094,19 @@ class SyncServer(Server):
|
||||
}
|
||||
return agent_config
|
||||
|
||||
def list_agents(
|
||||
self,
|
||||
user_id: str,
|
||||
) -> List[AgentState]:
|
||||
def list_agents(self, user_id: str, tags: Optional[List[str]] = None) -> List[AgentState]:
|
||||
"""List all available agents to a user"""
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
user = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
|
||||
agents_states = self.ms.list_agents(user_id=user_id)
|
||||
return agents_states
|
||||
if tags is None:
|
||||
agents_states = self.ms.list_agents(user_id=user_id)
|
||||
return agents_states
|
||||
else:
|
||||
agent_ids = []
|
||||
for tag in tags:
|
||||
agent_ids += self.agents_tags_manager.get_agents_by_tag(tag=tag, actor=user)
|
||||
|
||||
return [self.get_agent_state(user_id=user.id, agent_id=agent_id) for agent_id in agent_ids]
|
||||
|
||||
def get_blocks(
|
||||
self,
|
||||
@@ -1160,18 +1178,6 @@ class SyncServer(Server):
|
||||
raise ValueError("Source does not exist")
|
||||
return existing_source.id
|
||||
|
||||
def get_agent(self, user_id: str, agent_id: Optional[str] = None, agent_name: Optional[str] = None):
|
||||
"""Get the agent state"""
|
||||
return self.ms.get_agent(agent_id=agent_id, agent_name=agent_name, user_id=user_id)
|
||||
|
||||
# def get_user(self, user_id: str) -> User:
|
||||
# """Get the user"""
|
||||
# user = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
# if user is None:
|
||||
# raise ValueError(f"User with user_id {user_id} does not exist")
|
||||
# else:
|
||||
# return user
|
||||
|
||||
def get_agent_memory(self, agent_id: str) -> Memory:
|
||||
"""Return the memory of an agent (core memory)"""
|
||||
agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
@@ -1389,8 +1395,7 @@ class SyncServer(Server):
|
||||
|
||||
def get_agent_state(self, user_id: str, agent_id: Optional[str], agent_name: Optional[str] = None) -> Optional[AgentState]:
|
||||
"""Return the config of an agent"""
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
user = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
if agent_id:
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
return None
|
||||
@@ -1403,7 +1408,11 @@ class SyncServer(Server):
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
assert isinstance(letta_agent.memory, Memory)
|
||||
return letta_agent.agent_state.model_copy(deep=True)
|
||||
agent_state = letta_agent.agent_state.model_copy(deep=True)
|
||||
|
||||
# Load the tags in for the agent_state
|
||||
agent_state.tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent_id, actor=user)
|
||||
return agent_state
|
||||
|
||||
def get_server_config(self, include_defaults: bool = False) -> dict:
|
||||
"""Return the base config"""
|
||||
@@ -1485,8 +1494,11 @@ class SyncServer(Server):
|
||||
|
||||
def delete_agent(self, user_id: str, agent_id: str):
|
||||
"""Delete an agent in the database"""
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
actor = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
# TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM MODEL
|
||||
# TODO: EVENTUALLY WE GET AUTO-DELETES WHEN WE SPECIFY RELATIONSHIPS IN THE ORM
|
||||
self.agents_tags_manager.delete_all_tags_from_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
|
||||
|
||||
64
letta/services/agents_tags_manager.py
Normal file
64
letta/services/agents_tags_manager.py
Normal file
@@ -0,0 +1,64 @@
|
||||
from typing import List
|
||||
|
||||
from letta.orm.agents_tags import AgentsTags as AgentsTagsModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.utils import enforce_types
|
||||
|
||||
|
||||
class AgentsTagsManager:
|
||||
"""Manager class to handle business logic related to Tags."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
|
||||
self.session_maker = db_context
|
||||
|
||||
@enforce_types
|
||||
def add_tag_to_agent(self, agent_id: str, tag: str, actor: PydanticUser) -> PydanticAgentsTags:
|
||||
"""Add a tag to an agent."""
|
||||
with self.session_maker() as session:
|
||||
# Check if the tag already exists for this agent
|
||||
try:
|
||||
agents_tags_model = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor)
|
||||
return agents_tags_model.to_pydantic()
|
||||
except NoResultFound:
|
||||
agents_tags = PydanticAgentsTags(agent_id=agent_id, tag=tag).model_dump(exclude_none=True)
|
||||
new_tag = AgentsTagsModel(**agents_tags, organization_id=actor.organization_id)
|
||||
new_tag.create(session, actor=actor)
|
||||
return new_tag.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_all_tags_from_agent(self, agent_id: str, actor: PydanticUser):
|
||||
"""Delete a tag from an agent. This is a permanent hard delete."""
|
||||
tags = self.get_tags_for_agent(agent_id=agent_id, actor=actor)
|
||||
for tag in tags:
|
||||
self.delete_tag_from_agent(agent_id=agent_id, tag=tag, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def delete_tag_from_agent(self, agent_id: str, tag: str, actor: PydanticUser):
|
||||
"""Delete a tag from an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Retrieve and delete the tag association
|
||||
tag_association = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor)
|
||||
tag_association.hard_delete(session, actor=actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tag '{tag}' not found for agent '{agent_id}'.")
|
||||
|
||||
@enforce_types
|
||||
def get_agents_by_tag(self, tag: str, actor: PydanticUser) -> List[str]:
|
||||
"""Retrieve all agent IDs associated with a specific tag."""
|
||||
with self.session_maker() as session:
|
||||
# Query for all agents with the given tag
|
||||
agents_with_tag = AgentsTagsModel.list(db_session=session, tag=tag, organization_id=actor.organization_id)
|
||||
return [record.agent_id for record in agents_with_tag]
|
||||
|
||||
@enforce_types
|
||||
def get_tags_for_agent(self, agent_id: str, actor: PydanticUser) -> List[str]:
|
||||
"""Retrieve all tags associated with a specific agent."""
|
||||
with self.session_maker() as session:
|
||||
# Query for all tags associated with the given agent
|
||||
tags_for_agent = AgentsTagsModel.list(db_session=session, agent_id=agent_id, organization_id=actor.organization_id)
|
||||
return [record.tag for record in tags_for_agent]
|
||||
@@ -133,7 +133,7 @@ class ToolManager:
|
||||
"""Delete a tool by its ID."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
tool = ToolModel.read(db_session=session, identifier=tool_id)
|
||||
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
|
||||
tool.delete(db_session=session, actor=actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool with id {tool_id} not found.")
|
||||
|
||||
@@ -76,7 +76,6 @@ def client(request):
|
||||
client = create_client(base_url=server_url, token=None) # This yields control back to the test function
|
||||
else:
|
||||
# use local client (no server)
|
||||
server_url = None
|
||||
client = create_client()
|
||||
|
||||
client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
|
||||
@@ -88,7 +87,6 @@ def client(request):
|
||||
@pytest.fixture(scope="module")
|
||||
def agent(client: Union[LocalClient, RESTClient]):
|
||||
agent_state = client.create_agent(name=test_agent_name)
|
||||
print("AGENT ID", agent_state.id)
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
@@ -676,3 +674,38 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
|
||||
len(custom_agent_state.message_ids) == len(custom_sequence) + 1
|
||||
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
|
||||
assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
|
||||
|
||||
|
||||
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
"""
|
||||
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
|
||||
"""
|
||||
|
||||
# Step 1: Add multiple tags to the agent
|
||||
tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"]
|
||||
client.update_agent(agent_id=agent.id, tags=tags_to_add)
|
||||
|
||||
# Step 2: Retrieve tags for the agent and verify they match the added tags
|
||||
retrieved_tags = client.get_agent(agent_id=agent.id).tags
|
||||
assert set(retrieved_tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {retrieved_tags}"
|
||||
|
||||
# Step 3: Retrieve agents by each tag to ensure the agent is associated correctly
|
||||
for tag in tags_to_add:
|
||||
agents_with_tag = client.list_agents(tags=[tag])
|
||||
assert agent.id in [a.id for a in agents_with_tag], f"Expected agent {agent.id} to be associated with tag '{tag}'"
|
||||
|
||||
# Step 4: Delete a specific tag from the agent and verify its removal
|
||||
tag_to_delete = tags_to_add.pop()
|
||||
client.update_agent(agent_id=agent.id, tags=tags_to_add)
|
||||
|
||||
# Verify the tag is removed from the agent's tags
|
||||
remaining_tags = client.get_agent(agent_id=agent.id).tags
|
||||
assert tag_to_delete not in remaining_tags, f"Tag '{tag_to_delete}' was not removed as expected"
|
||||
assert set(remaining_tags) == set(tags_to_add), f"Expected remaining tags to be {tags_to_add[1:]}, but got {remaining_tags}"
|
||||
|
||||
# Step 5: Delete all remaining tags from the agent
|
||||
client.update_agent(agent_id=agent.id, tags=[])
|
||||
|
||||
# Verify all tags are removed
|
||||
final_tags = client.get_agent(agent_id=agent.id).tags
|
||||
assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}"
|
||||
|
||||
@@ -6,6 +6,10 @@ from letta.functions.functions import derive_openai_json_schema, parse_source_co
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.user import User
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.organization import Organization as PydanticOrganization
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolUpdate
|
||||
@@ -29,7 +33,68 @@ def clear_tables(server: SyncServer):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_fixture(server: SyncServer):
|
||||
def default_organization(server: SyncServer):
|
||||
"""Fixture to create and return the default organization."""
|
||||
org = server.organization_manager.create_default_organization()
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_user(server: SyncServer, default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_default_user(org_id=default_organization.id)
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_user(server: SyncServer, default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_user(PydanticUser(name="other", organization_id=default_organization.id))
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sarah_agent(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create and return a sample agent within the default organization."""
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="sarah_agent",
|
||||
memory=ChatMemory(
|
||||
human="Charles",
|
||||
persona="I am a helpful assistant",
|
||||
),
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
yield agent_state
|
||||
|
||||
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def charles_agent(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create and return a sample agent within the default organization."""
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="charles_agent",
|
||||
memory=ChatMemory(
|
||||
human="Sarah",
|
||||
persona="I am a helpful assistant",
|
||||
),
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
yield agent_state
|
||||
|
||||
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tool_fixture(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create a tool with default settings and clean up after the test."""
|
||||
|
||||
def print_tool(message: str):
|
||||
@@ -43,6 +108,7 @@ def tool_fixture(server: SyncServer):
|
||||
print(message)
|
||||
return message
|
||||
|
||||
# Set up tool details
|
||||
source_code = parse_source_code(print_tool)
|
||||
source_type = "python"
|
||||
description = "test_description"
|
||||
@@ -50,9 +116,9 @@ def tool_fixture(server: SyncServer):
|
||||
|
||||
org = server.organization_manager.create_default_organization()
|
||||
user = server.user_manager.create_default_user()
|
||||
other_user = server.user_manager.create_user(PydanticUser(name="other", organization_id=org.id))
|
||||
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
@@ -60,7 +126,7 @@ def tool_fixture(server: SyncServer):
|
||||
tool = server.tool_manager.create_tool(tool, actor=user)
|
||||
|
||||
# Yield the created tool, organization, and user for use in tests
|
||||
yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool}
|
||||
yield {"tool": tool, "organization": org, "user": user, "tool_create": tool}
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
@@ -172,15 +238,13 @@ def test_update_user(server: SyncServer):
|
||||
# ======================================================================================================================
|
||||
# Tool Manager Tests
|
||||
# ======================================================================================================================
|
||||
def test_create_tool(server: SyncServer, tool_fixture):
|
||||
def test_create_tool(server: SyncServer, tool_fixture, default_user, default_organization):
|
||||
tool = tool_fixture["tool"]
|
||||
tool_create = tool_fixture["tool_create"]
|
||||
user = tool_fixture["user"]
|
||||
org = tool_fixture["organization"]
|
||||
|
||||
# Assertions to ensure the created tool matches the expected values
|
||||
assert tool.created_by_id == user.id
|
||||
assert tool.organization_id == org.id
|
||||
assert tool.created_by_id == default_user.id
|
||||
assert tool.organization_id == default_organization.id
|
||||
assert tool.description == tool_create.description
|
||||
assert tool.tags == tool_create.tags
|
||||
assert tool.source_code == tool_create.source_code
|
||||
@@ -188,12 +252,11 @@ def test_create_tool(server: SyncServer, tool_fixture):
|
||||
assert tool.json_schema == derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name)
|
||||
|
||||
|
||||
def test_get_tool_by_id(server: SyncServer, tool_fixture):
|
||||
def test_get_tool_by_id(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
|
||||
# Fetch the tool by ID using the manager method
|
||||
fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
|
||||
fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
|
||||
# Assertions to check if the fetched tool matches the created tool
|
||||
assert fetched_tool.id == tool.id
|
||||
@@ -204,54 +267,51 @@ def test_get_tool_by_id(server: SyncServer, tool_fixture):
|
||||
assert fetched_tool.source_type == tool.source_type
|
||||
|
||||
|
||||
def test_get_tool_with_actor(server: SyncServer, tool_fixture):
|
||||
def test_get_tool_with_actor(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
|
||||
# Fetch the tool by name and organization ID
|
||||
fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=user)
|
||||
fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=default_user)
|
||||
|
||||
# Assertions to check if the fetched tool matches the created tool
|
||||
assert fetched_tool.id == tool.id
|
||||
assert fetched_tool.name == tool.name
|
||||
assert fetched_tool.created_by_id == user.id
|
||||
assert fetched_tool.created_by_id == default_user.id
|
||||
assert fetched_tool.description == tool.description
|
||||
assert fetched_tool.tags == tool.tags
|
||||
assert fetched_tool.source_code == tool.source_code
|
||||
assert fetched_tool.source_type == tool.source_type
|
||||
|
||||
|
||||
def test_list_tools(server: SyncServer, tool_fixture):
|
||||
def test_list_tools(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
|
||||
# List tools (should include the one created by the fixture)
|
||||
tools = server.tool_manager.list_tools(actor=user)
|
||||
tools = server.tool_manager.list_tools(actor=default_user)
|
||||
|
||||
# Assertions to check that the created tool is listed
|
||||
assert len(tools) == 1
|
||||
assert any(t.id == tool.id for t in tools)
|
||||
|
||||
|
||||
def test_update_tool_by_id(server: SyncServer, tool_fixture):
|
||||
def test_update_tool_by_id(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
updated_description = "updated_description"
|
||||
|
||||
# Create a ToolUpdate object to modify the tool's description
|
||||
tool_update = ToolUpdate(description=updated_description)
|
||||
|
||||
# Update the tool using the manager method
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
|
||||
# Assertions to check if the update was successful
|
||||
assert updated_tool.description == updated_description
|
||||
|
||||
|
||||
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture):
|
||||
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture, default_user):
|
||||
def counter_tool(counter: int):
|
||||
"""
|
||||
Args:
|
||||
@@ -267,7 +327,6 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
|
||||
|
||||
# Test begins
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
og_json_schema = tool_fixture["tool_create"].json_schema
|
||||
|
||||
source_code = parse_source_code(counter_tool)
|
||||
@@ -276,10 +335,10 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
|
||||
tool_update = ToolUpdate(source_code=source_code)
|
||||
|
||||
# Update the tool using the manager method
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
|
||||
# Assertions to check if the update was successful, and json_schema is updated as well
|
||||
assert updated_tool.source_code == source_code
|
||||
@@ -289,7 +348,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
|
||||
assert updated_tool.json_schema == new_schema
|
||||
|
||||
|
||||
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture):
|
||||
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture, default_user):
|
||||
def counter_tool(counter: int):
|
||||
"""
|
||||
Args:
|
||||
@@ -305,7 +364,6 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
|
||||
|
||||
# Test begins
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
og_json_schema = tool_fixture["tool_create"].json_schema
|
||||
|
||||
source_code = parse_source_code(counter_tool)
|
||||
@@ -315,10 +373,10 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
|
||||
tool_update = ToolUpdate(name=name, source_code=source_code)
|
||||
|
||||
# Update the tool using the manager method
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
|
||||
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
|
||||
# Assertions to check if the update was successful, and json_schema is updated as well
|
||||
assert updated_tool.source_code == source_code
|
||||
@@ -329,10 +387,8 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
|
||||
assert updated_tool.name == name
|
||||
|
||||
|
||||
def test_update_tool_multi_user(server: SyncServer, tool_fixture):
|
||||
def test_update_tool_multi_user(server: SyncServer, tool_fixture, default_user, other_user):
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
other_user = tool_fixture["other_user"]
|
||||
updated_description = "updated_description"
|
||||
|
||||
# Create a ToolUpdate object to modify the tool's description
|
||||
@@ -343,18 +399,99 @@ def test_update_tool_multi_user(server: SyncServer, tool_fixture):
|
||||
|
||||
# Check that the created_by and last_updated_by fields are correct
|
||||
# Fetch the updated tool to verify the changes
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
|
||||
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
|
||||
|
||||
assert updated_tool.last_updated_by_id == other_user.id
|
||||
assert updated_tool.created_by_id == user.id
|
||||
assert updated_tool.created_by_id == default_user.id
|
||||
|
||||
|
||||
def test_delete_tool_by_id(server: SyncServer, tool_fixture):
|
||||
def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user):
|
||||
tool = tool_fixture["tool"]
|
||||
user = tool_fixture["user"]
|
||||
|
||||
# Delete the tool using the manager method
|
||||
server.tool_manager.delete_tool_by_id(tool.id, actor=user)
|
||||
server.tool_manager.delete_tool_by_id(tool.id, actor=default_user)
|
||||
|
||||
tools = server.tool_manager.list_tools(actor=user)
|
||||
tools = server.tool_manager.list_tools(actor=default_user)
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentsTagsManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_add_tag_to_agent(server: SyncServer, sarah_agent, default_user):
|
||||
# Add a tag to the agent
|
||||
tag_name = "test_tag"
|
||||
tag_association = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
|
||||
|
||||
# Assert that the tag association was created correctly
|
||||
assert tag_association.agent_id == sarah_agent.id
|
||||
assert tag_association.tag == tag_name
|
||||
|
||||
|
||||
def test_add_duplicate_tag_to_agent(server: SyncServer, sarah_agent, default_user):
|
||||
# Add the same tag twice to the agent
|
||||
tag_name = "test_tag"
|
||||
first_tag = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
|
||||
duplicate_tag = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
|
||||
|
||||
# Assert that the second addition returns the existing tag without creating a duplicate
|
||||
assert first_tag.agent_id == duplicate_tag.agent_id
|
||||
assert first_tag.tag == duplicate_tag.tag
|
||||
|
||||
# Get all the tags belonging to the agent
|
||||
tags = server.agents_tags_manager.get_tags_for_agent(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len(tags) == 1
|
||||
assert tags[0] == first_tag.tag
|
||||
|
||||
|
||||
def test_delete_tag_from_agent(server: SyncServer, sarah_agent, default_user):
|
||||
# Add a tag, then delete it
|
||||
tag_name = "test_tag"
|
||||
server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
|
||||
server.agents_tags_manager.delete_tag_from_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
|
||||
|
||||
# Assert the tag was deleted
|
||||
agent_tags = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user)
|
||||
assert sarah_agent.id not in agent_tags
|
||||
|
||||
|
||||
def test_delete_nonexistent_tag_from_agent(server: SyncServer, sarah_agent, default_user):
|
||||
# Attempt to delete a tag that doesn't exist
|
||||
tag_name = "nonexistent_tag"
|
||||
with pytest.raises(ValueError, match=f"Tag '{tag_name}' not found for agent '{sarah_agent.id}'"):
|
||||
server.agents_tags_manager.delete_tag_from_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
|
||||
|
||||
|
||||
def test_delete_tag_from_nonexistent_agent(server: SyncServer, default_user):
|
||||
# Attempt to delete a tag that doesn't exist
|
||||
tag_name = "nonexistent_tag"
|
||||
agent_id = "abc"
|
||||
with pytest.raises(ValueError, match=f"Tag '{tag_name}' not found for agent '{agent_id}'"):
|
||||
server.agents_tags_manager.delete_tag_from_agent(agent_id=agent_id, tag=tag_name, actor=default_user)
|
||||
|
||||
|
||||
def test_get_agents_by_tag(server: SyncServer, sarah_agent, charles_agent, default_user, default_organization):
|
||||
# Add a shared tag to multiple agents
|
||||
tag_name = "shared_tag"
|
||||
|
||||
# Add the same tag to both agents
|
||||
server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
|
||||
server.agents_tags_manager.add_tag_to_agent(agent_id=charles_agent.id, tag=tag_name, actor=default_user)
|
||||
|
||||
# Retrieve agents by tag
|
||||
agent_ids = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user)
|
||||
|
||||
# Assert that both agents are returned for the tag
|
||||
assert sarah_agent.id in agent_ids
|
||||
assert charles_agent.id in agent_ids
|
||||
assert len(agent_ids) == 2
|
||||
|
||||
# Delete tags from only sarah agent
|
||||
server.agents_tags_manager.delete_all_tags_from_agent(agent_id=sarah_agent.id, actor=default_user)
|
||||
agent_ids = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user)
|
||||
# Assert that both agents are returned for the tag
|
||||
assert sarah_agent.id not in agent_ids
|
||||
assert charles_agent.id in agent_ids
|
||||
assert len(agent_ids) == 1
|
||||
|
||||
Reference in New Issue
Block a user