"""Tests for ToolRulesSolver, tool-rule equality/hashing, and prefilled-argument merging/validation."""

import pytest

from letta.agents.helpers import merge_and_validate_prefilled_args
from letta.helpers import ToolRulesSolver
from letta.schemas.enums import ToolType
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import (
    ChildToolRule,
    ConditionalToolRule,
    ContinueToolRule,
    InitToolRule,
    MaxCountPerStepToolRule,
    ParentToolRule,
    RequiredBeforeExitToolRule,
    RequiresApprovalToolRule,
    TerminalToolRule,
    ToolCallNode,
)

# Constants for tool names used in the tests
START_TOOL = "start_tool"
PREP_TOOL = "prep_tool"
NEXT_TOOL = "next_tool"
HELPER_TOOL = "helper_tool"
FINAL_TOOL = "final_tool"
END_TOOL = "end_tool"
UNRECOGNIZED_TOOL = "unrecognized_tool"
REQUIRED_TOOL_1 = "required_tool_1"
REQUIRED_TOOL_2 = "required_tool_2"
SAVE_TOOL = "save_tool"
REQUIRES_APPROVAL_TOOL = "requires_approval_tool"


def test_get_allowed_tool_names_with_init_rules():
    """With no call history, only InitToolRule tools are allowed, in rule order."""
    init_rule_1 = InitToolRule(tool_name=START_TOOL)
    init_rule_2 = InitToolRule(tool_name=PREP_TOOL)
    solver = ToolRulesSolver(tool_rules=[init_rule_1, init_rule_2])

    allowed_tools = solver.get_allowed_tool_names(set())

    assert allowed_tools == [START_TOOL, PREP_TOOL], "Should allow only InitToolRule tools at the start"


def test_get_allowed_tool_names_with_subsequent_rule():
    """After calling a tool that has a ChildToolRule, only its children are allowed."""
    init_rule = InitToolRule(tool_name=START_TOOL)
    rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
    solver = ToolRulesSolver(tool_rules=[init_rule, rule_1])

    solver.register_tool_call(START_TOOL)
    allowed_tools = solver.get_allowed_tool_names({START_TOOL, NEXT_TOOL, HELPER_TOOL})

    assert sorted(allowed_tools) == sorted([NEXT_TOOL, HELPER_TOOL]), "Should allow only children of the last tool used"


def test_is_terminal_tool():
    init_rule = InitToolRule(tool_name=START_TOOL)
    terminal_rule = TerminalToolRule(tool_name=END_TOOL)
    solver = ToolRulesSolver(tool_rules=[init_rule, terminal_rule])

    assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool"
    assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool"


def test_is_requires_approval_tool():
    init_rule = InitToolRule(tool_name=START_TOOL)
    terminal_rule = TerminalToolRule(tool_name=END_TOOL)
    requires_approval_rule = RequiresApprovalToolRule(tool_name=REQUIRES_APPROVAL_TOOL)
    solver = ToolRulesSolver(tool_rules=[init_rule, terminal_rule, requires_approval_rule])

    assert solver.is_requires_approval_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a requires approval tool"
    assert solver.is_requires_approval_tool(END_TOOL) is False, "Should not recognize 'end_tool' as a requires approval tool"
    assert solver.is_requires_approval_tool(REQUIRES_APPROVAL_TOOL) is True, (
        "Should recognize 'requires_approval_tool' as a requires approval tool"
    )


def test_get_allowed_tool_names_no_matching_rule_error():
    init_rule = InitToolRule(tool_name=START_TOOL)
    solver = ToolRulesSolver(tool_rules=[init_rule])

    solver.register_tool_call(UNRECOGNIZED_TOOL)

    with pytest.raises(ValueError, match=r"No valid tools found based on tool rules."):
        solver.get_allowed_tool_names(set(), error_on_empty=True)


def test_update_tool_usage_and_get_allowed_tool_names_combined():
    """Walk a linear start -> next -> final chain and check the allowed set at each step."""
    init_rule = InitToolRule(tool_name=START_TOOL)
    rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
    rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[FINAL_TOOL])
    terminal_rule = TerminalToolRule(tool_name=FINAL_TOOL)
    solver = ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, terminal_rule])

    assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initial allowed tool should be 'start_tool'"

    solver.register_tool_call(START_TOOL)
    assert solver.get_allowed_tool_names({NEXT_TOOL}) == [NEXT_TOOL], "After 'start_tool', should allow 'next_tool'"

    solver.register_tool_call(NEXT_TOOL)
    assert solver.get_allowed_tool_names({FINAL_TOOL}) == [FINAL_TOOL], "After 'next_tool', should allow 'final_tool'"

    assert solver.is_terminal_tool(FINAL_TOOL) is True, "Should recognize 'final_tool' as terminal"


def test_conditional_tool_rule():
    """The tool allowed next depends on the last function response via child_output_mapping."""
    init_rule = InitToolRule(tool_name=START_TOOL)
    terminal_rule = TerminalToolRule(tool_name=END_TOOL)
    rule = ConditionalToolRule(tool_name=START_TOOL, default_child=None, child_output_mapping={True: END_TOOL, False: START_TOOL})
    solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule])

    assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initial allowed tool should be 'start_tool'"

    solver.register_tool_call(START_TOOL)
    assert solver.get_allowed_tool_names({END_TOOL}, last_function_response='{"message": "true"}') == [END_TOOL], (
        "After 'start_tool' returns true, should allow 'end_tool'"
    )
    assert solver.get_allowed_tool_names({START_TOOL}, last_function_response='{"message": "false"}') == [START_TOOL], (
        "After 'start_tool' returns false, should allow 'start_tool'"
    )

    assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"


def test_invalid_conditional_tool_rule():
    with pytest.raises(ValueError, match=r"Conditional tool rule must have at least one child tool."):
        ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={})


