diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index 73384971..4cb9c86c 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -209,3 +209,30 @@ class ToolRulesSolver(BaseModel): violated_rules.append(rendered_prompt) return violated_rules + + def should_force_tool_call(self) -> bool: + """ + Determine if a tool call should be forced (using 'required' instead of 'auto') based on active constrained tool rules. + + Returns: + bool: True if a constrained tool rule is currently active, False otherwise + """ + # check if we're at the start with init rules + if not self.tool_call_history and self.init_tool_rules: + return True + + # check if any constrained rule is currently active + if self.tool_call_history: + last_tool = self.tool_call_history[-1] + + # check child-based rules (ChildToolRule, ConditionalToolRule) + for rule in self.child_based_tool_rules: + if rule.requires_force_tool_call and rule.tool_name == last_tool: + return True + + # check parent rules, `requires_force_tool_call` for safety in case this gets expanded + for rule in self.parent_tool_rules: + if rule.requires_force_tool_call and rule.tool_name == last_tool: + return True + + return False diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index 211ed925..c8bc5c09 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -18,6 +18,7 @@ class BaseToolRule(LettaBase): None, description="Optional template string (ignored). Rendering uses fast built-in formatting for performance.", ) + requires_force_tool_call: bool = False def __hash__(self): """Base hash using tool_name and type.""" @@ -48,6 +49,7 @@ class ChildToolRule(BaseToolRule): default=None, description="Optional template string (ignored).", ) + requires_force_tool_call: bool = True def __hash__(self): """Hash including children list (sorted for consistency).""" @@ -76,6 +78,7 @@ class ParentToolRule(BaseToolRule): type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool children: List[str] = Field(..., description="The children tools that can be invoked.") prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).") + requires_force_tool_call: bool = True def __hash__(self): """Hash including children list (sorted for consistency).""" @@ -106,6 +109,7 @@ class ConditionalToolRule(BaseToolRule): child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping") require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case") prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).") + requires_force_tool_call: bool = True def __hash__(self): """Hash including all configuration fields.""" @@ -187,6 +191,7 @@ class InitToolRule(BaseToolRule): """ type: Literal[ToolRuleType.run_first] = ToolRuleType.run_first + requires_force_tool_call: bool = True class TerminalToolRule(BaseToolRule): @@ -196,6 +201,7 @@ class TerminalToolRule(BaseToolRule): type: Literal[ToolRuleType.exit_loop] = ToolRuleType.exit_loop prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).") + requires_force_tool_call: bool = False def render_prompt(self) -> str | None: return f"\n{self.tool_name} ends your response (yields control) when called\n" @@ -208,6 +214,7 @@ class ContinueToolRule(BaseToolRule): type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).") + requires_force_tool_call: bool = False def render_prompt(self) -> str | None: return f"\n{self.tool_name} requires continuing your response when called\n" @@ -220,6 +227,7 @@ class RequiredBeforeExitToolRule(BaseToolRule): type: Literal[ToolRuleType.required_before_exit] = ToolRuleType.required_before_exit prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).") + requires_force_tool_call: bool = False def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: """Returns all available tools - the logic for preventing exit is handled elsewhere.""" @@ -237,6 +245,7 @@ class MaxCountPerStepToolRule(BaseToolRule): type: Literal[ToolRuleType.max_count_per_step] = ToolRuleType.max_count_per_step max_count_limit: int = Field(..., description="The max limit for the total number of times this tool can be invoked in a single step.") prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).") + requires_force_tool_call: bool = False def __hash__(self): """Hash including max_count_limit.""" @@ -268,6 +277,7 @@ class RequiresApprovalToolRule(BaseToolRule): """ type: Literal[ToolRuleType.requires_approval] = ToolRuleType.requires_approval + requires_force_tool_call: bool = False def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: """Does not enforce any restrictions on which tools are valid""" diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index 932cb1ea..2609bae7 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -567,3 +567,158 @@ def test_required_before_exit_tool_rule_clear_history(): assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False after clearing history" assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should show required tool as uncalled after clearing history" + + +def test_should_force_tool_call_no_rules(): + """Test should_force_tool_call with no tool rules.""" + solver = ToolRulesSolver(tool_rules=[]) + assert solver.should_force_tool_call() is False, "Should return False when no tool rules are present" + + +def test_should_force_tool_call_init_rule_no_history(): + """Test should_force_tool_call with InitToolRule and no history.""" + init_rule = InitToolRule(tool_name=START_TOOL) + solver = ToolRulesSolver(tool_rules=[init_rule]) + assert solver.should_force_tool_call() is True, "Should return True when InitToolRule is present and no history" + + +def test_should_force_tool_call_init_rule_after_first_call(): + """Test should_force_tool_call with InitToolRule after first tool call.""" + init_rule = InitToolRule(tool_name=START_TOOL) + solver = ToolRulesSolver(tool_rules=[init_rule]) + + solver.register_tool_call(START_TOOL) + assert solver.should_force_tool_call() is False, "Should return False after first tool call" + + +def test_should_force_tool_call_child_rule_active(): + """Test should_force_tool_call when ChildToolRule is active.""" + child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) + solver = ToolRulesSolver(tool_rules=[child_rule]) + + solver.register_tool_call(START_TOOL) + assert solver.should_force_tool_call() is True, "Should return True when last tool matches ChildToolRule" + + +def test_should_force_tool_call_child_rule_inactive(): + """Test should_force_tool_call when ChildToolRule is not active.""" + child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) + solver = ToolRulesSolver(tool_rules=[child_rule]) + + solver.register_tool_call(HELPER_TOOL) + assert solver.should_force_tool_call() is False, "Should return False when last tool doesn't match ChildToolRule" + + +def test_should_force_tool_call_conditional_rule_active(): + """Test should_force_tool_call when ConditionalToolRule is active.""" + conditional_rule = ConditionalToolRule( + tool_name=START_TOOL, child_output_mapping={True: END_TOOL, False: NEXT_TOOL}, default_child=None + ) + solver = ToolRulesSolver(tool_rules=[conditional_rule]) + + solver.register_tool_call(START_TOOL) + assert solver.should_force_tool_call() is True, "Should return True when last tool matches ConditionalToolRule" + + +def test_should_force_tool_call_parent_rule_active(): + """Test should_force_tool_call when ParentToolRule is active.""" + parent_rule = ParentToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) + solver = ToolRulesSolver(tool_rules=[parent_rule]) + + solver.register_tool_call(START_TOOL) + assert solver.should_force_tool_call() is True, "Should return True when last tool matches ParentToolRule" + + +def test_should_force_tool_call_max_count_rule(): + """Test should_force_tool_call with MaxCountPerStepToolRule (non-constraining).""" + max_count_rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2) + solver = ToolRulesSolver(tool_rules=[max_count_rule]) + + solver.register_tool_call(START_TOOL) + assert solver.should_force_tool_call() is False, "Should return False for MaxCountPerStepToolRule (not a constraining rule)" + + +def test_should_force_tool_call_terminal_rule(): + """Test should_force_tool_call with TerminalToolRule.""" + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + solver = ToolRulesSolver(tool_rules=[terminal_rule]) + + solver.register_tool_call(END_TOOL) + assert solver.should_force_tool_call() is False, "Should return False for TerminalToolRule" + + +def test_should_force_tool_call_continue_rule(): + """Test should_force_tool_call with ContinueToolRule.""" + continue_rule = ContinueToolRule(tool_name=NEXT_TOOL) + solver = ToolRulesSolver(tool_rules=[continue_rule]) + + solver.register_tool_call(NEXT_TOOL) + assert solver.should_force_tool_call() is False, "Should return False for ContinueToolRule" + + +def test_should_force_tool_call_required_before_exit_rule(): + """Test should_force_tool_call with RequiredBeforeExitToolRule.""" + required_rule = RequiredBeforeExitToolRule(tool_name=SAVE_TOOL) + solver = ToolRulesSolver(tool_rules=[required_rule]) + + solver.register_tool_call(SAVE_TOOL) + assert solver.should_force_tool_call() is False, "Should return False for RequiredBeforeExitToolRule" + + +def test_should_force_tool_call_requires_approval_rule(): + """Test should_force_tool_call with RequiresApprovalToolRule.""" + approval_rule = RequiresApprovalToolRule(tool_name=REQUIRES_APPROVAL_TOOL) + solver = ToolRulesSolver(tool_rules=[approval_rule]) + + solver.register_tool_call(REQUIRES_APPROVAL_TOOL) + assert solver.should_force_tool_call() is False, "Should return False for RequiresApprovalToolRule" + + +def test_should_force_tool_call_multiple_constrained_rules_one_active(): + """Test should_force_tool_call with multiple constrained rules where one is active.""" + child_rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) + child_rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[FINAL_TOOL]) + parent_rule = ParentToolRule(tool_name=PREP_TOOL, children=[HELPER_TOOL]) + solver = ToolRulesSolver(tool_rules=[child_rule_1, child_rule_2, parent_rule]) + + solver.register_tool_call(START_TOOL) + assert solver.should_force_tool_call() is True, "Should return True when one constrained rule is active" + + solver.register_tool_call(NEXT_TOOL) + assert solver.should_force_tool_call() is True, "Should return True when a different constrained rule becomes active" + + solver.register_tool_call(FINAL_TOOL) + assert solver.should_force_tool_call() is False, "Should return False when no constrained rules are active" + + +def test_should_force_tool_call_after_clear_with_init_rule(): + """Test should_force_tool_call after clearing history with InitToolRule.""" + init_rule = InitToolRule(tool_name=START_TOOL) + child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) + solver = ToolRulesSolver(tool_rules=[init_rule, child_rule]) + + assert solver.should_force_tool_call() is True, "Should return True initially with InitToolRule" + + solver.register_tool_call(START_TOOL) + assert solver.should_force_tool_call() is True, "Should return True when ChildToolRule is active" + + solver.clear_tool_history() + assert solver.should_force_tool_call() is True, "Should return True again after clearing history with InitToolRule" + + +def test_should_force_tool_call_mixed_rules(): + """Test should_force_tool_call with a mix of constraining and non-constraining rules.""" + init_rule = InitToolRule(tool_name=START_TOOL) + child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) + terminal_rule = TerminalToolRule(tool_name=END_TOOL) + continue_rule = ContinueToolRule(tool_name=HELPER_TOOL) + max_count_rule = MaxCountPerStepToolRule(tool_name=NEXT_TOOL, max_count_limit=2) + solver = ToolRulesSolver(tool_rules=[init_rule, child_rule, terminal_rule, continue_rule, max_count_rule]) + + assert solver.should_force_tool_call() is True, "Should return True with InitToolRule at start" + + solver.register_tool_call(START_TOOL) + assert solver.should_force_tool_call() is True, "Should return True when ChildToolRule is active" + + solver.register_tool_call(NEXT_TOOL) + assert solver.should_force_tool_call() is False, "Should return False when no constraining rules are active"