feat: add requires approval logic to tool rules solver (#4294)
This commit is contained in:
@@ -11,6 +11,7 @@ from letta.schemas.tool_rule import (
|
|||||||
MaxCountPerStepToolRule,
|
MaxCountPerStepToolRule,
|
||||||
ParentToolRule,
|
ParentToolRule,
|
||||||
RequiredBeforeExitToolRule,
|
RequiredBeforeExitToolRule,
|
||||||
|
RequiresApprovalToolRule,
|
||||||
TerminalToolRule,
|
TerminalToolRule,
|
||||||
ToolRule,
|
ToolRule,
|
||||||
)
|
)
|
||||||
@@ -44,6 +45,9 @@ class ToolRulesSolver(BaseModel):
|
|||||||
required_before_exit_tool_rules: list[RequiredBeforeExitToolRule] = Field(
|
required_before_exit_tool_rules: list[RequiredBeforeExitToolRule] = Field(
|
||||||
default_factory=list, description="Tool rules that must be called before the agent can exit.", exclude=True
|
default_factory=list, description="Tool rules that must be called before the agent can exit.", exclude=True
|
||||||
)
|
)
|
||||||
|
requires_approval_tool_rules: list[RequiresApprovalToolRule] = Field(
|
||||||
|
default_factory=list, description="Tool rules that trigger an approval request for human-in-the-loop.", exclude=True
|
||||||
|
)
|
||||||
tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
|
tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
|
||||||
|
|
||||||
def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs):
|
def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs):
|
||||||
@@ -68,6 +72,8 @@ class ToolRulesSolver(BaseModel):
|
|||||||
self.parent_tool_rules.append(rule)
|
self.parent_tool_rules.append(rule)
|
||||||
elif isinstance(rule, RequiredBeforeExitToolRule):
|
elif isinstance(rule, RequiredBeforeExitToolRule):
|
||||||
self.required_before_exit_tool_rules.append(rule)
|
self.required_before_exit_tool_rules.append(rule)
|
||||||
|
elif isinstance(rule, RequiresApprovalToolRule):
|
||||||
|
self.requires_approval_tool_rules.append(rule)
|
||||||
|
|
||||||
def register_tool_call(self, tool_name: str):
|
def register_tool_call(self, tool_name: str):
|
||||||
"""Update the internal state to track tool call history."""
|
"""Update the internal state to track tool call history."""
|
||||||
@@ -117,6 +123,10 @@ class ToolRulesSolver(BaseModel):
|
|||||||
"""Check if the tool is defined as a continue tool in the tool rules."""
|
"""Check if the tool is defined as a continue tool in the tool rules."""
|
||||||
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
||||||
|
|
||||||
|
def is_requires_approval_tool(self, tool_name: ToolName):
|
||||||
|
"""Check if the tool is defined as a requires-approval tool in the tool rules."""
|
||||||
|
return any(rule.tool_name == tool_name for rule in self.requires_approval_tool_rules)
|
||||||
|
|
||||||
def has_required_tools_been_called(self, available_tools: set[ToolName]) -> bool:
|
def has_required_tools_been_called(self, available_tools: set[ToolName]) -> bool:
|
||||||
"""Check if all required-before-exit tools have been called."""
|
"""Check if all required-before-exit tools have been called."""
|
||||||
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
|
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ from letta.schemas.tool_rule import (
|
|||||||
MaxCountPerStepToolRule,
|
MaxCountPerStepToolRule,
|
||||||
ParentToolRule,
|
ParentToolRule,
|
||||||
RequiredBeforeExitToolRule,
|
RequiredBeforeExitToolRule,
|
||||||
|
RequiresApprovalToolRule,
|
||||||
TerminalToolRule,
|
TerminalToolRule,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,6 +24,7 @@ UNRECOGNIZED_TOOL = "unrecognized_tool"
|
|||||||
REQUIRED_TOOL_1 = "required_tool_1"
|
REQUIRED_TOOL_1 = "required_tool_1"
|
||||||
REQUIRED_TOOL_2 = "required_tool_2"
|
REQUIRED_TOOL_2 = "required_tool_2"
|
||||||
SAVE_TOOL = "save_tool"
|
SAVE_TOOL = "save_tool"
|
||||||
|
REQUIRES_APPROVAL_TOOL = "requires_approval_tool"
|
||||||
|
|
||||||
|
|
||||||
def test_get_allowed_tool_names_with_init_rules():
|
def test_get_allowed_tool_names_with_init_rules():
|
||||||
@@ -55,6 +57,17 @@ def test_is_terminal_tool():
|
|||||||
assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool"
|
assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool"
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_requires_approval_tool():
|
||||||
|
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||||
|
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||||
|
requires_approval_tool = RequiresApprovalToolRule(tool_name=REQUIRES_APPROVAL_TOOL)
|
||||||
|
solver = ToolRulesSolver(tool_rules=[init_rule, terminal_rule, requires_approval_tool])
|
||||||
|
|
||||||
|
assert solver.is_requires_approval_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a requires approval tool"
|
||||||
|
assert solver.is_requires_approval_tool(END_TOOL) is False, "Should not recognize 'end_tool' as a requires approval tool"
|
||||||
|
assert solver.is_requires_approval_tool(REQUIRES_APPROVAL_TOOL) is True, "Should recognize 'requires_approval_tool' as a terminal tool"
|
||||||
|
|
||||||
|
|
||||||
def test_get_allowed_tool_names_no_matching_rule_error():
|
def test_get_allowed_tool_names_no_matching_rule_error():
|
||||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||||
solver = ToolRulesSolver(tool_rules=[init_rule])
|
solver = ToolRulesSolver(tool_rules=[init_rule])
|
||||||
|
|||||||
Reference in New Issue
Block a user