feat: Add ConditionalToolRules (#2279)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -2,7 +2,12 @@ import pytest
|
||||
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.tool_rule_solver import ToolRuleValidationError
|
||||
from letta.schemas.tool_rule import ChildToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import (
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
InitToolRule,
|
||||
TerminalToolRule
|
||||
)
|
||||
|
||||
# Constants for tool names used in the tests
|
||||
START_TOOL = "start_tool"
|
||||
@@ -60,7 +65,7 @@ def test_get_allowed_tool_names_no_matching_rule_warning():
|
||||
# Action: Set last tool to an unrecognized tool and check warnings
|
||||
solver.update_tool_usage(UNRECOGNIZED_TOOL)
|
||||
|
||||
# NOTE: removed for now since this warning is getting triggered on every LLM call
|
||||
# # NOTE: removed for now since this warning is getting triggered on every LLM call
|
||||
# with warnings.catch_warnings(record=True) as w:
|
||||
# allowed_tools = solver.get_allowed_tool_names()
|
||||
|
||||
@@ -75,9 +80,9 @@ def test_get_allowed_tool_names_no_matching_rule_error():
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[])
|
||||
|
||||
# Action & Assert: Set last tool to an unrecognized tool and expect RuntimeError when error_on_empty=True
|
||||
# Action & Assert: Set last tool to an unrecognized tool and expect ValueError
|
||||
solver.update_tool_usage(UNRECOGNIZED_TOOL)
|
||||
with pytest.raises(RuntimeError, match="resolved to no more possible tool calls"):
|
||||
with pytest.raises(ValueError, match=f"No tool rule found for {UNRECOGNIZED_TOOL}"):
|
||||
solver.get_allowed_tool_names(error_on_empty=True)
|
||||
|
||||
|
||||
@@ -104,7 +109,46 @@ def test_update_tool_usage_and_get_allowed_tool_names_combined():
|
||||
assert solver.is_terminal_tool(FINAL_TOOL) is True, "Should recognize 'final_tool' as terminal"
|
||||
|
||||
|
||||
def test_tool_rules_with_cycle_detection():
|
||||
def test_conditional_tool_rule():
|
||||
# Setup: Define a conditional tool rule
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
rule = ConditionalToolRule(
|
||||
tool_name=START_TOOL,
|
||||
default_child=None,
|
||||
child_output_mapping={True: END_TOOL, False: START_TOOL}
|
||||
)
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule])
|
||||
|
||||
# Action & Assert: Verify the rule properties
|
||||
# Step 1: Initially allowed tools
|
||||
assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'"
|
||||
|
||||
# Step 2: After using 'start_tool'
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [END_TOOL], "After 'start_tool' returns true, should allow 'end_tool'"
|
||||
assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [START_TOOL], "After 'start_tool' returns false, should allow 'start_tool'"
|
||||
|
||||
# Step 3: After using 'end_tool'
|
||||
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"
|
||||
|
||||
|
||||
def test_invalid_conditional_tool_rule():
|
||||
# Setup: Define an invalid conditional tool rule
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
invalid_rule_1 = ConditionalToolRule(
|
||||
tool_name=START_TOOL,
|
||||
default_child=END_TOOL,
|
||||
child_output_mapping={}
|
||||
)
|
||||
|
||||
# Test 1: Missing child output mapping
|
||||
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."):
|
||||
ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule])
|
||||
|
||||
|
||||
def test_tool_rules_with_invalid_path():
|
||||
# Setup: Define tool rules with both connected, disconnected nodes and a cycle
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
|
||||
@@ -113,15 +157,12 @@ def test_tool_rules_with_cycle_detection():
|
||||
rule_4 = ChildToolRule(tool_name=FINAL_TOOL, children=[END_TOOL]) # Disconnected rule, no cycle here
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
|
||||
# Action & Assert: Attempt to create the ToolRulesSolver with a cycle should raise ValidationError
|
||||
with pytest.raises(ToolRuleValidationError, match="Tool rules contain cycles"):
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule])
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule])
|
||||
|
||||
# Extra setup: Define tool rules without a cycle but with hanging nodes
|
||||
rule_5 = ChildToolRule(tool_name=PREP_TOOL, children=[FINAL_TOOL]) # Hanging node with no connection to start_tool
|
||||
|
||||
# Assert that a configuration without cycles does not raise an error
|
||||
try:
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_4, rule_5, terminal_rule])
|
||||
except ToolRuleValidationError:
|
||||
pytest.fail("ToolRulesSolver raised ValidationError unexpectedly on a valid DAG with hanging nodes")
|
||||
# Now: add a path from the start tool to the final tool
|
||||
rule_5 = ConditionalToolRule(
|
||||
tool_name=HELPER_TOOL,
|
||||
default_child=FINAL_TOOL,
|
||||
child_output_mapping={True: START_TOOL, False: FINAL_TOOL},
|
||||
)
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, rule_5, terminal_rule])
|
||||
|
||||
Reference in New Issue
Block a user