feat: add requires approval logic to tool rules solver (#4294)

This commit is contained in:
cthomas
2025-08-28 17:06:46 -07:00
committed by GitHub
parent c1f8c48818
commit e6c2c2121e
2 changed files with 23 additions and 0 deletions

View File

@@ -11,6 +11,7 @@ from letta.schemas.tool_rule import (
MaxCountPerStepToolRule,
ParentToolRule,
RequiredBeforeExitToolRule,
RequiresApprovalToolRule,
TerminalToolRule,
ToolRule,
)
@@ -44,6 +45,9 @@ class ToolRulesSolver(BaseModel):
required_before_exit_tool_rules: list[RequiredBeforeExitToolRule] = Field(
default_factory=list, description="Tool rules that must be called before the agent can exit.", exclude=True
)
requires_approval_tool_rules: list[RequiresApprovalToolRule] = Field(
default_factory=list, description="Tool rules that trigger an approval request for human-in-the-loop.", exclude=True
)
tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs):
@@ -68,6 +72,8 @@ class ToolRulesSolver(BaseModel):
self.parent_tool_rules.append(rule)
elif isinstance(rule, RequiredBeforeExitToolRule):
self.required_before_exit_tool_rules.append(rule)
elif isinstance(rule, RequiresApprovalToolRule):
self.requires_approval_tool_rules.append(rule)
def register_tool_call(self, tool_name: str):
"""Update the internal state to track tool call history."""
@@ -117,6 +123,10 @@ class ToolRulesSolver(BaseModel):
"""Check if the tool is defined as a continue tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
def is_requires_approval_tool(self, tool_name: ToolName):
"""Check if the tool is defined as a requires-approval tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.requires_approval_tool_rules)
def has_required_tools_been_called(self, available_tools: set[ToolName]) -> bool:
"""Check if all required-before-exit tools have been called."""
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0

View File

@@ -9,6 +9,7 @@ from letta.schemas.tool_rule import (
MaxCountPerStepToolRule,
ParentToolRule,
RequiredBeforeExitToolRule,
RequiresApprovalToolRule,
TerminalToolRule,
)
@@ -23,6 +24,7 @@ UNRECOGNIZED_TOOL = "unrecognized_tool"
REQUIRED_TOOL_1 = "required_tool_1"
REQUIRED_TOOL_2 = "required_tool_2"
SAVE_TOOL = "save_tool"
REQUIRES_APPROVAL_TOOL = "requires_approval_tool"
def test_get_allowed_tool_names_with_init_rules():
@@ -55,6 +57,17 @@ def test_is_terminal_tool():
assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool"
def test_is_requires_approval_tool():
init_rule = InitToolRule(tool_name=START_TOOL)
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
requires_approval_tool = RequiresApprovalToolRule(tool_name=REQUIRES_APPROVAL_TOOL)
solver = ToolRulesSolver(tool_rules=[init_rule, terminal_rule, requires_approval_tool])
assert solver.is_requires_approval_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a requires approval tool"
assert solver.is_requires_approval_tool(END_TOOL) is False, "Should not recognize 'end_tool' as a requires approval tool"
assert solver.is_requires_approval_tool(REQUIRES_APPROVAL_TOOL) is True, "Should recognize 'requires_approval_tool' as a terminal tool"
def test_get_allowed_tool_names_no_matching_rule_error():
init_rule = InitToolRule(tool_name=START_TOOL)
solver = ToolRulesSolver(tool_rules=[init_rule])