fix: Fix infinite loop in required tools test (#3084)

This commit is contained in:
Matthew Zhou
2025-06-27 16:23:55 -07:00
committed by GitHub
parent 1287f44515
commit fbe5e7cdd1
3 changed files with 39 additions and 24 deletions

View File

@@ -993,6 +993,7 @@ class LettaAgent(BaseAgent):
# 4. Decide whether to keep stepping (<<< focal section simplified)
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
agent_state=agent_state,
request_heartbeat=request_heartbeat,
tool_call_name=tool_call_name,
tool_rule_violated=tool_rule_violated,
@@ -1048,6 +1049,7 @@ class LettaAgent(BaseAgent):
def _decide_continuation(
self,
agent_state: AgentState,
request_heartbeat: bool,
tool_call_name: str,
tool_rule_violated: bool,
@@ -1083,10 +1085,12 @@ class LettaAgent(BaseAgent):
continue_stepping = False
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
else:
uncalled = tool_rules_solver.get_uncalled_required_tools()
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
if not continue_stepping and uncalled:
continue_stepping = True
heartbeat_reason = f"{NON_USER_MSG_PREFIX}Missing required tools: " f"{', '.join(uncalled)}"
heartbeat_reason = (
f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [" f"{', '.join(uncalled)}] to be called still."
)
stop_reason = None # reset: we're still going

View File

@@ -151,11 +151,11 @@ class ToolRulesSolver(BaseModel):
"""Check if the tool is defined as a continue tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
def has_required_tools_been_called(self) -> bool:
def has_required_tools_been_called(self, available_tools: Set[str]) -> bool:
"""Check if all required-before-exit tools that are in `available_tools` have been called."""
return len(self.get_uncalled_required_tools()) == 0
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
def get_uncalled_required_tools(self) -> List[str]:
def get_uncalled_required_tools(self, available_tools: Set[str]) -> List[str]:
"""Get the list of required-before-exit tools that are in `available_tools` and have not been called yet."""
if not self.required_before_exit_tool_rules:
return [] # No required tools means no uncalled tools
@@ -163,7 +163,8 @@ class ToolRulesSolver(BaseModel):
required_tool_names = {rule.tool_name for rule in self.required_before_exit_tool_rules}
called_tool_names = set(self.tool_call_history)
return list(required_tool_names - called_tool_names)
# Get required tools that are uncalled AND available
return list((required_tool_names & available_tools) - called_tool_names)
def get_ending_tool_names(self) -> List[str]:
"""Get the names of tools that are required before exit."""

View File

@@ -191,7 +191,7 @@ def test_required_before_exit_tool_rule_has_required_tools_been_called():
"""Test has_required_tools_been_called() with no required tools."""
solver = ToolRulesSolver(tool_rules=[])
assert solver.has_required_tools_been_called() is True, "Should return True when no required tools are defined"
assert solver.has_required_tools_been_called(set()) is True, "Should return True when no required tools are defined"
def test_required_before_exit_tool_rule_single_required_tool():
@@ -199,13 +199,13 @@ def test_required_before_exit_tool_rule_single_required_tool():
required_rule = RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)
solver = ToolRulesSolver(tool_rules=[required_rule])
assert solver.has_required_tools_been_called() is False, "Should return False when required tool hasn't been called"
assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should return list with uncalled required tool"
assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False when required tool hasn't been called"
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should return list with uncalled required tool"
solver.register_tool_call(SAVE_TOOL)
assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
assert solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called"
assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called"
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [], "Should return empty list after required tool is called"
def test_required_before_exit_tool_rule_multiple_required_tools():
@@ -214,21 +214,31 @@ def test_required_before_exit_tool_rule_multiple_required_tools():
required_rule_2 = RequiredBeforeExitToolRule(tool_name=REQUIRED_TOOL_2)
solver = ToolRulesSolver(tool_rules=[required_rule_1, required_rule_2])
assert solver.has_required_tools_been_called() is False, "Should return False when no required tools have been called"
uncalled_tools = solver.get_uncalled_required_tools()
assert (
solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is False
), "Should return False when no required tools have been called"
uncalled_tools = solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2})
assert set(uncalled_tools) == {REQUIRED_TOOL_1, REQUIRED_TOOL_2}, "Should return both uncalled required tools"
# Call first required tool
solver.register_tool_call(REQUIRED_TOOL_1)
assert solver.has_required_tools_been_called() is False, "Should return False when only one required tool has been called"
assert solver.get_uncalled_required_tools() == [REQUIRED_TOOL_2], "Should return remaining uncalled required tool"
assert (
solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is False
), "Should return False when only one required tool has been called"
assert solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) == [
REQUIRED_TOOL_2
], "Should return remaining uncalled required tool"
# Call second required tool
solver.register_tool_call(REQUIRED_TOOL_2)
assert solver.has_required_tools_been_called() is True, "Should return True when all required tools have been called"
assert solver.get_uncalled_required_tools() == [], "Should return empty list when all required tools have been called"
assert (
solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is True
), "Should return True when all required tools have been called"
assert (
solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) == []
), "Should return empty list when all required tools have been called"
def test_required_before_exit_tool_rule_mixed_with_other_tools():
@@ -240,14 +250,14 @@ def test_required_before_exit_tool_rule_mixed_with_other_tools():
solver.register_tool_call(START_TOOL)
solver.register_tool_call(HELPER_TOOL)
assert solver.has_required_tools_been_called() is False, "Should return False even after calling other tools"
assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should still show required tool as uncalled"
assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False even after calling other tools"
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should still show required tool as uncalled"
# Call required tool
solver.register_tool_call(SAVE_TOOL)
assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
assert solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called"
assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called"
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [], "Should return empty list after required tool is called"
def test_required_before_exit_tool_rule_clear_history():
@@ -257,10 +267,10 @@ def test_required_before_exit_tool_rule_clear_history():
# Call required tool
solver.register_tool_call(SAVE_TOOL)
assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called"
# Clear history
solver.clear_tool_history()
assert solver.has_required_tools_been_called() is False, "Should return False after clearing history"
assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should show required tool as uncalled after clearing history"
assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False after clearing history"
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should show required tool as uncalled after clearing history"