def test_tool_rules_with_invalid_path():
    """Building a solver must not raise for a cyclic rule graph or an unreachable branch.

    Rules 1-3 form the cycle start -> next -> helper -> start. Rule 4 (final -> end) is
    unreachable from start until rule 5 adds a conditional edge helper -> final. The test
    has no asserts: it passes as long as both constructions succeed.
    """
    init_rule = InitToolRule(tool_name=START_TOOL)
    rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
    rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[HELPER_TOOL])
    rule_3 = ChildToolRule(tool_name=HELPER_TOOL, children=[START_TOOL])
    rule_4 = ChildToolRule(tool_name=FINAL_TOOL, children=[END_TOOL])
    terminal_rule = TerminalToolRule(tool_name=END_TOOL)

    ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule])

    rule_5 = ConditionalToolRule(
        tool_name=HELPER_TOOL,
        default_child=FINAL_TOOL,
        child_output_mapping={True: START_TOOL, False: FINAL_TOOL},
    )
    ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, rule_5, terminal_rule])


def test_max_count_per_step_tool_rule():
    init_rule = InitToolRule(tool_name=START_TOOL)
    rule_1 = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2)
    solver = ToolRulesSolver(tool_rules=[init_rule, rule_1])

    assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initially should allow 'start_tool'"

    solver.register_tool_call(START_TOOL)
    assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 'start_tool'"

    solver.register_tool_call(START_TOOL)
    assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], (
        "After reaching max count, 'start_tool' should no longer be allowed"
    )


def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
    """Ensure the tool is allowed exactly max_count_limit times."""
    rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=3)
    solver = ToolRulesSolver(tool_rules=[rule])

    assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initially should allow 'start_tool'"

    solver.register_tool_call(START_TOOL)
    assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 1 use"

    solver.register_tool_call(START_TOOL)
    assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses"

    solver.register_tool_call(START_TOOL)
    assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should no longer allow 'start_tool' after 3 uses"


def test_max_count_per_step_tool_rule_does_not_affect_other_tools():
    """Ensure exceeding max count for one tool does not impact others."""
    rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2)
    another_tool_rules = ChildToolRule(tool_name=NEXT_TOOL, children=[HELPER_TOOL])
    solver = ToolRulesSolver(tool_rules=[rule, another_tool_rules])

    solver.register_tool_call(START_TOOL)
    solver.register_tool_call(START_TOOL)

    assert sorted(solver.get_allowed_tool_names({START_TOOL, NEXT_TOOL, HELPER_TOOL})) == sorted([NEXT_TOOL, HELPER_TOOL]), (
        "Other tools should still be allowed even if 'start_tool' is over limit"
    )


def test_max_count_per_step_tool_rule_resets_on_clear():
    """Ensure clearing tool history resets the rule's limit."""
    rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2)
    solver = ToolRulesSolver(tool_rules=[rule])

    solver.register_tool_call(START_TOOL)
    solver.register_tool_call(START_TOOL)

    assert solver.get_allowed_tool_names({START_TOOL}, error_on_empty=False) == [], "Should not allow 'start_tool' after reaching limit"

    solver.clear_tool_history()

    assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should allow 'start_tool' again after clearing history"


def test_tool_rule_equality_and_hashing():
    """Test __eq__ and __hash__ methods for all tool rule types."""
    # test InitToolRule equality
    rule1 = InitToolRule(tool_name="test_tool")
    rule2 = InitToolRule(tool_name="test_tool")
    rule3 = InitToolRule(tool_name="different_tool")

    assert rule1 == rule2, "InitToolRules with same tool_name should be equal"
    assert rule1 != rule3, "InitToolRules with different tool_name should not be equal"
    assert hash(rule1) == hash(rule2), "Equal InitToolRules should have same hash"
    assert hash(rule1) != hash(rule3), "Different InitToolRules should have different hash"

    # test ChildToolRule equality
    child_rule1 = ChildToolRule(tool_name="parent", children=["child1", "child2"])
    child_rule2 = ChildToolRule(tool_name="parent", children=["child2", "child1"])  # different order
    child_rule3 = ChildToolRule(tool_name="parent", children=["child1"])
    child_rule4 = ChildToolRule(tool_name="different_parent", children=["child1", "child2"])

    assert child_rule1 == child_rule2, "ChildToolRules with same children (different order) should be equal"
    assert child_rule1 != child_rule3, "ChildToolRules with different children should not be equal"
    assert child_rule1 != child_rule4, "ChildToolRules with different tool_name should not be equal"
    assert hash(child_rule1) == hash(child_rule2), "Equal ChildToolRules should have same hash"
    assert hash(child_rule1) != hash(child_rule3), "Different ChildToolRules should have different hash"

    # test ConditionalToolRule equality
    cond_rule1 = ConditionalToolRule(
        tool_name="conditional", child_output_mapping={"yes": "tool1", "no": "tool2"}, default_child="tool3", require_output_mapping=True
    )
    cond_rule2 = ConditionalToolRule(
        tool_name="conditional",
        child_output_mapping={"no": "tool2", "yes": "tool1"},  # different order
        default_child="tool3",
        require_output_mapping=True,
    )
    cond_rule3 = ConditionalToolRule(
        tool_name="conditional",
        child_output_mapping={"yes": "tool1", "no": "tool2"},
        default_child="different_tool",
        require_output_mapping=True,
    )
    cond_rule4 = ConditionalToolRule(
        tool_name="conditional",
        child_output_mapping={"yes": "tool1", "no": "tool2"},
        default_child="tool3",
        require_output_mapping=False,  # different require_output_mapping
    )

    assert cond_rule1 == cond_rule2, "ConditionalToolRules with same mapping (different order) should be equal"
    assert cond_rule1 != cond_rule3, "ConditionalToolRules with different default_child should not be equal"
    assert cond_rule1 != cond_rule4, "ConditionalToolRules with different require_output_mapping should not be equal"
    assert hash(cond_rule1) == hash(cond_rule2), "Equal ConditionalToolRules should have same hash"
    assert hash(cond_rule1) != hash(cond_rule3), "Different ConditionalToolRules should have different hash"

    # test MaxCountPerStepToolRule equality
    max_rule1 = MaxCountPerStepToolRule(tool_name="limited_tool", max_count_limit=3)
    max_rule2 = MaxCountPerStepToolRule(tool_name="limited_tool", max_count_limit=3)
    max_rule3 = MaxCountPerStepToolRule(tool_name="limited_tool", max_count_limit=5)
    max_rule4 = MaxCountPerStepToolRule(tool_name="different_tool", max_count_limit=3)

    assert max_rule1 == max_rule2, "MaxCountPerStepToolRules with same limit should be equal"
    assert max_rule1 != max_rule3, "MaxCountPerStepToolRules with different limit should not be equal"
    assert max_rule1 != max_rule4, "MaxCountPerStepToolRules with different tool_name should not be equal"
    assert hash(max_rule1) == hash(max_rule2), "Equal MaxCountPerStepToolRules should have same hash"
    assert hash(max_rule1) != hash(max_rule3), "Different MaxCountPerStepToolRules should have different hash"

    # test TerminalToolRule equality
    term_rule1 = TerminalToolRule(tool_name="exit_tool")
    term_rule2 = TerminalToolRule(tool_name="exit_tool")
    term_rule3 = TerminalToolRule(tool_name="different_exit_tool")

    assert term_rule1 == term_rule2, "TerminalToolRules with same tool_name should be equal"
    assert term_rule1 != term_rule3, "TerminalToolRules with different tool_name should not be equal"
    assert hash(term_rule1) == hash(term_rule2), "Equal TerminalToolRules should have same hash"

    # test RequiredBeforeExitToolRule equality
    req_rule1 = RequiredBeforeExitToolRule(tool_name="required_tool")
    req_rule2 = RequiredBeforeExitToolRule(tool_name="required_tool")
    req_rule3 = RequiredBeforeExitToolRule(tool_name="different_required_tool")

    assert req_rule1 == req_rule2, "RequiredBeforeExitToolRules with same tool_name should be equal"
    assert req_rule1 != req_rule3, "RequiredBeforeExitToolRules with different tool_name should not be equal"
    assert hash(req_rule1) == hash(req_rule2), "Equal RequiredBeforeExitToolRules should have same hash"

    # test cross-type inequality
    assert rule1 != child_rule1, "Different rule types should never be equal"
    assert child_rule1 != cond_rule1, "Different rule types should never be equal"
    assert max_rule1 != term_rule1, "Different rule types should never be equal"


