feat: Add ability to add tags to agents (#1984)

This commit is contained in:
Matthew Zhou
2024-11-06 16:16:23 -08:00
committed by GitHub
parent 8414a94b96
commit 960f7421c1
14 changed files with 465 additions and 197 deletions

View File

@@ -0,0 +1,52 @@
"""Add agents tags table
Revision ID: b6d7ca024aa9
Revises: d14ae606614c
Create Date: 2024-11-06 10:48:08.424108
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "b6d7ca024aa9"
down_revision: Union[str, None] = "d14ae606614c"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"agents_tags",
sa.Column("agent_id", sa.String(), nullable=False),
sa.Column("tag", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["agent_id"],
["agents.id"],
),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("agent_id", "id"),
sa.UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("agents_tags")
# ### end Alembic commands ###

View File

@@ -334,8 +334,12 @@ class RESTClient(AbstractClient):
self._default_llm_config = default_llm_config
self._default_embedding_config = default_embedding_config
def list_agents(self) -> List[AgentState]:
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers)
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
params = {}
if tags:
params["tags"] = tags
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params)
return [AgentState(**agent) for agent in response.json()]
def agent_exists(self, agent_id: str) -> bool:
@@ -480,6 +484,7 @@ class RESTClient(AbstractClient):
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
@@ -509,6 +514,7 @@ class RESTClient(AbstractClient):
name=name,
system=system,
tools=tools,
tags=tags,
description=description,
metadata_=metadata,
llm_config=llm_config,
@@ -1617,13 +1623,10 @@ class LocalClient(AbstractClient):
self.organization = self.server.get_organization_or_default(self.org_id)
# agents
def list_agents(self) -> List[AgentState]:
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
self.interface.clear()
# TODO: fix the server function
# return self.server.list_agents(user_id=self.user_id)
return self.server.ms.list_agents(user_id=self.user_id)
return self.server.list_agents(user_id=self.user_id, tags=tags)
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
"""
@@ -1757,6 +1760,7 @@ class LocalClient(AbstractClient):
description: Optional[str] = None,
system: Optional[str] = None,
tools: Optional[List[str]] = None,
tags: Optional[List[str]] = None,
metadata: Optional[Dict] = None,
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
@@ -1788,6 +1792,7 @@ class LocalClient(AbstractClient):
name=name,
system=system,
tools=tools,
tags=tags,
description=description,
metadata_=metadata,
llm_config=llm_config,
@@ -1872,7 +1877,7 @@ class LocalClient(AbstractClient):
agent_state (AgentState): State of the agent
"""
self.interface.clear()
return self.server.get_agent(agent_name=agent_name, user_id=self.user_id, agent_id=None)
return self.server.get_agent_state(agent_name=agent_name, user_id=self.user_id, agent_id=None)
def get_agent(self, agent_id: str) -> AgentState:
"""

View File

@@ -493,6 +493,7 @@ class MetadataStore:
fields = vars(agent)
fields["memory"] = agent.memory.to_dict()
del fields["_internal_memory"]
del fields["tags"]
session.add(AgentModel(**fields))
session.commit()
@@ -531,6 +532,7 @@ class MetadataStore:
if isinstance(agent.memory, Memory): # TODO: this is nasty but this whole class will soon be removed so whatever
fields["memory"] = agent.memory.to_dict()
del fields["_internal_memory"]
del fields["tags"]
session.query(AgentModel).filter(AgentModel.id == agent.id).update(fields)
session.commit()

28
letta/orm/agents_tags.py Normal file
View File

@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING
from sqlalchemy import ForeignKey, String, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags
if TYPE_CHECKING:
from letta.orm.organization import Organization
class AgentsTags(SqlalchemyBase, OrganizationMixin):
"""Associates tags with agents, allowing agents to have multiple tags and supporting tag-based filtering."""
__tablename__ = "agents_tags"
__pydantic_model__ = PydanticAgentsTags
__table_args__ = (UniqueConstraint("agent_id", "tag", name="unique_agent_tag"),)
# The agent associated with this tag
agent_id = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
# The name of the tag
tag: Mapped[str] = mapped_column(String, nullable=False, doc="The name of the tag associated with the agent.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="agents_tags")

View File

@@ -23,6 +23,7 @@ class Organization(SqlalchemyBase):
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
agents_tags: Mapped[List["AgentsTags"]] = relationship("AgentsTags", back_populates="organization", cascade="all, delete-orphan")
# TODO: Map these relationships later when we actually make these models
# below is just a suggestion

View File

@@ -112,6 +112,22 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
self.is_deleted = True
return self.update(db_session)
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
"""Permanently removes the record from the database."""
if actor:
logger.info(f"User {actor.id} requested hard deletion of {self.__class__.__name__} with ID {self.id}")
with db_session as session:
try:
session.delete(self)
session.commit()
except Exception as e:
session.rollback()
logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}")
raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}")
else:
logger.info(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
def update(self, db_session: "Session", actor: Optional["User"] = None) -> Type["SqlalchemyBase"]:
if actor:
self._set_created_and_updated_by_fields(actor.id)

View File

@@ -64,6 +64,9 @@ class AgentState(BaseAgent, validate_assignment=True):
# tool rules
tool_rules: Optional[List[BaseToolRule]] = Field(default=None, description="The list of tool rules.")
# tags
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
# system prompt
system: str = Field(..., description="The system prompt used by the agent.")
@@ -108,6 +111,7 @@ class CreateAgent(BaseAgent):
memory: Optional[Memory] = Field(None, description="The in-context memory of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tool_rules: Optional[List[BaseToolRule]] = Field(None, description="The tool rules governing the agent.")
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
agent_type: Optional[AgentType] = Field(None, description="The type of agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
@@ -148,6 +152,7 @@ class UpdateAgentState(BaseAgent):
id: str = Field(..., description="The id of the agent.")
name: Optional[str] = Field(None, description="The name of the agent.")
tools: Optional[List[str]] = Field(None, description="The tools used by the agent.")
tags: Optional[List[str]] = Field(None, description="The tags associated with the agent.")
system: Optional[str] = Field(None, description="The system prompt used by the agent.")
llm_config: Optional[LLMConfig] = Field(None, description="The LLM configuration used by the agent.")
embedding_config: Optional[EmbeddingConfig] = Field(None, description="The embedding configuration used by the agent.")

View File

@@ -0,0 +1,33 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from letta.schemas.letta_base import LettaBase
class AgentsTagsBase(LettaBase):
__id_prefix__ = "agents_tags"
class AgentsTags(AgentsTagsBase):
"""
Schema representing the relationship between tags and agents.
Parameters:
agent_id (str): The ID of the associated agent.
tag_id (str): The ID of the associated tag.
tag_name (str): The name of the tag.
created_at (datetime): The date this relationship was created.
"""
id: str = AgentsTagsBase.generate_id_field()
agent_id: str = Field(..., description="The ID of the associated agent.")
tag: str = Field(..., description="The name of the tag.")
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
updated_at: Optional[datetime] = Field(None, description="The update date of the tag.")
is_deleted: bool = Field(False, description="Whether this tag is deleted or not.")
class AgentsTagsCreate(AgentsTagsBase):
tag: str = Field(..., description="The tag name.")

View File

@@ -41,6 +41,7 @@ router = APIRouter(prefix="/agents", tags=["agents"])
@router.get("/", response_model=List[AgentState], operation_id="list_agents")
def list_agents(
name: Optional[str] = Query(None, description="Name of the agent"),
tags: Optional[List[str]] = Query(None, description="List of tags to filter agents by"),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
@@ -50,7 +51,7 @@ def list_agents(
"""
actor = server.get_user_or_default(user_id=user_id)
agents = server.list_agents(user_id=actor.id)
agents = server.list_agents(user_id=actor.id, tags=tags)
# TODO: move this logic to the ORM
if name:
agents = [a for a in agents if a.name == name]
@@ -534,124 +535,3 @@ async def send_message_to_agent(
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"{e}")
##### MISSING #######
# @router.post("/{agent_id}/command")
# def run_command(
# agent_id: "UUID",
# command: "AgentCommandRequest",
#
# server: "SyncServer" = Depends(get_letta_server),
# ):
# """
# Execute a command on a specified agent.
# This endpoint receives a command to be executed on an agent. It uses the user and agent identifiers to authenticate and route the command appropriately.
# Raises an HTTPException for any processing errors.
# """
# actor = server.get_current_user()
#
# response = server.run_command(user_id=actor.id,
# agent_id=agent_id,
# command=command.command)
# return AgentCommandResponse(response=response)
# @router.get("/{agent_id}/config")
# def get_agent_config(
# agent_id: "UUID",
#
# server: "SyncServer" = Depends(get_letta_server),
# ):
# """
# Retrieve the configuration for a specific agent.
# This endpoint fetches the configuration details for a given agent, identified by the user and agent IDs.
# """
# actor = server.get_current_user()
#
# if not server.ms.get_agent(user_id=actor.id, agent_id=agent_id):
## agent does not exist
# raise HTTPException(status_code=404, detail=f"Agent agent_id={agent_id} not found.")
# agent_state = server.get_agent_config(user_id=actor.id, agent_id=agent_id)
## get sources
# attached_sources = server.list_attached_sources(agent_id=agent_id)
## configs
# llm_config = LLMConfig(**vars(agent_state.llm_config))
# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config))
# return GetAgentResponse(
# agent_state=AgentState(
# id=agent_state.id,
# name=agent_state.name,
# user_id=agent_state.user_id,
# llm_config=llm_config,
# embedding_config=embedding_config,
# state=agent_state.state,
# created_at=int(agent_state.created_at.timestamp()),
# tools=agent_state.tools,
# system=agent_state.system,
# metadata=agent_state._metadata,
# ),
# last_run_at=None, # TODO
# sources=attached_sources,
# )
# @router.patch("/{agent_id}/rename", response_model=GetAgentResponse)
# def update_agent_name(
# agent_id: "UUID",
# agent_rename: AgentRenameRequest,
#
# server: "SyncServer" = Depends(get_letta_server),
# ):
# """
# Updates the name of a specific agent.
# This changes the name of the agent in the database but does NOT edit the agent's persona.
# """
# valid_name = agent_rename.agent_name
# actor = server.get_current_user()
#
# agent_state = server.rename_agent(user_id=actor.id, agent_id=agent_id, new_agent_name=valid_name)
## get sources
# attached_sources = server.list_attached_sources(agent_id=agent_id)
# llm_config = LLMConfig(**vars(agent_state.llm_config))
# embedding_config = EmbeddingConfig(**vars(agent_state.embedding_config))
# return GetAgentResponse(
# agent_state=AgentState(
# id=agent_state.id,
# name=agent_state.name,
# user_id=agent_state.user_id,
# llm_config=llm_config,
# embedding_config=embedding_config,
# state=agent_state.state,
# created_at=int(agent_state.created_at.timestamp()),
# tools=agent_state.tools,
# system=agent_state.system,
# ),
# last_run_at=None, # TODO
# sources=attached_sources,
# )
# @router.get("/{agent_id}/archival/all", response_model=GetAgentArchivalMemoryResponse)
# def get_agent_archival_memory_all(
# agent_id: "UUID",
#
# server: "SyncServer" = Depends(get_letta_server),
# ):
# """
# Retrieve the memories in an agent's archival memory store (non-paginated, returns all entries at once).
# """
# actor = server.get_current_user()
#
# archival_memories = server.get_all_archival_memories(user_id=actor.id, agent_id=agent_id)
# print("archival_memories:", archival_memories)
# archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["contents"]) for passage in archival_memories]
# return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)

View File

@@ -82,6 +82,7 @@ from letta.schemas.source import Source, SourceCreate, SourceUpdate
from letta.schemas.tool import Tool, ToolCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.agents_tags_manager import AgentsTagsManager
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
@@ -248,6 +249,7 @@ class SyncServer(Server):
self.organization_manager = OrganizationManager()
self.user_manager = UserManager()
self.tool_manager = ToolManager()
self.agents_tags_manager = AgentsTagsManager()
# Make default user and org
if init_with_default_org_and_user:
@@ -969,6 +971,19 @@ class SyncServer(Server):
if request.metadata_:
letta_agent.agent_state.metadata_ = request.metadata_
# Manage tag state
if request.tags is not None:
current_tags = set(self.agents_tags_manager.get_tags_for_agent(agent_id=letta_agent.agent_state.id, actor=actor))
target_tags = set(request.tags)
tags_to_add = target_tags - current_tags
tags_to_remove = current_tags - target_tags
for tag in tags_to_add:
self.agents_tags_manager.add_tag_to_agent(agent_id=letta_agent.agent_state.id, tag=tag, actor=actor)
for tag in tags_to_remove:
self.agents_tags_manager.delete_tag_from_agent(agent_id=letta_agent.agent_state.id, tag=tag, actor=actor)
# save the agent
assert isinstance(letta_agent.memory, Memory)
save_agent(letta_agent, self.ms)
@@ -1079,16 +1094,19 @@ class SyncServer(Server):
}
return agent_config
def list_agents(
self,
user_id: str,
) -> List[AgentState]:
def list_agents(self, user_id: str, tags: Optional[List[str]] = None) -> List[AgentState]:
"""List all available agents to a user"""
if self.user_manager.get_user_by_id(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
user = self.user_manager.get_user_by_id(user_id=user_id)
agents_states = self.ms.list_agents(user_id=user_id)
return agents_states
if tags is None:
agents_states = self.ms.list_agents(user_id=user_id)
return agents_states
else:
agent_ids = []
for tag in tags:
agent_ids += self.agents_tags_manager.get_agents_by_tag(tag=tag, actor=user)
return [self.get_agent_state(user_id=user.id, agent_id=agent_id) for agent_id in agent_ids]
def get_blocks(
self,
@@ -1160,18 +1178,6 @@ class SyncServer(Server):
raise ValueError("Source does not exist")
return existing_source.id
def get_agent(self, user_id: str, agent_id: Optional[str] = None, agent_name: Optional[str] = None):
"""Get the agent state"""
return self.ms.get_agent(agent_id=agent_id, agent_name=agent_name, user_id=user_id)
# def get_user(self, user_id: str) -> User:
# """Get the user"""
# user = self.user_manager.get_user_by_id(user_id=user_id)
# if user is None:
# raise ValueError(f"User with user_id {user_id} does not exist")
# else:
# return user
def get_agent_memory(self, agent_id: str) -> Memory:
"""Return the memory of an agent (core memory)"""
agent = self._get_or_load_agent(agent_id=agent_id)
@@ -1389,8 +1395,7 @@ class SyncServer(Server):
def get_agent_state(self, user_id: str, agent_id: Optional[str], agent_name: Optional[str] = None) -> Optional[AgentState]:
"""Return the config of an agent"""
if self.user_manager.get_user_by_id(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
user = self.user_manager.get_user_by_id(user_id=user_id)
if agent_id:
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
return None
@@ -1403,7 +1408,11 @@ class SyncServer(Server):
# Get the agent object (loaded in memory)
letta_agent = self._get_or_load_agent(agent_id=agent_id)
assert isinstance(letta_agent.memory, Memory)
return letta_agent.agent_state.model_copy(deep=True)
agent_state = letta_agent.agent_state.model_copy(deep=True)
# Load the tags in for the agent_state
agent_state.tags = self.agents_tags_manager.get_tags_for_agent(agent_id=agent_id, actor=user)
return agent_state
def get_server_config(self, include_defaults: bool = False) -> dict:
"""Return the base config"""
@@ -1485,8 +1494,11 @@ class SyncServer(Server):
def delete_agent(self, user_id: str, agent_id: str):
"""Delete an agent in the database"""
if self.user_manager.get_user_by_id(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
actor = self.user_manager.get_user_by_id(user_id=user_id)
# TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM MODEL
# TODO: EVENTUALLY WE GET AUTO-DELETES WHEN WE SPECIFY RELATIONSHIPS IN THE ORM
self.agents_tags_manager.delete_all_tags_from_agent(agent_id=agent_id, actor=actor)
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")

View File

@@ -0,0 +1,64 @@
from typing import List
from letta.orm.agents_tags import AgentsTags as AgentsTagsModel
from letta.orm.errors import NoResultFound
from letta.schemas.agents_tags import AgentsTags as PydanticAgentsTags
from letta.schemas.user import User as PydanticUser
from letta.utils import enforce_types
class AgentsTagsManager:
"""Manager class to handle business logic related to Tags."""
def __init__(self):
from letta.server.server import db_context
self.session_maker = db_context
@enforce_types
def add_tag_to_agent(self, agent_id: str, tag: str, actor: PydanticUser) -> PydanticAgentsTags:
"""Add a tag to an agent."""
with self.session_maker() as session:
# Check if the tag already exists for this agent
try:
agents_tags_model = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor)
return agents_tags_model.to_pydantic()
except NoResultFound:
agents_tags = PydanticAgentsTags(agent_id=agent_id, tag=tag).model_dump(exclude_none=True)
new_tag = AgentsTagsModel(**agents_tags, organization_id=actor.organization_id)
new_tag.create(session, actor=actor)
return new_tag.to_pydantic()
@enforce_types
def delete_all_tags_from_agent(self, agent_id: str, actor: PydanticUser):
"""Delete a tag from an agent. This is a permanent hard delete."""
tags = self.get_tags_for_agent(agent_id=agent_id, actor=actor)
for tag in tags:
self.delete_tag_from_agent(agent_id=agent_id, tag=tag, actor=actor)
@enforce_types
def delete_tag_from_agent(self, agent_id: str, tag: str, actor: PydanticUser):
"""Delete a tag from an agent."""
with self.session_maker() as session:
try:
# Retrieve and delete the tag association
tag_association = AgentsTagsModel.read(db_session=session, agent_id=agent_id, tag=tag, actor=actor)
tag_association.hard_delete(session, actor=actor)
except NoResultFound:
raise ValueError(f"Tag '{tag}' not found for agent '{agent_id}'.")
@enforce_types
def get_agents_by_tag(self, tag: str, actor: PydanticUser) -> List[str]:
"""Retrieve all agent IDs associated with a specific tag."""
with self.session_maker() as session:
# Query for all agents with the given tag
agents_with_tag = AgentsTagsModel.list(db_session=session, tag=tag, organization_id=actor.organization_id)
return [record.agent_id for record in agents_with_tag]
@enforce_types
def get_tags_for_agent(self, agent_id: str, actor: PydanticUser) -> List[str]:
"""Retrieve all tags associated with a specific agent."""
with self.session_maker() as session:
# Query for all tags associated with the given agent
tags_for_agent = AgentsTagsModel.list(db_session=session, agent_id=agent_id, organization_id=actor.organization_id)
return [record.tag for record in tags_for_agent]

View File

@@ -133,7 +133,7 @@ class ToolManager:
"""Delete a tool by its ID."""
with self.session_maker() as session:
try:
tool = ToolModel.read(db_session=session, identifier=tool_id)
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
tool.delete(db_session=session, actor=actor)
except NoResultFound:
raise ValueError(f"Tool with id {tool_id} not found.")

View File

@@ -76,7 +76,6 @@ def client(request):
client = create_client(base_url=server_url, token=None) # This yields control back to the test function
else:
# use local client (no server)
server_url = None
client = create_client()
client.set_default_llm_config(LLMConfig.default_config("gpt-4"))
@@ -88,7 +87,6 @@ def client(request):
@pytest.fixture(scope="module")
def agent(client: Union[LocalClient, RESTClient]):
agent_state = client.create_agent(name=test_agent_name)
print("AGENT ID", agent_state.id)
yield agent_state
# delete agent
@@ -676,3 +674,38 @@ def test_initial_message_sequence(client: Union[LocalClient, RESTClient], agent:
len(custom_agent_state.message_ids) == len(custom_sequence) + 1
), f"Expected {len(custom_sequence) + 1} messages, got {len(custom_agent_state.message_ids)}"
assert custom_agent_state.message_ids[1:] == [msg.id for msg in custom_sequence]
def test_add_and_manage_tags_for_agent(client: Union[LocalClient, RESTClient], agent: AgentState):
"""
Comprehensive happy path test for adding, retrieving, and managing tags on an agent.
"""
# Step 1: Add multiple tags to the agent
tags_to_add = ["test_tag_1", "test_tag_2", "test_tag_3"]
client.update_agent(agent_id=agent.id, tags=tags_to_add)
# Step 2: Retrieve tags for the agent and verify they match the added tags
retrieved_tags = client.get_agent(agent_id=agent.id).tags
assert set(retrieved_tags) == set(tags_to_add), f"Expected tags {tags_to_add}, but got {retrieved_tags}"
# Step 3: Retrieve agents by each tag to ensure the agent is associated correctly
for tag in tags_to_add:
agents_with_tag = client.list_agents(tags=[tag])
assert agent.id in [a.id for a in agents_with_tag], f"Expected agent {agent.id} to be associated with tag '{tag}'"
# Step 4: Delete a specific tag from the agent and verify its removal
tag_to_delete = tags_to_add.pop()
client.update_agent(agent_id=agent.id, tags=tags_to_add)
# Verify the tag is removed from the agent's tags
remaining_tags = client.get_agent(agent_id=agent.id).tags
assert tag_to_delete not in remaining_tags, f"Tag '{tag_to_delete}' was not removed as expected"
assert set(remaining_tags) == set(tags_to_add), f"Expected remaining tags to be {tags_to_add[1:]}, but got {remaining_tags}"
# Step 5: Delete all remaining tags from the agent
client.update_agent(agent_id=agent.id, tags=[])
# Verify all tags are removed
final_tags = client.get_agent(agent_id=agent.id).tags
assert len(final_tags) == 0, f"Expected no tags, but found {final_tags}"

View File

@@ -6,6 +6,10 @@ from letta.functions.functions import derive_openai_json_schema, parse_source_co
from letta.orm.organization import Organization
from letta.orm.tool import Tool
from letta.orm.user import User
from letta.schemas.agent import CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory
from letta.schemas.organization import Organization as PydanticOrganization
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolUpdate
@@ -29,7 +33,68 @@ def clear_tables(server: SyncServer):
@pytest.fixture
def tool_fixture(server: SyncServer):
def default_organization(server: SyncServer):
"""Fixture to create and return the default organization."""
org = server.organization_manager.create_default_organization()
yield org
@pytest.fixture
def default_user(server: SyncServer, default_organization):
"""Fixture to create and return the default user within the default organization."""
user = server.user_manager.create_default_user(org_id=default_organization.id)
yield user
@pytest.fixture
def other_user(server: SyncServer, default_organization):
"""Fixture to create and return the default user within the default organization."""
user = server.user_manager.create_user(PydanticUser(name="other", organization_id=default_organization.id))
yield user
@pytest.fixture
def sarah_agent(server: SyncServer, default_user, default_organization):
"""Fixture to create and return a sample agent within the default organization."""
agent_state = server.create_agent(
request=CreateAgent(
name="sarah_agent",
memory=ChatMemory(
human="Charles",
persona="I am a helpful assistant",
),
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=default_user,
)
yield agent_state
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
@pytest.fixture
def charles_agent(server: SyncServer, default_user, default_organization):
"""Fixture to create and return a sample agent within the default organization."""
agent_state = server.create_agent(
request=CreateAgent(
name="charles_agent",
memory=ChatMemory(
human="Sarah",
persona="I am a helpful assistant",
),
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=default_user,
)
yield agent_state
server.delete_agent(user_id=default_user.id, agent_id=agent_state.id)
@pytest.fixture
def tool_fixture(server: SyncServer, default_user, default_organization):
"""Fixture to create a tool with default settings and clean up after the test."""
def print_tool(message: str):
@@ -43,6 +108,7 @@ def tool_fixture(server: SyncServer):
print(message)
return message
# Set up tool details
source_code = parse_source_code(print_tool)
source_type = "python"
description = "test_description"
@@ -50,9 +116,9 @@ def tool_fixture(server: SyncServer):
org = server.organization_manager.create_default_organization()
user = server.user_manager.create_default_user()
other_user = server.user_manager.create_user(PydanticUser(name="other", organization_id=org.id))
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
derived_name = derived_json_schema["name"]
tool.json_schema = derived_json_schema
tool.name = derived_name
@@ -60,7 +126,7 @@ def tool_fixture(server: SyncServer):
tool = server.tool_manager.create_tool(tool, actor=user)
# Yield the created tool, organization, and user for use in tests
yield {"tool": tool, "organization": org, "user": user, "other_user": other_user, "tool_create": tool}
yield {"tool": tool, "organization": org, "user": user, "tool_create": tool}
@pytest.fixture(scope="module")
@@ -172,15 +238,13 @@ def test_update_user(server: SyncServer):
# ======================================================================================================================
# Tool Manager Tests
# ======================================================================================================================
def test_create_tool(server: SyncServer, tool_fixture):
def test_create_tool(server: SyncServer, tool_fixture, default_user, default_organization):
tool = tool_fixture["tool"]
tool_create = tool_fixture["tool_create"]
user = tool_fixture["user"]
org = tool_fixture["organization"]
# Assertions to ensure the created tool matches the expected values
assert tool.created_by_id == user.id
assert tool.organization_id == org.id
assert tool.created_by_id == default_user.id
assert tool.organization_id == default_organization.id
assert tool.description == tool_create.description
assert tool.tags == tool_create.tags
assert tool.source_code == tool_create.source_code
@@ -188,12 +252,11 @@ def test_create_tool(server: SyncServer, tool_fixture):
assert tool.json_schema == derive_openai_json_schema(source_code=tool_create.source_code, name=tool_create.name)
def test_get_tool_by_id(server: SyncServer, tool_fixture):
def test_get_tool_by_id(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
user = tool_fixture["user"]
# Fetch the tool by ID using the manager method
fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
fetched_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
# Assertions to check if the fetched tool matches the created tool
assert fetched_tool.id == tool.id
@@ -204,54 +267,51 @@ def test_get_tool_by_id(server: SyncServer, tool_fixture):
assert fetched_tool.source_type == tool.source_type
def test_get_tool_with_actor(server: SyncServer, tool_fixture):
def test_get_tool_with_actor(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
user = tool_fixture["user"]
# Fetch the tool by name and organization ID
fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=user)
fetched_tool = server.tool_manager.get_tool_by_name(tool.name, actor=default_user)
# Assertions to check if the fetched tool matches the created tool
assert fetched_tool.id == tool.id
assert fetched_tool.name == tool.name
assert fetched_tool.created_by_id == user.id
assert fetched_tool.created_by_id == default_user.id
assert fetched_tool.description == tool.description
assert fetched_tool.tags == tool.tags
assert fetched_tool.source_code == tool.source_code
assert fetched_tool.source_type == tool.source_type
def test_list_tools(server: SyncServer, tool_fixture):
def test_list_tools(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
user = tool_fixture["user"]
# List tools (should include the one created by the fixture)
tools = server.tool_manager.list_tools(actor=user)
tools = server.tool_manager.list_tools(actor=default_user)
# Assertions to check that the created tool is listed
assert len(tools) == 1
assert any(t.id == tool.id for t in tools)
def test_update_tool_by_id(server: SyncServer, tool_fixture):
def test_update_tool_by_id(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
user = tool_fixture["user"]
updated_description = "updated_description"
# Create a ToolUpdate object to modify the tool's description
tool_update = ToolUpdate(description=updated_description)
# Update the tool using the manager method
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
# Assertions to check if the update was successful
assert updated_tool.description == updated_description
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture):
def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, tool_fixture, default_user):
def counter_tool(counter: int):
"""
Args:
@@ -267,7 +327,6 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
# Test begins
tool = tool_fixture["tool"]
user = tool_fixture["user"]
og_json_schema = tool_fixture["tool_create"].json_schema
source_code = parse_source_code(counter_tool)
@@ -276,10 +335,10 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
tool_update = ToolUpdate(source_code=source_code)
# Update the tool using the manager method
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
# Assertions to check if the update was successful, and json_schema is updated as well
assert updated_tool.source_code == source_code
@@ -289,7 +348,7 @@ def test_update_tool_source_code_refreshes_schema_and_name(server: SyncServer, t
assert updated_tool.json_schema == new_schema
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture):
def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_fixture, default_user):
def counter_tool(counter: int):
"""
Args:
@@ -305,7 +364,6 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
# Test begins
tool = tool_fixture["tool"]
user = tool_fixture["user"]
og_json_schema = tool_fixture["tool_create"].json_schema
source_code = parse_source_code(counter_tool)
@@ -315,10 +373,10 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
tool_update = ToolUpdate(name=name, source_code=source_code)
# Update the tool using the manager method
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=user)
server.tool_manager.update_tool_by_id(tool.id, tool_update, actor=default_user)
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
# Assertions to check if the update was successful, and json_schema is updated as well
assert updated_tool.source_code == source_code
@@ -329,10 +387,8 @@ def test_update_tool_source_code_refreshes_schema_only(server: SyncServer, tool_
assert updated_tool.name == name
def test_update_tool_multi_user(server: SyncServer, tool_fixture):
def test_update_tool_multi_user(server: SyncServer, tool_fixture, default_user, other_user):
tool = tool_fixture["tool"]
user = tool_fixture["user"]
other_user = tool_fixture["other_user"]
updated_description = "updated_description"
# Create a ToolUpdate object to modify the tool's description
@@ -343,18 +399,99 @@ def test_update_tool_multi_user(server: SyncServer, tool_fixture):
# Check that the created_by and last_updated_by fields are correct
# Fetch the updated tool to verify the changes
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=user)
updated_tool = server.tool_manager.get_tool_by_id(tool.id, actor=default_user)
assert updated_tool.last_updated_by_id == other_user.id
assert updated_tool.created_by_id == user.id
assert updated_tool.created_by_id == default_user.id
def test_delete_tool_by_id(server: SyncServer, tool_fixture):
def test_delete_tool_by_id(server: SyncServer, tool_fixture, default_user):
tool = tool_fixture["tool"]
user = tool_fixture["user"]
# Delete the tool using the manager method
server.tool_manager.delete_tool_by_id(tool.id, actor=user)
server.tool_manager.delete_tool_by_id(tool.id, actor=default_user)
tools = server.tool_manager.list_tools(actor=user)
tools = server.tool_manager.list_tools(actor=default_user)
assert len(tools) == 0
# ======================================================================================================================
# AgentsTagsManager Tests
# ======================================================================================================================
def test_add_tag_to_agent(server: SyncServer, sarah_agent, default_user):
# Add a tag to the agent
tag_name = "test_tag"
tag_association = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
# Assert that the tag association was created correctly
assert tag_association.agent_id == sarah_agent.id
assert tag_association.tag == tag_name
def test_add_duplicate_tag_to_agent(server: SyncServer, sarah_agent, default_user):
# Add the same tag twice to the agent
tag_name = "test_tag"
first_tag = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
duplicate_tag = server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
# Assert that the second addition returns the existing tag without creating a duplicate
assert first_tag.agent_id == duplicate_tag.agent_id
assert first_tag.tag == duplicate_tag.tag
# Get all the tags belonging to the agent
tags = server.agents_tags_manager.get_tags_for_agent(agent_id=sarah_agent.id, actor=default_user)
assert len(tags) == 1
assert tags[0] == first_tag.tag
def test_delete_tag_from_agent(server: SyncServer, sarah_agent, default_user):
# Add a tag, then delete it
tag_name = "test_tag"
server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
server.agents_tags_manager.delete_tag_from_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
# Assert the tag was deleted
agent_tags = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user)
assert sarah_agent.id not in agent_tags
def test_delete_nonexistent_tag_from_agent(server: SyncServer, sarah_agent, default_user):
# Attempt to delete a tag that doesn't exist
tag_name = "nonexistent_tag"
with pytest.raises(ValueError, match=f"Tag '{tag_name}' not found for agent '{sarah_agent.id}'"):
server.agents_tags_manager.delete_tag_from_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
def test_delete_tag_from_nonexistent_agent(server: SyncServer, default_user):
# Attempt to delete a tag that doesn't exist
tag_name = "nonexistent_tag"
agent_id = "abc"
with pytest.raises(ValueError, match=f"Tag '{tag_name}' not found for agent '{agent_id}'"):
server.agents_tags_manager.delete_tag_from_agent(agent_id=agent_id, tag=tag_name, actor=default_user)
def test_get_agents_by_tag(server: SyncServer, sarah_agent, charles_agent, default_user, default_organization):
# Add a shared tag to multiple agents
tag_name = "shared_tag"
# Add the same tag to both agents
server.agents_tags_manager.add_tag_to_agent(agent_id=sarah_agent.id, tag=tag_name, actor=default_user)
server.agents_tags_manager.add_tag_to_agent(agent_id=charles_agent.id, tag=tag_name, actor=default_user)
# Retrieve agents by tag
agent_ids = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user)
# Assert that both agents are returned for the tag
assert sarah_agent.id in agent_ids
assert charles_agent.id in agent_ids
assert len(agent_ids) == 2
# Delete tags from only sarah agent
server.agents_tags_manager.delete_all_tags_from_agent(agent_id=sarah_agent.id, actor=default_user)
agent_ids = server.agents_tags_manager.get_agents_by_tag(tag=tag_name, actor=default_user)
# Assert that both agents are returned for the tag
assert sarah_agent.id not in agent_ids
assert charles_agent.id in agent_ids
assert len(agent_ids) == 1