From b0cf7f0e93db3b475fa7b08f3e08fecec43fc533 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 28 Aug 2025 15:38:59 -0700 Subject: [PATCH] feat: add new default_requires_approval flag on tools (#4287) --- ...add_default_requires_approval_field_on_.py | 31 +++++ letta/orm/tool.py | 1 + letta/schemas/tool.py | 5 + letta/services/agent_manager.py | 36 ++++-- tests/test_managers.py | 116 ++++++++++++++++++ 5 files changed, 181 insertions(+), 8 deletions(-) create mode 100644 alembic/versions/c41c87205254_add_default_requires_approval_field_on_.py diff --git a/alembic/versions/c41c87205254_add_default_requires_approval_field_on_.py b/alembic/versions/c41c87205254_add_default_requires_approval_field_on_.py new file mode 100644 index 00000000..cb13822d --- /dev/null +++ b/alembic/versions/c41c87205254_add_default_requires_approval_field_on_.py @@ -0,0 +1,31 @@ +"""add default requires approval field on tools + +Revision ID: c41c87205254 +Revises: 068588268b02 +Create Date: 2025-08-28 13:17:51.636159 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c41c87205254" +down_revision: Union[str, None] = "068588268b02" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("tools", sa.Column("default_requires_approval", sa.Boolean(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("tools", "default_requires_approval") + # ### end Alembic commands ### diff --git a/letta/orm/tool.py b/letta/orm/tool.py index c564f765..e3bd9081 100644 --- a/letta/orm/tool.py +++ b/letta/orm/tool.py @@ -49,6 +49,7 @@ class Tool(SqlalchemyBase, OrganizationMixin): JSON, nullable=True, doc="Optional list of pip packages required by this tool." ) npm_requirements: Mapped[list | None] = mapped_column(JSON, doc="Optional list of npm packages required by this tool.") + default_requires_approval: Mapped[Optional[bool]] = mapped_column(nullable=True, doc="Whether or not to require approval.") metadata_: Mapped[Optional[dict]] = mapped_column(JSON, default=lambda: {}, doc="A dictionary of additional metadata for the tool.") # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="tools", lazy="selectin") diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index fb71392a..3e630061 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -67,6 +67,9 @@ class Tool(BaseTool): return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.") pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.") npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.") + default_requires_approval: Optional[bool] = Field( + None, description="Default value for whether or not executing this tool requires approval."
+ ) # metadata fields created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") @@ -168,6 +171,7 @@ class ToolCreate(LettaBase): return_char_limit: int = Field(FUNCTION_RETURN_CHAR_LIMIT, description="The maximum number of characters in the response.") pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.") npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.") + default_requires_approval: Optional[bool] = Field(None, description="Whether or not to require approval before executing this tool.") @classmethod def from_mcp(cls, mcp_server_name: str, mcp_tool: MCPTool) -> "ToolCreate": @@ -248,6 +252,7 @@ class ToolUpdate(LettaBase): pip_requirements: list[PipRequirement] | None = Field(None, description="Optional list of pip packages required by this tool.") npm_requirements: list[NpmRequirement] | None = Field(None, description="Optional list of npm packages required by this tool.") metadata_: Optional[Dict[str, Any]] = Field(None, description="A dictionary of additional metadata for the tool.") + default_requires_approval: Optional[bool] = Field(None, description="Whether or not to require approval before executing this tool.") model_config = ConfigDict(extra="ignore") # Allows extra fields without validation errors # TODO: Remove this, and clean usage of ToolUpdate everywhere else diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 8e91e695..b7d3d87b 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -60,7 +60,7 @@ from letta.schemas.message import MessageCreate, MessageUpdate from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.source import Source as PydanticSource from letta.schemas.tool import Tool as PydanticTool -from letta.schemas.tool_rule import ContinueToolRule, TerminalToolRule 
+from letta.schemas.tool_rule import ContinueToolRule, RequiresApprovalToolRule, TerminalToolRule from letta.schemas.user import User as PydanticUser from letta.serialize_schemas import MarshmallowAgentSchema from letta.serialize_schemas.marshmallow_message import SerializedMessageSchema @@ -163,14 +163,16 @@ class AgentManager: return name_to_id, id_to_name @staticmethod - async def _resolve_tools_async(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]: + async def _resolve_tools_async( + session, names: Set[str], ids: Set[str], org_id: str + ) -> Tuple[Dict[str, str], Dict[str, str], List[str]]: """ Bulk‑fetch all ToolModel rows matching either name ∈ names or id ∈ ids (and scoped to this organization), and return two maps: name_to_id, id_to_name. Raises if any requested name or id was not found. """ - stmt = select(ToolModel.id, ToolModel.name).where( + stmt = select(ToolModel.id, ToolModel.name, ToolModel.default_requires_approval).where( ToolModel.organization_id == org_id, or_( ToolModel.name.in_(names), @@ -181,6 +183,7 @@ class AgentManager: rows = result.fetchall() # Use fetchall() name_to_id = {row[1]: row[0] for row in rows} # row[1] is name, row[0] is id id_to_name = {row[0]: row[1] for row in rows} # row[0] is id, row[1] is name + requires_approval = [row[1] for row in rows if row[2]] # row[1] is name, row[2] is default_requires_approval missing_names = names - set(name_to_id.keys()) missing_ids = ids - set(id_to_name.keys()) @@ -189,7 +192,7 @@ class AgentManager: if missing_ids: raise ValueError(f"Tools not found by id: {missing_ids}") - return name_to_id, id_to_name + return name_to_id, id_to_name, requires_approval @staticmethod def _bulk_insert_pivot(session, table, rows: list[dict]): @@ -556,7 +559,7 @@ class AgentManager: async with db_registry.async_session() as session: async with session.begin(): # Note: This will need to be modified if _resolve_tools needs an async version - name_to_id, id_to_name 
= await self._resolve_tools_async( + name_to_id, id_to_name, requires_approval = await self._resolve_tools_async( session, tool_names, supplied_ids, @@ -588,6 +591,9 @@ class AgentManager: elif tn in (BASE_TOOLS + BASE_MEMORY_TOOLS + BASE_MEMORY_TOOLS_V2 + BASE_SLEEPTIME_TOOLS): tool_rules.append(ContinueToolRule(tool_name=tn)) + for tool_with_requires_approval in requires_approval: + tool_rules.append(RequiresApprovalToolRule(tool_name=tool_with_requires_approval)) + if tool_rules: check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=tool_rules) @@ -2855,12 +2861,15 @@ class AgentManager: # verify tool exists and belongs to organization in a single query with the insert # first, check if tool exists with correct organization - tool_check_query = select(func.count(ToolModel.id)).where( + tool_check_query = select(ToolModel.name, ToolModel.default_requires_approval).where( ToolModel.id == tool_id, ToolModel.organization_id == actor.organization_id ) - tool_result = await session.execute(tool_check_query) - if tool_result.scalar() == 0: + result = await session.execute(tool_check_query) + tool_rows = result.fetchall() + + if len(tool_rows) == 0: raise NoResultFound(f"Tool with id={tool_id} not found in organization={actor.organization_id}") + tool_name, default_requires_approval = tool_rows[0] # use postgresql on conflict or mysql on duplicate key update for atomic operation if settings.letta_pg_uri_no_default: @@ -2884,6 +2893,17 @@ class AgentManager: else: logger.info(f"Tool id={tool_id} is already attached to agent id={agent_id}") + if default_requires_approval: + agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor) + existing_rules = [rule for rule in (agent.tool_rules or []) if rule.tool_name == tool_name and rule.type == "requires_approval"] + if len(existing_rules) == 0: + # Create a new list to ensure SQLAlchemy detects the change + # This is critical for JSON columns - modifying in place doesn't
trigger change detection + tool_rules = list(agent.tool_rules) if agent.tool_rules else [] + tool_rules.append(RequiresApprovalToolRule(tool_name=tool_name)) + agent.tool_rules = tool_rules + session.add(agent) + await session.commit() @enforce_types diff --git a/tests/test_managers.py b/tests/test_managers.py index 5a15f6e4..87b227c8 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -247,6 +247,42 @@ async def print_tool(server: SyncServer, default_user, default_organization): yield tool +@pytest.fixture +async def bash_tool(server: SyncServer, default_user, default_organization): + """Fixture to create a bash tool with requires_approval and clean up after the test.""" + + def bash_tool(operation: str): + """ + Args: + operation (str): The bash operation to execute. + + Returns: + str: The result of the executed operation. + """ + print("scary bash operation") + return "success" + + # Set up tool details + source_code = parse_source_code(bash_tool) + source_type = "python" + description = "test_description" + tags = ["test"] + metadata = {"a": "b"} + + tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata) + derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) + + derived_name = derived_json_schema["name"] + tool.json_schema = derived_json_schema + tool.name = derived_name + tool.default_requires_approval = True + + tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user) + + # Yield the created tool + yield tool + + @pytest.fixture def composio_github_star_tool(server, default_user): tool_create = ToolCreate.from_composio(action_name="GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER") @@ -1881,6 +1917,57 @@ async def test_detach_all_files_tools_async_idempotent(server: SyncServer, sarah assert len(final_agent_state.tools) == tool_count_after_first +@pytest.mark.asyncio +async def 
test_attach_tool_with_default_requires_approval(server: SyncServer, sarah_agent, bash_tool, default_user): + """Test that attaching a tool with default requires_approval adds associated tool rule.""" + # Attach the tool + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=bash_tool.id, actor=default_user) + + # Verify attachment through get_agent_by_id + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert bash_tool.id in [t.id for t in agent.tools] + tool_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name] + assert len(tool_rules) == 1 + assert tool_rules[0].type == "requires_approval" + + # Verify that attaching the same tool again doesn't cause duplication + await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=bash_tool.id, actor=default_user) + agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user) + assert len([t for t in agent.tools if t.id == bash_tool.id]) == 1 + tool_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name] + assert len(tool_rules) == 1 + assert tool_rules[0].type == "requires_approval" + + +@pytest.mark.asyncio +async def test_attach_tool_with_default_requires_approval_on_creation(server: SyncServer, bash_tool, default_user): + """Test that creating an agent with a default requires_approval tool adds the associated tool rule.""" + # Create agent with tool + agent = await server.agent_manager.create_agent_async( + agent_create=CreateAgent( + name="agent11", + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tools=[bash_tool.name], + include_base_tools=False, + ), + actor=default_user, + ) + + assert bash_tool.id in [t.id for t in agent.tools] + tool_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name] + assert len(tool_rules) == 1 + assert
tool_rules[0].type == "requires_approval" + + # Verify that attaching the same tool again doesn't cause duplication + await server.agent_manager.attach_tool_async(agent_id=agent.id, tool_id=bash_tool.id, actor=default_user) + agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user) + assert len([t for t in agent.tools if t.id == bash_tool.id]) == 1 + tool_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name] + assert len(tool_rules) == 1 + assert tool_rules[0].type == "requires_approval" + + # ====================================================================================================================== # AgentManager Tests - Sources Relationship # ====================================================================================================================== @@ -3627,6 +3714,13 @@ def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user server.tool_manager.create_tool(tool, actor=default_user) +def test_create_tool_requires_approval(server: SyncServer, bash_tool, default_user, default_organization): + # Assertions to ensure the created tool matches the expected values + assert bash_tool.created_by_id == default_user.id + assert bash_tool.tool_type == ToolType.CUSTOM + assert bash_tool.default_requires_approval is True + + def test_get_tool_by_id(server: SyncServer, print_tool, default_user): # Fetch the tool by ID using the manager method fetched_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user) @@ -4360,6 +4454,28 @@ async def test_pip_requirements_roundtrip(server: SyncServer, default_user, defa assert reqs_dict["numpy"] is None +async def test_update_default_requires_approval(server: SyncServer, bash_tool, default_user): + # Update field + tool_update = ToolUpdate(default_requires_approval=False) + await server.tool_manager.update_tool_by_id_async(bash_tool.id, tool_update, actor=default_user) + + # Fetch the updated tool +
updated_tool = await server.tool_manager.get_tool_by_id_async(bash_tool.id, actor=default_user) + + # Assertions + assert updated_tool.default_requires_approval is False + + # Revert update + tool_update = ToolUpdate(default_requires_approval=True) + await server.tool_manager.update_tool_by_id_async(bash_tool.id, tool_update, actor=default_user) + + # Fetch the updated tool + updated_tool = await server.tool_manager.get_tool_by_id_async(bash_tool.id, actor=default_user) + + # Assertions + assert updated_tool.default_requires_approval is True + + # ====================================================================================================================== # Message Manager Tests # ======================================================================================================================