diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py
index 73384971..4cb9c86c 100644
--- a/letta/helpers/tool_rule_solver.py
+++ b/letta/helpers/tool_rule_solver.py
@@ -209,3 +209,30 @@ class ToolRulesSolver(BaseModel):
violated_rules.append(rendered_prompt)
return violated_rules
+
+ def should_force_tool_call(self) -> bool:
+ """
+ Determine if a tool call should be forced (using 'required' instead of 'auto') based on active constrained tool rules.
+
+ Returns:
+ bool: True if a constrained tool rule is currently active, False otherwise
+ """
+ # check if we're at the start with init rules
+ if not self.tool_call_history and self.init_tool_rules:
+ return True
+
+ # check if any constrained rule is currently active
+ if self.tool_call_history:
+ last_tool = self.tool_call_history[-1]
+
+ # check child-based rules (ChildToolRule, ConditionalToolRule)
+ for rule in self.child_based_tool_rules:
+ if rule.requires_force_tool_call and rule.tool_name == last_tool:
+ return True
+
+ # check parent rules, `requires_force_tool_call` for safety in case this gets expanded
+ for rule in self.parent_tool_rules:
+ if rule.requires_force_tool_call and rule.tool_name == last_tool:
+ return True
+
+ return False
diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py
index 211ed925..c8bc5c09 100644
--- a/letta/schemas/tool_rule.py
+++ b/letta/schemas/tool_rule.py
@@ -18,6 +18,7 @@ class BaseToolRule(LettaBase):
None,
description="Optional template string (ignored). Rendering uses fast built-in formatting for performance.",
)
+ requires_force_tool_call: bool = False
def __hash__(self):
"""Base hash using tool_name and type."""
@@ -48,6 +49,7 @@ class ChildToolRule(BaseToolRule):
default=None,
description="Optional template string (ignored).",
)
+ requires_force_tool_call: bool = True
def __hash__(self):
"""Hash including children list (sorted for consistency)."""
@@ -76,6 +78,7 @@ class ParentToolRule(BaseToolRule):
type: Literal[ToolRuleType.parent_last_tool] = ToolRuleType.parent_last_tool
children: List[str] = Field(..., description="The children tools that can be invoked.")
prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).")
+ requires_force_tool_call: bool = True
def __hash__(self):
"""Hash including children list (sorted for consistency)."""
@@ -106,6 +109,7 @@ class ConditionalToolRule(BaseToolRule):
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")
prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).")
+ requires_force_tool_call: bool = True
def __hash__(self):
"""Hash including all configuration fields."""
@@ -187,6 +191,7 @@ class InitToolRule(BaseToolRule):
"""
type: Literal[ToolRuleType.run_first] = ToolRuleType.run_first
+ requires_force_tool_call: bool = True
class TerminalToolRule(BaseToolRule):
@@ -196,6 +201,7 @@ class TerminalToolRule(BaseToolRule):
type: Literal[ToolRuleType.exit_loop] = ToolRuleType.exit_loop
prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).")
+ requires_force_tool_call: bool = False
def render_prompt(self) -> str | None:
return f"\n{self.tool_name} ends your response (yields control) when called\n"
@@ -208,6 +214,7 @@ class ContinueToolRule(BaseToolRule):
type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop
prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).")
+ requires_force_tool_call: bool = False
def render_prompt(self) -> str | None:
return f"\n{self.tool_name} requires continuing your response when called\n"
@@ -220,6 +227,7 @@ class RequiredBeforeExitToolRule(BaseToolRule):
type: Literal[ToolRuleType.required_before_exit] = ToolRuleType.required_before_exit
prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).")
+ requires_force_tool_call: bool = False
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
"""Returns all available tools - the logic for preventing exit is handled elsewhere."""
@@ -237,6 +245,7 @@ class MaxCountPerStepToolRule(BaseToolRule):
type: Literal[ToolRuleType.max_count_per_step] = ToolRuleType.max_count_per_step
max_count_limit: int = Field(..., description="The max limit for the total number of times this tool can be invoked in a single step.")
prompt_template: Optional[str] = Field(default=None, description="Optional template string (ignored).")
+ requires_force_tool_call: bool = False
def __hash__(self):
"""Hash including max_count_limit."""
@@ -268,6 +277,7 @@ class RequiresApprovalToolRule(BaseToolRule):
"""
type: Literal[ToolRuleType.requires_approval] = ToolRuleType.requires_approval
+ requires_force_tool_call: bool = False
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
"""Does not enforce any restrictions on which tools are valid"""
diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py
index 932cb1ea..2609bae7 100644
--- a/tests/test_tool_rule_solver.py
+++ b/tests/test_tool_rule_solver.py
@@ -567,3 +567,158 @@ def test_required_before_exit_tool_rule_clear_history():
assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False after clearing history"
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should show required tool as uncalled after clearing history"
+
+
+def test_should_force_tool_call_no_rules():
+ """Test should_force_tool_call with no tool rules."""
+ solver = ToolRulesSolver(tool_rules=[])
+ assert solver.should_force_tool_call() is False, "Should return False when no tool rules are present"
+
+
+def test_should_force_tool_call_init_rule_no_history():
+ """Test should_force_tool_call with InitToolRule and no history."""
+ init_rule = InitToolRule(tool_name=START_TOOL)
+ solver = ToolRulesSolver(tool_rules=[init_rule])
+ assert solver.should_force_tool_call() is True, "Should return True when InitToolRule is present and no history"
+
+
+def test_should_force_tool_call_init_rule_after_first_call():
+ """Test should_force_tool_call with InitToolRule after first tool call."""
+ init_rule = InitToolRule(tool_name=START_TOOL)
+ solver = ToolRulesSolver(tool_rules=[init_rule])
+
+ solver.register_tool_call(START_TOOL)
+ assert solver.should_force_tool_call() is False, "Should return False after first tool call"
+
+
+def test_should_force_tool_call_child_rule_active():
+ """Test should_force_tool_call when ChildToolRule is active."""
+ child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
+ solver = ToolRulesSolver(tool_rules=[child_rule])
+
+ solver.register_tool_call(START_TOOL)
+ assert solver.should_force_tool_call() is True, "Should return True when last tool matches ChildToolRule"
+
+
+def test_should_force_tool_call_child_rule_inactive():
+ """Test should_force_tool_call when ChildToolRule is not active."""
+ child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
+ solver = ToolRulesSolver(tool_rules=[child_rule])
+
+ solver.register_tool_call(HELPER_TOOL)
+ assert solver.should_force_tool_call() is False, "Should return False when last tool doesn't match ChildToolRule"
+
+
+def test_should_force_tool_call_conditional_rule_active():
+ """Test should_force_tool_call when ConditionalToolRule is active."""
+ conditional_rule = ConditionalToolRule(
+ tool_name=START_TOOL, child_output_mapping={True: END_TOOL, False: NEXT_TOOL}, default_child=None
+ )
+ solver = ToolRulesSolver(tool_rules=[conditional_rule])
+
+ solver.register_tool_call(START_TOOL)
+ assert solver.should_force_tool_call() is True, "Should return True when last tool matches ConditionalToolRule"
+
+
+def test_should_force_tool_call_parent_rule_active():
+ """Test should_force_tool_call when ParentToolRule is active."""
+ parent_rule = ParentToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
+ solver = ToolRulesSolver(tool_rules=[parent_rule])
+
+ solver.register_tool_call(START_TOOL)
+ assert solver.should_force_tool_call() is True, "Should return True when last tool matches ParentToolRule"
+
+
+def test_should_force_tool_call_max_count_rule():
+ """Test should_force_tool_call with MaxCountPerStepToolRule (non-constraining)."""
+ max_count_rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2)
+ solver = ToolRulesSolver(tool_rules=[max_count_rule])
+
+ solver.register_tool_call(START_TOOL)
+ assert solver.should_force_tool_call() is False, "Should return False for MaxCountPerStepToolRule (not a constraining rule)"
+
+
+def test_should_force_tool_call_terminal_rule():
+ """Test should_force_tool_call with TerminalToolRule."""
+ terminal_rule = TerminalToolRule(tool_name=END_TOOL)
+ solver = ToolRulesSolver(tool_rules=[terminal_rule])
+
+ solver.register_tool_call(END_TOOL)
+ assert solver.should_force_tool_call() is False, "Should return False for TerminalToolRule"
+
+
+def test_should_force_tool_call_continue_rule():
+ """Test should_force_tool_call with ContinueToolRule."""
+ continue_rule = ContinueToolRule(tool_name=NEXT_TOOL)
+ solver = ToolRulesSolver(tool_rules=[continue_rule])
+
+ solver.register_tool_call(NEXT_TOOL)
+ assert solver.should_force_tool_call() is False, "Should return False for ContinueToolRule"
+
+
+def test_should_force_tool_call_required_before_exit_rule():
+ """Test should_force_tool_call with RequiredBeforeExitToolRule."""
+ required_rule = RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)
+ solver = ToolRulesSolver(tool_rules=[required_rule])
+
+ solver.register_tool_call(SAVE_TOOL)
+ assert solver.should_force_tool_call() is False, "Should return False for RequiredBeforeExitToolRule"
+
+
+def test_should_force_tool_call_requires_approval_rule():
+ """Test should_force_tool_call with RequiresApprovalToolRule."""
+ approval_rule = RequiresApprovalToolRule(tool_name=REQUIRES_APPROVAL_TOOL)
+ solver = ToolRulesSolver(tool_rules=[approval_rule])
+
+ solver.register_tool_call(REQUIRES_APPROVAL_TOOL)
+ assert solver.should_force_tool_call() is False, "Should return False for RequiresApprovalToolRule"
+
+
+def test_should_force_tool_call_multiple_constrained_rules_one_active():
+ """Test should_force_tool_call with multiple constrained rules where one is active."""
+ child_rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
+ child_rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[FINAL_TOOL])
+ parent_rule = ParentToolRule(tool_name=PREP_TOOL, children=[HELPER_TOOL])
+ solver = ToolRulesSolver(tool_rules=[child_rule_1, child_rule_2, parent_rule])
+
+ solver.register_tool_call(START_TOOL)
+ assert solver.should_force_tool_call() is True, "Should return True when one constrained rule is active"
+
+ solver.register_tool_call(NEXT_TOOL)
+ assert solver.should_force_tool_call() is True, "Should return True when a different constrained rule becomes active"
+
+ solver.register_tool_call(FINAL_TOOL)
+ assert solver.should_force_tool_call() is False, "Should return False when no constrained rules are active"
+
+
+def test_should_force_tool_call_after_clear_with_init_rule():
+ """Test should_force_tool_call after clearing history with InitToolRule."""
+ init_rule = InitToolRule(tool_name=START_TOOL)
+ child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
+ solver = ToolRulesSolver(tool_rules=[init_rule, child_rule])
+
+ assert solver.should_force_tool_call() is True, "Should return True initially with InitToolRule"
+
+ solver.register_tool_call(START_TOOL)
+ assert solver.should_force_tool_call() is True, "Should return True when ChildToolRule is active"
+
+ solver.clear_tool_history()
+ assert solver.should_force_tool_call() is True, "Should return True again after clearing history with InitToolRule"
+
+
+def test_should_force_tool_call_mixed_rules():
+ """Test should_force_tool_call with a mix of constraining and non-constraining rules."""
+ init_rule = InitToolRule(tool_name=START_TOOL)
+ child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
+ terminal_rule = TerminalToolRule(tool_name=END_TOOL)
+ continue_rule = ContinueToolRule(tool_name=HELPER_TOOL)
+ max_count_rule = MaxCountPerStepToolRule(tool_name=NEXT_TOOL, max_count_limit=2)
+ solver = ToolRulesSolver(tool_rules=[init_rule, child_rule, terminal_rule, continue_rule, max_count_rule])
+
+ assert solver.should_force_tool_call() is True, "Should return True with InitToolRule at start"
+
+ solver.register_tool_call(START_TOOL)
+ assert solver.should_force_tool_call() is True, "Should return True when ChildToolRule is active"
+
+ solver.register_tool_call(NEXT_TOOL)
+ assert solver.should_force_tool_call() is False, "Should return False when no constraining rules are active"