def test_tool_rule_deduplication_in_set():
    """Test that duplicate tool rules are properly deduplicated when used in sets."""
    # create duplicate rules
    rule1 = InitToolRule(tool_name="start")
    rule2 = InitToolRule(tool_name="start")  # duplicate
    rule3 = InitToolRule(tool_name="different_start")

    child1 = ChildToolRule(tool_name="parent", children=["a", "b"])
    child2 = ChildToolRule(tool_name="parent", children=["b", "a"])  # duplicate (different order)
    child3 = ChildToolRule(tool_name="parent", children=["a", "b", "c"])  # different

    max1 = MaxCountPerStepToolRule(tool_name="limited", max_count_limit=2)
    max2 = MaxCountPerStepToolRule(tool_name="limited", max_count_limit=2)  # duplicate
    max3 = MaxCountPerStepToolRule(tool_name="limited", max_count_limit=3)  # different

    # test set deduplication
    rules_set = {rule1, rule2, rule3, child1, child2, child3, max1, max2, max3}
    assert len(rules_set) == 6, "Set should contain only unique rules"

    # test list deduplication using dict.fromkeys (keeps the first occurrence of each key, in order)
    rules_list = [rule1, rule2, rule3, child1, child2, child3, max1, max2, max3]
    deduplicated = list(dict.fromkeys(rules_list))
    assert len(deduplicated) == 6, "dict.fromkeys should deduplicate rules"
    assert deduplicated == [rule1, rule3, child1, child3, max1, max3], "Order should be preserved"


def test_parent_tool_rule_equality():
    """Test ParentToolRule equality and hashing."""
    parent_rule1 = ParentToolRule(tool_name="parent", children=["child1", "child2"])
    parent_rule2 = ParentToolRule(tool_name="parent", children=["child2", "child1"])  # different order
    parent_rule3 = ParentToolRule(tool_name="parent", children=["child1"])
    parent_rule4 = ParentToolRule(tool_name="different_parent", children=["child1", "child2"])

    assert parent_rule1 == parent_rule2, "ParentToolRules with same children (different order) should be equal"
    assert parent_rule1 != parent_rule3, "ParentToolRules with different children should not be equal"
    assert parent_rule1 != parent_rule4, "ParentToolRules with different tool_name should not be equal"
    assert hash(parent_rule1) == hash(parent_rule2), "Equal ParentToolRules should have same hash"
    assert hash(parent_rule1) != hash(parent_rule3), "Different ParentToolRules should have different hash"


def test_continue_tool_rule_equality_and_hashing():
    r1 = ContinueToolRule(tool_name="go_on")
    r2 = ContinueToolRule(tool_name="go_on")
    r3 = ContinueToolRule(tool_name="different")

    assert r1 == r2
    assert hash(r1) == hash(r2)
    assert r1 != r3
    assert hash(r1) != hash(r3)


@pytest.mark.parametrize(
    "rule_factory, kwargs_a, kwargs_b",
    [
        # The rule classes are callables themselves, so no lambda wrapper is needed.
        (InitToolRule, dict(tool_name="t"), dict(tool_name="t")),
        (TerminalToolRule, dict(tool_name="t"), dict(tool_name="t")),
        (ContinueToolRule, dict(tool_name="t"), dict(tool_name="t")),
        (RequiredBeforeExitToolRule, dict(tool_name="t"), dict(tool_name="t")),
        (MaxCountPerStepToolRule, dict(tool_name="t", max_count_limit=2), dict(tool_name="t", max_count_limit=2)),
        (ChildToolRule, dict(tool_name="t", children=["a", "b"]), dict(tool_name="t", children=["a", "b"])),
        (ParentToolRule, dict(tool_name="t", children=["a", "b"]), dict(tool_name="t", children=["a", "b"])),
        (
            ConditionalToolRule,
            dict(tool_name="t", child_output_mapping={"x": "a"}, default_child=None, require_output_mapping=False),
            dict(tool_name="t", child_output_mapping={"x": "a"}, default_child=None, require_output_mapping=False),
        ),
    ],
)
def test_prompt_template_ignored(rule_factory, kwargs_a, kwargs_b):
    """Rules that differ only in prompt_template compare equal and hash equal."""
    r1 = rule_factory(**kwargs_a, prompt_template="A")
    r2 = rule_factory(**kwargs_b, prompt_template="B")

    assert r1 == r2, f"{type(r1).__name__} should ignore prompt_template in equality"
    assert hash(r1) == hash(r2), f"{type(r1).__name__} should ignore prompt_template in hash"


