diff --git a/alembic/versions/b6d7ca024aa9_add_agents_tags_table.py b/alembic/versions/b6d7ca024aa9_add_agents_tags_table.py new file mode 100644 index 00000000..2aec8a09 --- /dev/null +++ b/alembic/versions/b6d7ca024aa9_add_agents_tags_table.py @@ -0,0 +1,52 @@ +"""Add agents tags table + +Revision ID: b6d7ca024aa9 +Revises: d14ae606614c +Create Date: 2024-11-06 10:48:08.424108 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "b6d7ca024aa9" +down_revision: Union[str, None] = "d14ae606614c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "agents_tags", + sa.Column("agent_id", sa.String(), nullable=False), + sa.Column("tag", sa.String(), nullable=False), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["agent_id"], + ["agents.id"], + ), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("agent_id", "id"), + sa.UniqueConstraint("agent_id", "tag", name="unique_agent_tag"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("agents_tags") + # ### end Alembic commands ### diff --git a/letta/client/client.py b/letta/client/client.py index 840c85b9..c69c5478 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -334,8 +334,12 @@ class RESTClient(AbstractClient): self._default_llm_config = default_llm_config self._default_embedding_config = default_embedding_config - def list_agents(self) -> List[AgentState]: - response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers) + def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]: + params = {} + if tags: + params["tags"] = tags + + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params) return [AgentState(**agent) for agent in response.json()] def agent_exists(self, agent_id: str) -> bool: @@ -480,6 +484,7 @@ class RESTClient(AbstractClient): description: Optional[str] = None, system: Optional[str] = None, tools: Optional[List[str]] = None, + tags: Optional[List[str]] = None, metadata: Optional[Dict] = None, llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, @@ -509,6 +514,7 @@ class RESTClient(AbstractClient): name=name, system=system, tools=tools, + tags=tags, description=description, metadata_=metadata, llm_config=llm_config, @@ -1617,13 +1623,10 @@ class LocalClient(AbstractClient): self.organization = self.server.get_organization_or_default(self.org_id) # agents - def list_agents(self) -> List[AgentState]: + def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]: self.interface.clear() - # TODO: fix the server function - # return self.server.list_agents(user_id=self.user_id) - - return self.server.ms.list_agents(user_id=self.user_id) + return self.server.list_agents(user_id=self.user_id, tags=tags) def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: """ @@ -1757,6 +1760,7 @@ class LocalClient(AbstractClient): description: Optional[str] = None, system: Optional[str] = None, tools: Optional[List[str]] = None, + tags: Optional[List[str]] = None, metadata: Optional[Dict] = None, llm_config: Optional[LLMConfig] = None, embedding_config: Optional[EmbeddingConfig] = None, @@ -1788,6 +1792,7 @@ class LocalClient(AbstractClient): name=name, system=system, tools=tools, + tags=tags, description=description, metadata_=metadata, llm_config=llm_config, @@ -1872,7 +1877,7 @@ class LocalClient(AbstractClient): agent_state (AgentState): State of the agent """ self.interface.clear() - return self.server.get_agent(agent_name=agent_name, user_id=self.user_id, agent_id=None) + return self.server.get_agent_state(agent_name=agent_name, user_id=self.user_id, agent_id=None) def get_agent(self, agent_id: str) -> AgentState: """ diff --git a/letta/metadata.py b/letta/metadata.py index 3f9eea5a..dddced11 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -493,6 +493,7 @@ class MetadataStore: fields = vars(agent) fields["memory"] = agent.memory.to_dict() del fields["_internal_memory"] + del fields["tags"] session.add(AgentModel(**fields)) session.commit() @@ -531,6 +532,7 @@ class MetadataStore: if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever fields["memory"] = agent.memory.to_dict() del fields["_internal_memory"] + del fields["tags"] session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields) session.commit() diff --git a/letta/orm/agents_tags.py b/letta/orm/agents_tags.py new file mode 100644 index 00000000..1910f528 --- /dev/null +++ b/letta/orm/agents_tags.py @@ -0,0 +1,28 @@ +from typing import TYPE_CHECKING + +from sqlalchemy import ForeignKey, String, UniqueConstraint +from sqlalchemy.orm import Mapped, mapped_column, relationship + +from letta.orm.mixins import OrganizationMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase +from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags + +if TYPE_CHECKING: + from letta.orm.organization import Organization + + +class AgentsTags(SqlalchemyBase, OrganizationMixin): + """Associates tags with agents, allowing agents to have multiple tags and supporting tag-based filtering.""" + + __tablename__ = "agents_tags" + __pydantic_model__ = PydanticAgentsTags + __table_args__ = (UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),) + + # The agent associated with this tag + agent_id = mapped_column(String, ForeignKey("agents.id"), primary_key=True) + + # The name of the tag + tag: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the tag associated with the agent.") + + # relationships + organization: Mapped["Organization"] = relationship("Organization", back_populates="agents_tags") diff --git a/letta/orm/organization.py b/letta/orm/organization.py index 88f8ea5d..4b641607 100644 --- a/letta/orm/organization.py +++ b/letta/orm/organization.py @@ -23,6 +23,7 @@ class Organization(SqlalchemyBase): users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan") tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan") + agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan") # TODO: Map these relationships later when we actually make these models # below is just a suggestion diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 20728d7b..9469fdb7 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -112,6 +112,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): self.is_deleted = True return self.update(db_session) + def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None: + """Permanently removes the record from the database.""" + if actor: + logger.info(f"User {actor.id} requested hard deletion of {self.__class__.__name__} with ID {self.id}") + + with db_session as session: + try: + session.delete(self) + session.commit() + except Exception as e: + session.rollback() + logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}") + raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}") + else: + logger.info(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted") + def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]: if actor: self._set_created_and_updated_by_fields(actor.id) diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 92661024..648546ef 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -64,6 +64,9 @@ class AgentState(BaseAgent, validate_assignment=True): # tool rules tool_rules: Optional[List[BaseToolRule]] = Field(default=None, description="The list of tool rules.") + # tags + tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.") + # system prompt system: str = Field(..., description="The system prompt used by the agent.") @@ -108,6 +111,7 @@ class CreateAgent(BaseAgent): memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.") tools: Optional[List[str]] = Field(None, description="The tools used by the agent.") tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.") + tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.") system: Optional[str] = Field(None, description="The system prompt used by the agent.") agent_type: Optional[AgentType] = Field(None, description="The type of agent.") llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") @@ -148,6 +152,7 @@ class UpdateAgentState(BaseAgent): id: str = Field(..., description="The id of the agent.") name: Optional[str] = Field(None, description="The name of the agent.") tools: Optional[List[str]] = Field(None, description="The tools used by the agent.") + tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.") system: Optional[str] = Field(None, description="The system prompt used by the agent.") llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.") embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.") diff --git a/letta/schemas/agents_tags.py b/letta/schemas/agents_tags.py new file mode 100644 index 00000000..eba5e0db --- /dev/null +++ b/letta/schemas/agents_tags.py @@ -0,0 +1,33 @@ +from datetime import datetime +from typing import Optional + +from pydantic import Field + +from letta.schemas.letta_base import LettaBase + + +class AgentsTagsBase(LettaBase): + __id_prefix__ = "agents_tags" + + +class AgentsTags(AgentsTagsBase): + """ + Schema representing the relationship between tags and agents. + + Parameters: + agent_id (str): The ID of the associated agent. + tag_id (str): The ID of the associated tag. + tag_name (str): The name of the tag. + created_at (datetime): The date this relationship was created. + """ + + id: str = AgentsTagsBase.generate_id_field() + agent_id: str = Field(..., description="The ID of the associated agent.") + tag: str = Field(..., description="The name of the tag.") + created_at: Optional[datetime] = Field(None, description="The creation date of the association.") + updated_at: Optional[datetime] = Field(None, description="The update date of the tag.") + is_deleted: bool = Field(False, description="Whether this tag is deleted or not.") + + +class AgentsTagsCreate(AgentsTagsBase): + tag: str = Field(..., description="The tag name.") diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 6836d765..d100f9bf 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -41,6 +41,7 @@ router = APIRouter(prefix="/agents", tags=["agents"]) @router.get("/", response_model=List[AgentState], operation_id="list_agents") def list_agents( name: Optional[str] = Query(None, description="Name of the agent"), + tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"), server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): @@ -50,7 +51,7 @@ def list_agents( """ actor = server.get_user_or_default(user_id=user_id) - agents = server.list_agents(user_id=actor.id) + agents = server.list_agents(user_id=actor.id, tags=tags) # TODO: move this logic to the ORM if name: agents = [a for a in agents if a.name == name] @@ -534,124 +535,3 @@ async def send_message_to_agent( traceback.print_exc() raise HTTPException(status_code=500, detail=f"{e}") - - -##### MISSING ####### - -# @router.post("/{agent_id}/command") -# def run_command( -# agent_id: "UUID", -# command: "AgentCommandRequest", -# -# server: "SyncServer" = Depends(get_letta_server), -# ): -# """ -# Execute a command on a specified agent. - -# This endpoint receives a command to be executed on an agent. It uses the user and agent identifiers to authenticate and route the command appropriately. - -# Raises an HTTPException for any processing errors. -# """ -# actor = server.get_current_user() -# -# response = server.run_command(user_id=actor.id, -# agent_id=agent_id, -# command=command.command) - -# return AgentCommandResponse(response=response) - -# @router.get("/{agent_id}/config") -# def get_agent_config( -# agent_id: "UUID", -# -# server: "SyncServer" = Depends(get_letta_server), -# ): -# """ -# Retrieve the configuration for a specific agent. - -# This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs. -# """ -# actor = server.get_current_user() -# -# if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id): -## agent does not exist -# raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.") - -# agent_state = server.get_agent_config(user_id=actor.id, agent_id=agent_id) -## get sources -# attached_sources = server.list_attached_sources(agent_id=agent_id) - -## configs -# llm_config = LLMConfig(**vars(agent_state.llm_config)) -# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config)) - -# return GetAgentResponse( -# agent_state=AgentState( -# id=agent_state.id, -# name=agent_state.name, -# user_id=agent_state.user_id, -# llm_config=llm_config, -# embedding_config=embedding_config, -# state=agent_state.state, -# created_at=int(agent_state.created_at.timestamp()), -# tools=agent_state.tools, -# system=agent_state.system, -# metadata=agent_state._metadata, -# ), -# last_run_at=None, # TODO -# sources=attached_sources, -# ) - -# @router.patch("/{agent_id}/rename", response_model=GetAgentResponse) -# def update_agent_name( -# agent_id: "UUID", -# agent_rename: AgentRenameRequest, -# -# server: "SyncServer" = Depends(get_letta_server), -# ): -# """ -# Updates the name of a specific agent. - -# This changes the name of the agent in the database but does NOT edit the agent's persona. -# """ -# valid_name = agent_rename.agent_name -# actor = server.get_current_user() -# -# agent_state = server.rename_agent(user_id=actor.id, agent_id=agent_id, new_agent_name=valid_name) -## get sources -# attached_sources = server.list_attached_sources(agent_id=agent_id) -# llm_config = LLMConfig(**vars(agent_state.llm_config)) -# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config)) - -# return GetAgentResponse( -# agent_state=AgentState( -# id=agent_state.id, -# name=agent_state.name, -# user_id=agent_state.user_id, -# llm_config=llm_config, -# embedding_config=embedding_config, -# state=agent_state.state, -# created_at=int(agent_state.created_at.timestamp()), -# tools=agent_state.tools, -# system=agent_state.system, -# ), -# last_run_at=None, # TODO -# sources=attached_sources, -# ) - - -# @router.get("/{agent_id}/archival/all", response_model=GetAgentArchivalMemoryResponse) -# def get_agent_archival_memory_all( -# agent_id: "UUID", -# -# server: "SyncServer" = Depends(get_letta_server), -# ): -# """ -# Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once). -# """ -# actor = server.get_current_user() -# -# archival_memories = server.get_all_archival_memories(user_id=actor.id, agent_id=agent_id) -# print("archival_memories:", archival_memories) -# archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["contents"]) for passage in archival_memories] -# return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects) diff --git a/letta/server/server.py b/letta/server/server.py index 03272a77..39baf02a 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -82,6 +82,7 @@ from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.tool import Tool, ToolCreate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User +from letta.services.agents_tags_manager import AgentsTagsManager from letta.services.organization_manager import OrganizationManager from letta.services.tool_manager import ToolManager from letta.services.user_manager import UserManager @@ -248,6 +249,7 @@ class SyncServer(Server): self.organization_manager = OrganizationManager() self.user_manager = UserManager() self.tool_manager = ToolManager() + self.agents_tags_manager = AgentsTagsManager() # Make default user and org if init_with_default_org_and_user: @@ -969,6 +971,19 @@ class SyncServer(Server): if request.metadata_: letta_agent.agent_state.metadata_ = request.metadata_ + # Manage tag state + if request.tags is not None: + current_tags = set(self.agents_tags_manager.get_tags_for_agent(agent_id=letta_agent.agent_state.id, actor=actor)) + target_tags = set(request.tags) + + tags_to_add = target_tags - current_tags + tags_to_remove = current_tags - target_tags + + for tag in tags_to_add: + self.agents_tags_manager.add_tag_to_agent(agent_id=letta_agent.agent_state.id, tag=tag, actor=actor) + for tag in tags_to_remove: + self.agents_tags_manager.delete_tag_from_agent(agent_id=letta_agent.agent_state.id, tag=tag, actor=actor) + # save the agent assert isinstance(letta_agent.memory, Memory) save_agent(letta_agent, self.ms) @@ -1079,16 +1094,19 @@ class SyncServer(Server): } return agent_config - def list_agents( - self, - user_id: str, - ) -> List[AgentState]: + def list_agents(self, user_id: str, tags: Optional[List[str]] = None) -> List[AgentState]: """List all available agents to a user""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") + user = self.user_manager.get_user_by_id(user_id=user_id) - agents_states = self.ms.list_agents(user_id=user_id) - return agents_states + if tags is None: + agents_states = self.ms.list_agents(user_id=user_id) + return agents_states + else: + agent_ids = [] + for tag in tags: + agent_ids += self.agents_tags_manager.get_agents_by_tag(tag=tag, actor=user) + + return [self.get_agent_state(user_id=user.id, agent_id=agent_id) for agent_id in agent_ids] def get_blocks( self, @@ -1160,18 +1178,6 @@ class SyncServer(Server): raise ValueError("Source does not exist") return existing_source.id - def get_agent(self, user_id: str, agent_id: Optional[str] = None, agent_name: Optional[str] = None): - """Get the agent state""" - return self.ms.get_agent(agent_id=agent_id, agent_name=agent_name, user_id=user_id) - - # def get_user(self, user_id: str) -> User: - # """Get the user""" - # user = self.user_manager.get_user_by_id(user_id=user_id) - # if user is None: - # raise ValueError(f"User with user_id {user_id} does not exist") - # else: - # return user - def get_agent_memory(self, agent_id: str) -> Memory: """Return the memory of an agent (core memory)""" agent = self._get_or_load_agent(agent_id=agent_id) @@ -1389,8 +1395,7 @@ class SyncServer(Server): def get_agent_state(self, user_id: str, agent_id: Optional[str], agent_name: Optional[str] = None) -> Optional[AgentState]: """Return the config of an agent""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") + user = self.user_manager.get_user_by_id(user_id=user_id) if agent_id: if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: return None @@ -1403,7 +1408,11 @@ class SyncServer(Server): # Get the agent object (loaded in memory) letta_agent = self._get_or_load_agent(agent_id=agent_id) assert isinstance(letta_agent.memory, Memory) - return letta_agent.agent_state.model_copy(deep=True) + agent_state = letta_agent.agent_state.model_copy(deep=True) + + # Load the tags in for the agent_state + agent_state.tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent_id, actor=user) + return agent_state def get_server_config(self, include_defaults: bool = False) -> dict: """Return the base config""" @@ -1485,8 +1494,11 @@ class SyncServer(Server): def delete_agent(self, user_id: str, agent_id: str): """Delete an agent in the database""" - if self.user_manager.get_user_by_id(user_id=user_id) is None: - raise ValueError(f"User user_id={user_id} does not exist") + actor = self.user_manager.get_user_by_id(user_id=user_id) + # TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM MODEL + # TODO: EVENTUALLY WE GET AUTO-DELETES WHEN WE SPECIFY RELATIONSHIPS IN THE ORM + self.agents_tags_manager.delete_all_tags_from_agent(agent_id=agent_id, actor=actor) + if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None: raise ValueError(f"Agent agent_id={agent_id} does not exist") diff --git a/letta/services/agents_tags_manager.py b/letta/services/agents_tags_manager.py new file mode 100644 index 00000000..f84ea11b --- /dev/null +++ b/letta/services/agents_tags_manager.py @@ -0,0 +1,64 @@ +from typing import List + +from letta.orm.agents_tags import AgentsTags as AgentsTagsModel +from letta.orm.errors import NoResultFound +from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags +from letta.schemas.user import User as PydanticUser +from letta.utils import enforce_types + + +class AgentsTagsManager: + """Manager class to handle business logic related to Tags.""" + + def __init__(self): + from letta.server.server import db_context + + self.session_maker = db_context + + @enforce_types + def add_tag_to_agent(self, agent_id: str, tag: str, actor: PydanticUser) -> PydanticAgentsTags: + """Add a tag to an agent.""" + with self.session_maker() as session: + # Check if the tag already exists for this agent + try: + agents_tags_model = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor) + return agents_tags_model.to_pydantic() + except NoResultFound: + agents_tags = PydanticAgentsTags(agent_id=agent_id, tag=tag).model_dump(exclude_none=True) + new_tag = AgentsTagsModel(**agents_tags, organization_id=actor.organization_id) + new_tag.create(session, actor=actor) + return new_tag.to_pydantic() + + @enforce_types + def delete_all_tags_from_agent(self, agent_id: str, actor: PydanticUser): + """Delete a tag from an agent. This is a permanent hard delete.""" + tags = self.get_tags_for_agent(agent_id=agent_id, actor=actor) + for tag in tags: + self.delete_tag_from_agent(agent_id=agent_id, tag=tag, actor=actor) + + @enforce_types + def delete_tag_from_agent(self, agent_id: str, tag: str, actor: PydanticUser): + """Delete a tag from an agent.""" + with self.session_maker() as session: + try: + # Retrieve and delete the tag association + tag_association = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor) + tag_association.hard_delete(session, actor=actor) + except NoResultFound: + raise ValueError(f"Tag '{tag}' not found for agent '{agent_id}'.") + + @enforce_types + def get_agents_by_tag(self, tag: str, actor: PydanticUser) -> List[str]: + """Retrieve all agent IDs associated with a specific tag.""" + with self.session_maker() as session: + # Query for all agents with the given tag + agents_with_tag = AgentsTagsModel.list(db_session=session, tag=tag, organization_id=actor.organization_id) + return [record.agent_id for record in agents_with_tag] + + @enforce_types + def get_tags_for_agent(self, agent_id: str, actor: PydanticUser) -> List[str]: + """Retrieve all tags associated with a specific agent.""" + with self.session_maker() as session: + # Query for all tags associated with the given agent + tags_for_agent = AgentsTagsModel.list(db_session=session, agent_id=agent_id, organization_id=actor.organization_id) + return [record.tag for record in tags_for_agent] diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 94c73188..778e5307 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -133,7 +133,7 @@ class ToolManager: """Delete a tool by its ID.""" with self.session_maker() as session: try: - tool = ToolModel.read(db_session=session, identifier=tool_id) + tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor) tool.delete(db_session=session, actor=actor) except NoResultFound: raise ValueError(f"Tool with id {tool_id} not found.") diff --git a/tests/test_client.py b/tests/test_client.py index dda27cce..3c1647b7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -76,7 +76,6 @@ def client(request): client = create_client(base_url=server_url, token=None) # This yields control back to the test function else: # use local client (no server) - server_url = None client = create_client() client.set_default_llm_config(LLMConfig.default_config("gpt-4")) @@ -88,7 +87,6 @@ def client(request): @pytest.fixture(scope="module") def agent(client: Union[LocalClient, RESTClient]): agent_state = client.create_agent(name=test_agent_name) - print("AGENT ID", agent_state.id) yield agent_state # delete agent @@ -676,3 +674,38 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent: len(custom_agent_state.message_ids) == len(custom_sequence) + 1 ), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}" assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence] + + +def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState): + """ + Comprehensive happy path test for adding, retrieving, and managing tags on an agent. + """ + + # Step 1: Add multiple tags to the agent + tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"] + client.update_agent(agent_id=agent.id, tags=tags_to_add) + + # Step 2: Retrieve tags for the agent and verify they match the added tags + retrieved_tags = client.get_agent(agent_id=agent.id).tags + assert set(retrieved_tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {retrieved_tags}" + + # Step 3: Retrieve agents by each tag to ensure the agent is associated correctly + for tag in tags_to_add: + agents_with_tag = client.list_agents(tags=[tag]) + assert agent.id in [a.id for a in agents_with_tag], f"Expected agent {agent.id} to be associated with tag '{tag}'" + + # Step 4: Delete a specific tag from the agent and verify its removal + tag_to_delete = tags_to_add.pop() + client.update_agent(agent_id=agent.id, tags=tags_to_add) + + # Verify the tag is removed from the agent's tags + remaining_tags = client.get_agent(agent_id=agent.id).tags + assert tag_to_delete not in remaining_tags, f"Tag '{tag_to_delete}' was not removed as expected" + assert set(remaining_tags) == set(tags_to_add), f"Expected remaining tags to be {tags_to_add[1:]}, but got {remaining_tags}" + + # Step 5: Delete all remaining tags from the agent + client.update_agent(agent_id=agent.id, tags=[]) + + # Verify all tags are removed + final_tags = client.get_agent(agent_id=agent.id).tags + assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}" diff --git a/tests/test_managers.py b/tests/test_managers.py index 82c256d1..6bfbf548 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -6,6 +6,10 @@ from letta.functions.functions import derive_openai_json_schema, parse_source_co from letta.orm.organization import Organization from letta.orm.tool import Tool from letta.orm.user import User +from letta.schemas.agent import CreateAgent +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig +from letta.schemas.memory import ChatMemory from letta.schemas.organization import Organization as PydanticOrganization from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolUpdate @@ -29,7 +33,68 @@ def clear_tables(server: SyncServer): @pytest.fixture -def tool_fixture(server: SyncServer): +def default_organization(server: SyncServer): + """Fixture to create and return the default organization.""" + org = server.organization_manager.create_default_organization() + yield org + + +@pytest.fixture +def default_user(server: SyncServer, default_organization): + """Fixture to create and return the default user within the default organization.""" + user = server.user_manager.create_default_user(org_id=default_organization.id) + yield user + + +@pytest.fixture +def other_user(server: SyncServer, default_organization): + """Fixture to create and return the default user within the default organization.""" + user = server.user_manager.create_user(PydanticUser(name="other", organization_id=default_organization.id)) + yield user + + +@pytest.fixture +def sarah_agent(server: SyncServer, default_user, default_organization): + """Fixture to create and return a sample agent within the default organization.""" + agent_state = server.create_agent( + request=CreateAgent( + name="sarah_agent", + memory=ChatMemory( + human="Charles", + persona="I am a helpful assistant", + ), + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=default_user, + ) + yield agent_state + + server.delete_agent(user_id=default_user.id, agent_id=agent_state.id) + + +@pytest.fixture +def charles_agent(server: SyncServer, default_user, default_organization): + """Fixture to create and return a sample agent within the default organization.""" + agent_state = server.create_agent( + request=CreateAgent( + name="charles_agent", + memory=ChatMemory( + human="Sarah", + persona="I am a helpful assistant", + ), + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=default_user, + ) + yield agent_state + + server.delete_agent(user_id=default_user.id, agent_id=agent_state.id) + + +@pytest.fixture +def tool_fixture(server: SyncServer, default_user, default_organization): """Fixture to create a tool with default settings and clean up after the test.""" def print_tool(message: str): @@ -43,6 +108,7 @@ def tool_fixture(server: SyncServer): print(message) return message + # Set up tool details source_code = parse_source_code(print_tool) source_type = "python" description = "test_description" @@ -50,9 +116,9 @@ def tool_fixture(server: SyncServer): org = server.organization_manager.create_default_organization() user = server.user_manager.create_default_user() - other_user = server.user_manager.create_user(PydanticUser(name="other", organization_id=org.id)) tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type) derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) + derived_name = derived_json_schema["name"] tool.json_schema = derived_json_schema tool.name = derived_name @@ -60,7 +126,7 @@ def tool_fixture(server: SyncServer): tool = server.tool_manager.create_tool(tool, actor=user) # Yield the created tool, organization, and user for use in tests - yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool} + yield {"tool": tool, "organization": org, "user": user, "tool_create": tool} @pytest.fixture(scope="module") @@ -172,15 +238,13 @@ def test_update_user(server: SyncServer): # ====================================================================================================================== # Tool Manager Tests # ====================================================================================================================== -def test_create_tool(server: SyncServer, tool_fixture): +def test_create_tool(server: SyncServer, tool_fixture, default_user, default_organization): tool = tool_fixture["tool"] tool_create = tool_fixture["tool_create"] - user = tool_fixture["user"] - org = tool_fixture["organization"] # Assertions to ensure the created tool matches the expected values - assert tool.created_by_id == user.id - assert tool.organization_id == org.id + assert tool.created_by_id == default_user.id + assert tool.organization_id == default_organization.id assert tool.description == tool_create.description assert tool.tags == tool_create.tags assert tool.source_code == tool_create.source_code @@ -188,12 +252,11 @@ def test_create_tool(server: SyncServer, tool_fixture): assert tool.json_schema == derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name) -def test_get_tool_by_id(server: SyncServer, tool_fixture): +def test_get_tool_by_id(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] # Fetch the tool by ID using the manager method - fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) # Assertions to check if the fetched tool matches the created tool assert fetched_tool.id == tool.id @@ -204,54 +267,51 @@ def test_get_tool_by_id(server: SyncServer, tool_fixture): assert fetched_tool.source_type == tool.source_type -def test_get_tool_with_actor(server: SyncServer, tool_fixture): +def test_get_tool_with_actor(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] # Fetch the tool by name and organization ID - fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=user) + fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=default_user) # Assertions to check if the fetched tool matches the created tool assert fetched_tool.id == tool.id assert fetched_tool.name == tool.name - assert fetched_tool.created_by_id == user.id + assert fetched_tool.created_by_id == default_user.id assert fetched_tool.description == tool.description assert fetched_tool.tags == tool.tags assert fetched_tool.source_code == tool.source_code assert fetched_tool.source_type == tool.source_type -def test_list_tools(server: SyncServer, tool_fixture): +def test_list_tools(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] # List tools (should include the one created by the fixture) - tools = server.tool_manager.list_tools(actor=user) + tools = server.tool_manager.list_tools(actor=default_user) # Assertions to check that the created tool is listed assert len(tools) == 1 assert any(t.id == tool.id for t in tools) -def test_update_tool_by_id(server: SyncServer, tool_fixture): +def test_update_tool_by_id(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] updated_description = "updated_description" # Create a ToolUpdate object to modify the tool's description tool_update = ToolUpdate(description=updated_description) # Update the tool using the manager method - server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user) + server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user) # Fetch the updated tool to verify the changes - updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) # Assertions to check if the update was successful assert updated_tool.description == updated_description -def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture): +def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture, default_user): def counter_tool(counter: int): """ Args: @@ -267,7 +327,6 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t # Test begins tool = tool_fixture["tool"] - user = tool_fixture["user"] og_json_schema = tool_fixture["tool_create"].json_schema source_code = parse_source_code(counter_tool) @@ -276,10 +335,10 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t tool_update = ToolUpdate(source_code=source_code) # Update the tool using the manager method - server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user) + server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user) # Fetch the updated tool to verify the changes - updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) # Assertions to check if the update was successful, and json_schema is updated as well assert updated_tool.source_code == source_code @@ -289,7 +348,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t assert updated_tool.json_schema == new_schema -def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture): +def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture, default_user): def counter_tool(counter: int): """ Args: @@ -305,7 +364,6 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_ # Test begins tool = tool_fixture["tool"] - user = tool_fixture["user"] og_json_schema = tool_fixture["tool_create"].json_schema source_code = parse_source_code(counter_tool) @@ -315,10 +373,10 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_ tool_update = ToolUpdate(name=name, source_code=source_code) # Update the tool using the manager method - server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user) + server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user) # Fetch the updated tool to verify the changes - updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) # Assertions to check if the update was successful, and json_schema is updated as well assert updated_tool.source_code == source_code @@ -329,10 +387,8 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_ assert updated_tool.name == name -def test_update_tool_multi_user(server: SyncServer, tool_fixture): +def test_update_tool_multi_user(server: SyncServer, tool_fixture, default_user, other_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] - other_user = tool_fixture["other_user"] updated_description = "updated_description" # Create a ToolUpdate object to modify the tool's description @@ -343,18 +399,99 @@ def test_update_tool_multi_user(server: SyncServer, tool_fixture): # Check that the created_by and last_updated_by fields are correct # Fetch the updated tool to verify the changes - updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user) + updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user) assert updated_tool.last_updated_by_id == other_user.id - assert updated_tool.created_by_id == user.id + assert updated_tool.created_by_id == default_user.id -def test_delete_tool_by_id(server: SyncServer, tool_fixture): +def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user): tool = tool_fixture["tool"] - user = tool_fixture["user"] # Delete the tool using the manager method - server.tool_manager.delete_tool_by_id(tool.id, actor=user) + server.tool_manager.delete_tool_by_id(tool.id, actor=default_user) - tools = server.tool_manager.list_tools(actor=user) + tools = server.tool_manager.list_tools(actor=default_user) assert len(tools) == 0 + + +# ====================================================================================================================== +# AgentsTagsManager Tests +# ====================================================================================================================== + + +def test_add_tag_to_agent(server: SyncServer, sarah_agent, default_user): + # Add a tag to the agent + tag_name = "test_tag" + tag_association = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) + + # Assert that the tag association was created correctly + assert tag_association.agent_id == sarah_agent.id + assert tag_association.tag == tag_name + + +def test_add_duplicate_tag_to_agent(server: SyncServer, sarah_agent, default_user): + # Add the same tag twice to the agent + tag_name = "test_tag" + first_tag = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) + duplicate_tag = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) + + # Assert that the second addition returns the existing tag without creating a duplicate + assert first_tag.agent_id == duplicate_tag.agent_id + assert first_tag.tag == duplicate_tag.tag + + # Get all the tags belonging to the agent + tags = server.agents_tags_manager.get_tags_for_agent(agent_id=sarah_agent.id, actor=default_user) + assert len(tags) == 1 + assert tags[0] == first_tag.tag + + +def test_delete_tag_from_agent(server: SyncServer, sarah_agent, default_user): + # Add a tag, then delete it + tag_name = "test_tag" + server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) + server.agents_tags_manager.delete_tag_from_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) + + # Assert the tag was deleted + agent_tags = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user) + assert sarah_agent.id not in agent_tags + + +def test_delete_nonexistent_tag_from_agent(server: SyncServer, sarah_agent, default_user): + # Attempt to delete a tag that doesn't exist + tag_name = "nonexistent_tag" + with pytest.raises(ValueError, match=f"Tag '{tag_name}' not found for agent '{sarah_agent.id}'"): + server.agents_tags_manager.delete_tag_from_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) + + +def test_delete_tag_from_nonexistent_agent(server: SyncServer, default_user): + # Attempt to delete a tag that doesn't exist + tag_name = "nonexistent_tag" + agent_id = "abc" + with pytest.raises(ValueError, match=f"Tag '{tag_name}' not found for agent '{agent_id}'"): + server.agents_tags_manager.delete_tag_from_agent(agent_id=agent_id, tag=tag_name, actor=default_user) + + +def test_get_agents_by_tag(server: SyncServer, sarah_agent, charles_agent, default_user, default_organization): + # Add a shared tag to multiple agents + tag_name = "shared_tag" + + # Add the same tag to both agents + server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user) + server.agents_tags_manager.add_tag_to_agent(agent_id=charles_agent.id, tag=tag_name, actor=default_user) + + # Retrieve agents by tag + agent_ids = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user) + + # Assert that both agents are returned for the tag + assert sarah_agent.id in agent_ids + assert charles_agent.id in agent_ids + assert len(agent_ids) == 2 + + # Delete tags from only sarah agent + server.agents_tags_manager.delete_all_tags_from_agent(agent_id=sarah_agent.id, actor=default_user) + agent_ids = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user) + # Assert that both agents are returned for the tag + assert sarah_agent.id not in agent_ids + assert charles_agent.id in agent_ids + assert len(agent_ids) == 1