diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index b3ffc402..e0f1f4d5 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -11,6 +11,7 @@ from letta.schemas.tool_rule import ( MaxCountPerStepToolRule, ParentToolRule, RequiredBeforeExitToolRule, + RequiresApprovalToolRule, TerminalToolRule, ToolRule, ) @@ -44,6 +45,9 @@ class ToolRulesSolver(BaseModel): required_before_exit_tool_rules: list[RequiredBeforeExitToolRule] = Field( default_factory=list, description="Tool rules that must be called before the agent can exit.", exclude=True ) + requires_approval_tool_rules: list[RequiresApprovalToolRule] = Field( + default_factory=list, description="Tool rules that trigger an approval request for human-in-the-loop.", exclude=True + ) tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.") def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs): @@ -68,6 +72,8 @@ class ToolRulesSolver(BaseModel): self.parent_tool_rules.append(rule) elif isinstance(rule, RequiredBeforeExitToolRule): self.required_before_exit_tool_rules.append(rule) + elif isinstance(rule, RequiresApprovalToolRule): + self.requires_approval_tool_rules.append(rule) def register_tool_call(self, tool_name: str): """Update the internal state to track tool call history.""" @@ -117,6 +123,10 @@ class ToolRulesSolver(BaseModel): """Check if the tool is defined as a continue tool in the tool rules.""" return any(rule.tool_name == tool_name for rule in self.continue_tool_rules) + def is_requires_approval_tool(self, tool_name: ToolName): + """Check if the tool is defined as a requires-approval tool in the tool rules.""" + return any(rule.tool_name == tool_name for rule in self.requires_approval_tool_rules) + def has_required_tools_been_called(self, available_tools: set[ToolName]) -> bool: """Check if all required-before-exit tools have been called.""" return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0 diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index 5c6a5e86..54077563 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -9,6 +9,7 @@ from letta.schemas.tool_rule import ( MaxCountPerStepToolRule, ParentToolRule, RequiredBeforeExitToolRule, + RequiresApprovalToolRule, TerminalToolRule, ) @@ -23,6 +24,7 @@ UNRECOGNIZED_TOOL = "unrecognized_tool" REQUIRED_TOOL_1 = "required_tool_1" REQUIRED_TOOL_2 = "required_tool_2" SAVE_TOOL = "save_tool" +REQUIRES_APPROVAL_TOOL = "requires_approval_tool" def test_get_allowed_tool_names_with_init_rules(): @@ -55,6 +57,17 @@ def test_is_terminal_tool(): assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool" +def test_is_requires_approval_tool(): + init_rule = InitToolRule(tool_name=START_TOOL) + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + requires_approval_tool = RequiresApprovalToolRule(tool_name=REQUIRES_APPROVAL_TOOL) + solver = ToolRulesSolver(tool_rules=[init_rule, terminal_rule, requires_approval_tool]) + + assert solver.is_requires_approval_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a requires approval tool" + assert solver.is_requires_approval_tool(END_TOOL) is False, "Should not recognize 'end_tool' as a requires approval tool" + assert solver.is_requires_approval_tool(REQUIRES_APPROVAL_TOOL) is True, "Should recognize 'requires_approval_tool' as a terminal tool" + + def test_get_allowed_tool_names_no_matching_rule_error(): init_rule = InitToolRule(tool_name=START_TOOL) solver = ToolRulesSolver(tool_rules=[init_rule])