@pytest.mark.parametrize(
    "a,b",
    [
        (InitToolRule(tool_name="same"), TerminalToolRule(tool_name="same")),
        (ContinueToolRule(tool_name="same"), RequiredBeforeExitToolRule(tool_name="same")),
        (ChildToolRule(tool_name="same", children=["x"]), ParentToolRule(tool_name="same", children=["x"])),
    ],
)
def test_cross_type_hash_distinguishes_types(a, b):
    """Rules of different types with identical fields are neither equal nor hash-equal."""
    assert a != b
    assert hash(a) != hash(b)


@pytest.mark.parametrize(
    "rule",
    [
        InitToolRule(tool_name="x"),
        TerminalToolRule(tool_name="x"),
        ContinueToolRule(tool_name="x"),
        RequiredBeforeExitToolRule(tool_name="x"),
        MaxCountPerStepToolRule(tool_name="x", max_count_limit=1),
        ChildToolRule(tool_name="x", children=["a"]),
        ParentToolRule(tool_name="x", children=["a"]),
        ConditionalToolRule(tool_name="x", child_output_mapping={"k": "a"}, default_child=None, require_output_mapping=False),
    ],
)
def test_equality_with_non_rule_objects(rule):
    assert rule != object()
    # Deliberately `!=` rather than `is not`: this exercises the rule's __eq__/__ne__ against None.
    assert rule != None  # noqa: E711


def test_conditional_tool_rule_mapping_order_and_hash():
    r1 = ConditionalToolRule(
        tool_name="cond", child_output_mapping={"yes": "tool1", "no": "tool2"}, default_child="tool3", require_output_mapping=True
    )
    r2 = ConditionalToolRule(
        tool_name="cond", child_output_mapping={"no": "tool2", "yes": "tool1"}, default_child="tool3", require_output_mapping=True
    )

    assert r1 == r2
    assert hash(r1) == hash(r2)


def test_conditional_tool_rule_mapping_numeric_and_bool_keys_equivalence_current_behavior():
    # NOTE: Python dict equality treats True == 1 and 1 == 1.0 as equal keys.
    # This test documents current behavior of __eq__ on mapping equality.
    r_bool = ConditionalToolRule(tool_name="cond", child_output_mapping={True: "A"}, default_child=None, require_output_mapping=False)
    r_int = ConditionalToolRule(tool_name="cond", child_output_mapping={1: "A"}, default_child=None, require_output_mapping=False)
    r_float = ConditionalToolRule(tool_name="cond", child_output_mapping={1.0: "A"}, default_child=None, require_output_mapping=False)

    # Document current semantics: these are equal under Python's dict equality.
    assert r_bool == r_int
    assert r_int == r_float
    assert hash(r_bool) == hash(r_int) == hash(r_float)


def test_conditional_tool_rule_mapping_string_vs_numeric_not_equal():
    r_num = ConditionalToolRule(tool_name="cond", child_output_mapping={1: "A"}, default_child=None, require_output_mapping=False)
    r_str = ConditionalToolRule(tool_name="cond", child_output_mapping={"1": "A"}, default_child=None, require_output_mapping=False)

    assert r_num != r_str
    assert hash(r_num) != hash(r_str)


def test_child_and_parent_order_invariance_multiple_permutations():
    # permute a few ways
    variants = [
        ["a", "b", "c"],
        ["b", "c", "a"],
        ["c", "a", "b"],
    ]
    child_rules = [ChildToolRule(tool_name="t", children=ch) for ch in variants]
    parent_rules = [ParentToolRule(tool_name="t", children=ch) for ch in variants]

    # All child rules equal and same hash
    for r in child_rules[1:]:
        assert child_rules[0] == r
        assert hash(child_rules[0]) == hash(r)

    # All parent rules equal and same hash
    for r in parent_rules[1:]:
        assert parent_rules[0] == r
        assert hash(parent_rules[0]) == hash(r)


def test_conditional_order_invariance_multiple_permutations():
    maps = [
        {"x": "a", "y": "b", "z": "c"},
        {"z": "c", "y": "b", "x": "a"},
        {"y": "b", "x": "a", "z": "c"},
    ]
    rules = [ConditionalToolRule(tool_name="t", child_output_mapping=m, default_child=None, require_output_mapping=False) for m in maps]

    for r in rules[1:]:
        assert rules[0] == r
        assert hash(rules[0]) == hash(r)


# ---------- Set deduplication across rule types, including ContinueToolRule ----------
def test_dedup_in_set_with_continue_and_required_and_terminal():
    s = {
        ContinueToolRule(tool_name="x"),
        ContinueToolRule(tool_name="x"),  # dup
        RequiredBeforeExitToolRule(tool_name="y"),
        RequiredBeforeExitToolRule(tool_name="y"),  # dup
        TerminalToolRule(tool_name="z"),
        TerminalToolRule(tool_name="z"),  # dup
    }
    assert len(s) == 3


def test_required_before_exit_tool_rule_has_required_tools_been_called():
    """Test has_required_tools_been_called() with no required tools."""
    solver = ToolRulesSolver(tool_rules=[])

    assert solver.has_required_tools_been_called(set()) is True, "Should return True when no required tools are defined"


def test_required_before_exit_tool_rule_single_required_tool():
    """Test with a single required-before-exit tool."""
    required_rule = RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)
    solver = ToolRulesSolver(tool_rules=[required_rule])

    assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False when required tool hasn't been called"
    assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should return list with uncalled required tool"

    solver.register_tool_call(SAVE_TOOL)

    assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called"
    assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [], "Should return empty list after required tool is called"


