feat: add new default_requires_approval flag on tools (#4287)

This commit is contained in:
cthomas
2025-08-28 15:38:59 -07:00
committed by GitHub
parent f99fbfa280
commit e3d3bc09eb
5 changed files with 181 additions and 8 deletions

View File

@@ -0,0 +1,31 @@
"""add default requires approval field on tools
Revision ID: c41c87205254
Revises: 068588268b02
Create Date: 2025-08-28 13:17:51.636159
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c41c87205254"
down_revision: Union[str, None] = "068588268b02"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply the migration: add ``tools.default_requires_approval``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Nullable so existing rows need no backfill; NULL appears to mean
    # "no default specified" to the application layer — TODO confirm.
    op.add_column("tools", sa.Column("default_requires_approval", sa.Boolean(), nullable=True))
    # ### end Alembic commands ###
def downgrade() -> None:
    """Revert the migration: drop ``tools.default_requires_approval``."""
    # ### commands auto generated by Alembic - please adjust! ###
    # Dropping the column discards any per-tool approval defaults.
    op.drop_column("tools", "default_requires_approval")
    # ### end Alembic commands ###

View File

@@ -49,6 +49,7 @@ class Tool(SqlalchemyBase, OrganizationMixin):
JSON, nullable=True, doc="Optional list of pip packages required by this tool."
)
npm_requirements: Mapped[list | None] = mapped_column(JSON, doc="Optional list of npm packages required by this tool.")
default_requires_approval: Mapped[bool] = mapped_column(nullable=True, doc="Whether or not to require approval.")
metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="A dictionary of additional metadata for the tool.")
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin")

View File

@@ -67,6 +67,9 @@ class Tool(BaseTool):
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.")
npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.")
default_requires_approval: Optional[bool] = Field(
None, description="Default value for whether or not executing this tool requires approval."
)
# metadata fields
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
@@ -168,6 +171,7 @@ class ToolCreate(LettaBase):
return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.")
pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.")
npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.")
default_requires_approval: Optional[bool] = Field(None, description="Whether or not to require approval before executing this tool.")
@classmethod
def from_mcp(cls, mcp_server_name: str, mcp_tool: MCPTool) -> "ToolCreate":
@@ -248,6 +252,7 @@ class ToolUpdate(LettaBase):
pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.")
npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.")
metadata_: Optional[Dict[str, Any]] = Field(None, description="A dictionary of additional metadata for the tool.")
default_requires_approval: Optional[bool] = Field(None, description="Whether or not to require approval before executing this tool.")
model_config = ConfigDict(extra="ignore") # Allows extra fields without validation errors
# TODO: Remove this, and clean usage of ToolUpdate everywhere else

View File

@@ -60,7 +60,7 @@ from letta.schemas.message import MessageCreate, MessageUpdate
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool_rule import ContinueToolRule, TerminalToolRule
from letta.schemas.tool_rule import ContinueToolRule, RequiresApprovalToolRule, TerminalToolRule
from letta.schemas.user import User as PydanticUser
from letta.serialize_schemas import MarshmallowAgentSchema
from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema
@@ -163,14 +163,16 @@ class AgentManager:
return name_to_id, id_to_name
@staticmethod
async def _resolve_tools_async(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]:
async def _resolve_tools_async(
session, names: Set[str], ids: Set[str], org_id: str
) -> Tuple[Dict[str, str], Dict[str, str], List[str]]:
"""
Bulk-fetch all ToolModel rows matching either name ∈ names or id ∈ ids
(and scoped to this organization), and return three values:
name_to_id, id_to_name, and the list of tool names whose
default_requires_approval flag is set.
Raises if any requested name or id was not found.
"""
stmt = select(ToolModel.id, ToolModel.name).where(
stmt = select(ToolModel.id, ToolModel.name, ToolModel.default_requires_approval).where(
ToolModel.organization_id == org_id,
or_(
ToolModel.name.in_(names),
@@ -181,6 +183,7 @@ class AgentManager:
rows = result.fetchall() # Use fetchall()
name_to_id = {row[1]: row[0] for row in rows} # row[1] is name, row[0] is id
id_to_name = {row[0]: row[1] for row in rows} # row[0] is id, row[1] is name
requires_approval = [row[1] for row in rows if row[2]] # row[1] is name, row[2] is default_requires_approval
missing_names = names - set(name_to_id.keys())
missing_ids = ids - set(id_to_name.keys())
@@ -189,7 +192,7 @@ class AgentManager:
if missing_ids:
raise ValueError(f"Tools not found by id: {missing_ids}")
return name_to_id, id_to_name
return name_to_id, id_to_name, requires_approval
@staticmethod
def _bulk_insert_pivot(session, table, rows: list[dict]):
@@ -556,7 +559,7 @@ class AgentManager:
async with db_registry.async_session() as session:
async with session.begin():
# Resolve tool names/ids in bulk (also collecting which tools default to requiring approval)
name_to_id, id_to_name = await self._resolve_tools_async(
name_to_id, id_to_name, requires_approval = await self._resolve_tools_async(
session,
tool_names,
supplied_ids,
@@ -588,6 +591,9 @@ class AgentManager:
elif tn in (BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_MEMORY_TOOLS_V2 + BASE_SLEEPTIME_TOOLS):
tool_rules.append(ContinueToolRule(tool_name=tn))
for tool_with_requires_approval in requires_approval:
tool_rules.append(RequiresApprovalToolRule(tool_name=tool_with_requires_approval))
if tool_rules:
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=tool_rules)
@@ -2855,12 +2861,15 @@ class AgentManager:
# verify tool exists and belongs to organization in a single query with the insert
# first, check if tool exists with correct organization
tool_check_query = select(func.count(ToolModel.id)).where(
tool_check_query = select(ToolModel.name, ToolModel.default_requires_approval).where(
ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id
)
tool_result = await session.execute(tool_check_query)
if tool_result.scalar() == 0:
result = await session.execute(tool_check_query)
tool_rows = result.fetchall()
if len(tool_rows) == 0:
raise NoResultFound(f"Tool with id={tool_id} not found in organization={actor.organization_id}")
tool_name, default_requires_approval = tool_rows[0]
# use postgresql on conflict or mysql on duplicate key update for atomic operation
if settings.letta_pg_uri_no_default:
@@ -2884,6 +2893,17 @@ class AgentManager:
else:
logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}")
if default_requires_approval:
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
existing_rules = [rule for rule in agent.tool_rules if rule.tool_name == tool_name and rule.type == "requires_approval"]
if len(existing_rules) == 0:
# Create a new list to ensure SQLAlchemy detects the change
# This is critical for JSON columns - modifying in place doesn't trigger change detection
tool_rules = list(agent.tool_rules) if agent.tool_rules else []
tool_rules.append(RequiresApprovalToolRule(tool_name=tool_name))
agent.tool_rules = tool_rules
session.add(agent)
await session.commit()
@enforce_types

View File

@@ -247,6 +247,42 @@ async def print_tool(server: SyncServer, default_user, default_organization):
yield tool
@pytest.fixture
async def bash_tool(server: SyncServer, default_user, default_organization):
    """Fixture that creates and persists a tool with ``default_requires_approval=True``.

    NOTE(review): despite the yield, this fixture performs no teardown after
    the test — the created tool is not deleted here.
    """

    # The inner function's source text (including its docstring) is runtime
    # data: it is parsed below to derive the tool's JSON schema, so it must
    # not be edited cosmetically.
    def bash_tool(operation: str):
        """
        Args:
            operation (str): The bash operation to execute.

        Returns:
            str: The result of the executed operation.
        """
        print("scary bash operation")
        return "success"

    # Set up tool details
    source_code = parse_source_code(bash_tool)
    source_type = "python"
    description = "test_description"
    tags = ["test"]
    metadata = {"a": "b"}

    tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata)
    # Derive the schema (and the tool's name) from the inner function's source.
    derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
    derived_name = derived_json_schema["name"]
    tool.json_schema = derived_json_schema
    tool.name = derived_name
    # The flag under test: mark this tool as requiring approval by default.
    tool.default_requires_approval = True

    tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)

    # Yield the created tool
    yield tool
@pytest.fixture
def composio_github_star_tool(server, default_user):
tool_create = ToolCreate.from_composio(action_name="GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER")
@@ -1881,6 +1917,57 @@ async def test_detach_all_files_tools_async_idempotent(server: SyncServer, sarah
assert len(final_agent_state.tools) == tool_count_after_first
@pytest.mark.asyncio
async def test_attach_tool_with_default_requires_approval(server: SyncServer, sarah_agent, bash_tool, default_user):
    """Attaching a tool with default_requires_approval set adds a requires_approval rule, idempotently."""

    async def _attach_and_fetch():
        # Attach, then re-read the agent so assertions run against persisted state.
        await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=bash_tool.id, actor=default_user)
        return await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)

    def _rules_for_tool(agent_state):
        return [rule for rule in agent_state.tool_rules if rule.tool_name == bash_tool.name]

    # First attach: tool is present and exactly one requires_approval rule was created.
    agent = await _attach_and_fetch()
    assert any(t.id == bash_tool.id for t in agent.tools)
    rules = _rules_for_tool(agent)
    assert len(rules) == 1
    assert rules[0].type == "requires_approval"

    # Second attach is a no-op: neither the tool nor the rule is duplicated.
    agent = await _attach_and_fetch()
    assert sum(1 for t in agent.tools if t.id == bash_tool.id) == 1
    rules = _rules_for_tool(agent)
    assert len(rules) == 1
    assert rules[0].type == "requires_approval"
@pytest.mark.asyncio
async def test_attach_tool_with_default_requires_approval_on_creation(server: SyncServer, bash_tool, default_user):
    """Creating an agent with a default_requires_approval tool seeds a requires_approval rule."""
    create_request = CreateAgent(
        name="agent11",
        llm_config=LLMConfig.default_config("gpt-4o-mini"),
        embedding_config=EmbeddingConfig.default_config(provider="openai"),
        tools=[bash_tool.name],
        include_base_tools=False,
    )
    agent = await server.agent_manager.create_agent_async(agent_create=create_request, actor=default_user)

    # The rule must exist immediately after creation, without an explicit attach.
    assert any(t.id == bash_tool.id for t in agent.tools)
    matching_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name]
    assert len(matching_rules) == 1
    assert matching_rules[0].type == "requires_approval"

    # Verify that attaching the same tool again doesn't cause duplication
    await server.agent_manager.attach_tool_async(agent_id=agent.id, tool_id=bash_tool.id, actor=default_user)
    agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user)
    assert sum(1 for t in agent.tools if t.id == bash_tool.id) == 1
    matching_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name]
    assert len(matching_rules) == 1
    assert matching_rules[0].type == "requires_approval"
# ======================================================================================================================
# AgentManager Tests - Sources Relationship
# ======================================================================================================================
@@ -3627,6 +3714,13 @@ def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user
server.tool_manager.create_tool(tool, actor=default_user)
def test_create_tool_requires_approval(server: SyncServer, bash_tool, default_user, default_organization):
    """Verify the bash_tool fixture persists with default_requires_approval=True."""
    # Assertions to ensure the created tool matches the expected values
    assert bash_tool.created_by_id == default_user.id
    assert bash_tool.tool_type == ToolType.CUSTOM
    # Fixed: compare booleans with `is`, not `==` (PEP 8 / flake8 E712).
    assert bash_tool.default_requires_approval is True
def test_get_tool_by_id(server: SyncServer, print_tool, default_user):
# Fetch the tool by ID using the manager method
fetched_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
@@ -4360,6 +4454,28 @@ async def test_pip_requirements_roundtrip(server: SyncServer, default_user, defa
assert reqs_dict["numpy"] is None
async def test_update_default_requires_approval(server: SyncServer, bash_tool, default_user):
    """Round-trip update of Tool.default_requires_approval via ToolUpdate.

    Fixes: boolean comparisons now use ``is`` instead of ``==`` (PEP 8 /
    flake8 E712), and the duplicated update-then-fetch sequence is factored
    into a helper.
    """

    async def _set_and_fetch(value: bool):
        # Update through the public path, then re-read to confirm persistence
        # rather than trusting the update call's return value.
        await server.tool_manager.update_tool_by_id_async(
            bash_tool.id, ToolUpdate(default_requires_approval=value), actor=default_user
        )
        return await server.tool_manager.get_tool_by_id_async(bash_tool.id, actor=default_user)

    # Flip the flag off, then restore the fixture's original value (True).
    updated_tool = await _set_and_fetch(False)
    assert updated_tool.default_requires_approval is False

    updated_tool = await _set_and_fetch(True)
    assert updated_tool.default_requires_approval is True
# ======================================================================================================================
# Message Manager Tests
# ======================================================================================================================