From 641eb0354f83254939920635a04f8ca4c28a33e0 Mon Sep 17 00:00:00 2001
From: Matthew Zhou
Date: Thu, 28 Aug 2025 12:43:32 -0700
Subject: [PATCH] feat: De-dupe tool rules [LET-4091] (#4282)

* Add hash/eqs for de-dupe

* Add sdk test
---
 letta/helpers/converters.py    |   5 +-
 letta/schemas/tool_rule.py     |  58 +++++++
 tests/test_sdk_client.py       |  57 ++++++-
 tests/test_tool_rule_solver.py | 283 +++++++++++++++++++++++++++++++++
 4 files changed, 400 insertions(+), 3 deletions(-)

diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py
index daa92522..3ad3ce67 100644
--- a/letta/helpers/converters.py
+++ b/letta/helpers/converters.py
@@ -91,8 +91,11 @@ def serialize_tool_rules(tool_rules: Optional[List[ToolRule]]) -> List[Dict[str,
     if not tool_rules:
         return []
 
+    # de-duplicate tool rules using dict.fromkeys (preserves order in Python 3.7+)
+    deduplicated_rules = list(dict.fromkeys(tool_rules))
+
     data = [
-        {**rule.model_dump(mode="json"), "type": rule.type.value} for rule in tool_rules
+        {**rule.model_dump(mode="json"), "type": rule.type.value} for rule in deduplicated_rules
     ]  # Convert Enum to string for JSON compatibility
 
     # Validate ToolRule structure
diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py
index b4744abf..9c310f58 100644
--- a/letta/schemas/tool_rule.py
+++ b/letta/schemas/tool_rule.py
@@ -20,6 +20,16 @@ class BaseToolRule(LettaBase):
         description="Optional Jinja2 template for generating agent prompt about this tool rule. Template can use variables like 'tool_name' and rule-specific attributes.",
     )
 
+    def __hash__(self):
+        """Base hash using tool_name and type."""
+        return hash((self.tool_name, self.type))
+
+    def __eq__(self, other):
+        """Base equality using tool_name and type."""
+        if not isinstance(other, BaseToolRule):
+            return False
+        return self.tool_name == other.tool_name and self.type == other.type
+
     def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> set[str]:
         raise NotImplementedError
 
@@ -54,6 +64,16 @@ class ChildToolRule(BaseToolRule):
         description="Optional Jinja2 template for generating agent prompt about this tool rule.",
     )
 
+    def __hash__(self):
+        """Hash including children list (sorted for consistency)."""
+        return hash((self.tool_name, self.type, tuple(sorted(self.children))))
+
+    def __eq__(self, other):
+        """Equality including children list."""
+        if not isinstance(other, ChildToolRule):
+            return False
+        return self.tool_name == other.tool_name and self.type == other.type and sorted(self.children) == sorted(other.children)
+
     def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
         last_tool = tool_call_history[-1] if tool_call_history else None
         return set(self.children) if last_tool == self.tool_name else available_tools
@@ -71,6 +91,16 @@ class ParentToolRule(BaseToolRule):
         description="Optional Jinja2 template for generating agent prompt about this tool rule.",
     )
 
+    def __hash__(self):
+        """Hash including children list (sorted for consistency)."""
+        return hash((self.tool_name, self.type, tuple(sorted(self.children))))
+
+    def __eq__(self, other):
+        """Equality including children list."""
+        if not isinstance(other, ParentToolRule):
+            return False
+        return self.tool_name == other.tool_name and self.type == other.type and sorted(self.children) == sorted(other.children)
+
     def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
         last_tool = tool_call_history[-1] if tool_call_history else None
         return set(self.children) if last_tool == self.tool_name else available_tools - set(self.children)
@@ -90,6 +120,24 @@ class ConditionalToolRule(BaseToolRule):
         description="Optional Jinja2 template for generating agent prompt about this tool rule.",
     )
 
+    def __hash__(self):
+        """Hash including all configuration fields."""
+        # frozenset is order-independent and, unlike sorted(), works when mapping keys are not mutually orderable
+        mapping_items = frozenset(self.child_output_mapping.items())
+        return hash((self.tool_name, self.type, self.default_child, mapping_items, self.require_output_mapping))
+
+    def __eq__(self, other):
+        """Equality including all configuration fields."""
+        if not isinstance(other, ConditionalToolRule):
+            return False
+        return (
+            self.tool_name == other.tool_name
+            and self.type == other.type
+            and self.default_child == other.default_child
+            and self.child_output_mapping == other.child_output_mapping
+            and self.require_output_mapping == other.require_output_mapping
+        )
+
     def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
         """Determine valid tools based on function output mapping."""
         if not tool_call_history or tool_call_history[-1] != self.tool_name:
