From fbe5e7cdd1cc8f3c98e580940d79a96433afb31e Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 27 Jun 2025 16:23:55 -0700 Subject: [PATCH] fix: Fix infinite loop required tools test (#3084) --- letta/agents/letta_agent.py | 8 ++++-- letta/helpers/tool_rule_solver.py | 9 +++--- tests/test_tool_rule_solver.py | 46 +++++++++++++++++++------------ 3 files changed, 39 insertions(+), 24 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index d6d3ce0c..8a310747 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -993,6 +993,7 @@ class LettaAgent(BaseAgent): # 4. Decide whether to keep stepping (<<< focal section simplified) continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation( + agent_state=agent_state, request_heartbeat=request_heartbeat, tool_call_name=tool_call_name, tool_rule_violated=tool_rule_violated, @@ -1048,6 +1049,7 @@ class LettaAgent(BaseAgent): def _decide_continuation( self, + agent_state: AgentState, request_heartbeat: bool, tool_call_name: str, tool_rule_violated: bool, @@ -1083,10 +1085,12 @@ class LettaAgent(BaseAgent): continue_stepping = False stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value) else: - uncalled = tool_rules_solver.get_uncalled_required_tools() + uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools])) if not continue_stepping and uncalled: continue_stepping = True - heartbeat_reason = f"{NON_USER_MSG_PREFIX}Missing required tools: " f"{', '.join(uncalled)}" + heartbeat_reason = ( + f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [" f"{', '.join(uncalled)}] to be called still." + ) stop_reason = None # reset – we’re still going diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index 65258f5b..e9a2dd71 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -151,11 +151,11 @@ class ToolRulesSolver(BaseModel): """Check if the tool is defined as a continue tool in the tool rules.""" return any(rule.tool_name == tool_name for rule in self.continue_tool_rules) - def has_required_tools_been_called(self) -> bool: + def has_required_tools_been_called(self, available_tools: Set[str]) -> bool: """Check if all required-before-exit tools have been called.""" - return len(self.get_uncalled_required_tools()) == 0 + return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0 - def get_uncalled_required_tools(self) -> List[str]: + def get_uncalled_required_tools(self, available_tools: Set[str]) -> List[str]: """Get the list of required-before-exit tools that have not been called yet.""" if not self.required_before_exit_tool_rules: return [] # No required tools means no uncalled tools @@ -163,7 +163,8 @@ class ToolRulesSolver(BaseModel): required_tool_names = {rule.tool_name for rule in self.required_before_exit_tool_rules} called_tool_names = set(self.tool_call_history) - return list(required_tool_names - called_tool_names) + # Get required tools that are uncalled AND available + return list((required_tool_names & available_tools) - called_tool_names) def get_ending_tool_names(self) -> List[str]: """Get the names of tools that are required before exit.""" diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index d319b575..d81b2011 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -191,7 +191,7 @@ def test_required_before_exit_tool_rule_has_required_tools_been_called(): """Test has_required_tools_been_called() with no required tools.""" solver = ToolRulesSolver(tool_rules=[]) - assert solver.has_required_tools_been_called() is True, "Should return True when no required tools are defined" + assert solver.has_required_tools_been_called(set()) is True, "Should return True when no required tools are defined" def test_required_before_exit_tool_rule_single_required_tool(): @@ -199,13 +199,13 @@ def test_required_before_exit_tool_rule_single_required_tool(): required_rule = RequiredBeforeExitToolRule(tool_name=SAVE_TOOL) solver = ToolRulesSolver(tool_rules=[required_rule]) - assert solver.has_required_tools_been_called() is False, "Should return False when required tool hasn't been called" - assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should return list with uncalled required tool" + assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False when required tool hasn't been called" + assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should return list with uncalled required tool" solver.register_tool_call(SAVE_TOOL) - assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called" - assert solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called" + assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called" + assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [], "Should return empty list after required tool is called" def test_required_before_exit_tool_rule_multiple_required_tools(): @@ -214,21 +214,31 @@ def test_required_before_exit_tool_rule_multiple_required_tools(): required_rule_2 = RequiredBeforeExitToolRule(tool_name=REQUIRED_TOOL_2) solver = ToolRulesSolver(tool_rules=[required_rule_1, required_rule_2]) - assert solver.has_required_tools_been_called() is False, "Should return False when no required tools have been called" - uncalled_tools = solver.get_uncalled_required_tools() + assert ( + solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is False + ), "Should return False when no required tools have been called" + uncalled_tools = solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) assert set(uncalled_tools) == {REQUIRED_TOOL_1, REQUIRED_TOOL_2}, "Should return both uncalled required tools" # Call first required tool solver.register_tool_call(REQUIRED_TOOL_1) - assert solver.has_required_tools_been_called() is False, "Should return False when only one required tool has been called" - assert solver.get_uncalled_required_tools() == [REQUIRED_TOOL_2], "Should return remaining uncalled required tool" + assert ( + solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is False + ), "Should return False when only one required tool has been called" + assert solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) == [ + REQUIRED_TOOL_2 + ], "Should return remaining uncalled required tool" # Call second required tool solver.register_tool_call(REQUIRED_TOOL_2) - assert solver.has_required_tools_been_called() is True, "Should return True when all required tools have been called" - assert solver.get_uncalled_required_tools() == [], "Should return empty list when all required tools have been called" + assert ( + solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is True + ), "Should return True when all required tools have been called" + assert ( + solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) == [] + ), "Should return empty list when all required tools have been called" def test_required_before_exit_tool_rule_mixed_with_other_tools(): @@ -240,14 +250,14 @@ def test_required_before_exit_tool_rule_mixed_with_other_tools(): solver.register_tool_call(START_TOOL) solver.register_tool_call(HELPER_TOOL) - assert solver.has_required_tools_been_called() is False, "Should return False even after calling other tools" - assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should still show required tool as uncalled" + assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False even after calling other tools" + assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should still show required tool as uncalled" # Call required tool solver.register_tool_call(SAVE_TOOL) - assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called" - assert solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called" + assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called" + assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [], "Should return empty list after required tool is called" def test_required_before_exit_tool_rule_clear_history(): @@ -257,10 +267,10 @@ def test_required_before_exit_tool_rule_clear_history(): # Call required tool solver.register_tool_call(SAVE_TOOL) - assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called" + assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called" # Clear history solver.clear_tool_history() - assert solver.has_required_tools_been_called() is False, "Should return False after clearing history" - assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should show required tool as uncalled after clearing history" + assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False after clearing history" + assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should show required tool as uncalled after clearing history"