fix: Fix infinite loop required tools test (#3084)
This commit is contained in:
@@ -993,6 +993,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
# 4. Decide whether to keep stepping (<<< focal section simplified)
|
||||
continue_stepping, heartbeat_reason, stop_reason = self._decide_continuation(
|
||||
agent_state=agent_state,
|
||||
request_heartbeat=request_heartbeat,
|
||||
tool_call_name=tool_call_name,
|
||||
tool_rule_violated=tool_rule_violated,
|
||||
@@ -1048,6 +1049,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
def _decide_continuation(
|
||||
self,
|
||||
agent_state: AgentState,
|
||||
request_heartbeat: bool,
|
||||
tool_call_name: str,
|
||||
tool_rule_violated: bool,
|
||||
@@ -1083,10 +1085,12 @@ class LettaAgent(BaseAgent):
|
||||
continue_stepping = False
|
||||
stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
|
||||
else:
|
||||
uncalled = tool_rules_solver.get_uncalled_required_tools()
|
||||
uncalled = tool_rules_solver.get_uncalled_required_tools(available_tools=set([t.name for t in agent_state.tools]))
|
||||
if not continue_stepping and uncalled:
|
||||
continue_stepping = True
|
||||
heartbeat_reason = f"{NON_USER_MSG_PREFIX}Missing required tools: " f"{', '.join(uncalled)}"
|
||||
heartbeat_reason = (
|
||||
f"{NON_USER_MSG_PREFIX}Continuing, user expects these tools: [" f"{', '.join(uncalled)}] to be called still."
|
||||
)
|
||||
|
||||
stop_reason = None # reset – we’re still going
|
||||
|
||||
|
||||
@@ -151,11 +151,11 @@ class ToolRulesSolver(BaseModel):
|
||||
"""Check if the tool is defined as a continue tool in the tool rules."""
|
||||
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
||||
|
||||
def has_required_tools_been_called(self) -> bool:
|
||||
def has_required_tools_been_called(self, available_tools: Set[str]) -> bool:
|
||||
"""Check if all required-before-exit tools have been called."""
|
||||
return len(self.get_uncalled_required_tools()) == 0
|
||||
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
|
||||
|
||||
def get_uncalled_required_tools(self) -> List[str]:
|
||||
def get_uncalled_required_tools(self, available_tools: Set[str]) -> List[str]:
|
||||
"""Get the list of required-before-exit tools that have not been called yet."""
|
||||
if not self.required_before_exit_tool_rules:
|
||||
return [] # No required tools means no uncalled tools
|
||||
@@ -163,7 +163,8 @@ class ToolRulesSolver(BaseModel):
|
||||
required_tool_names = {rule.tool_name for rule in self.required_before_exit_tool_rules}
|
||||
called_tool_names = set(self.tool_call_history)
|
||||
|
||||
return list(required_tool_names - called_tool_names)
|
||||
# Get required tools that are uncalled AND available
|
||||
return list((required_tool_names & available_tools) - called_tool_names)
|
||||
|
||||
def get_ending_tool_names(self) -> List[str]:
|
||||
"""Get the names of tools that are required before exit."""
|
||||
|
||||
@@ -191,7 +191,7 @@ def test_required_before_exit_tool_rule_has_required_tools_been_called():
|
||||
"""Test has_required_tools_been_called() with no required tools."""
|
||||
solver = ToolRulesSolver(tool_rules=[])
|
||||
|
||||
assert solver.has_required_tools_been_called() is True, "Should return True when no required tools are defined"
|
||||
assert solver.has_required_tools_been_called(set()) is True, "Should return True when no required tools are defined"
|
||||
|
||||
|
||||
def test_required_before_exit_tool_rule_single_required_tool():
|
||||
@@ -199,13 +199,13 @@ def test_required_before_exit_tool_rule_single_required_tool():
|
||||
required_rule = RequiredBeforeExitToolRule(tool_name=SAVE_TOOL)
|
||||
solver = ToolRulesSolver(tool_rules=[required_rule])
|
||||
|
||||
assert solver.has_required_tools_been_called() is False, "Should return False when required tool hasn't been called"
|
||||
assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should return list with uncalled required tool"
|
||||
assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False when required tool hasn't been called"
|
||||
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should return list with uncalled required tool"
|
||||
|
||||
solver.register_tool_call(SAVE_TOOL)
|
||||
|
||||
assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
|
||||
assert solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called"
|
||||
assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called"
|
||||
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [], "Should return empty list after required tool is called"
|
||||
|
||||
|
||||
def test_required_before_exit_tool_rule_multiple_required_tools():
|
||||
@@ -214,21 +214,31 @@ def test_required_before_exit_tool_rule_multiple_required_tools():
|
||||
required_rule_2 = RequiredBeforeExitToolRule(tool_name=REQUIRED_TOOL_2)
|
||||
solver = ToolRulesSolver(tool_rules=[required_rule_1, required_rule_2])
|
||||
|
||||
assert solver.has_required_tools_been_called() is False, "Should return False when no required tools have been called"
|
||||
uncalled_tools = solver.get_uncalled_required_tools()
|
||||
assert (
|
||||
solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is False
|
||||
), "Should return False when no required tools have been called"
|
||||
uncalled_tools = solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2})
|
||||
assert set(uncalled_tools) == {REQUIRED_TOOL_1, REQUIRED_TOOL_2}, "Should return both uncalled required tools"
|
||||
|
||||
# Call first required tool
|
||||
solver.register_tool_call(REQUIRED_TOOL_1)
|
||||
|
||||
assert solver.has_required_tools_been_called() is False, "Should return False when only one required tool has been called"
|
||||
assert solver.get_uncalled_required_tools() == [REQUIRED_TOOL_2], "Should return remaining uncalled required tool"
|
||||
assert (
|
||||
solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is False
|
||||
), "Should return False when only one required tool has been called"
|
||||
assert solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) == [
|
||||
REQUIRED_TOOL_2
|
||||
], "Should return remaining uncalled required tool"
|
||||
|
||||
# Call second required tool
|
||||
solver.register_tool_call(REQUIRED_TOOL_2)
|
||||
|
||||
assert solver.has_required_tools_been_called() is True, "Should return True when all required tools have been called"
|
||||
assert solver.get_uncalled_required_tools() == [], "Should return empty list when all required tools have been called"
|
||||
assert (
|
||||
solver.has_required_tools_been_called({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) is True
|
||||
), "Should return True when all required tools have been called"
|
||||
assert (
|
||||
solver.get_uncalled_required_tools({REQUIRED_TOOL_1, REQUIRED_TOOL_2}) == []
|
||||
), "Should return empty list when all required tools have been called"
|
||||
|
||||
|
||||
def test_required_before_exit_tool_rule_mixed_with_other_tools():
|
||||
@@ -240,14 +250,14 @@ def test_required_before_exit_tool_rule_mixed_with_other_tools():
|
||||
solver.register_tool_call(START_TOOL)
|
||||
solver.register_tool_call(HELPER_TOOL)
|
||||
|
||||
assert solver.has_required_tools_been_called() is False, "Should return False even after calling other tools"
|
||||
assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should still show required tool as uncalled"
|
||||
assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False even after calling other tools"
|
||||
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should still show required tool as uncalled"
|
||||
|
||||
# Call required tool
|
||||
solver.register_tool_call(SAVE_TOOL)
|
||||
|
||||
assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
|
||||
assert solver.get_uncalled_required_tools() == [], "Should return empty list after required tool is called"
|
||||
assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called"
|
||||
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [], "Should return empty list after required tool is called"
|
||||
|
||||
|
||||
def test_required_before_exit_tool_rule_clear_history():
|
||||
@@ -257,10 +267,10 @@ def test_required_before_exit_tool_rule_clear_history():
|
||||
|
||||
# Call required tool
|
||||
solver.register_tool_call(SAVE_TOOL)
|
||||
assert solver.has_required_tools_been_called() is True, "Should return True after required tool is called"
|
||||
assert solver.has_required_tools_been_called({SAVE_TOOL}) is True, "Should return True after required tool is called"
|
||||
|
||||
# Clear history
|
||||
solver.clear_tool_history()
|
||||
|
||||
assert solver.has_required_tools_been_called() is False, "Should return False after clearing history"
|
||||
assert solver.get_uncalled_required_tools() == [SAVE_TOOL], "Should show required tool as uncalled after clearing history"
|
||||
assert solver.has_required_tools_been_called({SAVE_TOOL}) is False, "Should return False after clearing history"
|
||||
assert solver.get_uncalled_required_tools({SAVE_TOOL}) == [SAVE_TOOL], "Should show required tool as uncalled after clearing history"
|
||||
|
||||
Reference in New Issue
Block a user