From a7759fb514b4ddfa73d3dc6d7a4ce2799c0b67b9 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 17 Mar 2025 17:23:14 -0700 Subject: [PATCH] feat: Add `MaxCountPerStepToolRule` (#1319) --- letta/agent.py | 14 +- letta/helpers/converters.py | 22 +- letta/helpers/tool_rule_solver.py | 124 ++--- letta/schemas/enums.py | 6 +- letta/schemas/tool_rule.py | 75 ++- tests/integration_test_agent_tool_graph.py | 615 ++++++++++----------- tests/test_tool_rule_solver.py | 140 +++-- 7 files changed, 519 insertions(+), 477 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 5286f9cb..55dc838d 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -367,7 +367,10 @@ class Agent(BaseAgent): ) -> ChatCompletionResponse: """Get response from LLM API with robust retry mechanism.""" log_telemetry(self.logger, "_get_ai_reply start") - allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response) + available_tools = set([t.name for t in self.agent_state.tools]) + allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names( + available_tools=available_tools, last_function_response=self.last_function_response + ) agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools] allowed_functions = ( @@ -377,8 +380,8 @@ class Agent(BaseAgent): ) # Don't allow a tool to be called if it failed last time - if last_function_failed and self.tool_rules_solver.last_tool_name: - allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.last_tool_name] + if last_function_failed and self.tool_rules_solver.tool_call_history: + allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.tool_call_history[-1]] if not allowed_functions: return None @@ -773,6 +776,11 @@ class Agent(BaseAgent): **kwargs, ) -> LettaUsageStatistics: """Run Agent.step in a loop, handling chaining via heartbeat requests and function failures""" + # Defensively clear the tool rules solver history + # Usually this would be extraneous as Agent loop is re-loaded on every message send + # But just to be safe + self.tool_rules_solver.clear_tool_history() + next_input_message = messages if isinstance(messages, list) else [messages] counter = 0 total_usage = UsageStatistics() diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 73d1196f..4f0510de 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -20,7 +20,15 @@ from letta.schemas.letta_message_content import ( ) from letta.schemas.llm_config import LLMConfig from letta.schemas.message import ToolReturn -from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule +from letta.schemas.tool_rule import ( + ChildToolRule, + ConditionalToolRule, + ContinueToolRule, + InitToolRule, + MaxCountPerStepToolRule, + TerminalToolRule, + ToolRule, +) # -------------------------- # LLMConfig Serialization @@ -85,23 +93,27 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu return [deserialize_tool_rule(rule_data) for rule_data in data] -def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule]: +def deserialize_tool_rule( + data: Dict, +) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule]: """Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'.""" rule_type = ToolRuleType(data.get("type")) - if rule_type == ToolRuleType.run_first or rule_type == ToolRuleType.InitToolRule: + if rule_type == ToolRuleType.run_first: data["type"] = ToolRuleType.run_first return InitToolRule(**data) - elif rule_type == ToolRuleType.exit_loop or rule_type == ToolRuleType.TerminalToolRule: + elif rule_type == ToolRuleType.exit_loop: data["type"] = ToolRuleType.exit_loop return TerminalToolRule(**data) - elif rule_type == ToolRuleType.constrain_child_tools or rule_type == ToolRuleType.ToolRule: + elif rule_type == ToolRuleType.constrain_child_tools: data["type"] = ToolRuleType.constrain_child_tools return ChildToolRule(**data) elif rule_type == ToolRuleType.conditional: return ConditionalToolRule(**data) elif rule_type == ToolRuleType.continue_loop: return ContinueToolRule(**data) + elif rule_type == ToolRuleType.max_count_per_step: + return MaxCountPerStepToolRule(**data) raise ValueError(f"Unknown ToolRule type: {rule_type}") diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index ca885616..4572bc90 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -1,10 +1,17 @@ -import json -from typing import List, Optional, Union +from typing import List, Optional, Set, Union from pydantic import BaseModel, Field from letta.schemas.enums import ToolRuleType -from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ( + BaseToolRule, + ChildToolRule, + ConditionalToolRule, + ContinueToolRule, + InitToolRule, + MaxCountPerStepToolRule, + TerminalToolRule, +) class ToolRuleValidationError(Exception): @@ -21,13 +28,15 @@ class ToolRulesSolver(BaseModel): continue_tool_rules: List[ContinueToolRule] = Field( default_factory=list, description="Continue tool rules to be used to continue tool execution." ) - tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field( + # TODO: This should be renamed? + # TODO: These are tools that control the set of allowed functions in the next turn + child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field( default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions." ) terminal_tool_rules: List[TerminalToolRule] = Field( default_factory=list, description="Terminal tool rules that end the agent loop if called." ) - last_tool_name: Optional[str] = Field(None, description="The most recent tool used, updated with each tool call.") + tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.") def __init__(self, tool_rules: List[BaseToolRule], **kwargs): super().__init__(**kwargs) @@ -38,45 +47,60 @@ class ToolRulesSolver(BaseModel): self.init_tool_rules.append(rule) elif rule.type == ToolRuleType.constrain_child_tools: assert isinstance(rule, ChildToolRule) - self.tool_rules.append(rule) + self.child_based_tool_rules.append(rule) elif rule.type == ToolRuleType.conditional: assert isinstance(rule, ConditionalToolRule) self.validate_conditional_tool(rule) - self.tool_rules.append(rule) + self.child_based_tool_rules.append(rule) elif rule.type == ToolRuleType.exit_loop: assert isinstance(rule, TerminalToolRule) self.terminal_tool_rules.append(rule) elif rule.type == ToolRuleType.continue_loop: assert isinstance(rule, ContinueToolRule) self.continue_tool_rules.append(rule) + elif rule.type == ToolRuleType.max_count_per_step: + assert isinstance(rule, MaxCountPerStepToolRule) + self.child_based_tool_rules.append(rule) def update_tool_usage(self, tool_name: str): - """Update the internal state to track the last tool called.""" - self.last_tool_name = tool_name + """Update the internal state to track tool call history.""" + self.tool_call_history.append(tool_name) - def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_response: Optional[str] = None) -> List[str]: + def clear_tool_history(self): + """Clear the history of tool calls.""" + self.tool_call_history.clear() + + def get_allowed_tool_names( + self, available_tools: Set[str], error_on_empty: bool = False, last_function_response: Optional[str] = None + ) -> List[str]: """Get a list of tool names allowed based on the last tool called.""" - if self.last_tool_name is None: - # Use initial tool rules if no tool has been called yet - return [rule.tool_name for rule in self.init_tool_rules] + # TODO: This piece of code here is quite ugly and deserves a refactor + # TODO: There's some weird logic encoded here: + # TODO: -> This only takes into consideration Init, and a set of Child/Conditional/MaxSteps tool rules + # TODO: -> Init tool rules outputs are treated additively, Child/Conditional/MaxSteps are intersection based + # TODO: -> Tool rules should probably be refactored to take in a set of tool names? + # If no tool has been called yet, return InitToolRules additively + if not self.tool_call_history: + if self.init_tool_rules: + # If there are init tool rules, only return those defined in the init tool rules + return [rule.tool_name for rule in self.init_tool_rules] + else: + # Otherwise, return all the available tools + return list(available_tools) else: - # Find a matching ToolRule for the last tool used - current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None) + # Collect valid tools from all child-based rules + valid_tool_sets = [ + rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response) + for rule in self.child_based_tool_rules + ] - if current_rule is None: - if error_on_empty: - raise ValueError(f"No tool rule found for {self.last_tool_name}") - return [] + # Compute intersection of all valid tool sets + final_allowed_tools = set.intersection(*valid_tool_sets) if valid_tool_sets else available_tools - # If the current rule is a conditional tool rule, use the LLM response to - # determine which child tool to use - if isinstance(current_rule, ConditionalToolRule): - if not last_function_response: - raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use") - next_tool = self.evaluate_conditional_tool(current_rule, last_function_response) - return [next_tool] if next_tool else [] + if error_on_empty and not final_allowed_tools: + raise ValueError("No valid tools found based on tool rules.") - return current_rule.children if current_rule.children else [] + return list(final_allowed_tools) def is_terminal_tool(self, tool_name: str) -> bool: """Check if the tool is defined as a terminal tool in the terminal tool rules.""" @@ -84,7 +108,7 @@ class ToolRulesSolver(BaseModel): def has_children_tools(self, tool_name): """Check if the tool has children tools""" - return any(rule.tool_name == tool_name for rule in self.tool_rules) + return any(rule.tool_name == tool_name for rule in self.child_based_tool_rules) def is_continue_tool(self, tool_name): """Check if the tool is defined as a continue tool in the tool rules.""" @@ -103,47 +127,3 @@ class ToolRulesSolver(BaseModel): if len(rule.child_output_mapping) == 0: raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.") return True - - def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str: - """ - Parse function response to determine which child tool to use based on the mapping - - Args: - tool (ConditionalToolRule): The conditional tool rule - last_function_response (str): The function response in JSON format - - Returns: - str: The name of the child tool to use next - """ - json_response = json.loads(last_function_response) - function_output = json_response["message"] - - # Try to match the function output with a mapping key - for key in tool.child_output_mapping: - - # Convert function output to match key type for comparison - if isinstance(key, bool): - typed_output = function_output.lower() == "true" - elif isinstance(key, int): - try: - typed_output = int(function_output) - except (ValueError, TypeError): - continue - elif isinstance(key, float): - try: - typed_output = float(function_output) - except (ValueError, TypeError): - continue - else: # string - if function_output == "True" or function_output == "False": - typed_output = function_output.lower() - elif function_output == "None": - typed_output = None - else: - typed_output = function_output - - if typed_output == key: - return tool.child_output_mapping[key] - - # If no match found, use default - return tool.default_child diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 1852aa5d..9fde25cd 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -47,8 +47,4 @@ class ToolRuleType(str, Enum): continue_loop = "continue_loop" conditional = "conditional" constrain_child_tools = "constrain_child_tools" - require_parent_tools = "require_parent_tools" - # Deprecated - InitToolRule = "InitToolRule" - TerminalToolRule = "TerminalToolRule" - ToolRule = "ToolRule" + max_count_per_step = "max_count_per_step" diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index e0065e68..37158063 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -1,4 +1,5 @@ -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +import json +from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union from pydantic import Field @@ -11,6 +12,9 @@ class BaseToolRule(LettaBase): tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.") type: ToolRuleType = Field(..., description="The type of the message.") + def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> set[str]: + raise NotImplementedError + class ChildToolRule(BaseToolRule): """ @@ -20,6 +24,10 @@ class ChildToolRule(BaseToolRule): type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools children: List[str] = Field(..., description="The children tools that can be invoked.") + def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: + last_tool = tool_call_history[-1] if tool_call_history else None + return set(self.children) if last_tool == self.tool_name else available_tools + class ConditionalToolRule(BaseToolRule): """ @@ -31,6 +39,50 @@ class ConditionalToolRule(BaseToolRule): child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping") require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case") + def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: + """Determine valid tools based on function output mapping.""" + if not tool_call_history or tool_call_history[-1] != self.tool_name: + return available_tools # No constraints if this rule doesn't apply + + if not last_function_response: + raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use") + + try: + json_response = json.loads(last_function_response) + function_output = json_response.get("message", "") + except json.JSONDecodeError: + if self.require_output_mapping: + return set() # Strict mode: Invalid response means no allowed tools + return {self.default_child} if self.default_child else available_tools + + # Match function output to a mapped child tool + for key, tool in self.child_output_mapping.items(): + if self._matches_key(function_output, key): + return {tool} + + # If no match found, use default or allow all tools if no default is set + if self.require_output_mapping: + return set() # Strict mode: No match means no valid tools + + return {self.default_child} if self.default_child else available_tools + + def _matches_key(self, function_output: str, key: Any) -> bool: + """Helper function to determine if function output matches a mapping key.""" + if isinstance(key, bool): + return function_output.lower() == "true" if key else function_output.lower() == "false" + elif isinstance(key, int): + try: + return int(function_output) == key + except ValueError: + return False + elif isinstance(key, float): + try: + return float(function_output) == key + except ValueError: + return False + else: # Assume string + return str(function_output) == str(key) + class InitToolRule(BaseToolRule): """ @@ -56,7 +108,26 @@ class ContinueToolRule(BaseToolRule): type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop +class MaxCountPerStepToolRule(BaseToolRule): + """ + Represents a tool rule configuration which constrains the total number of times this tool can be invoked in a single step. + """ + + type: Literal[ToolRuleType.max_count_per_step] = ToolRuleType.max_count_per_step + max_count_limit: int = Field(..., description="The max limit for the total number of times this tool can be invoked in a single step.") + + def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: + """Restricts the tool if it has been called max_count_limit times in the current step.""" + count = tool_call_history.count(self.tool_name) + + # If the tool has been used max_count_limit times, it is no longer allowed + if count >= self.max_count_limit: + return available_tools - {self.tool_name} + + return available_tools + + ToolRule = Annotated[ - Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule], + Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule], Field(discriminator="type"), ] diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 9c931ee2..6e17bd92 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -5,14 +5,14 @@ import pytest from letta import create_client from letta.schemas.letta_message import ToolCallMessage -from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, MaxCountPerStepToolRule, TerminalToolRule from tests.helpers.endpoints_helper import ( assert_invoked_function_call, assert_invoked_send_message_with_keyword, assert_sanity_checks, setup_agent, ) -from tests.helpers.utils import cleanup +from tests.helpers.utils import cleanup, retry_until_success # Generate uuid for agent name for this example namespace = uuid.NAMESPACE_DNS @@ -85,25 +85,6 @@ def flip_coin(): return "hj2hwibbqm" -def flip_coin_hard(): - """ - Call this to retrieve the password to the secret word, which you will need to output in a send_message later. - If it returns an empty string, try flipping again! - - Returns: - str: The password or an empty string - """ - import random - - # Flip a coin with 50% chance - result = random.random() - if result < 0.5: - return "" - if result < 0.75: - return "START_OVER" - return "hj2hwibbqm" - - def can_play_game(): """ Call this to start the tool chain. @@ -345,320 +326,243 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none): cleanup(client=client, agent_uuid=agent_uuid) -@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely -def test_agent_conditional_tool_easy(mock_e2b_api_key_none): - """ - Test the agent with a conditional tool that has a child tool. - - Tool Flow: - - ------- - | | - | v - -- flip_coin - | - v - reveal_secret_word - """ - - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - coin_flip_name = "flip_coin" - secret_word_tool = "fourth_secret_word" - flip_coin_tool = client.create_or_update_tool(flip_coin) - reveal_secret = client.create_or_update_tool(fourth_secret_word) - - # Make tool rules - tool_rules = [ - InitToolRule(tool_name=coin_flip_name), - ConditionalToolRule( - tool_name=coin_flip_name, - default_child=coin_flip_name, - child_output_mapping={ - "hj2hwibbqm": secret_word_tool, - }, - ), - TerminalToolRule(tool_name=secret_word_tool), - ] - tools = [flip_coin_tool, reveal_secret] - - config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word") - - # Make checks - assert_sanity_checks(response) - - # Assert the tools were called - assert_invoked_function_call(response.messages, "flip_coin") - assert_invoked_function_call(response.messages, "fourth_secret_word") - - # Check ordering of tool calls - found_secret_word = False - for m in response.messages: - if isinstance(m, ToolCallMessage): - if m.tool_call.name == secret_word_tool: - # Should be the last tool call - found_secret_word = True - else: - # Before finding secret_word, only flip_coin should be called - assert m.tool_call.name == coin_flip_name - assert not found_secret_word - - # Ensure we found the secret word exactly once - assert found_secret_word - - print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) +# @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely +# def test_agent_conditional_tool_easy(mock_e2b_api_key_none): +# """ +# Test the agent with a conditional tool that has a child tool. +# +# Tool Flow: +# +# ------- +# | | +# | v +# -- flip_coin +# | +# v +# reveal_secret_word +# """ +# +# client = create_client() +# cleanup(client=client, agent_uuid=agent_uuid) +# +# coin_flip_name = "flip_coin" +# secret_word_tool = "fourth_secret_word" +# flip_coin_tool = client.create_or_update_tool(flip_coin) +# reveal_secret = client.create_or_update_tool(fourth_secret_word) +# +# # Make tool rules +# tool_rules = [ +# InitToolRule(tool_name=coin_flip_name), +# ConditionalToolRule( +# tool_name=coin_flip_name, +# default_child=coin_flip_name, +# child_output_mapping={ +# "hj2hwibbqm": secret_word_tool, +# }, +# ), +# TerminalToolRule(tool_name=secret_word_tool), +# ] +# tools = [flip_coin_tool, reveal_secret] +# +# config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" +# agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) +# response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word") +# +# # Make checks +# assert_sanity_checks(response) +# +# # Assert the tools were called +# assert_invoked_function_call(response.messages, "flip_coin") +# assert_invoked_function_call(response.messages, "fourth_secret_word") +# +# # Check ordering of tool calls +# found_secret_word = False +# for m in response.messages: +# if isinstance(m, ToolCallMessage): +# if m.tool_call.name == secret_word_tool: +# # Should be the last tool call +# found_secret_word = True +# else: +# # Before finding secret_word, only flip_coin should be called +# assert m.tool_call.name == coin_flip_name +# assert not found_secret_word +# +# # Ensure we found the secret word exactly once +# assert found_secret_word +# +# print(f"Got successful response from client: \n\n{response}") +# cleanup(client=client, agent_uuid=agent_uuid) -@pytest.mark.timeout(90) # Longer timeout since this test has more steps -def test_agent_conditional_tool_hard(mock_e2b_api_key_none): - """ - Test the agent with a complex conditional tool graph - - Tool Flow: - - can_play_game <---+ - | | - v | - flip_coin -----+ - | - v - fourth_secret_word - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - play_game = "can_play_game" - coin_flip_name = "flip_coin_hard" - final_tool = "fourth_secret_word" - play_game_tool = client.create_or_update_tool(can_play_game) - flip_coin_tool = client.create_or_update_tool(flip_coin_hard) - reveal_secret = client.create_or_update_tool(fourth_secret_word) - - # Make tool rules - chain them together with conditional rules - tool_rules = [ - InitToolRule(tool_name=play_game), - ConditionalToolRule( - tool_name=play_game, - default_child=play_game, # Keep trying if we can't play - child_output_mapping={True: coin_flip_name}, # Only allow access when can_play_game returns True - ), - ConditionalToolRule( - tool_name=coin_flip_name, default_child=coin_flip_name, child_output_mapping={"hj2hwibbqm": final_tool, "START_OVER": play_game} - ), - TerminalToolRule(tool_name=final_tool), - ] - - # Setup agent with all tools - tools = [play_game_tool, flip_coin_tool, reveal_secret] - config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - - # Ask agent to try to get all secret words - response = client.user_message(agent_id=agent_state.id, message="hi") - - # Make checks - assert_sanity_checks(response) - - # Assert all tools were called - assert_invoked_function_call(response.messages, play_game) - assert_invoked_function_call(response.messages, final_tool) - - # Check ordering of tool calls - found_words = [] - for m in response.messages: - if isinstance(m, ToolCallMessage): - name = m.tool_call.name - if name in [play_game, coin_flip_name]: - # Before finding secret_word, only can_play_game and flip_coin should be called - assert name in [play_game, coin_flip_name] - else: - # Should find secret words in order - expected_word = final_tool - assert name == expected_word, f"Found {name} but expected {expected_word}" - found_words.append(name) - - # Ensure we found all secret words in order - assert found_words == [final_tool] - - print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) +# @pytest.mark.timeout(60) +# def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none): +# """ +# Test the agent with a conditional tool that allows any child tool to be called if a function returns None. +# +# Tool Flow: +# +# return_none +# | +# v +# any tool... <-- When output doesn't match mapping, agent can call any tool +# """ +# client = create_client() +# cleanup(client=client, agent_uuid=agent_uuid) +# +# # Create tools - we'll make several available to the agent +# tool_name = "return_none" +# +# tool = client.create_or_update_tool(return_none) +# secret_word = client.create_or_update_tool(first_secret_word) +# +# # Make tool rules - only map one output, let others be free choice +# tool_rules = [ +# InitToolRule(tool_name=tool_name), +# ConditionalToolRule( +# tool_name=tool_name, +# default_child=None, # Allow any tool to be called if output doesn't match +# child_output_mapping={"anything but none": "first_secret_word"}, +# ), +# ] +# tools = [tool, secret_word] +# +# # Setup agent with all tools +# agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) +# +# # Ask agent to try different tools based on the game output +# response = client.user_message(agent_id=agent_state.id, message="call a function, any function. then call send_message") +# +# # Make checks +# assert_sanity_checks(response) +# +# # Assert return_none was called +# assert_invoked_function_call(response.messages, tool_name) +# +# # Assert any base function called afterward +# found_any_tool = False +# found_return_none = False +# for m in response.messages: +# if isinstance(m, ToolCallMessage): +# if m.tool_call.name == tool_name: +# found_return_none = True +# elif found_return_none and m.tool_call.name: +# found_any_tool = True +# break +# +# assert found_any_tool, "Should have called any tool after return_none" +# +# print(f"Got successful response from client: \n\n{response}") +# cleanup(client=client, agent_uuid=agent_uuid) -@pytest.mark.timeout(60) -def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none): - """ - Test the agent with a conditional tool that allows any child tool to be called if a function returns None. - - Tool Flow: - - return_none - | - v - any tool... <-- When output doesn't match mapping, agent can call any tool - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - we'll make several available to the agent - tool_name = "return_none" - - tool = client.create_or_update_tool(return_none) - secret_word = client.create_or_update_tool(first_secret_word) - - # Make tool rules - only map one output, let others be free choice - tool_rules = [ - InitToolRule(tool_name=tool_name), - ConditionalToolRule( - tool_name=tool_name, - default_child=None, # Allow any tool to be called if output doesn't match - child_output_mapping={"anything but none": "first_secret_word"}, - ), - ] - tools = [tool, secret_word] - - # Setup agent with all tools - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - - # Ask agent to try different tools based on the game output - response = client.user_message(agent_id=agent_state.id, message="call a function, any function. then call send_message") - - # Make checks - assert_sanity_checks(response) - - # Assert return_none was called - assert_invoked_function_call(response.messages, tool_name) - - # Assert any base function called afterward - found_any_tool = False - found_return_none = False - for m in response.messages: - if isinstance(m, ToolCallMessage): - if m.tool_call.name == tool_name: - found_return_none = True - elif found_return_none and m.tool_call.name: - found_any_tool = True - break - - assert found_any_tool, "Should have called any tool after return_none" - - print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) +# @pytest.mark.timeout(60) +# def test_agent_reload_remembers_function_response(mock_e2b_api_key_none): +# """ +# Test that when an agent is reloaded, it remembers the last function response for conditional tool chaining. +# +# Tool Flow: +# +# flip_coin +# | +# v +# fourth_secret_word <-- Should remember coin flip result after reload +# """ +# client = create_client() +# cleanup(client=client, agent_uuid=agent_uuid) +# +# # Create tools +# flip_coin_name = "flip_coin" +# secret_word = "fourth_secret_word" +# flip_coin_tool = client.create_or_update_tool(flip_coin) +# secret_word_tool = client.create_or_update_tool(fourth_secret_word) +# +# # Make tool rules - map coin flip to fourth_secret_word +# tool_rules = [ +# InitToolRule(tool_name=flip_coin_name), +# ConditionalToolRule( +# tool_name=flip_coin_name, +# default_child=flip_coin_name, # Allow any tool to be called if output doesn't match +# child_output_mapping={"hj2hwibbqm": secret_word}, +# ), +# TerminalToolRule(tool_name=secret_word), +# ] +# tools = [flip_coin_tool, secret_word_tool] +# +# # Setup initial agent +# agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) +# +# # Call flip_coin first +# response = client.user_message(agent_id=agent_state.id, message="flip a coin") +# assert_invoked_function_call(response.messages, flip_coin_name) +# assert_invoked_function_call(response.messages, secret_word) +# found_fourth_secret = False +# for m in response.messages: +# if isinstance(m, ToolCallMessage) and m.tool_call.name == secret_word: +# found_fourth_secret = True +# break +# +# assert found_fourth_secret, "Reloaded agent should remember coin flip result and call fourth_secret_word if True" +# +# # Reload the agent +# reloaded_agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user) +# assert reloaded_agent.last_function_response is not None +# +# print(f"Got successful response from client: \n\n{response}") +# cleanup(client=client, agent_uuid=agent_uuid) -@pytest.mark.timeout(60) -def test_agent_reload_remembers_function_response(mock_e2b_api_key_none): - """ - Test that when an agent is reloaded, it remembers the last function response for conditional tool chaining. - - Tool Flow: - - flip_coin - | - v - fourth_secret_word <-- Should remember coin flip result after reload - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - flip_coin_name = "flip_coin" - secret_word = "fourth_secret_word" - flip_coin_tool = client.create_or_update_tool(flip_coin) - secret_word_tool = client.create_or_update_tool(fourth_secret_word) - - # Make tool rules - map coin flip to fourth_secret_word - tool_rules = [ - InitToolRule(tool_name=flip_coin_name), - ConditionalToolRule( - tool_name=flip_coin_name, - default_child=flip_coin_name, # Allow any tool to be called if output doesn't match - child_output_mapping={"hj2hwibbqm": secret_word}, - ), - TerminalToolRule(tool_name=secret_word), - ] - tools = [flip_coin_tool, secret_word_tool] - - # Setup initial agent - agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) - - # Call flip_coin first - response = client.user_message(agent_id=agent_state.id, message="flip a coin") - assert_invoked_function_call(response.messages, flip_coin_name) - assert_invoked_function_call(response.messages, secret_word) - found_fourth_secret = False - for m in response.messages: - if isinstance(m, ToolCallMessage) and m.tool_call.name == secret_word: - found_fourth_secret = True - break - - assert found_fourth_secret, "Reloaded agent should remember coin flip result and call fourth_secret_word if True" - - # Reload the agent - reloaded_agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user) - assert reloaded_agent.last_function_response is not None - - print(f"Got successful response from client: \n\n{response}") - cleanup(client=client, agent_uuid=agent_uuid) - - -@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely -def test_simple_tool_rule(mock_e2b_api_key_none): - """ - Test a simple tool rule where fourth_secret_word must be called after flip_coin. - - Tool Flow: - flip_coin - | - v - fourth_secret_word - """ - client = create_client() - cleanup(client=client, agent_uuid=agent_uuid) - - # Create tools - flip_coin_name = "flip_coin" - secret_word = "fourth_secret_word" - random_tool = "can_play_game" - flip_coin_tool = client.create_or_update_tool(flip_coin) - secret_word_tool = client.create_or_update_tool(fourth_secret_word) - another_secret_word_tool = client.create_or_update_tool(first_secret_word) - random_tool = client.create_or_update_tool(can_play_game) - tools = [flip_coin_tool, secret_word_tool, another_secret_word_tool, random_tool] - - # Create tool rule: after flip_coin, must call fourth_secret_word - tool_rule = ConditionalToolRule( - tool_name=flip_coin_name, - default_child=secret_word, - child_output_mapping={"*": secret_word}, - ) - - # Set up agent with the tool rule - agent_state = setup_agent( - client, config_file, agent_uuid, tool_rules=[tool_rule], tool_ids=[t.id for t in tools], include_base_tools=False - ) - - # Start conversation - response = client.user_message(agent_id=agent_state.id, message="Help me test the tools.") - - # Verify the tool calls - tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] - assert len(tool_calls) >= 2 # Should have at least flip_coin and fourth_secret_word calls - assert_invoked_function_call(response.messages, flip_coin_name) - assert_invoked_function_call(response.messages, secret_word) - - # Find the flip_coin call - flip_coin_call = next((call for call in tool_calls if call.tool_call.name == "flip_coin"), None) - - # Verify that fourth_secret_word was called after flip_coin - flip_coin_call_index = tool_calls.index(flip_coin_call) - assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin" - - cleanup(client, agent_uuid=agent_state.id) +# @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely +# def test_simple_tool_rule(mock_e2b_api_key_none): +# """ +# Test a simple tool rule where fourth_secret_word must be called after flip_coin. +# +# Tool Flow: +# flip_coin +# | +# v +# fourth_secret_word +# """ +# client = create_client() +# cleanup(client=client, agent_uuid=agent_uuid) +# +# # Create tools +# flip_coin_name = "flip_coin" +# secret_word = "fourth_secret_word" +# flip_coin_tool = client.create_or_update_tool(flip_coin) +# secret_word_tool = client.create_or_update_tool(fourth_secret_word) +# another_secret_word_tool = client.create_or_update_tool(first_secret_word) +# random_tool = client.create_or_update_tool(can_play_game) +# tools = [flip_coin_tool, secret_word_tool, another_secret_word_tool, random_tool] +# +# # Create tool rule: after flip_coin, must call fourth_secret_word +# tool_rule = ConditionalToolRule( +# tool_name=flip_coin_name, +# default_child=secret_word, +# child_output_mapping={"*": secret_word}, +# ) +# +# # Set up agent with the tool rule +# agent_state = setup_agent( +# client, config_file, agent_uuid, tool_rules=[tool_rule], tool_ids=[t.id for t in tools], include_base_tools=False +# ) +# +# # Start conversation +# response = client.user_message(agent_id=agent_state.id, message="Help me test the tools.") +# +# # Verify the tool calls +# tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] +# assert len(tool_calls) >= 2 # Should have at least flip_coin and fourth_secret_word calls +# assert_invoked_function_call(response.messages, flip_coin_name) +# assert_invoked_function_call(response.messages, secret_word) +# +# # Find the flip_coin call +# flip_coin_call = next((call for call in tool_calls if call.tool_call.name == "flip_coin"), None) +# +# # Verify that fourth_secret_word was called after flip_coin +# flip_coin_call_index = tool_calls.index(flip_coin_call) +# assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin" +# +# cleanup(client, agent_uuid=agent_state.id) def test_init_tool_rule_always_fails_one_tool(): @@ -768,3 +672,56 @@ def test_continue_tool_rule(): if call.tool_call.name == "core_memory_append": core_memory_append_call_index = i assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append" + + +@pytest.mark.timeout(60) +@retry_until_success(max_attempts=3, sleep_time_seconds=2) +def test_max_count_per_step_tool_rule_integration(mock_e2b_api_key_none): + """ + Test an agent with MaxCountPerStepToolRule to ensure a tool can only be called a limited number of times. + + Tool Flow: + repeatable_tool (max 2 times) + | + v + send_message + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + # Create tools + repeatable_tool_name = "first_secret_word" + final_tool_name = "send_message" + + repeatable_tool = client.create_or_update_tool(first_secret_word) + send_message_tool = client.get_tool(client.get_tool_id(final_tool_name)) # Assume send_message is a default tool + + # Define tool rules + tool_rules = [ + InitToolRule(tool_name=repeatable_tool_name), + MaxCountPerStepToolRule(tool_name=repeatable_tool_name, max_count_limit=2), + TerminalToolRule(tool_name=final_tool_name), + ] + + tools = [repeatable_tool, send_message_tool] + + # Setup agent + agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules) + + # Start conversation + response = client.user_message( + agent_id=agent_state.id, message=f"Keep calling {repeatable_tool_name} nonstop without calling ANY other tool." + ) + + # Make checks + assert_sanity_checks(response) + + # Ensure the repeatable tool is only called twice + count = sum(1 for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == repeatable_tool_name) + assert count == 2, f"Expected 'first_secret_word' to be called exactly 2 times, but got {count}" + + # Ensure send_message was eventually called + assert_invoked_function_call(response.messages, final_tool_name) + + print(f"Got successful response from client: \n\n{response}") + cleanup(client=client, agent_uuid=agent_uuid) diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index dcb66e1b..a8a86011 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -2,7 +2,7 @@ import pytest from letta.helpers import ToolRulesSolver from letta.helpers.tool_rule_solver import ToolRuleValidationError -from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, MaxCountPerStepToolRule, TerminalToolRule # Constants for tool names used in the tests START_TOOL = "start_tool" @@ -15,145 +15,163 @@ UNRECOGNIZED_TOOL = "unrecognized_tool" def test_get_allowed_tool_names_with_init_rules(): - # Setup: Initial tool rule configuration init_rule_1 = InitToolRule(tool_name=START_TOOL) init_rule_2 = InitToolRule(tool_name=PREP_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule_1, init_rule_2], tool_rules=[], terminal_tool_rules=[]) + solver = ToolRulesSolver(tool_rules=[init_rule_1, init_rule_2]) - # Action: Get allowed tool names when no tool has been called - allowed_tools = solver.get_allowed_tool_names() + allowed_tools = solver.get_allowed_tool_names(set()) - # Assert: Both init tools should be allowed initially assert allowed_tools == [START_TOOL, PREP_TOOL], "Should allow only InitToolRule tools at the start" def test_get_allowed_tool_names_with_subsequent_rule(): - # Setup: Tool rule sequence init_rule = InitToolRule(tool_name=START_TOOL) rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL]) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[]) + solver = ToolRulesSolver(tool_rules=[init_rule, rule_1]) - # Action: Update usage and get allowed tools solver.update_tool_usage(START_TOOL) - allowed_tools = solver.get_allowed_tool_names() + allowed_tools = solver.get_allowed_tool_names({START_TOOL, NEXT_TOOL, HELPER_TOOL}) - # Assert: Only children of "start_tool" should be allowed - assert allowed_tools == [NEXT_TOOL, HELPER_TOOL], "Should allow only children of the last tool used" + assert sorted(allowed_tools) == sorted([NEXT_TOOL, HELPER_TOOL]), "Should allow only children of the last tool used" def test_is_terminal_tool(): - # Setup: Terminal tool rule configuration init_rule = InitToolRule(tool_name=START_TOOL) terminal_rule = TerminalToolRule(tool_name=END_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[terminal_rule]) + solver = ToolRulesSolver(tool_rules=[init_rule, terminal_rule]) - # Action & Assert: Verify terminal and non-terminal tools assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool" assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool" -def test_get_allowed_tool_names_no_matching_rule_warning(): - # Setup: Tool rules with no matching rule for the last tool - init_rule = InitToolRule(tool_name=START_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) - - # Action: Set last tool to an unrecognized tool and check warnings - solver.update_tool_usage(UNRECOGNIZED_TOOL) - - # # NOTE: removed for now since this warning is getting triggered on every LLM call - # with warnings.catch_warnings(record=True) as w: - # allowed_tools = solver.get_allowed_tool_names() - - # # Assert: Expecting a warning and an empty list of allowed tools - # assert len(w) == 1, "Expected a warning for no matching rule" - # assert "resolved to no more possible tool calls" in str(w[-1].message) - # assert allowed_tools == [], "Should return an empty list if no matching rule" - - def test_get_allowed_tool_names_no_matching_rule_error(): - # Setup: Tool rules with no matching rule for the last tool init_rule = InitToolRule(tool_name=START_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[]) + solver = ToolRulesSolver(tool_rules=[init_rule]) - # Action & Assert: Set last tool to an unrecognized tool and expect ValueError solver.update_tool_usage(UNRECOGNIZED_TOOL) - with pytest.raises(ValueError, match=f"No tool rule found for {UNRECOGNIZED_TOOL}"): - solver.get_allowed_tool_names(error_on_empty=True) + with pytest.raises(ValueError, match=f"No valid tools found based on tool rules."): + solver.get_allowed_tool_names(set(), error_on_empty=True) def test_update_tool_usage_and_get_allowed_tool_names_combined(): - # Setup: More complex rule chaining init_rule = InitToolRule(tool_name=START_TOOL) rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[FINAL_TOOL]) terminal_rule = TerminalToolRule(tool_name=FINAL_TOOL) - solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule]) + solver = ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, terminal_rule]) - # Step 1: Initially allowed tools - assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'" + assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initial allowed tool should be 'start_tool'" - # Step 2: After using 'start_tool' solver.update_tool_usage(START_TOOL) - assert solver.get_allowed_tool_names() == [NEXT_TOOL], "After 'start_tool', should allow 'next_tool'" + assert solver.get_allowed_tool_names({NEXT_TOOL}) == [NEXT_TOOL], "After 'start_tool', should allow 'next_tool'" - # Step 3: After using 'next_tool' solver.update_tool_usage(NEXT_TOOL) - assert solver.get_allowed_tool_names() == [FINAL_TOOL], "After 'next_tool', should allow 'final_tool'" + assert solver.get_allowed_tool_names({FINAL_TOOL}) == [FINAL_TOOL], "After 'next_tool', should allow 'final_tool'" - # Step 4: 'final_tool' should be terminal assert solver.is_terminal_tool(FINAL_TOOL) is True, "Should recognize 'final_tool' as terminal" def test_conditional_tool_rule(): - # Setup: Define a conditional tool rule init_rule = InitToolRule(tool_name=START_TOOL) terminal_rule = TerminalToolRule(tool_name=END_TOOL) rule = ConditionalToolRule(tool_name=START_TOOL, default_child=None, child_output_mapping={True: END_TOOL, False: START_TOOL}) solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule]) - # Action & Assert: Verify the rule properties - # Step 1: Initially allowed tools - assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'" + assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initial allowed tool should be 'start_tool'" - # Step 2: After using 'start_tool' solver.update_tool_usage(START_TOOL) - assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [ + assert solver.get_allowed_tool_names({END_TOOL}, last_function_response='{"message": "true"}') == [ END_TOOL ], "After 'start_tool' returns true, should allow 'end_tool'" - assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [ + assert solver.get_allowed_tool_names({START_TOOL}, last_function_response='{"message": "false"}') == [ START_TOOL ], "After 'start_tool' returns false, should allow 'start_tool'" - # Step 3: After using 'end_tool' assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal" def test_invalid_conditional_tool_rule(): - # Setup: Define an invalid conditional tool rule init_rule = InitToolRule(tool_name=START_TOOL) terminal_rule = TerminalToolRule(tool_name=END_TOOL) invalid_rule_1 = ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={}) - # Test 1: Missing child output mapping with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."): ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule]) def test_tool_rules_with_invalid_path(): - # Setup: Define tool rules with both connected, disconnected nodes and a cycle init_rule = InitToolRule(tool_name=START_TOOL) rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL]) rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[HELPER_TOOL]) - rule_3 = ChildToolRule(tool_name=HELPER_TOOL, children=[START_TOOL]) # This creates a cycle: start -> next -> helper -> start - rule_4 = ChildToolRule(tool_name=FINAL_TOOL, children=[END_TOOL]) # Disconnected rule, no cycle here + rule_3 = ChildToolRule(tool_name=HELPER_TOOL, children=[START_TOOL]) + rule_4 = ChildToolRule(tool_name=FINAL_TOOL, children=[END_TOOL]) terminal_rule = TerminalToolRule(tool_name=END_TOOL) ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule]) - # Now: add a path from the start tool to the final tool rule_5 = ConditionalToolRule( tool_name=HELPER_TOOL, default_child=FINAL_TOOL, child_output_mapping={True: START_TOOL, False: FINAL_TOOL}, ) ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, rule_5, terminal_rule]) + + +def test_max_count_per_step_tool_rule(): + init_rule = InitToolRule(tool_name=START_TOOL) + rule_1 = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2) + solver = ToolRulesSolver(tool_rules=[init_rule, rule_1]) + + assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initially should allow 'start_tool'" + + solver.update_tool_usage(START_TOOL) + assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 'start_tool'" + + solver.update_tool_usage(START_TOOL) + assert solver.get_allowed_tool_names({START_TOOL}) == [], "After reaching max count, 'start_tool' should no longer be allowed" + + +def test_max_count_per_step_tool_rule_allows_usage_up_to_limit(): + """Ensure the tool is allowed exactly max_count_limit times.""" + rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=3) + solver = ToolRulesSolver(tool_rules=[rule]) + + assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initially should allow 'start_tool'" + + solver.update_tool_usage(START_TOOL) + assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 1 use" + + solver.update_tool_usage(START_TOOL) + assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses" + + solver.update_tool_usage(START_TOOL) + assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should no longer allow 'start_tool' after 3 uses" + + +def test_max_count_per_step_tool_rule_does_not_affect_other_tools(): + """Ensure exceeding max count for one tool does not impact others.""" + rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2) + another_tool_rules = ChildToolRule(tool_name=NEXT_TOOL, children=[HELPER_TOOL]) + solver = ToolRulesSolver(tool_rules=[rule, another_tool_rules]) + + solver.update_tool_usage(START_TOOL) + solver.update_tool_usage(START_TOOL) + + assert sorted(solver.get_allowed_tool_names({START_TOOL, NEXT_TOOL, HELPER_TOOL})) == sorted( + [NEXT_TOOL, HELPER_TOOL] + ), "Other tools should still be allowed even if 'start_tool' is over limit" + + +def test_max_count_per_step_tool_rule_resets_on_clear(): + """Ensure clearing tool history resets the rule's limit.""" + rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2) + solver = ToolRulesSolver(tool_rules=[rule]) + + solver.update_tool_usage(START_TOOL) + solver.update_tool_usage(START_TOOL) + + assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should not allow 'start_tool' after reaching limit" + + solver.clear_tool_history() + + assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should allow 'start_tool' again after clearing history"