def test_required_before_exit_tool_rule_multiple_required_tools():
    """Test with multiple required-before-exit tools."""
    required_rule_1 = RequiredBeforeExitToolRule(tool_name=REQUIRED_TOOL_1)
    required_rule_2 = RequiredBeforeExitToolRule(tool_name=REQUIRED_TOOL_2)
    solver = ToolRulesSolver(tool_rules=[required_rule_1, required_rule_2])

    assert solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is False, (
        "Should return False when no required tools have been called"
    )
    uncalled_tools = solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2})
    assert set(uncalled_tools) == {REQUIRED_TOOL_1, REQUIRED_TOOL_2}, "Should return both uncalled required tools"

    # Call first required tool
    solver.register_tool_call(REQUIRED_TOOL_1)

    assert solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is False, (
        "Should return False when only one required tool has been called"
    )
    assert solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) == [REQUIRED_TOOL_2], (
        "Should return remaining uncalled required tool"
    )

    # Call second required tool
    solver.register_tool_call(REQUIRED_TOOL_2)

    assert solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is True, (
        "Should return True when all required tools have been called"
    )
    assert solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) == [], (
        "Should return empty list when all required tools have been called"
    )


def test_required_before_exit_tool_rule_mixed_with_other_tools():
    """Test required-before-exit tools mixed with other tool calls."""
    required_rule = RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)
    solver = ToolRulesSolver(tool_rules=[required_rule])

    # Call other tools first
    solver.register_tool_call(START_TOOL)
    solver.register_tool_call(HELPER_TOOL)

    assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False even after calling other tools"
    assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should still show required tool as uncalled"

    # Call required tool
    solver.register_tool_call(SAVE_TOOL)

    assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called"
    assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [], "Should return empty list after required tool is called"


def test_required_before_exit_tool_rule_clear_history():
    """Test that clearing history resets the required tools state."""
    required_rule = RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)
    solver = ToolRulesSolver(tool_rules=[required_rule])

    # Call required tool
    solver.register_tool_call(SAVE_TOOL)
    assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called"

    # Clear history
    solver.clear_tool_history()

    assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False after clearing history"
    assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should show required tool as uncalled after clearing history"


def test_should_force_tool_call_no_rules():
    """Test should_force_tool_call with no tool rules."""
    solver = ToolRulesSolver(tool_rules=[])

    assert solver.should_force_tool_call() is False, "Should return False when no tool rules are present"


def test_should_force_tool_call_init_rule_no_history():
    """Test should_force_tool_call with InitToolRule and no history."""
    init_rule = InitToolRule(tool_name=START_TOOL)
    solver = ToolRulesSolver(tool_rules=[init_rule])

    assert solver.should_force_tool_call() is True, "Should return True when InitToolRule is present and no history"


def test_should_force_tool_call_init_rule_after_first_call():
    """Test should_force_tool_call with InitToolRule after first tool call."""
    init_rule = InitToolRule(tool_name=START_TOOL)
    solver = ToolRulesSolver(tool_rules=[init_rule])

    solver.register_tool_call(START_TOOL)

    assert solver.should_force_tool_call() is False, "Should return False after first tool call"


def test_should_force_tool_call_child_rule_active():
    """Test should_force_tool_call when ChildToolRule is active."""
    child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
    solver = ToolRulesSolver(tool_rules=[child_rule])

    solver.register_tool_call(START_TOOL)

    assert solver.should_force_tool_call() is True, "Should return True when last tool matches ChildToolRule"


def test_should_force_tool_call_child_rule_inactive():
    """Test should_force_tool_call when ChildToolRule is not active."""
    child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
    solver = ToolRulesSolver(tool_rules=[child_rule])

    solver.register_tool_call(HELPER_TOOL)

    assert solver.should_force_tool_call() is False, "Should return False when last tool doesn't match ChildToolRule"


def test_should_force_tool_call_conditional_rule_active():
    """Test should_force_tool_call when ConditionalToolRule is active."""
    conditional_rule = ConditionalToolRule(
        tool_name=START_TOOL, child_output_mapping={True: END_TOOL, False: NEXT_TOOL}, default_child=None
    )
    solver = ToolRulesSolver(tool_rules=[conditional_rule])

    solver.register_tool_call(START_TOOL)

    assert solver.should_force_tool_call() is True, "Should return True when last tool matches ConditionalToolRule"


def test_should_force_tool_call_parent_rule_active():
    """Test should_force_tool_call when ParentToolRule is active."""
    parent_rule = ParentToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
    solver = ToolRulesSolver(tool_rules=[parent_rule])

    solver.register_tool_call(START_TOOL)

    assert solver.should_force_tool_call() is True, "Should return True when last tool matches ParentToolRule"


def test_should_force_tool_call_max_count_rule():
    """Test should_force_tool_call with MaxCountPerStepToolRule (non-constraining)."""
    max_count_rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2)
    solver = ToolRulesSolver(tool_rules=[max_count_rule])

    solver.register_tool_call(START_TOOL)

    assert solver.should_force_tool_call() is False, "Should return False for MaxCountPerStepToolRule (not a constraining rule)"


def test_should_force_tool_call_terminal_rule():
    """Test should_force_tool_call with TerminalToolRule."""
    terminal_rule = TerminalToolRule(tool_name=END_TOOL)
    solver = ToolRulesSolver(tool_rules=[terminal_rule])

    solver.register_tool_call(END_TOOL)

    assert solver.should_force_tool_call() is False, "Should return False for TerminalToolRule"


def test_should_force_tool_call_continue_rule():
    """Test should_force_tool_call with ContinueToolRule."""
    continue_rule = ContinueToolRule(tool_name=NEXT_TOOL)
    solver = ToolRulesSolver(tool_rules=[continue_rule])

    solver.register_tool_call(NEXT_TOOL)

    assert solver.should_force_tool_call() is False, "Should return False for ContinueToolRule"


def test_should_force_tool_call_required_before_exit_rule():
    """Test should_force_tool_call with RequiredBeforeExitToolRule."""
    required_rule = RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)
    solver = ToolRulesSolver(tool_rules=[required_rule])

    solver.register_tool_call(SAVE_TOOL)

    assert solver.should_force_tool_call() is False, "Should return False for RequiredBeforeExitToolRule"


