feat: add new default_requires_approval flag on tools (#4287)
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
"""add default requires approval field on tools
|
||||
|
||||
Revision ID: c41c87205254
|
||||
Revises: 068588268b02
|
||||
Create Date: 2025-08-28 13:17:51.636159
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c41c87205254"
|
||||
down_revision: Union[str, None] = "068588268b02"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Add the nullable boolean ``default_requires_approval`` column to the ``tools`` table."""
    # Declared nullable with no default, so rows that already exist need no backfill.
    requires_approval_column = sa.Column(
        "default_requires_approval",
        sa.Boolean(),
        nullable=True,
    )
    op.add_column("tools", requires_approval_column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Drop the ``default_requires_approval`` column from the ``tools`` table."""
    table_name, column_name = "tools", "default_requires_approval"
    op.drop_column(table_name, column_name)
|
||||
@@ -49,6 +49,7 @@ class Tool(SqlalchemyBase, OrganizationMixin):
|
||||
JSON, nullable=True, doc="Optional list of pip packages required by this tool."
|
||||
)
|
||||
npm_requirements: Mapped[list | None] = mapped_column(JSON, doc="Optional list of npm packages required by this tool.")
|
||||
# The column is nullable, so the Python-side type must admit None as well as True/False.
default_requires_approval: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether or not to require approval.")
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="A dictionary of additional metadata for the tool.")
|
||||
# relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")
|
||||
|
||||
@@ -67,6 +67,9 @@ class Tool(BaseTool):
|
||||
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
|
||||
pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.")
|
||||
npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.")
|
||||
default_requires_approval: Optional[bool] = Field(
|
||||
None, description="Default value for whether or not executing this tool requires approval."
|
||||
)
|
||||
|
||||
# metadata fields
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
@@ -168,6 +171,7 @@ class ToolCreate(LettaBase):
|
||||
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
|
||||
pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.")
|
||||
npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.")
|
||||
default_requires_approval: Optional[bool] = Field(None, description="Whether or not to require approval before executing this tool.")
|
||||
|
||||
@classmethod
|
||||
def from_mcp(cls, mcp_server_name: str, mcp_tool: MCPTool) -> "ToolCreate":
|
||||
@@ -248,6 +252,7 @@ class ToolUpdate(LettaBase):
|
||||
pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.")
|
||||
npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(None, description="A dictionary of additional metadata for the tool.")
|
||||
default_requires_approval: Optional[bool] = Field(None, description="Whether or not to require approval before executing this tool.")
|
||||
|
||||
model_config = ConfigDict(extra="ignore") # Allows extra fields without validation errors
|
||||
# TODO: Remove this, and clean usage of ToolUpdate everywhere else
|
||||
|
||||
@@ -60,7 +60,7 @@ from letta.schemas.message import MessageCreate, MessageUpdate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool_rule import ContinueToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import ContinueToolRule, RequiresApprovalToolRule, TerminalToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.serialize_schemas import MarshmallowAgentSchema
|
||||
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
|
||||
@@ -163,14 +163,16 @@ class AgentManager:
|
||||
return name_to_id, id_to_name
|
||||
|
||||
@staticmethod
|
||||
async def _resolve_tools_async(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]:
|
||||
async def _resolve_tools_async(
|
||||
session, names: Set[str], ids: Set[str], org_id: str
|
||||
) -> Tuple[Dict[str, str], Dict[str, str], List[str]]:
|
||||
"""
|
||||
Bulk‑fetch all ToolModel rows matching either name ∈ names or id ∈ ids
|
||||
(and scoped to this organization), and return two maps plus a list:
|
||||
name_to_id, id_to_name, and the names of tools whose default_requires_approval is truthy.
|
||||
Raises if any requested name or id was not found.
|
||||
"""
|
||||
stmt = select(ToolModel.id, ToolModel.name).where(
|
||||
stmt = select(ToolModel.id, ToolModel.name, ToolModel.default_requires_approval).where(
|
||||
ToolModel.organization_id == org_id,
|
||||
or_(
|
||||
ToolModel.name.in_(names),
|
||||
@@ -181,6 +183,7 @@ class AgentManager:
|
||||
rows = result.fetchall() # Use fetchall()
|
||||
name_to_id = {row[1]: row[0] for row in rows} # row[1] is name, row[0] is id
|
||||
id_to_name = {row[0]: row[1] for row in rows} # row[0] is id, row[1] is name
|
||||
requires_approval = [row[1] for row in rows if row[2]] # row[1] is name, row[2] is default_requires_approval
|
||||
|
||||
missing_names = names - set(name_to_id.keys())
|
||||
missing_ids = ids - set(id_to_name.keys())
|
||||
@@ -189,7 +192,7 @@ class AgentManager:
|
||||
if missing_ids:
|
||||
raise ValueError(f"Tools not found by id: {missing_ids}")
|
||||
|
||||
return name_to_id, id_to_name
|
||||
return name_to_id, id_to_name, requires_approval
|
||||
|
||||
@staticmethod
|
||||
def _bulk_insert_pivot(session, table, rows: list[dict]):
|
||||
@@ -556,7 +559,7 @@ class AgentManager:
|
||||
async with db_registry.async_session() as session:
|
||||
async with session.begin():
|
||||
# Note: This will need to be modified if _resolve_tools needs an async version
|
||||
name_to_id, id_to_name = await self._resolve_tools_async(
|
||||
name_to_id, id_to_name, requires_approval = await self._resolve_tools_async(
|
||||
session,
|
||||
tool_names,
|
||||
supplied_ids,
|
||||
@@ -588,6 +591,9 @@ class AgentManager:
|
||||
elif tn in (BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_MEMORY_TOOLS_V2 + BASE_SLEEPTIME_TOOLS):
|
||||
tool_rules.append(ContinueToolRule(tool_name=tn))
|
||||
|
||||
for tool_with_requires_approval in requires_approval:
|
||||
tool_rules.append(RequiresApprovalToolRule(tool_name=tool_with_requires_approval))
|
||||
|
||||
if tool_rules:
|
||||
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=tool_rules)
|
||||
|
||||
@@ -2855,12 +2861,15 @@ class AgentManager:
|
||||
|
||||
# verify tool exists and belongs to organization in a single query with the insert
|
||||
# first, check if tool exists with correct organization
|
||||
tool_check_query = select(func.count(ToolModel.id)).where(
|
||||
tool_check_query = select(ToolModel.name, ToolModel.default_requires_approval).where(
|
||||
ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id
|
||||
)
|
||||
tool_result = await session.execute(tool_check_query)
|
||||
if tool_result.scalar() == 0:
|
||||
result = await session.execute(tool_check_query)
|
||||
tool_rows = result.fetchall()
|
||||
|
||||
if len(tool_rows) == 0:
|
||||
raise NoResultFound(f"Tool with id={tool_id} not found in organization={actor.organization_id}")
|
||||
tool_name, default_requires_approval = tool_rows[0]
|
||||
|
||||
# use postgresql on conflict or mysql on duplicate key update for atomic operation
|
||||
if settings.letta_pg_uri_no_default:
|
||||
@@ -2884,6 +2893,17 @@ class AgentManager:
|
||||
else:
|
||||
logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")
|
||||
|
||||
if default_requires_approval:
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
existing_rules = [rule for rule in agent.tool_rules if rule.tool_name == tool_name and rule.type == "requires_approval"]
|
||||
if len(existing_rules) == 0:
|
||||
# Create a new list to ensure SQLAlchemy detects the change
|
||||
# This is critical for JSON columns - modifying in place doesn't trigger change detection
|
||||
tool_rules = list(agent.tool_rules) if agent.tool_rules else []
|
||||
tool_rules.append(RequiresApprovalToolRule(tool_name=tool_name))
|
||||
agent.tool_rules = tool_rules
|
||||
session.add(agent)
|
||||
|
||||
await session.commit()
|
||||
|
||||
@enforce_types
|
||||
|
||||
@@ -247,6 +247,42 @@ async def print_tool(server: SyncServer, default_user, default_organization):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
async def bash_tool(server: SyncServer, default_user, default_organization):
    """Yield a persisted tool whose ``default_requires_approval`` flag is set to True.

    The tool is built from the inner ``bash_tool`` function's source and saved through
    ``server.tool_manager``. Nothing is torn down here after the ``yield``.
    """

    def bash_tool(operation: str):
        """
        Args:
            operation (str): The bash operation to execute.

        Returns:
            str: The result of the executed operation.
        """
        print("scary bash operation")
        return "success"

    # The inner function's source text is what gets stored on the tool, so it is kept verbatim above.
    pending_tool = PydanticTool(
        description="test_description",
        tags=["test"],
        source_code=parse_source_code(bash_tool),
        source_type="python",
        metadata_={"a": "b"},
    )

    # Derive the JSON schema from the source, then take the tool's name from that schema.
    schema = derive_openai_json_schema(source_code=pending_tool.source_code, name=pending_tool.name)
    schema_name = schema["name"]
    pending_tool.json_schema = schema
    pending_tool.name = schema_name
    pending_tool.default_requires_approval = True

    persisted_tool = await server.tool_manager.create_or_update_tool_async(pending_tool, actor=default_user)

    yield persisted_tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def composio_github_star_tool(server, default_user):
|
||||
tool_create = ToolCreate.from_composio(action_name="GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER")
|
||||
@@ -1881,6 +1917,57 @@ async def test_detach_all_files_tools_async_idempotent(server: SyncServer, sarah
|
||||
assert len(final_agent_state.tools) == tool_count_after_first
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_tool_with_default_requires_approval(server: SyncServer, sarah_agent, bash_tool, default_user):
    """Attaching a tool flagged ``default_requires_approval`` yields exactly one ``requires_approval`` rule.

    The attach is performed twice to confirm that neither the tool nor its rule is duplicated.
    """
    agent_manager = server.agent_manager

    def rules_for_bash_tool(agent_state):
        # Only the rules that target the tool under test are relevant.
        return [rule for rule in agent_state.tool_rules if rule.tool_name == bash_tool.name]

    # First attach: the tool appears on the agent together with its approval rule.
    await agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=bash_tool.id, actor=default_user)
    agent_state = await agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
    attached_ids = [attached.id for attached in agent_state.tools]
    assert bash_tool.id in attached_ids
    matching_rules = rules_for_bash_tool(agent_state)
    assert len(matching_rules) == 1
    assert matching_rules[0].type == "requires_approval"

    # Second attach of the same tool: still one tool entry and one rule.
    await agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=bash_tool.id, actor=default_user)
    agent_state = await agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
    attached_ids = [attached.id for attached in agent_state.tools]
    assert attached_ids.count(bash_tool.id) == 1
    matching_rules = rules_for_bash_tool(agent_state)
    assert len(matching_rules) == 1
    assert matching_rules[0].type == "requires_approval"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_attach_tool_with_default_requires_approval_on_creation(server: SyncServer, bash_tool, default_user):
    """Creating an agent with a tool flagged ``default_requires_approval`` adds one ``requires_approval`` rule.

    A later explicit attach of the same tool must not duplicate the tool or its rule.
    """
    agent_manager = server.agent_manager

    def rules_for_bash_tool(agent_state):
        # Only the rules that target the tool under test are relevant.
        return [rule for rule in agent_state.tool_rules if rule.tool_name == bash_tool.name]

    # Create the agent with the tool requested by name and no base tools.
    creation_request = CreateAgent(
        name="agent11",
        llm_config=LLMConfig.default_config("gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        tools=[bash_tool.name],
        include_base_tools=False,
    )
    agent_state = await agent_manager.create_agent_async(agent_create=creation_request, actor=default_user)

    attached_ids = [attached.id for attached in agent_state.tools]
    assert bash_tool.id in attached_ids
    matching_rules = rules_for_bash_tool(agent_state)
    assert len(matching_rules) == 1
    assert matching_rules[0].type == "requires_approval"

    # Attaching the already-present tool: still one tool entry and one rule.
    await agent_manager.attach_tool_async(agent_id=agent_state.id, tool_id=bash_tool.id, actor=default_user)
    agent_state = await agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=default_user)
    attached_ids = [attached.id for attached in agent_state.tools]
    assert attached_ids.count(bash_tool.id) == 1
    matching_rules = rules_for_bash_tool(agent_state)
    assert len(matching_rules) == 1
    assert matching_rules[0].type == "requires_approval"
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Sources Relationship
|
||||
# ======================================================================================================================
|
||||
@@ -3627,6 +3714,13 @@ def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user
|
||||
server.tool_manager.create_tool(tool, actor=default_user)
|
||||
|
||||
|
||||
def test_create_tool_requires_approval(server: SyncServer, bash_tool, default_user, default_organization):
    """Check that the ``bash_tool`` fixture yields a custom tool with ``default_requires_approval`` enabled."""
    # Provenance and classification of the tool returned by the fixture.
    assert bash_tool.created_by_id == default_user.id
    assert bash_tool.tool_type == ToolType.CUSTOM
    # Identity check: the flag must be the boolean True, not merely a value equal to it.
    assert bash_tool.default_requires_approval is True
|
||||
|
||||
|
||||
def test_get_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
# Fetch the tool by ID using the manager method
|
||||
fetched_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
|
||||
@@ -4360,6 +4454,28 @@ async def test_pip_requirements_roundtrip(server: SyncServer, default_user, defa
|
||||
assert reqs_dict["numpy"] is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_update_default_requires_approval(server: SyncServer, bash_tool, default_user):
    """Check that ``default_requires_approval`` can be switched off and back on through ``ToolUpdate``."""
    tool_manager = server.tool_manager

    # Turn the flag off and read the tool back.
    await tool_manager.update_tool_by_id_async(bash_tool.id, ToolUpdate(default_requires_approval=False), actor=default_user)
    updated_tool = await tool_manager.get_tool_by_id_async(bash_tool.id, actor=default_user)
    # `is False` rather than `== False`: a None (unset) value must not pass as "disabled".
    assert updated_tool.default_requires_approval is False

    # Turn the flag back on and read the tool back again.
    await tool_manager.update_tool_by_id_async(bash_tool.id, ToolUpdate(default_requires_approval=True), actor=default_user)
    updated_tool = await tool_manager.get_tool_by_id_async(bash_tool.id, actor=default_user)
    assert updated_tool.default_requires_approval is True
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Message Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user