feat: orm ToolsAgents migration (#2173)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -551,6 +551,7 @@ tags
|
||||
[._]*.un~
|
||||
|
||||
### VisualStudioCode ###
|
||||
.vscode/
|
||||
.vscode/*
|
||||
!.vscode/settings.json
|
||||
!.vscode/tasks.json
|
||||
|
||||
44
alembic/versions/08b2f8225812_adding_toolsagents_orm.py
Normal file
44
alembic/versions/08b2f8225812_adding_toolsagents_orm.py
Normal file
@@ -0,0 +1,44 @@
|
||||
"""adding ToolsAgents ORM
|
||||
|
||||
Revision ID: 08b2f8225812
|
||||
Revises: 3c683a662c82
|
||||
Create Date: 2024-12-05 16:46:51.258831
|
||||
|
||||
"""
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '08b2f8225812'
|
||||
down_revision: Union[str, None] = '3c683a662c82'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tools_agents',
|
||||
sa.Column('agent_id', sa.String(), nullable=False),
|
||||
sa.Column('tool_id', sa.String(), nullable=False),
|
||||
sa.Column('tool_name', sa.String(), nullable=False),
|
||||
sa.Column('id', sa.String(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=True),
|
||||
sa.Column('is_deleted', sa.Boolean(), server_default=sa.text('FALSE'), nullable=False),
|
||||
sa.Column('_created_by_id', sa.String(), nullable=True),
|
||||
sa.Column('_last_updated_by_id', sa.String(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['agent_id'], ['agents.id'], ),
|
||||
sa.ForeignKeyConstraint(['tool_id'], ['tools.id'], name='fk_tool_id'),
|
||||
sa.PrimaryKeyConstraint('agent_id', 'tool_id', 'tool_name', 'id'),
|
||||
sa.UniqueConstraint('agent_id', 'tool_name', name='unique_tool_per_agent')
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table('tools_agents')
|
||||
# ### end Alembic commands ###
|
||||
@@ -56,7 +56,7 @@ def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]:
|
||||
pause_heartbeats.__doc__ = pause_heartbeats_docstring
|
||||
|
||||
|
||||
def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
"""
|
||||
Search prior conversation history using case-insensitive string matching.
|
||||
|
||||
@@ -91,7 +91,7 @@ def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Opt
|
||||
return results_str
|
||||
|
||||
|
||||
def conversation_search_date(self: Agent, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
def conversation_search_date(self: "Agent", start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
"""
|
||||
Search prior conversation history using a date range.
|
||||
|
||||
@@ -126,7 +126,7 @@ def conversation_search_date(self: Agent, start_date: str, end_date: str, page:
|
||||
return results_str
|
||||
|
||||
|
||||
def archival_memory_insert(self: Agent, content: str) -> Optional[str]:
|
||||
def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
|
||||
"""
|
||||
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.
|
||||
|
||||
@@ -140,7 +140,7 @@ def archival_memory_insert(self: Agent, content: str) -> Optional[str]:
|
||||
return None
|
||||
|
||||
|
||||
def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
def archival_memory_search(self: "Agent", query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
"""
|
||||
Search archival memory using semantic (embedding-based) search.
|
||||
|
||||
|
||||
@@ -7,4 +7,5 @@ from letta.orm.organization import Organization
|
||||
from letta.orm.sandbox_config import SandboxConfig, SandboxEnvironmentVariable
|
||||
from letta.orm.source import Source
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.tools_agents import ToolsAgents
|
||||
from letta.orm.user import User
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy import JSON, String, UniqueConstraint, event
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
# TODO everything in functions should live in this model
|
||||
@@ -11,6 +11,7 @@ from letta.schemas.tool import Tool as PydanticTool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.tools_agents import ToolsAgents
|
||||
|
||||
|
||||
class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
@@ -40,3 +41,23 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
||||
tools_agents: Mapped[List["ToolsAgents"]] = relationship("ToolsAgents", back_populates="tool", cascade="all, delete-orphan")
|
||||
|
||||
|
||||
# Add event listener to update tool_name in ToolsAgents when Tool name changes
|
||||
@event.listens_for(Tool, 'before_update')
|
||||
def update_tool_name_in_tools_agents(mapper, connection, target):
|
||||
"""Update tool_name in ToolsAgents when Tool name changes."""
|
||||
state = target._sa_instance_state
|
||||
history = state.get_history('name', passive=True)
|
||||
if not history.has_changes():
|
||||
return
|
||||
|
||||
# Get the new name and update all associated ToolsAgents records
|
||||
new_name = target.name
|
||||
from letta.orm.tools_agents import ToolsAgents
|
||||
connection.execute(
|
||||
ToolsAgents.__table__.update().where(
|
||||
ToolsAgents.tool_id == target.id
|
||||
).values(tool_name=new_name)
|
||||
)
|
||||
|
||||
32
letta/orm/tools_agents.py
Normal file
32
letta/orm/tools_agents.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from sqlalchemy import ForeignKey, ForeignKeyConstraint, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
|
||||
|
||||
|
||||
class ToolsAgents(SqlalchemyBase):
|
||||
"""Agents can have one or many tools associated with them."""
|
||||
|
||||
__tablename__ = "tools_agents"
|
||||
__pydantic_model__ = PydanticToolsAgents
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"agent_id",
|
||||
"tool_name",
|
||||
name="unique_tool_per_agent",
|
||||
),
|
||||
ForeignKeyConstraint(
|
||||
["tool_id"],
|
||||
["tools.id"],
|
||||
name="fk_tool_id",
|
||||
),
|
||||
)
|
||||
|
||||
# Each agent must have unique tool names
|
||||
agent_id: Mapped[str] = mapped_column(String, ForeignKey("agents.id"), primary_key=True)
|
||||
tool_id: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
tool_name: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
|
||||
# relationships
|
||||
tool: Mapped["Tool"] = relationship("Tool", back_populates="tools_agents") # agent: Mapped["Agent"] = relationship("Agent", back_populates="tools_agents")
|
||||
32
letta/schemas/tools_agents.py
Normal file
32
letta/schemas/tools_agents.py
Normal file
@@ -0,0 +1,32 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class ToolsAgentsBase(LettaBase):
|
||||
__id_prefix__ = "tools_agents"
|
||||
|
||||
|
||||
class ToolsAgents(ToolsAgentsBase):
|
||||
"""
|
||||
Schema representing the relationship between tools and agents.
|
||||
|
||||
Parameters:
|
||||
agent_id (str): The ID of the associated agent.
|
||||
tool_id (str): The ID of the associated tool.
|
||||
tool_name (str): The name of the tool.
|
||||
created_at (datetime): The date this relationship was created.
|
||||
updated_at (datetime): The date this relationship was last updated.
|
||||
is_deleted (bool): Whether this tool-agent relationship is deleted or not.
|
||||
"""
|
||||
|
||||
id: str = ToolsAgentsBase.generate_id_field()
|
||||
agent_id: str = Field(..., description="The ID of the associated agent.")
|
||||
tool_id: str = Field(..., description="The ID of the associated tool.")
|
||||
tool_name: str = Field(..., description="The name of the tool.")
|
||||
created_at: Optional[datetime] = Field(None, description="The creation date of the association.")
|
||||
updated_at: Optional[datetime] = Field(None, description="The update date of the association.")
|
||||
is_deleted: bool = Field(False, description="Whether this tool-agent relationship is deleted or not.")
|
||||
@@ -77,6 +77,7 @@ from letta.schemas.user import User
|
||||
from letta.services.agents_tags_manager import AgentsTagsManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
||||
from letta.services.tools_agents_manager import ToolsAgentsManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
@@ -259,6 +260,7 @@ class SyncServer(Server):
|
||||
self.agents_tags_manager = AgentsTagsManager()
|
||||
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
||||
self.blocks_agents_manager = BlocksAgentsManager()
|
||||
self.tools_agents_manager = ToolsAgentsManager()
|
||||
self.job_manager = JobManager()
|
||||
|
||||
# Managers that interface with parallelism
|
||||
|
||||
@@ -130,7 +130,7 @@ class ToolManager:
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
tool = ToolModel.read(db_session=session, identifier=tool_id, actor=actor)
|
||||
tool.delete(db_session=session, actor=actor)
|
||||
tool.hard_delete(db_session=session, actor=actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool with id {tool_id} not found.")
|
||||
|
||||
|
||||
94
letta/services/tools_agents_manager.py
Normal file
94
letta/services/tools_agents_manager.py
Normal file
@@ -0,0 +1,94 @@
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.tool import Tool
|
||||
from letta.orm.tools_agents import ToolsAgents as ToolsAgentsModel
|
||||
from letta.schemas.tools_agents import ToolsAgents as PydanticToolsAgents
|
||||
|
||||
class ToolsAgentsManager:
|
||||
"""Manages the relationship between tools and agents."""
|
||||
|
||||
def __init__(self):
|
||||
from letta.server.server import db_context
|
||||
self.session_maker = db_context
|
||||
|
||||
def add_tool_to_agent(self, agent_id: str, tool_id: str, tool_name: str) -> PydanticToolsAgents:
|
||||
"""Add a tool to an agent.
|
||||
|
||||
When a tool is added to an agent, it will be added to all agents in the same organization.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Check if the tool-agent combination already exists for this agent
|
||||
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
|
||||
warnings.warn(f"Tool name '{tool_name}' already exists for agent '{agent_id}'.")
|
||||
except NoResultFound:
|
||||
tools_agents_record = PydanticToolsAgents(agent_id=agent_id, tool_id=tool_id, tool_name=tool_name)
|
||||
tools_agents_record = ToolsAgentsModel(**tools_agents_record.model_dump(exclude_none=True))
|
||||
tools_agents_record.create(session)
|
||||
|
||||
return tools_agents_record.to_pydantic()
|
||||
|
||||
def remove_tool_with_name_from_agent(self, agent_id: str, tool_name: str) -> None:
|
||||
"""Remove a tool from an agent by its name.
|
||||
|
||||
When a tool is removed from an agent, it will be removed from all agents in the same organization.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
# Find and delete the tool-agent association for the agent
|
||||
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
|
||||
tools_agents_record.hard_delete(session)
|
||||
return tools_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
|
||||
|
||||
def remove_tool_with_id_from_agent(self, agent_id: str, tool_id: str) -> PydanticToolsAgents:
|
||||
"""Remove a tool with an ID from an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_id=tool_id)
|
||||
tools_agents_record.hard_delete(session)
|
||||
return tools_agents_record.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool ID '{tool_id}' not found for agent '{agent_id}'.")
|
||||
|
||||
def list_tool_ids_for_agent(self, agent_id: str) -> List[str]:
|
||||
"""List all tool IDs associated with a specific agent."""
|
||||
with self.session_maker() as session:
|
||||
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||
return [record.tool_id for record in tools_agents_record]
|
||||
|
||||
def list_tool_names_for_agent(self, agent_id: str) -> List[str]:
|
||||
"""List all tool names associated with a specific agent."""
|
||||
with self.session_maker() as session:
|
||||
tools_agents_record = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||
return [record.tool_name for record in tools_agents_record]
|
||||
|
||||
def list_agent_ids_with_tool(self, tool_id: str) -> List[str]:
|
||||
"""List all agents associated with a specific tool."""
|
||||
with self.session_maker() as session:
|
||||
tools_agents_record = ToolsAgentsModel.list(db_session=session, tool_id=tool_id)
|
||||
return [record.agent_id for record in tools_agents_record]
|
||||
|
||||
def get_tool_id_for_name(self, agent_id: str, tool_name: str) -> str:
|
||||
"""Get the tool ID for a specific tool name for an agent."""
|
||||
with self.session_maker() as session:
|
||||
try:
|
||||
tools_agents_record = ToolsAgentsModel.read(db_session=session, agent_id=agent_id, tool_name=tool_name)
|
||||
return tools_agents_record.tool_id
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Tool name '{tool_name}' not found for agent '{agent_id}'.")
|
||||
|
||||
def remove_all_agent_tools(self, agent_id: str) -> None:
|
||||
"""Remove all tools associated with an agent."""
|
||||
with self.session_maker() as session:
|
||||
tools_agents_records = ToolsAgentsModel.list(db_session=session, agent_id=agent_id)
|
||||
for record in tools_agents_records:
|
||||
record.hard_delete(session)
|
||||
@@ -16,6 +16,7 @@ from letta.orm import (
|
||||
SandboxEnvironmentVariable,
|
||||
Source,
|
||||
Tool,
|
||||
ToolsAgents,
|
||||
User,
|
||||
)
|
||||
from letta.orm.agents_tags import AgentsTags
|
||||
@@ -46,9 +47,10 @@ from letta.schemas.sandbox_config import (
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.source import SourceUpdate
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolUpdate
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import tool_settings
|
||||
|
||||
utils.DEBUG = True
|
||||
@@ -76,6 +78,7 @@ def clear_tables(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Job))
|
||||
session.execute(delete(ToolsAgents)) # Clear ToolsAgents first
|
||||
session.execute(delete(BlocksAgents))
|
||||
session.execute(delete(AgentsTags))
|
||||
session.execute(delete(SandboxEnvironmentVariable))
|
||||
@@ -240,6 +243,37 @@ def other_block(server: SyncServer, default_user):
|
||||
yield block
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_tool(server: SyncServer, default_user, default_organization):
|
||||
def print_other_tool(message: str):
|
||||
"""
|
||||
Args:
|
||||
message (str): The message to print.
|
||||
|
||||
Returns:
|
||||
str: The message that was printed.
|
||||
"""
|
||||
print(message)
|
||||
return message
|
||||
|
||||
# Set up tool details
|
||||
source_code = parse_source_code(print_other_tool)
|
||||
source_type = "python"
|
||||
description = "other_tool_description"
|
||||
tags = ["test"]
|
||||
|
||||
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
|
||||
tool = server.tool_manager.create_tool(tool, actor=default_user)
|
||||
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
@@ -1155,6 +1189,107 @@ def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user
|
||||
server.blocks_agents_manager.add_block_to_agent(agent_id=sarah_agent.id, block_id=default_block.id, block_label=default_block.label)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# ToolsAgentsManager Tests
|
||||
# ======================================================================================================================
|
||||
def test_add_tool_to_agent(server, sarah_agent, default_user, print_tool):
|
||||
tool_association = server.tools_agents_manager.add_tool_to_agent(
|
||||
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
|
||||
)
|
||||
|
||||
assert tool_association.agent_id == sarah_agent.id
|
||||
assert tool_association.tool_id == print_tool.id
|
||||
assert tool_association.tool_name == print_tool.name
|
||||
|
||||
|
||||
def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent, default_user, print_tool):
|
||||
# Add the tool
|
||||
tool_association = server.tools_agents_manager.add_tool_to_agent(
|
||||
agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name
|
||||
)
|
||||
assert tool_association.tool_name == print_tool.name
|
||||
|
||||
# Change the tool name
|
||||
new_name = "banana"
|
||||
tool = server.tool_manager.update_tool_by_id(
|
||||
tool_id=print_tool.id, tool_update=ToolUpdate(name=new_name), actor=default_user
|
||||
)
|
||||
assert tool.name == new_name
|
||||
|
||||
# Get the association
|
||||
names = server.tools_agents_manager.list_tool_names_for_agent(agent_id=sarah_agent.id)
|
||||
assert new_name in names
|
||||
assert print_tool.name not in names
|
||||
|
||||
|
||||
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.tools_agents_manager.add_tool_to_agent(
|
||||
agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name"
|
||||
)
|
||||
|
||||
|
||||
def test_add_tool_to_agent_duplicate_name(server, sarah_agent, default_user, print_tool, other_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
with pytest.warns(UserWarning, match=f"Tool name '{print_tool.name}' already exists for agent '{sarah_agent.id}'"):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=print_tool.name)
|
||||
|
||||
|
||||
def test_remove_tool_with_name_from_agent(server, sarah_agent, default_user, print_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
removed_tool = server.tools_agents_manager.remove_tool_with_name_from_agent(
|
||||
agent_id=sarah_agent.id, tool_name=print_tool.name
|
||||
)
|
||||
|
||||
assert removed_tool.tool_name == print_tool.name
|
||||
assert removed_tool.tool_id == print_tool.id
|
||||
assert removed_tool.agent_id == sarah_agent.id
|
||||
|
||||
with pytest.raises(ValueError, match=f"Tool name '{print_tool.name}' not found for agent '{sarah_agent.id}'"):
|
||||
server.tools_agents_manager.remove_tool_with_name_from_agent(agent_id=sarah_agent.id, tool_name=print_tool.name)
|
||||
|
||||
|
||||
def test_list_tool_ids_for_agent(server, sarah_agent, default_user, print_tool, other_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
|
||||
|
||||
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
|
||||
|
||||
assert set(retrieved_tool_ids) == {print_tool.id, other_tool.id}
|
||||
|
||||
|
||||
def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_user, print_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=charles_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
agent_ids = server.tools_agents_manager.list_agent_ids_with_tool(tool_id=print_tool.id)
|
||||
|
||||
assert sarah_agent.id in agent_ids
|
||||
assert charles_agent.id in agent_ids
|
||||
assert len(agent_ids) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||
def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool):
|
||||
tool_manager = ToolManager()
|
||||
tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
|
||||
def test_remove_all_agent_tools(server, sarah_agent, default_user, print_tool, other_tool):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=print_tool.id, tool_name=print_tool.name)
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id=other_tool.id, tool_name=other_tool.name)
|
||||
|
||||
server.tools_agents_manager.remove_all_agent_tools(agent_id=sarah_agent.id)
|
||||
|
||||
retrieved_tool_ids = server.tools_agents_manager.list_tool_ids_for_agent(agent_id=sarah_agent.id)
|
||||
|
||||
assert not retrieved_tool_ids
|
||||
|
||||
# ======================================================================================================================
|
||||
# JobManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user