def test_should_force_tool_call_requires_approval_rule():
    """Test should_force_tool_call with RequiresApprovalToolRule."""
    approval_rule = RequiresApprovalToolRule(tool_name=REQUIRES_APPROVAL_TOOL)
    solver = ToolRulesSolver(tool_rules=[approval_rule])

    solver.register_tool_call(REQUIRES_APPROVAL_TOOL)

    assert solver.should_force_tool_call() is False, "Should return False for RequiresApprovalToolRule"


def test_should_force_tool_call_multiple_constrained_rules_one_active():
    """Test should_force_tool_call with multiple constrained rules where one is active."""
    child_rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
    child_rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[FINAL_TOOL])
    parent_rule = ParentToolRule(tool_name=PREP_TOOL, children=[HELPER_TOOL])
    solver = ToolRulesSolver(tool_rules=[child_rule_1, child_rule_2, parent_rule])

    solver.register_tool_call(START_TOOL)
    assert solver.should_force_tool_call() is True, "Should return True when one constrained rule is active"

    solver.register_tool_call(NEXT_TOOL)
    assert solver.should_force_tool_call() is True, "Should return True when a different constrained rule becomes active"

    solver.register_tool_call(FINAL_TOOL)
    assert solver.should_force_tool_call() is False, "Should return False when no constrained rules are active"


def test_should_force_tool_call_after_clear_with_init_rule():
    """Test should_force_tool_call after clearing history with InitToolRule."""
    init_rule = InitToolRule(tool_name=START_TOOL)
    child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
    solver = ToolRulesSolver(tool_rules=[init_rule, child_rule])

    assert solver.should_force_tool_call() is True, "Should return True initially with InitToolRule"

    solver.register_tool_call(START_TOOL)
    assert solver.should_force_tool_call() is True, "Should return True when ChildToolRule is active"

    solver.clear_tool_history()
    assert solver.should_force_tool_call() is True, "Should return True again after clearing history with InitToolRule"


def test_should_force_tool_call_mixed_rules():
    """Test should_force_tool_call with a mix of constraining and non-constraining rules."""
    init_rule = InitToolRule(tool_name=START_TOOL)
    child_rule = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
    terminal_rule = TerminalToolRule(tool_name=END_TOOL)
    continue_rule = ContinueToolRule(tool_name=HELPER_TOOL)
    max_count_rule = MaxCountPerStepToolRule(tool_name=NEXT_TOOL, max_count_limit=2)
    solver = ToolRulesSolver(tool_rules=[init_rule, child_rule, terminal_rule, continue_rule, max_count_rule])

    assert solver.should_force_tool_call() is True, "Should return True with InitToolRule at start"

    solver.register_tool_call(START_TOOL)
    assert solver.should_force_tool_call() is True, "Should return True when ChildToolRule is active"

    solver.register_tool_call(NEXT_TOOL)
    assert solver.should_force_tool_call() is False, "Should return False when no constraining rules are active"


def make_tool(name: str, properties: dict) -> Tool:
    """Helper to build a minimal custom Tool with a JSON schema.

    Args:
        name: Tool name, also used as the schema's "name".
        properties: JSON-schema "properties" mapping for the tool's parameters. No parameter is
            marked required, and additionalProperties is disabled.
    """
    return Tool(
        name=name,
        tool_type=ToolType.CUSTOM,
        json_schema={
            "name": name,
            "parameters": {
                "type": "object",
                "properties": properties,
                "required": [],
                "additionalProperties": False,
            },
        },
    )


def test_init_rule_args_are_cached_in_solver():
    solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha", args={"x": 1, "y": "s"})])
    allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})
    assert set(allowed) == {"alpha"}

    # Cached mappings
    assert solver.last_prefilled_args_by_tool == {"alpha": {"x": 1, "y": "s"}}
    assert solver.last_prefilled_args_provenance.get("alpha") == "InitToolRule(alpha)"


def test_cached_provenance_format():
    solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="tool_one", args={"a": 123})])
    _ = solver.get_allowed_tool_names(available_tools={"tool_one"})

    prov = solver.last_prefilled_args_provenance.get("tool_one")
    # Check presence first so a missing entry fails as an assertion rather than an AttributeError on None.
    assert prov is not None, "Provenance should be cached for 'tool_one'"
    assert prov.startswith("InitToolRule(") and prov.endswith(")") and "tool_one" in prov


def test_cache_empty_when_no_args():
    solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha")])
    allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})

    assert set(allowed) == {"alpha"}
    assert solver.last_prefilled_args_by_tool == {}
    assert solver.last_prefilled_args_provenance == {}


def test_cache_recomputed_on_next_call():
    # First call caches args for init tool
    solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha", args={"p": 5})])
    _ = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})
    assert solver.last_prefilled_args_by_tool == {"alpha": {"p": 5}}

    # After a tool call, init rules no longer apply; next computation should clear caches
    solver.register_tool_call("alpha")
    _ = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})
    assert solver.last_prefilled_args_by_tool == {}
    assert solver.last_prefilled_args_provenance == {}


def test_merge_and_validate_prefilled_args_overrides_llm_values():
    tool = make_tool("my_tool", properties={"a": {"type": "integer"}, "b": {"type": "string"}})
    llm_args = {"a": 1, "b": "hello"}
    prefilled = {"a": 42}

    merged = merge_and_validate_prefilled_args(tool, llm_args, prefilled)

    # Prefilled values win over LLM-provided values; untouched keys pass through.
    assert merged == {"a": 42, "b": "hello"}


def test_merge_and_validate_prefilled_args_type_validation():
    tool = make_tool("typed_tool", properties={"a": {"type": "integer"}})
    llm_args = {"a": 1}
    prefilled = {"a": "not-an-int"}

    with pytest.raises(ValueError) as ei:
        _ = merge_and_validate_prefilled_args(tool, llm_args, prefilled)

    assert "Invalid value for 'a'" in str(ei.value)
    assert "integer" in str(ei.value)


