feat: add new default_requires_approval flag on tools (#4287)
This commit is contained in:
@@ -247,6 +247,42 @@ async def print_tool(server: SyncServer, default_user, default_organization):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
async def bash_tool(server: SyncServer, default_user, default_organization):
|
||||
"""Fixture to create a bash tool with requires_approval and clean up after the test."""
|
||||
|
||||
def bash_tool(operation: str):
|
||||
"""
|
||||
Args:
|
||||
operation (str): The bash operation to execute.
|
||||
|
||||
Returns:
|
||||
str: The result of the executed operation.
|
||||
"""
|
||||
print("scary bash operation")
|
||||
return "success"
|
||||
|
||||
# Set up tool details
|
||||
source_code = parse_source_code(bash_tool)
|
||||
source_type = "python"
|
||||
description = "test_description"
|
||||
tags = ["test"]
|
||||
metadata = {"a": "b"}
|
||||
|
||||
tool = PydanticTool(description=description, tags=tags, source_code=source_code, source_type=source_type, metadata_=metadata)
|
||||
derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name)
|
||||
|
||||
derived_name = derived_json_schema["name"]
|
||||
tool.json_schema = derived_json_schema
|
||||
tool.name = derived_name
|
||||
tool.default_requires_approval = True
|
||||
|
||||
tool = await server.tool_manager.create_or_update_tool_async(tool, actor=default_user)
|
||||
|
||||
# Yield the created tool
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def composio_github_star_tool(server, default_user):
|
||||
tool_create = ToolCreate.from_composio(action_name="GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER")
|
||||
@@ -1881,6 +1917,57 @@ async def test_detach_all_files_tools_async_idempotent(server: SyncServer, sarah
|
||||
assert len(final_agent_state.tools) == tool_count_after_first
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_tool_with_default_requires_approval(server: SyncServer, sarah_agent, bash_tool, default_user):
|
||||
"""Test that attaching a tool with default requires_approval adds associated tool rule."""
|
||||
# Attach the tool
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=bash_tool.id, actor=default_user)
|
||||
|
||||
# Verify attachment through get_agent_by_id
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert bash_tool.id in [t.id for t in agent.tools]
|
||||
tool_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name]
|
||||
assert len(tool_rules) == 1
|
||||
assert tool_rules[0].type == "requires_approval"
|
||||
|
||||
# Verify that attaching the same tool again doesn't cause duplication
|
||||
await server.agent_manager.attach_tool_async(agent_id=sarah_agent.id, tool_id=bash_tool.id, actor=default_user)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=sarah_agent.id, actor=default_user)
|
||||
assert len([t for t in agent.tools if t.id == bash_tool.id]) == 1
|
||||
tool_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name]
|
||||
assert len(tool_rules) == 1
|
||||
assert tool_rules[0].type == "requires_approval"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_attach_tool_with_default_requires_approval_on_creation(server: SyncServer, bash_tool, default_user):
|
||||
"""Test that attaching a tool with default requires_approval adds associated tool rule."""
|
||||
# Create agent with tool
|
||||
agent = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="agent11",
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tools=[bash_tool.name],
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert bash_tool.id in [t.id for t in agent.tools]
|
||||
tool_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name]
|
||||
assert len(tool_rules) == 1
|
||||
assert tool_rules[0].type == "requires_approval"
|
||||
|
||||
# Verify that attaching the same tool again doesn't cause duplication
|
||||
await server.agent_manager.attach_tool_async(agent_id=agent.id, tool_id=bash_tool.id, actor=default_user)
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id=agent.id, actor=default_user)
|
||||
assert len([t for t in agent.tools if t.id == bash_tool.id]) == 1
|
||||
tool_rules = [rule for rule in agent.tool_rules if rule.tool_name == bash_tool.name]
|
||||
assert len(tool_rules) == 1
|
||||
assert tool_rules[0].type == "requires_approval"
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Sources Relationship
|
||||
# ======================================================================================================================
|
||||
@@ -3627,6 +3714,13 @@ def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user
|
||||
server.tool_manager.create_tool(tool, actor=default_user)
|
||||
|
||||
|
||||
def test_create_tool_requires_approval(server: SyncServer, bash_tool, default_user, default_organization):
|
||||
# Assertions to ensure the created tool matches the expected values
|
||||
assert bash_tool.created_by_id == default_user.id
|
||||
assert bash_tool.tool_type == ToolType.CUSTOM
|
||||
assert bash_tool.default_requires_approval == True
|
||||
|
||||
|
||||
def test_get_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
# Fetch the tool by ID using the manager method
|
||||
fetched_tool = server.tool_manager.get_tool_by_id(print_tool.id, actor=default_user)
|
||||
@@ -4360,6 +4454,28 @@ async def test_pip_requirements_roundtrip(server: SyncServer, default_user, defa
|
||||
assert reqs_dict["numpy"] is None
|
||||
|
||||
|
||||
async def test_update_default_requires_approval(server: SyncServer, bash_tool, default_user):
|
||||
# Update field
|
||||
tool_update = ToolUpdate(default_requires_approval=False)
|
||||
await server.tool_manager.update_tool_by_id_async(bash_tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool
|
||||
updated_tool = await server.tool_manager.get_tool_by_id_async(bash_tool.id, actor=default_user)
|
||||
|
||||
# Assertions
|
||||
assert updated_tool.default_requires_approval == False
|
||||
|
||||
# Revert update
|
||||
tool_update = ToolUpdate(default_requires_approval=True)
|
||||
await server.tool_manager.update_tool_by_id_async(bash_tool.id, tool_update, actor=default_user)
|
||||
|
||||
# Fetch the updated tool
|
||||
updated_tool = await server.tool_manager.get_tool_by_id_async(bash_tool.id, actor=default_user)
|
||||
|
||||
# Assertions
|
||||
assert updated_tool.default_requires_approval == True
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Message Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user