@@ -203,6 +251,16 @@ class MaxCountPerStepToolRule(BaseToolRule):
         description="Optional Jinja2 template for generating agent prompt about this tool rule.",
     )
 
+    def __hash__(self):
+        """Hash including max_count_limit."""
+        return hash((self.tool_name, self.type, self.max_count_limit))
+
+    def __eq__(self, other):
+        """Equality including max_count_limit."""
+        if not isinstance(other, MaxCountPerStepToolRule):
+            return False
+        return self.tool_name == other.tool_name and self.type == other.type and self.max_count_limit == other.max_count_limit
+
     def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
         """Restricts the tool if it has been called max_count_limit times in the current step."""
         count = tool_call_history.count(self.tool_name)
diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py
index 4ed51b35..1ac6fddd 100644
--- a/tests/test_sdk_client.py
+++ b/tests/test_sdk_client.py
@@ -9,9 +9,9 @@ from typing import List, Type
 
 import pytest
 from dotenv import load_dotenv
-from letta_client import CreateBlock
+from letta_client import ContinueToolRule, CreateBlock
 from letta_client import Letta as LettaSDKClient
-from letta_client import LettaRequest, MessageCreate, TerminalToolRule, TextContent
+from letta_client import LettaRequest, MaxCountPerStepToolRule, MessageCreate, TerminalToolRule, TextContent
 from letta_client.client import BaseTool
 from letta_client.core import ApiError
 from letta_client.types import AgentState, ToolReturnMessage
@@ -1251,6 +1251,59 @@ def test_agent_tools_list(client: LettaSDKClient):
         client.agents.delete(agent_id=agent_state.id)
 
 
+def test_agent_tool_rules_deduplication(client: LettaSDKClient):
+    """Test that duplicate tool rules are properly deduplicated when creating/updating agents."""
+    # Create agent with duplicate tool rules
+    duplicate_rules = [
+        TerminalToolRule(tool_name="send_message"),
+        TerminalToolRule(tool_name="send_message"),  # exact duplicate
+        TerminalToolRule(tool_name="send_message"),  # another duplicate
+    ]
+
+    agent_state = client.agents.create(
+        name="test_agent_dedup",
+        memory_blocks=[
+            CreateBlock(
+                label="persona",
+                value="You are a helpful assistant.",
+            ),
+        ],
+        model="openai/gpt-4o-mini",
+        embedding="openai/text-embedding-3-small",
+        tool_rules=duplicate_rules,
+        include_base_tools=False,
+    )
+
+    # Get the agent and check tool rules
+    retrieved_agent = client.agents.retrieve(agent_id=agent_state.id)
+    assert len(retrieved_agent.tool_rules) == 1, f"Expected 1 unique tool rule, got {len(retrieved_agent.tool_rules)}"
+    assert retrieved_agent.tool_rules[0].tool_name == "send_message"
+    assert retrieved_agent.tool_rules[0].type == "exit_loop"
+
+    # Test update with duplicates
+    update_rules = [
+        ContinueToolRule(tool_name="conversation_search"),
+        ContinueToolRule(tool_name="conversation_search"),  # duplicate
+        MaxCountPerStepToolRule(tool_name="test_tool", max_count_limit=2),
+        MaxCountPerStepToolRule(tool_name="test_tool", max_count_limit=2),  # exact duplicate
+        MaxCountPerStepToolRule(tool_name="test_tool", max_count_limit=3),  # different limit, not a duplicate
+    ]
+
+    updated_agent = client.agents.modify(agent_id=agent_state.id, tool_rules=update_rules)
+
+    # Check that duplicates were removed
+    assert len(updated_agent.tool_rules) == 3, f"Expected 3 unique tool rules after update, got {len(updated_agent.tool_rules)}"
+
+    # Verify the specific rules
+    rule_set = {(r.tool_name, r.type, getattr(r, "max_count_limit", None)) for r in updated_agent.tool_rules}
+    expected_set = {
+        ("conversation_search", "continue_loop", None),
+        ("test_tool", "max_count_per_step", 2),
+        ("test_tool", "max_count_per_step", 3),
+    }
+    assert rule_set == expected_set, f"Tool rules don't match expected. Got: {rule_set}"
+
+
 def test_add_tool_with_multiple_functions_in_source_code(client: LettaSDKClient):
     """Test adding a tool with multiple functions in the source code"""
     import textwrap
diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py
index 25bb5d31..5c6a5e86 100644
--- a/tests/test_tool_rule_solver.py
+++ b/tests/test_tool_rule_solver.py
@@ -4,8 +4,10 @@ from letta.helpers import ToolRulesSolver
 from letta.schemas.tool_rule import (
     ChildToolRule,
     ConditionalToolRule,
+    ContinueToolRule,
     InitToolRule,
     MaxCountPerStepToolRule,
+    ParentToolRule,
     RequiredBeforeExitToolRule,
     TerminalToolRule,
 )
@@ -184,6 +186,287 @@ def test_max_count_per_step_tool_rule_resets_on_clear():
     assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should allow 'start_tool' again after clearing history"
 
 
+def test_tool_rule_equality_and_hashing():
+    """Test __eq__ and __hash__ methods for all tool rule types."""
+
+    # test InitToolRule equality
+    rule1 = InitToolRule(tool_name="test_tool")
+    rule2 = InitToolRule(tool_name="test_tool")
+    rule3 = InitToolRule(tool_name="different_tool")
+
+    assert rule1 == rule2, "InitToolRules with same tool_name should be equal"
+    assert rule1 != rule3, "InitToolRules with different tool_name should not be equal"
+    assert hash(rule1) == hash(rule2), "Equal InitToolRules should have same hash"
+    assert hash(rule1) != hash(rule3), "Different InitToolRules should have different hash"
+
+    # test ChildToolRule equality
+    child_rule1 = ChildToolRule(tool_name="parent", children=["child1", "child2"])
+    child_rule2 = ChildToolRule(tool_name="parent", children=["child2", "child1"])  # different order
+    child_rule3 = ChildToolRule(tool_name="parent", children=["child1"])
+    child_rule4 = ChildToolRule(tool_name="different_parent", children=["child1", "child2"])
+
+    assert child_rule1 == child_rule2, "ChildToolRules with same children (different order) should be equal"
+    assert child_rule1 != child_rule3, "ChildToolRules with different children should not be equal"
+    assert child_rule1 != child_rule4, "ChildToolRules with different tool_name should not be equal"
+    assert hash(child_rule1) == hash(child_rule2), "Equal ChildToolRules should have same hash"
+    assert hash(child_rule1) != hash(child_rule3), "Different ChildToolRules should have different hash"
+
+    # test ConditionalToolRule equality
+    cond_rule1 = ConditionalToolRule(
+        tool_name="conditional", child_output_mapping={"yes": "tool1", "no": "tool2"}, default_child="tool3", require_output_mapping=True
+    )
+    cond_rule2 = ConditionalToolRule(
+        tool_name="conditional",
+        child_output_mapping={"no": "tool2", "yes": "tool1"},  # different order
+        default_child="tool3",
+        require_output_mapping=True,
+    )
+    cond_rule3 = ConditionalToolRule(
+        tool_name="conditional",
+        child_output_mapping={"yes": "tool1", "no": "tool2"},
+        default_child="different_tool",
+        require_output_mapping=True,
+    )
+    cond_rule4 = ConditionalToolRule(
+        tool_name="conditional",
+        child_output_mapping={"yes": "tool1", "no": "tool2"},
+        default_child="tool3",
+        require_output_mapping=False,  # different require_output_mapping
+    )
+
+    assert cond_rule1 == cond_rule2, "ConditionalToolRules with same mapping (different order) should be equal"
+    assert cond_rule1 != cond_rule3, "ConditionalToolRules with different default_child should not be equal"
+    assert cond_rule1 != cond_rule4, "ConditionalToolRules with different require_output_mapping should not be equal"
+    assert hash(cond_rule1) == hash(cond_rule2), "Equal ConditionalToolRules should have same hash"
+    assert hash(cond_rule1) != hash(cond_rule3), "Different ConditionalToolRules should have different hash"
+
+    # test MaxCountPerStepToolRule equality
+    max_rule1 = MaxCountPerStepToolRule(tool_name="limited_tool", max_count_limit=3)
+    max_rule2 = MaxCountPerStepToolRule(tool_name="limited_tool", max_count_limit=3)
+    max_rule3 = MaxCountPerStepToolRule(tool_name="limited_tool", max_count_limit=5)
+    max_rule4 = MaxCountPerStepToolRule(tool_name="different_tool", max_count_limit=3)
+
+    assert max_rule1 == max_rule2, "MaxCountPerStepToolRules with same limit should be equal"
+    assert max_rule1 != max_rule3, "MaxCountPerStepToolRules with different limit should not be equal"
+    assert max_rule1 != max_rule4, "MaxCountPerStepToolRules with different tool_name should not be equal"
+    assert hash(max_rule1) == hash(max_rule2), "Equal MaxCountPerStepToolRules should have same hash"
+    assert hash(max_rule1) != hash(max_rule3), "Different MaxCountPerStepToolRules should have different hash"
+
+    # test TerminalToolRule equality
+    term_rule1 = TerminalToolRule(tool_name="exit_tool")
+    term_rule2 = TerminalToolRule(tool_name="exit_tool")
+    term_rule3 = TerminalToolRule(tool_name="different_exit_tool")
+
+    assert term_rule1 == term_rule2, "TerminalToolRules with same tool_name should be equal"
+    assert term_rule1 != term_rule3, "TerminalToolRules with different tool_name should not be equal"
+    assert hash(term_rule1) == hash(term_rule2), "Equal TerminalToolRules should have same hash"
+
+    # test RequiredBeforeExitToolRule equality
+    req_rule1 = RequiredBeforeExitToolRule(tool_name="required_tool")
+    req_rule2 = RequiredBeforeExitToolRule(tool_name="required_tool")
+    req_rule3 = RequiredBeforeExitToolRule(tool_name="different_required_tool")
+
+    assert req_rule1 == req_rule2, "RequiredBeforeExitToolRules with same tool_name should be equal"
+    assert req_rule1 != req_rule3, "RequiredBeforeExitToolRules with different tool_name should not be equal"
+    assert hash(req_rule1) == hash(req_rule2), "Equal RequiredBeforeExitToolRules should have same hash"
+
+    # test cross-type inequality
+    assert rule1 != child_rule1, "Different rule types should never be equal"
+    assert child_rule1 != cond_rule1, "Different rule types should never be equal"
+    assert max_rule1 != term_rule1, "Different rule types should never be equal"
+
+
+def test_tool_rule_deduplication_in_set():
+    """Test that duplicate tool rules are properly deduplicated when used in sets."""
+
+    # create duplicate rules
+    rule1 = InitToolRule(tool_name="start")
+    rule2 = InitToolRule(tool_name="start")  # duplicate
+    rule3 = InitToolRule(tool_name="different_start")
+
+    child1 = ChildToolRule(tool_name="parent", children=["a", "b"])
+    child2 = ChildToolRule(tool_name="parent", children=["b", "a"])  # duplicate (different order)
+    child3 = ChildToolRule(tool_name="parent", children=["a", "b", "c"])  # different
+
+    max1 = MaxCountPerStepToolRule(tool_name="limited", max_count_limit=2)
+    max2 = MaxCountPerStepToolRule(tool_name="limited", max_count_limit=2)  # duplicate
+    max3 = MaxCountPerStepToolRule(tool_name="limited", max_count_limit=3)  # different
+
+    # test set deduplication
+    rules_set = {rule1, rule2, rule3, child1, child2, child3, max1, max2, max3}
+    assert len(rules_set) == 6, "Set should contain only unique rules"
+
+    # test list deduplication using dict.fromkeys
+    rules_list = [rule1, rule2, rule3, child1, child2, child3, max1, max2, max3]
+    deduplicated = list(dict.fromkeys(rules_list))
+    assert len(deduplicated) == 6, "dict.fromkeys should deduplicate rules"
+    assert deduplicated[0] == rule1, "Order should be preserved"
+    assert deduplicated[1] == rule3, "Order should be preserved"
+    assert deduplicated[2] == child1, "Order should be preserved"
+    assert deduplicated[3] == child3, "Order should be preserved"
+    assert deduplicated[4] == max1, "Order should be preserved"
+    assert deduplicated[5] == max3, "Order should be preserved"
+
+
+def test_parent_tool_rule_equality():
+    """Test ParentToolRule equality and hashing."""
+    parent_rule1 = ParentToolRule(tool_name="parent", children=["child1", "child2"])
+    parent_rule2 = ParentToolRule(tool_name="parent", children=["child2", "child1"])  # different order
+    parent_rule3 = ParentToolRule(tool_name="parent", children=["child1"])
+    parent_rule4 = ParentToolRule(tool_name="different_parent", children=["child1", "child2"])
+
+    assert parent_rule1 == parent_rule2, "ParentToolRules with same children (different order) should be equal"
+    assert parent_rule1 != parent_rule3, "ParentToolRules with different children should not be equal"
+    assert parent_rule1 != parent_rule4, "ParentToolRules with different tool_name should not be equal"
+    assert hash(parent_rule1) == hash(parent_rule2), "Equal ParentToolRules should have same hash"
+    assert hash(parent_rule1) != hash(parent_rule3), "Different ParentToolRules should have different hash"
+
+
+def test_continue_tool_rule_equality_and_hashing():
+    r1 = ContinueToolRule(tool_name="go_on")
+    r2 = ContinueToolRule(tool_name="go_on")
+    r3 = ContinueToolRule(tool_name="different")
+
+    assert r1 == r2
+    assert hash(r1) == hash(r2)
+    assert r1 != r3
+    assert hash(r1) != hash(r3)
+
+
+@pytest.mark.parametrize(
+    "rule_factory, kwargs_a, kwargs_b",
+    [
+        (lambda **kw: InitToolRule(**kw), dict(tool_name="t"), dict(tool_name="t")),
+        (lambda **kw: TerminalToolRule(**kw), dict(tool_name="t"), dict(tool_name="t")),
+        (lambda **kw: ContinueToolRule(**kw), dict(tool_name="t"), dict(tool_name="t")),
+        (lambda **kw: RequiredBeforeExitToolRule(**kw), dict(tool_name="t"), dict(tool_name="t")),
+        (lambda **kw: MaxCountPerStepToolRule(**kw), dict(tool_name="t", max_count_limit=2), dict(tool_name="t", max_count_limit=2)),
+        (lambda **kw: ChildToolRule(**kw), dict(tool_name="t", children=["a", "b"]), dict(tool_name="t", children=["a", "b"])),
+        (lambda **kw: ParentToolRule(**kw), dict(tool_name="t", children=["a", "b"]), dict(tool_name="t", children=["a", "b"])),
+        (
+            lambda **kw: ConditionalToolRule(**kw),
+            dict(tool_name="t", child_output_mapping={"x": "a"}, default_child=None, require_output_mapping=False),
+            dict(tool_name="t", child_output_mapping={"x": "a"}, default_child=None, require_output_mapping=False),
+        ),
+    ],
+)
+def test_prompt_template_ignored(rule_factory, kwargs_a, kwargs_b):
+    r1 = rule_factory(**kwargs_a, prompt_template="A")
+    r2 = rule_factory(**kwargs_b, prompt_template="B")
+    assert r1 == r2, f"{type(r1).__name__} should ignore prompt_template in equality"
+    assert hash(r1) == hash(r2), f"{type(r1).__name__} should ignore prompt_template in hash"
+
+
+@pytest.mark.parametrize(
+    "a,b",
+    [
+        (InitToolRule(tool_name="same"), TerminalToolRule(tool_name="same")),
+        (ContinueToolRule(tool_name="same"), RequiredBeforeExitToolRule(tool_name="same")),
+        (ChildToolRule(tool_name="same", children=["x"]), ParentToolRule(tool_name="same", children=["x"])),
+    ],
+)
+def test_cross_type_hash_distinguishes_types(a, b):
+    assert a != b
+    assert hash(a) != hash(b)
+
+
+@pytest.mark.parametrize(
+    "rule",
+    [
+        InitToolRule(tool_name="x"),
+        TerminalToolRule(tool_name="x"),
+        ContinueToolRule(tool_name="x"),
+        RequiredBeforeExitToolRule(tool_name="x"),
+        MaxCountPerStepToolRule(tool_name="x", max_count_limit=1),
+        ChildToolRule(tool_name="x", children=["a"]),
+        ParentToolRule(tool_name="x", children=["a"]),
+        ConditionalToolRule(tool_name="x", child_output_mapping={"k": "a"}, default_child=None, require_output_mapping=False),
+    ],
+)
+def test_equality_with_non_rule_objects(rule):
+    assert rule != object()
+    assert rule != None  # noqa: E711
+
+
+def test_conditional_tool_rule_mapping_order_and_hash():
+    r1 = ConditionalToolRule(
+        tool_name="cond", child_output_mapping={"yes": "tool1", "no": "tool2"}, default_child="tool3", require_output_mapping=True
+    )
+    r2 = ConditionalToolRule(
+        tool_name="cond", child_output_mapping={"no": "tool2", "yes": "tool1"}, default_child="tool3", require_output_mapping=True
+    )
+    assert r1 == r2
+    assert hash(r1) == hash(r2)
+
+
+def test_conditional_tool_rule_mapping_numeric_and_bool_keys_equivalence_current_behavior():
+    # NOTE: Python dict equality treats True == 1 and 1 == 1.0 as equal keys.
+    # This test documents current behavior of __eq__ on mapping equality.
+    r_bool = ConditionalToolRule(tool_name="cond", child_output_mapping={True: "A"}, default_child=None, require_output_mapping=False)
+    r_int = ConditionalToolRule(tool_name="cond", child_output_mapping={1: "A"}, default_child=None, require_output_mapping=False)
+    r_float = ConditionalToolRule(tool_name="cond", child_output_mapping={1.0: "A"}, default_child=None, require_output_mapping=False)
+    # Document current semantics: these are equal under Python's dict equality.
+    assert r_bool == r_int
+    assert r_int == r_float
+    assert hash(r_bool) == hash(r_int) == hash(r_float)
+
+
+def test_conditional_tool_rule_mapping_string_vs_numeric_not_equal():
+    r_num = ConditionalToolRule(tool_name="cond", child_output_mapping={1: "A"}, default_child=None, require_output_mapping=False)
+    r_str = ConditionalToolRule(tool_name="cond", child_output_mapping={"1": "A"}, default_child=None, require_output_mapping=False)
+    assert r_num != r_str
+    assert hash(r_num) != hash(r_str)
+
+
+def test_child_and_parent_order_invariance_multiple_permutations():
+    """Child/parent rule equality and hashing must not depend on the order of children."""
+    # permute a few ways
+    variants = [
+        ["a", "b", "c"],
+        ["b", "c", "a"],
+        ["c", "a", "b"],
+    ]
+    child_rules = [ChildToolRule(tool_name="t", children=ch) for ch in variants]
+    parent_rules = [ParentToolRule(tool_name="t", children=ch) for ch in variants]
+
+    # All child rules equal and same hash
+    for r in child_rules[1:]:
+        assert child_rules[0] == r
+        assert hash(child_rules[0]) == hash(r)
+
+    # All parent rules equal and same hash
+    for r in parent_rules[1:]:
+        assert parent_rules[0] == r
+        assert hash(parent_rules[0]) == hash(r)
+
+
+def test_conditional_order_invariance_multiple_permutations():
+    maps = [
+        {"x": "a", "y": "b", "z": "c"},
+        {"z": "c", "y": "b", "x": "a"},
+        {"y": "b", "x": "a", "z": "c"},
+    ]
+    rules = [ConditionalToolRule(tool_name="t", child_output_mapping=m, default_child=None, require_output_mapping=False) for m in maps]
+    for r in rules[1:]:
+        assert rules[0] == r
+        assert hash(rules[0]) == hash(r)
+
+
+# ---------- Set-based dedup for ContinueToolRule, RequiredBeforeExitToolRule and TerminalToolRule ----------
+
+
+def test_dedup_in_set_with_continue_and_required_and_terminal():
+    s = {
+        ContinueToolRule(tool_name="x"),
+        ContinueToolRule(tool_name="x"),  # dup
+        RequiredBeforeExitToolRule(tool_name="y"),
+        RequiredBeforeExitToolRule(tool_name="y"),  # dup
+        TerminalToolRule(tool_name="z"),
+        TerminalToolRule(tool_name="z"),  # dup
+    }
+    assert len(s) == 3
+
+
 def test_required_before_exit_tool_rule_has_required_tools_been_called():
     """Test has_required_tools_been_called() with no required tools."""
     solver = ToolRulesSolver(tool_rules=[])