def test_merge_and_validate_prefilled_args_unknown_key_fails():
    tool = make_tool("limited_tool", properties={"a": {"type": "integer"}})

    with pytest.raises(ValueError) as ei:
        _ = merge_and_validate_prefilled_args(tool, llm_args={}, prefilled_args={"z": 3})

    assert "Unknown argument 'z'" in str(ei.value)


def test_merge_and_validate_prefilled_args_enum_const_anyof_oneof():
    tool = make_tool(
        "rich_tool",
        properties={
            "c": {"enum": ["x", "y"]},
            "d": {"const": 5},
            "e": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
            "f": {"oneOf": [{"type": "string"}, {"type": "integer"}]},
            "g": {"type": "number"},
        },
    )

    # Valid cases
    merged = merge_and_validate_prefilled_args(tool, {}, {"c": "x"})
    assert merged["c"] == "x"
    merged = merge_and_validate_prefilled_args(tool, {}, {"d": 5})
    assert merged["d"] == 5
    merged = merge_and_validate_prefilled_args(tool, {}, {"e": 7})
    assert merged["e"] == 7
    merged = merge_and_validate_prefilled_args(tool, {}, {"f": "hello"})
    assert merged["f"] == "hello"
    merged = merge_and_validate_prefilled_args(tool, {}, {"g": 3.14})
    assert merged["g"] == 3.14
    merged = merge_and_validate_prefilled_args(tool, {}, {"g": 3})
    assert merged["g"] == 3

    # Invalid cases
    with pytest.raises(ValueError):
        _ = merge_and_validate_prefilled_args(tool, {}, {"c": "z"})  # enum fail
    with pytest.raises(ValueError):
        _ = merge_and_validate_prefilled_args(tool, {}, {"d": 6})  # const fail
    with pytest.raises(ValueError):
        _ = merge_and_validate_prefilled_args(tool, {}, {"e": []})  # anyOf none match
    with pytest.raises(ValueError):
        _ = merge_and_validate_prefilled_args(tool, {}, {"f": []})  # oneOf none match
    with pytest.raises(ValueError):
        _ = merge_and_validate_prefilled_args(tool, {}, {"g": True})  # bool not a number


def test_merge_and_validate_prefilled_args_union_with_null():
    tool = make_tool("union_tool", properties={"h": {"type": ["string", "null"]}})

    merged = merge_and_validate_prefilled_args(tool, {}, {"h": None})
    assert "h" in merged and merged["h"] is None

    merged = merge_and_validate_prefilled_args(tool, {}, {"h": "ok"})
    assert merged["h"] == "ok"

    with pytest.raises(ValueError):
        _ = merge_and_validate_prefilled_args(tool, {}, {"h": 5})
test_merge_and_validate_prefilled_args_object_and_array_types(): tool = make_tool( "container_tool", properties={ "obj": {"type": "object"}, "arr": {"type": "array"}, }, ) merged = merge_and_validate_prefilled_args(tool, {}, {"obj": {"k": 1}}) assert merged["obj"] == {"k": 1} merged = merge_and_validate_prefilled_args(tool, {}, {"arr": [1, 2, 3]}) assert merged["arr"] == [1, 2, 3] with pytest.raises(ValueError): _ = merge_and_validate_prefilled_args(tool, {}, {"obj": "nope"}) with pytest.raises(ValueError): _ = merge_and_validate_prefilled_args(tool, {}, {"arr": {}}) def test_multiple_rules_args_last_write_wins_and_provenance(): # Two init rules for the same tool; the latter should overwrite overlapping keys and provenance r1 = InitToolRule(tool_name="alpha", args={"x": 1, "y": "first"}) r2 = InitToolRule(tool_name="alpha", args={"y": "second", "z": True}) solver = ToolRulesSolver(tool_rules=[r1, r2]) allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"}) assert set(allowed) == {"alpha"} assert solver.last_prefilled_args_by_tool["alpha"] == {"x": 1, "y": "second", "z": True} assert solver.last_prefilled_args_provenance.get("alpha") == "InitToolRule(alpha)" def test_child_rule_args_cached_only_when_parent_last_tool(): # Child with args and one without rule = ChildToolRule( tool_name="parent", children=["child_a", "child_b"], child_arg_nodes=[ToolCallNode(name="child_a", args={"x": 1})], ) solver = ToolRulesSolver(tool_rules=[rule]) # Before parent call, child args should not be cached allowed = solver.get_allowed_tool_names(available_tools={"parent", "child_a", "child_b"}) assert set(allowed) == {"parent", "child_a", "child_b"} assert solver.last_prefilled_args_by_tool == {} # After parent is last tool, cache should include child_a's args solver.register_tool_call("parent") allowed = solver.get_allowed_tool_names(available_tools={"parent", "child_a", "child_b"}) assert set(allowed) == {"child_a", "child_b"} assert 
solver.last_prefilled_args_by_tool.get("child_a") == {"x": 1} assert solver.last_prefilled_args_provenance.get("child_a") == "ChildToolRule(parent->child_a)" def test_init_then_child_args_applied_in_correct_phases(): # Init provides args for alpha; child provides args for beta init = InitToolRule(tool_name="alpha", args={"seed": "A"}) child = ChildToolRule( tool_name="alpha", children=["beta"], child_arg_nodes=[ToolCallNode(name="beta", args={"k": 1})], ) solver = ToolRulesSolver(tool_rules=[init, child]) # Phase 1: start — init args apply allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"}) assert set(allowed) == {"alpha"} assert solver.last_prefilled_args_by_tool == {"alpha": {"seed": "A"}} # Phase 2: after alpha executed — child args apply solver.register_tool_call("alpha") allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"}) assert set(allowed) == {"beta"} assert solver.last_prefilled_args_by_tool == {"beta": {"k": 1}} def test_multi_child_rules_last_write_wins_for_same_child(): # Two ChildToolRules for the same parent/child; second overrides overlapping keys child1 = ChildToolRule( tool_name="p", children=["c"], child_arg_nodes=[ToolCallNode(name="c", args={"x": 1, "y": "a"})], ) child2 = ChildToolRule( tool_name="p", children=["c"], child_arg_nodes=[ToolCallNode(name="c", args={"y": "b", "z": 3})], ) solver = ToolRulesSolver(tool_rules=[child1, child2]) solver.register_tool_call("p") allowed = solver.get_allowed_tool_names(available_tools={"p", "c"}) assert set(allowed) == {"c"} assert solver.last_prefilled_args_by_tool["c"] == {"x": 1, "y": "b", "z": 3} # Provenance reflects the last write source assert solver.last_prefilled_args_provenance.get("c") == "ChildToolRule(p->c)" def test_child_args_only_for_allowed_children(): # Provide args for two children, but restrict available_tools to one child rule = ChildToolRule( tool_name="p", children=["allowed", "blocked"], child_arg_nodes=[ ToolCallNode(name="allowed", 
args={"a": 1}), ToolCallNode(name="blocked", args={"b": 2}), ], ) solver = ToolRulesSolver(tool_rules=[rule]) solver.register_tool_call("p") allowed = solver.get_allowed_tool_names(available_tools={"allowed"}) assert set(allowed) == {"allowed"} assert solver.last_prefilled_args_by_tool == {"allowed": {"a": 1}} assert "blocked" not in solver.last_prefilled_args_by_tool def test_child_args_intersection_with_conditional_mapping(): # Child list has args for both, ConditionalToolRule limits to one based on output child = ChildToolRule( tool_name="decider", children=["c1", "c2"], child_arg_nodes=[ToolCallNode(name="c1", args={"x": 10}), ToolCallNode(name="c2", args={"y": 20})], ) cond = ConditionalToolRule( tool_name="decider", default_child=None, child_output_mapping={True: "c2", False: "c1"}, require_output_mapping=True, ) solver = ToolRulesSolver(tool_rules=[child, cond]) solver.register_tool_call("decider") allowed = solver.get_allowed_tool_names(available_tools={"c1", "c2"}, last_function_response='{"message": "true"}') assert set(allowed) == {"c2"} assert solver.last_prefilled_args_by_tool == {"c2": {"y": 20}} def test_child_rule_prefilled_complex_args_validation_success(): # Define complex child args with multiple JSON schema types complex_args = { "obj": {"k": 1, "nest": {"a": 2}}, "arr": [1, 2, 3], "union": None, # string | null "any": "text", # anyOf string|integer "one": 42, # oneOf string|integer "num": 3.5, "flag": True, "str": "hello", } rule = ChildToolRule( tool_name="p", children=["complex_child"], child_arg_nodes=[ToolCallNode(name="complex_child", args=complex_args)], ) solver = ToolRulesSolver(tool_rules=[rule]) solver.register_tool_call("p") allowed = solver.get_allowed_tool_names(available_tools={"complex_child"}) assert set(allowed) == {"complex_child"} assert solver.last_prefilled_args_by_tool.get("complex_child") == complex_args # Validate and merge against a tool schema with matching types properties = { "obj": {"type": "object"}, "arr": 
{"type": "array"}, "union": {"type": ["string", "null"]}, "any": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "one": {"oneOf": [{"type": "string"}, {"type": "integer"}]}, "num": {"type": "number"}, "flag": {"type": "boolean"}, "str": {"type": "string"}, } tool = make_tool("complex_child", properties) # LLM suggests competing values; prefilled should override llm_args = {"str": "fake", "num": 7, "extra": "ignored"} merged = merge_and_validate_prefilled_args(tool, llm_args, complex_args) for k, v in complex_args.items(): assert merged[k] == v assert merged.get("extra") == "ignored" # untouched by prefill validation def test_child_rule_prefilled_complex_args_validation_fail(): # Provide intentionally bad types for several keys bad_args = { "obj": "not-an-object", # should be object "arr": {"not": "an array"}, # should be array "union": 5, # should be string|null "any": [], # anyOf string|integer "one": [], # oneOf string|integer "num": True, # bool is not accepted as number "flag": "yes", # should be boolean "str": 123, # should be string } rule = ChildToolRule( tool_name="p", children=["complex_child"], child_arg_nodes=[ToolCallNode(name="complex_child", args=bad_args)], ) solver = ToolRulesSolver(tool_rules=[rule]) solver.register_tool_call("p") _ = solver.get_allowed_tool_names(available_tools={"complex_child"}) assert solver.last_prefilled_args_by_tool.get("complex_child") == bad_args properties = { "obj": {"type": "object"}, "arr": {"type": "array"}, "union": {"type": ["string", "null"]}, "any": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, "one": {"oneOf": [{"type": "string"}, {"type": "integer"}]}, "num": {"type": "number"}, "flag": {"type": "boolean"}, "str": {"type": "string"}, } tool = make_tool("complex_child", properties) with pytest.raises(ValueError) as ei: _ = merge_and_validate_prefilled_args(tool, llm_args={}, prefilled_args=bad_args) msg = str(ei.value) # Spot-check a few failures assert "Unknown argument" not in msg # keys exist 
assert "Invalid value" in msg def test_child_tool_rule_validation_unknown_child_in_arg_nodes(): """ChildToolRule should reject child_arg_nodes that reference names not in children.""" with pytest.raises(ValueError) as ei: _ = ChildToolRule( tool_name="parent", children=["known_child"], child_arg_nodes=[ToolCallNode(name="unknown_child", args={"x": 1})], ) assert "not in children" in str(ei.value) def test_child_tool_rule_validation_args_type_enforced(): """ToolCallNode.args must be a dict when present; otherwise Pydantic should raise.""" with pytest.raises(Exception) as ei: _ = ChildToolRule( tool_name="p", children=["c1"], child_arg_nodes=[ToolCallNode(name="c1", args="not-a-dict")], # type: ignore[arg-type] ) # Pydantic should raise a validation error about args type assert "dict" in str(ei.value) or "dictionary" in str(ei.value) def test_child_tool_rule_validation_accepts_valid_nodes(): """A valid ChildToolRule with matching child and typed arg node should construct cleanly.""" rule = ChildToolRule( tool_name="p", children=["c1", "c2"], child_arg_nodes=[ToolCallNode(name="c2", args={"k": 1})], ) assert isinstance(rule, ChildToolRule)