feat: Add MaxCountPerStepToolRule (#1319)
This commit is contained in:
@@ -367,7 +367,10 @@ class Agent(BaseAgent):
|
||||
) -> ChatCompletionResponse:
|
||||
"""Get response from LLM API with robust retry mechanism."""
|
||||
log_telemetry(self.logger, "_get_ai_reply start")
|
||||
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response)
|
||||
available_tools = set([t.name for t in self.agent_state.tools])
|
||||
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(
|
||||
available_tools=available_tools, last_function_response=self.last_function_response
|
||||
)
|
||||
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]
|
||||
|
||||
allowed_functions = (
|
||||
@@ -377,8 +380,8 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
# Don't allow a tool to be called if it failed last time
|
||||
if last_function_failed and self.tool_rules_solver.last_tool_name:
|
||||
allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.last_tool_name]
|
||||
if last_function_failed and self.tool_rules_solver.tool_call_history:
|
||||
allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.tool_call_history[-1]]
|
||||
if not allowed_functions:
|
||||
return None
|
||||
|
||||
@@ -773,6 +776,11 @@ class Agent(BaseAgent):
|
||||
**kwargs,
|
||||
) -> LettaUsageStatistics:
|
||||
"""Run Agent.step in a loop, handling chaining via heartbeat requests and function failures"""
|
||||
# Defensively clear the tool rules solver history
|
||||
# Usually this would be extraneous as Agent loop is re-loaded on every message send
|
||||
# But just to be safe
|
||||
self.tool_rules_solver.clear_tool_history()
|
||||
|
||||
next_input_message = messages if isinstance(messages, list) else [messages]
|
||||
counter = 0
|
||||
total_usage = UsageStatistics()
|
||||
|
||||
@@ -20,7 +20,15 @@ from letta.schemas.letta_message_content import (
|
||||
)
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import ToolReturn
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule
|
||||
from letta.schemas.tool_rule import (
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
ContinueToolRule,
|
||||
InitToolRule,
|
||||
MaxCountPerStepToolRule,
|
||||
TerminalToolRule,
|
||||
ToolRule,
|
||||
)
|
||||
|
||||
# --------------------------
|
||||
# LLMConfig Serialization
|
||||
@@ -85,23 +93,27 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu
|
||||
return [deserialize_tool_rule(rule_data) for rule_data in data]
|
||||
|
||||
|
||||
def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule]:
|
||||
def deserialize_tool_rule(
|
||||
data: Dict,
|
||||
) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule]:
|
||||
"""Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
|
||||
rule_type = ToolRuleType(data.get("type"))
|
||||
|
||||
if rule_type == ToolRuleType.run_first or rule_type == ToolRuleType.InitToolRule:
|
||||
if rule_type == ToolRuleType.run_first:
|
||||
data["type"] = ToolRuleType.run_first
|
||||
return InitToolRule(**data)
|
||||
elif rule_type == ToolRuleType.exit_loop or rule_type == ToolRuleType.TerminalToolRule:
|
||||
elif rule_type == ToolRuleType.exit_loop:
|
||||
data["type"] = ToolRuleType.exit_loop
|
||||
return TerminalToolRule(**data)
|
||||
elif rule_type == ToolRuleType.constrain_child_tools or rule_type == ToolRuleType.ToolRule:
|
||||
elif rule_type == ToolRuleType.constrain_child_tools:
|
||||
data["type"] = ToolRuleType.constrain_child_tools
|
||||
return ChildToolRule(**data)
|
||||
elif rule_type == ToolRuleType.conditional:
|
||||
return ConditionalToolRule(**data)
|
||||
elif rule_type == ToolRuleType.continue_loop:
|
||||
return ContinueToolRule(**data)
|
||||
elif rule_type == ToolRuleType.max_count_per_step:
|
||||
return MaxCountPerStepToolRule(**data)
|
||||
raise ValueError(f"Unknown ToolRule type: {rule_type}")
|
||||
|
||||
|
||||
|
||||
@@ -1,10 +1,17 @@
|
||||
import json
|
||||
from typing import List, Optional, Union
|
||||
from typing import List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import (
|
||||
BaseToolRule,
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
ContinueToolRule,
|
||||
InitToolRule,
|
||||
MaxCountPerStepToolRule,
|
||||
TerminalToolRule,
|
||||
)
|
||||
|
||||
|
||||
class ToolRuleValidationError(Exception):
|
||||
@@ -21,13 +28,15 @@ class ToolRulesSolver(BaseModel):
|
||||
continue_tool_rules: List[ContinueToolRule] = Field(
|
||||
default_factory=list, description="Continue tool rules to be used to continue tool execution."
|
||||
)
|
||||
tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field(
|
||||
# TODO: This should be renamed?
|
||||
# TODO: These are tools that control the set of allowed functions in the next turn
|
||||
child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field(
|
||||
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
|
||||
)
|
||||
terminal_tool_rules: List[TerminalToolRule] = Field(
|
||||
default_factory=list, description="Terminal tool rules that end the agent loop if called."
|
||||
)
|
||||
last_tool_name: Optional[str] = Field(None, description="The most recent tool used, updated with each tool call.")
|
||||
tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
|
||||
|
||||
def __init__(self, tool_rules: List[BaseToolRule], **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
@@ -38,45 +47,60 @@ class ToolRulesSolver(BaseModel):
|
||||
self.init_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.constrain_child_tools:
|
||||
assert isinstance(rule, ChildToolRule)
|
||||
self.tool_rules.append(rule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.conditional:
|
||||
assert isinstance(rule, ConditionalToolRule)
|
||||
self.validate_conditional_tool(rule)
|
||||
self.tool_rules.append(rule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.exit_loop:
|
||||
assert isinstance(rule, TerminalToolRule)
|
||||
self.terminal_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.continue_loop:
|
||||
assert isinstance(rule, ContinueToolRule)
|
||||
self.continue_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.max_count_per_step:
|
||||
assert isinstance(rule, MaxCountPerStepToolRule)
|
||||
self.child_based_tool_rules.append(rule)
|
||||
|
||||
def update_tool_usage(self, tool_name: str):
|
||||
"""Update the internal state to track the last tool called."""
|
||||
self.last_tool_name = tool_name
|
||||
"""Update the internal state to track tool call history."""
|
||||
self.tool_call_history.append(tool_name)
|
||||
|
||||
def get_allowed_tool_names(self, error_on_empty: bool = False, last_function_response: Optional[str] = None) -> List[str]:
|
||||
def clear_tool_history(self):
|
||||
"""Clear the history of tool calls."""
|
||||
self.tool_call_history.clear()
|
||||
|
||||
def get_allowed_tool_names(
|
||||
self, available_tools: Set[str], error_on_empty: bool = False, last_function_response: Optional[str] = None
|
||||
) -> List[str]:
|
||||
"""Get a list of tool names allowed based on the last tool called."""
|
||||
if self.last_tool_name is None:
|
||||
# Use initial tool rules if no tool has been called yet
|
||||
return [rule.tool_name for rule in self.init_tool_rules]
|
||||
# TODO: This piece of code here is quite ugly and deserves a refactor
|
||||
# TODO: There's some weird logic encoded here:
|
||||
# TODO: -> This only takes into consideration Init, and a set of Child/Conditional/MaxSteps tool rules
|
||||
# TODO: -> Init tool rules outputs are treated additively, Child/Conditional/MaxSteps are intersection based
|
||||
# TODO: -> Tool rules should probably be refactored to take in a set of tool names?
|
||||
# If no tool has been called yet, return InitToolRules additively
|
||||
if not self.tool_call_history:
|
||||
if self.init_tool_rules:
|
||||
# If there are init tool rules, only return those defined in the init tool rules
|
||||
return [rule.tool_name for rule in self.init_tool_rules]
|
||||
else:
|
||||
# Otherwise, return all the available tools
|
||||
return list(available_tools)
|
||||
else:
|
||||
# Find a matching ToolRule for the last tool used
|
||||
current_rule = next((rule for rule in self.tool_rules if rule.tool_name == self.last_tool_name), None)
|
||||
# Collect valid tools from all child-based rules
|
||||
valid_tool_sets = [
|
||||
rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
|
||||
for rule in self.child_based_tool_rules
|
||||
]
|
||||
|
||||
if current_rule is None:
|
||||
if error_on_empty:
|
||||
raise ValueError(f"No tool rule found for {self.last_tool_name}")
|
||||
return []
|
||||
# Compute intersection of all valid tool sets
|
||||
final_allowed_tools = set.intersection(*valid_tool_sets) if valid_tool_sets else available_tools
|
||||
|
||||
# If the current rule is a conditional tool rule, use the LLM response to
|
||||
# determine which child tool to use
|
||||
if isinstance(current_rule, ConditionalToolRule):
|
||||
if not last_function_response:
|
||||
raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use")
|
||||
next_tool = self.evaluate_conditional_tool(current_rule, last_function_response)
|
||||
return [next_tool] if next_tool else []
|
||||
if error_on_empty and not final_allowed_tools:
|
||||
raise ValueError("No valid tools found based on tool rules.")
|
||||
|
||||
return current_rule.children if current_rule.children else []
|
||||
return list(final_allowed_tools)
|
||||
|
||||
def is_terminal_tool(self, tool_name: str) -> bool:
|
||||
"""Check if the tool is defined as a terminal tool in the terminal tool rules."""
|
||||
@@ -84,7 +108,7 @@ class ToolRulesSolver(BaseModel):
|
||||
|
||||
def has_children_tools(self, tool_name):
|
||||
"""Check if the tool has children tools"""
|
||||
return any(rule.tool_name == tool_name for rule in self.tool_rules)
|
||||
return any(rule.tool_name == tool_name for rule in self.child_based_tool_rules)
|
||||
|
||||
def is_continue_tool(self, tool_name):
|
||||
"""Check if the tool is defined as a continue tool in the tool rules."""
|
||||
@@ -103,47 +127,3 @@ class ToolRulesSolver(BaseModel):
|
||||
if len(rule.child_output_mapping) == 0:
|
||||
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
|
||||
return True
|
||||
|
||||
def evaluate_conditional_tool(self, tool: ConditionalToolRule, last_function_response: str) -> str:
|
||||
"""
|
||||
Parse function response to determine which child tool to use based on the mapping
|
||||
|
||||
Args:
|
||||
tool (ConditionalToolRule): The conditional tool rule
|
||||
last_function_response (str): The function response in JSON format
|
||||
|
||||
Returns:
|
||||
str: The name of the child tool to use next
|
||||
"""
|
||||
json_response = json.loads(last_function_response)
|
||||
function_output = json_response["message"]
|
||||
|
||||
# Try to match the function output with a mapping key
|
||||
for key in tool.child_output_mapping:
|
||||
|
||||
# Convert function output to match key type for comparison
|
||||
if isinstance(key, bool):
|
||||
typed_output = function_output.lower() == "true"
|
||||
elif isinstance(key, int):
|
||||
try:
|
||||
typed_output = int(function_output)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
elif isinstance(key, float):
|
||||
try:
|
||||
typed_output = float(function_output)
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
else: # string
|
||||
if function_output == "True" or function_output == "False":
|
||||
typed_output = function_output.lower()
|
||||
elif function_output == "None":
|
||||
typed_output = None
|
||||
else:
|
||||
typed_output = function_output
|
||||
|
||||
if typed_output == key:
|
||||
return tool.child_output_mapping[key]
|
||||
|
||||
# If no match found, use default
|
||||
return tool.default_child
|
||||
|
||||
@@ -47,8 +47,4 @@ class ToolRuleType(str, Enum):
|
||||
continue_loop = "continue_loop"
|
||||
conditional = "conditional"
|
||||
constrain_child_tools = "constrain_child_tools"
|
||||
require_parent_tools = "require_parent_tools"
|
||||
# Deprecated
|
||||
InitToolRule = "InitToolRule"
|
||||
TerminalToolRule = "TerminalToolRule"
|
||||
ToolRule = "ToolRule"
|
||||
max_count_per_step = "max_count_per_step"
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
import json
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -11,6 +12,9 @@ class BaseToolRule(LettaBase):
|
||||
tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.")
|
||||
type: ToolRuleType = Field(..., description="The type of the message.")
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> set[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class ChildToolRule(BaseToolRule):
|
||||
"""
|
||||
@@ -20,6 +24,10 @@ class ChildToolRule(BaseToolRule):
|
||||
type: Literal[ToolRuleType.constrain_child_tools] = ToolRuleType.constrain_child_tools
|
||||
children: List[str] = Field(..., description="The children tools that can be invoked.")
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
|
||||
last_tool = tool_call_history[-1] if tool_call_history else None
|
||||
return set(self.children) if last_tool == self.tool_name else available_tools
|
||||
|
||||
|
||||
class ConditionalToolRule(BaseToolRule):
|
||||
"""
|
||||
@@ -31,6 +39,50 @@ class ConditionalToolRule(BaseToolRule):
|
||||
child_output_mapping: Dict[Any, str] = Field(..., description="The output case to check for mapping")
|
||||
require_output_mapping: bool = Field(default=False, description="Whether to throw an error when output doesn't match any case")
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
|
||||
"""Determine valid tools based on function output mapping."""
|
||||
if not tool_call_history or tool_call_history[-1] != self.tool_name:
|
||||
return available_tools # No constraints if this rule doesn't apply
|
||||
|
||||
if not last_function_response:
|
||||
raise ValueError("Conditional tool rule requires an LLM response to determine which child tool to use")
|
||||
|
||||
try:
|
||||
json_response = json.loads(last_function_response)
|
||||
function_output = json_response.get("message", "")
|
||||
except json.JSONDecodeError:
|
||||
if self.require_output_mapping:
|
||||
return set() # Strict mode: Invalid response means no allowed tools
|
||||
return {self.default_child} if self.default_child else available_tools
|
||||
|
||||
# Match function output to a mapped child tool
|
||||
for key, tool in self.child_output_mapping.items():
|
||||
if self._matches_key(function_output, key):
|
||||
return {tool}
|
||||
|
||||
# If no match found, use default or allow all tools if no default is set
|
||||
if self.require_output_mapping:
|
||||
return set() # Strict mode: No match means no valid tools
|
||||
|
||||
return {self.default_child} if self.default_child else available_tools
|
||||
|
||||
def _matches_key(self, function_output: str, key: Any) -> bool:
|
||||
"""Helper function to determine if function output matches a mapping key."""
|
||||
if isinstance(key, bool):
|
||||
return function_output.lower() == "true" if key else function_output.lower() == "false"
|
||||
elif isinstance(key, int):
|
||||
try:
|
||||
return int(function_output) == key
|
||||
except ValueError:
|
||||
return False
|
||||
elif isinstance(key, float):
|
||||
try:
|
||||
return float(function_output) == key
|
||||
except ValueError:
|
||||
return False
|
||||
else: # Assume string
|
||||
return str(function_output) == str(key)
|
||||
|
||||
|
||||
class InitToolRule(BaseToolRule):
|
||||
"""
|
||||
@@ -56,7 +108,26 @@ class ContinueToolRule(BaseToolRule):
|
||||
type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop
|
||||
|
||||
|
||||
class MaxCountPerStepToolRule(BaseToolRule):
|
||||
"""
|
||||
Represents a tool rule configuration which constrains the total number of times this tool can be invoked in a single step.
|
||||
"""
|
||||
|
||||
type: Literal[ToolRuleType.max_count_per_step] = ToolRuleType.max_count_per_step
|
||||
max_count_limit: int = Field(..., description="The max limit for the total number of times this tool can be invoked in a single step.")
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
|
||||
"""Restricts the tool if it has been called max_count_limit times in the current step."""
|
||||
count = tool_call_history.count(self.tool_name)
|
||||
|
||||
# If the tool has been used max_count_limit times, it is no longer allowed
|
||||
if count >= self.max_count_limit:
|
||||
return available_tools - {self.tool_name}
|
||||
|
||||
return available_tools
|
||||
|
||||
|
||||
ToolRule = Annotated[
|
||||
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule],
|
||||
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule, MaxCountPerStepToolRule],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
@@ -5,14 +5,14 @@ import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, MaxCountPerStepToolRule, TerminalToolRule
|
||||
from tests.helpers.endpoints_helper import (
|
||||
assert_invoked_function_call,
|
||||
assert_invoked_send_message_with_keyword,
|
||||
assert_sanity_checks,
|
||||
setup_agent,
|
||||
)
|
||||
from tests.helpers.utils import cleanup
|
||||
from tests.helpers.utils import cleanup, retry_until_success
|
||||
|
||||
# Generate uuid for agent name for this example
|
||||
namespace = uuid.NAMESPACE_DNS
|
||||
@@ -85,25 +85,6 @@ def flip_coin():
|
||||
return "hj2hwibbqm"
|
||||
|
||||
|
||||
def flip_coin_hard():
|
||||
"""
|
||||
Call this to retrieve the password to the secret word, which you will need to output in a send_message later.
|
||||
If it returns an empty string, try flipping again!
|
||||
|
||||
Returns:
|
||||
str: The password or an empty string
|
||||
"""
|
||||
import random
|
||||
|
||||
# Flip a coin with 50% chance
|
||||
result = random.random()
|
||||
if result < 0.5:
|
||||
return ""
|
||||
if result < 0.75:
|
||||
return "START_OVER"
|
||||
return "hj2hwibbqm"
|
||||
|
||||
|
||||
def can_play_game():
|
||||
"""
|
||||
Call this to start the tool chain.
|
||||
@@ -345,320 +326,243 @@ def test_agent_no_structured_output_with_one_child_tool(mock_e2b_api_key_none):
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
|
||||
def test_agent_conditional_tool_easy(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test the agent with a conditional tool that has a child tool.
|
||||
|
||||
Tool Flow:
|
||||
|
||||
-------
|
||||
| |
|
||||
| v
|
||||
-- flip_coin
|
||||
|
|
||||
v
|
||||
reveal_secret_word
|
||||
"""
|
||||
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
coin_flip_name = "flip_coin"
|
||||
secret_word_tool = "fourth_secret_word"
|
||||
flip_coin_tool = client.create_or_update_tool(flip_coin)
|
||||
reveal_secret = client.create_or_update_tool(fourth_secret_word)
|
||||
|
||||
# Make tool rules
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=coin_flip_name),
|
||||
ConditionalToolRule(
|
||||
tool_name=coin_flip_name,
|
||||
default_child=coin_flip_name,
|
||||
child_output_mapping={
|
||||
"hj2hwibbqm": secret_word_tool,
|
||||
},
|
||||
),
|
||||
TerminalToolRule(tool_name=secret_word_tool),
|
||||
]
|
||||
tools = [flip_coin_tool, reveal_secret]
|
||||
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")
|
||||
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Assert the tools were called
|
||||
assert_invoked_function_call(response.messages, "flip_coin")
|
||||
assert_invoked_function_call(response.messages, "fourth_secret_word")
|
||||
|
||||
# Check ordering of tool calls
|
||||
found_secret_word = False
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage):
|
||||
if m.tool_call.name == secret_word_tool:
|
||||
# Should be the last tool call
|
||||
found_secret_word = True
|
||||
else:
|
||||
# Before finding secret_word, only flip_coin should be called
|
||||
assert m.tool_call.name == coin_flip_name
|
||||
assert not found_secret_word
|
||||
|
||||
# Ensure we found the secret word exactly once
|
||||
assert found_secret_word
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
# @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
|
||||
# def test_agent_conditional_tool_easy(mock_e2b_api_key_none):
|
||||
# """
|
||||
# Test the agent with a conditional tool that has a child tool.
|
||||
#
|
||||
# Tool Flow:
|
||||
#
|
||||
# -------
|
||||
# | |
|
||||
# | v
|
||||
# -- flip_coin
|
||||
# |
|
||||
# v
|
||||
# reveal_secret_word
|
||||
# """
|
||||
#
|
||||
# client = create_client()
|
||||
# cleanup(client=client, agent_uuid=agent_uuid)
|
||||
#
|
||||
# coin_flip_name = "flip_coin"
|
||||
# secret_word_tool = "fourth_secret_word"
|
||||
# flip_coin_tool = client.create_or_update_tool(flip_coin)
|
||||
# reveal_secret = client.create_or_update_tool(fourth_secret_word)
|
||||
#
|
||||
# # Make tool rules
|
||||
# tool_rules = [
|
||||
# InitToolRule(tool_name=coin_flip_name),
|
||||
# ConditionalToolRule(
|
||||
# tool_name=coin_flip_name,
|
||||
# default_child=coin_flip_name,
|
||||
# child_output_mapping={
|
||||
# "hj2hwibbqm": secret_word_tool,
|
||||
# },
|
||||
# ),
|
||||
# TerminalToolRule(tool_name=secret_word_tool),
|
||||
# ]
|
||||
# tools = [flip_coin_tool, reveal_secret]
|
||||
#
|
||||
# config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
# agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
# response = client.user_message(agent_id=agent_state.id, message="flip a coin until you get the secret word")
|
||||
#
|
||||
# # Make checks
|
||||
# assert_sanity_checks(response)
|
||||
#
|
||||
# # Assert the tools were called
|
||||
# assert_invoked_function_call(response.messages, "flip_coin")
|
||||
# assert_invoked_function_call(response.messages, "fourth_secret_word")
|
||||
#
|
||||
# # Check ordering of tool calls
|
||||
# found_secret_word = False
|
||||
# for m in response.messages:
|
||||
# if isinstance(m, ToolCallMessage):
|
||||
# if m.tool_call.name == secret_word_tool:
|
||||
# # Should be the last tool call
|
||||
# found_secret_word = True
|
||||
# else:
|
||||
# # Before finding secret_word, only flip_coin should be called
|
||||
# assert m.tool_call.name == coin_flip_name
|
||||
# assert not found_secret_word
|
||||
#
|
||||
# # Ensure we found the secret word exactly once
|
||||
# assert found_secret_word
|
||||
#
|
||||
# print(f"Got successful response from client: \n\n{response}")
|
||||
# cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
@pytest.mark.timeout(90) # Longer timeout since this test has more steps
|
||||
def test_agent_conditional_tool_hard(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test the agent with a complex conditional tool graph
|
||||
|
||||
Tool Flow:
|
||||
|
||||
can_play_game <---+
|
||||
| |
|
||||
v |
|
||||
flip_coin -----+
|
||||
|
|
||||
v
|
||||
fourth_secret_word
|
||||
"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Create tools
|
||||
play_game = "can_play_game"
|
||||
coin_flip_name = "flip_coin_hard"
|
||||
final_tool = "fourth_secret_word"
|
||||
play_game_tool = client.create_or_update_tool(can_play_game)
|
||||
flip_coin_tool = client.create_or_update_tool(flip_coin_hard)
|
||||
reveal_secret = client.create_or_update_tool(fourth_secret_word)
|
||||
|
||||
# Make tool rules - chain them together with conditional rules
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=play_game),
|
||||
ConditionalToolRule(
|
||||
tool_name=play_game,
|
||||
default_child=play_game, # Keep trying if we can't play
|
||||
child_output_mapping={True: coin_flip_name}, # Only allow access when can_play_game returns True
|
||||
),
|
||||
ConditionalToolRule(
|
||||
tool_name=coin_flip_name, default_child=coin_flip_name, child_output_mapping={"hj2hwibbqm": final_tool, "START_OVER": play_game}
|
||||
),
|
||||
TerminalToolRule(tool_name=final_tool),
|
||||
]
|
||||
|
||||
# Setup agent with all tools
|
||||
tools = [play_game_tool, flip_coin_tool, reveal_secret]
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# Ask agent to try to get all secret words
|
||||
response = client.user_message(agent_id=agent_state.id, message="hi")
|
||||
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Assert all tools were called
|
||||
assert_invoked_function_call(response.messages, play_game)
|
||||
assert_invoked_function_call(response.messages, final_tool)
|
||||
|
||||
# Check ordering of tool calls
|
||||
found_words = []
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage):
|
||||
name = m.tool_call.name
|
||||
if name in [play_game, coin_flip_name]:
|
||||
# Before finding secret_word, only can_play_game and flip_coin should be called
|
||||
assert name in [play_game, coin_flip_name]
|
||||
else:
|
||||
# Should find secret words in order
|
||||
expected_word = final_tool
|
||||
assert name == expected_word, f"Found {name} but expected {expected_word}"
|
||||
found_words.append(name)
|
||||
|
||||
# Ensure we found all secret words in order
|
||||
assert found_words == [final_tool]
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
# @pytest.mark.timeout(60)
|
||||
# def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none):
|
||||
# """
|
||||
# Test the agent with a conditional tool that allows any child tool to be called if a function returns None.
|
||||
#
|
||||
# Tool Flow:
|
||||
#
|
||||
# return_none
|
||||
# |
|
||||
# v
|
||||
# any tool... <-- When output doesn't match mapping, agent can call any tool
|
||||
# """
|
||||
# client = create_client()
|
||||
# cleanup(client=client, agent_uuid=agent_uuid)
|
||||
#
|
||||
# # Create tools - we'll make several available to the agent
|
||||
# tool_name = "return_none"
|
||||
#
|
||||
# tool = client.create_or_update_tool(return_none)
|
||||
# secret_word = client.create_or_update_tool(first_secret_word)
|
||||
#
|
||||
# # Make tool rules - only map one output, let others be free choice
|
||||
# tool_rules = [
|
||||
# InitToolRule(tool_name=tool_name),
|
||||
# ConditionalToolRule(
|
||||
# tool_name=tool_name,
|
||||
# default_child=None, # Allow any tool to be called if output doesn't match
|
||||
# child_output_mapping={"anything but none": "first_secret_word"},
|
||||
# ),
|
||||
# ]
|
||||
# tools = [tool, secret_word]
|
||||
#
|
||||
# # Setup agent with all tools
|
||||
# agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
#
|
||||
# # Ask agent to try different tools based on the game output
|
||||
# response = client.user_message(agent_id=agent_state.id, message="call a function, any function. then call send_message")
|
||||
#
|
||||
# # Make checks
|
||||
# assert_sanity_checks(response)
|
||||
#
|
||||
# # Assert return_none was called
|
||||
# assert_invoked_function_call(response.messages, tool_name)
|
||||
#
|
||||
# # Assert any base function called afterward
|
||||
# found_any_tool = False
|
||||
# found_return_none = False
|
||||
# for m in response.messages:
|
||||
# if isinstance(m, ToolCallMessage):
|
||||
# if m.tool_call.name == tool_name:
|
||||
# found_return_none = True
|
||||
# elif found_return_none and m.tool_call.name:
|
||||
# found_any_tool = True
|
||||
# break
|
||||
#
|
||||
# assert found_any_tool, "Should have called any tool after return_none"
|
||||
#
|
||||
# print(f"Got successful response from client: \n\n{response}")
|
||||
# cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_agent_conditional_tool_without_default_child(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test the agent with a conditional tool that allows any child tool to be called if a function returns None.
|
||||
|
||||
Tool Flow:
|
||||
|
||||
return_none
|
||||
|
|
||||
v
|
||||
any tool... <-- When output doesn't match mapping, agent can call any tool
|
||||
"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Create tools - we'll make several available to the agent
|
||||
tool_name = "return_none"
|
||||
|
||||
tool = client.create_or_update_tool(return_none)
|
||||
secret_word = client.create_or_update_tool(first_secret_word)
|
||||
|
||||
# Make tool rules - only map one output, let others be free choice
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=tool_name),
|
||||
ConditionalToolRule(
|
||||
tool_name=tool_name,
|
||||
default_child=None, # Allow any tool to be called if output doesn't match
|
||||
child_output_mapping={"anything but none": "first_secret_word"},
|
||||
),
|
||||
]
|
||||
tools = [tool, secret_word]
|
||||
|
||||
# Setup agent with all tools
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# Ask agent to try different tools based on the game output
|
||||
response = client.user_message(agent_id=agent_state.id, message="call a function, any function. then call send_message")
|
||||
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Assert return_none was called
|
||||
assert_invoked_function_call(response.messages, tool_name)
|
||||
|
||||
# Assert any base function called afterward
|
||||
found_any_tool = False
|
||||
found_return_none = False
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage):
|
||||
if m.tool_call.name == tool_name:
|
||||
found_return_none = True
|
||||
elif found_return_none and m.tool_call.name:
|
||||
found_any_tool = True
|
||||
break
|
||||
|
||||
assert found_any_tool, "Should have called any tool after return_none"
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
# @pytest.mark.timeout(60)
|
||||
# def test_agent_reload_remembers_function_response(mock_e2b_api_key_none):
|
||||
# """
|
||||
# Test that when an agent is reloaded, it remembers the last function response for conditional tool chaining.
|
||||
#
|
||||
# Tool Flow:
|
||||
#
|
||||
# flip_coin
|
||||
# |
|
||||
# v
|
||||
# fourth_secret_word <-- Should remember coin flip result after reload
|
||||
# """
|
||||
# client = create_client()
|
||||
# cleanup(client=client, agent_uuid=agent_uuid)
|
||||
#
|
||||
# # Create tools
|
||||
# flip_coin_name = "flip_coin"
|
||||
# secret_word = "fourth_secret_word"
|
||||
# flip_coin_tool = client.create_or_update_tool(flip_coin)
|
||||
# secret_word_tool = client.create_or_update_tool(fourth_secret_word)
|
||||
#
|
||||
# # Make tool rules - map coin flip to fourth_secret_word
|
||||
# tool_rules = [
|
||||
# InitToolRule(tool_name=flip_coin_name),
|
||||
# ConditionalToolRule(
|
||||
# tool_name=flip_coin_name,
|
||||
# default_child=flip_coin_name, # Allow any tool to be called if output doesn't match
|
||||
# child_output_mapping={"hj2hwibbqm": secret_word},
|
||||
# ),
|
||||
# TerminalToolRule(tool_name=secret_word),
|
||||
# ]
|
||||
# tools = [flip_coin_tool, secret_word_tool]
|
||||
#
|
||||
# # Setup initial agent
|
||||
# agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
#
|
||||
# # Call flip_coin first
|
||||
# response = client.user_message(agent_id=agent_state.id, message="flip a coin")
|
||||
# assert_invoked_function_call(response.messages, flip_coin_name)
|
||||
# assert_invoked_function_call(response.messages, secret_word)
|
||||
# found_fourth_secret = False
|
||||
# for m in response.messages:
|
||||
# if isinstance(m, ToolCallMessage) and m.tool_call.name == secret_word:
|
||||
# found_fourth_secret = True
|
||||
# break
|
||||
#
|
||||
# assert found_fourth_secret, "Reloaded agent should remember coin flip result and call fourth_secret_word if True"
|
||||
#
|
||||
# # Reload the agent
|
||||
# reloaded_agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
# assert reloaded_agent.last_function_response is not None
|
||||
#
|
||||
# print(f"Got successful response from client: \n\n{response}")
|
||||
# cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
def test_agent_reload_remembers_function_response(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test that when an agent is reloaded, it remembers the last function response for conditional tool chaining.
|
||||
|
||||
Tool Flow:
|
||||
|
||||
flip_coin
|
||||
|
|
||||
v
|
||||
fourth_secret_word <-- Should remember coin flip result after reload
|
||||
"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Create tools
|
||||
flip_coin_name = "flip_coin"
|
||||
secret_word = "fourth_secret_word"
|
||||
flip_coin_tool = client.create_or_update_tool(flip_coin)
|
||||
secret_word_tool = client.create_or_update_tool(fourth_secret_word)
|
||||
|
||||
# Make tool rules - map coin flip to fourth_secret_word
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=flip_coin_name),
|
||||
ConditionalToolRule(
|
||||
tool_name=flip_coin_name,
|
||||
default_child=flip_coin_name, # Allow any tool to be called if output doesn't match
|
||||
child_output_mapping={"hj2hwibbqm": secret_word},
|
||||
),
|
||||
TerminalToolRule(tool_name=secret_word),
|
||||
]
|
||||
tools = [flip_coin_tool, secret_word_tool]
|
||||
|
||||
# Setup initial agent
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# Call flip_coin first
|
||||
response = client.user_message(agent_id=agent_state.id, message="flip a coin")
|
||||
assert_invoked_function_call(response.messages, flip_coin_name)
|
||||
assert_invoked_function_call(response.messages, secret_word)
|
||||
found_fourth_secret = False
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage) and m.tool_call.name == secret_word:
|
||||
found_fourth_secret = True
|
||||
break
|
||||
|
||||
assert found_fourth_secret, "Reloaded agent should remember coin flip result and call fourth_secret_word if True"
|
||||
|
||||
# Reload the agent
|
||||
reloaded_agent = client.server.load_agent(agent_id=agent_state.id, actor=client.user)
|
||||
assert reloaded_agent.last_function_response is not None
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
|
||||
def test_simple_tool_rule(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test a simple tool rule where fourth_secret_word must be called after flip_coin.
|
||||
|
||||
Tool Flow:
|
||||
flip_coin
|
||||
|
|
||||
v
|
||||
fourth_secret_word
|
||||
"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Create tools
|
||||
flip_coin_name = "flip_coin"
|
||||
secret_word = "fourth_secret_word"
|
||||
random_tool = "can_play_game"
|
||||
flip_coin_tool = client.create_or_update_tool(flip_coin)
|
||||
secret_word_tool = client.create_or_update_tool(fourth_secret_word)
|
||||
another_secret_word_tool = client.create_or_update_tool(first_secret_word)
|
||||
random_tool = client.create_or_update_tool(can_play_game)
|
||||
tools = [flip_coin_tool, secret_word_tool, another_secret_word_tool, random_tool]
|
||||
|
||||
# Create tool rule: after flip_coin, must call fourth_secret_word
|
||||
tool_rule = ConditionalToolRule(
|
||||
tool_name=flip_coin_name,
|
||||
default_child=secret_word,
|
||||
child_output_mapping={"*": secret_word},
|
||||
)
|
||||
|
||||
# Set up agent with the tool rule
|
||||
agent_state = setup_agent(
|
||||
client, config_file, agent_uuid, tool_rules=[tool_rule], tool_ids=[t.id for t in tools], include_base_tools=False
|
||||
)
|
||||
|
||||
# Start conversation
|
||||
response = client.user_message(agent_id=agent_state.id, message="Help me test the tools.")
|
||||
|
||||
# Verify the tool calls
|
||||
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
|
||||
assert len(tool_calls) >= 2 # Should have at least flip_coin and fourth_secret_word calls
|
||||
assert_invoked_function_call(response.messages, flip_coin_name)
|
||||
assert_invoked_function_call(response.messages, secret_word)
|
||||
|
||||
# Find the flip_coin call
|
||||
flip_coin_call = next((call for call in tool_calls if call.tool_call.name == "flip_coin"), None)
|
||||
|
||||
# Verify that fourth_secret_word was called after flip_coin
|
||||
flip_coin_call_index = tool_calls.index(flip_coin_call)
|
||||
assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin"
|
||||
|
||||
cleanup(client, agent_uuid=agent_state.id)
|
||||
# @pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
|
||||
# def test_simple_tool_rule(mock_e2b_api_key_none):
|
||||
# """
|
||||
# Test a simple tool rule where fourth_secret_word must be called after flip_coin.
|
||||
#
|
||||
# Tool Flow:
|
||||
# flip_coin
|
||||
# |
|
||||
# v
|
||||
# fourth_secret_word
|
||||
# """
|
||||
# client = create_client()
|
||||
# cleanup(client=client, agent_uuid=agent_uuid)
|
||||
#
|
||||
# # Create tools
|
||||
# flip_coin_name = "flip_coin"
|
||||
# secret_word = "fourth_secret_word"
|
||||
# flip_coin_tool = client.create_or_update_tool(flip_coin)
|
||||
# secret_word_tool = client.create_or_update_tool(fourth_secret_word)
|
||||
# another_secret_word_tool = client.create_or_update_tool(first_secret_word)
|
||||
# random_tool = client.create_or_update_tool(can_play_game)
|
||||
# tools = [flip_coin_tool, secret_word_tool, another_secret_word_tool, random_tool]
|
||||
#
|
||||
# # Create tool rule: after flip_coin, must call fourth_secret_word
|
||||
# tool_rule = ConditionalToolRule(
|
||||
# tool_name=flip_coin_name,
|
||||
# default_child=secret_word,
|
||||
# child_output_mapping={"*": secret_word},
|
||||
# )
|
||||
#
|
||||
# # Set up agent with the tool rule
|
||||
# agent_state = setup_agent(
|
||||
# client, config_file, agent_uuid, tool_rules=[tool_rule], tool_ids=[t.id for t in tools], include_base_tools=False
|
||||
# )
|
||||
#
|
||||
# # Start conversation
|
||||
# response = client.user_message(agent_id=agent_state.id, message="Help me test the tools.")
|
||||
#
|
||||
# # Verify the tool calls
|
||||
# tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
|
||||
# assert len(tool_calls) >= 2 # Should have at least flip_coin and fourth_secret_word calls
|
||||
# assert_invoked_function_call(response.messages, flip_coin_name)
|
||||
# assert_invoked_function_call(response.messages, secret_word)
|
||||
#
|
||||
# # Find the flip_coin call
|
||||
# flip_coin_call = next((call for call in tool_calls if call.tool_call.name == "flip_coin"), None)
|
||||
#
|
||||
# # Verify that fourth_secret_word was called after flip_coin
|
||||
# flip_coin_call_index = tool_calls.index(flip_coin_call)
|
||||
# assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin"
|
||||
#
|
||||
# cleanup(client, agent_uuid=agent_state.id)
|
||||
|
||||
|
||||
def test_init_tool_rule_always_fails_one_tool():
|
||||
@@ -768,3 +672,56 @@ def test_continue_tool_rule():
|
||||
if call.tool_call.name == "core_memory_append":
|
||||
core_memory_append_call_index = i
|
||||
assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append"
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_max_count_per_step_tool_rule_integration(mock_e2b_api_key_none):
|
||||
"""
|
||||
Test an agent with MaxCountPerStepToolRule to ensure a tool can only be called a limited number of times.
|
||||
|
||||
Tool Flow:
|
||||
repeatable_tool (max 2 times)
|
||||
|
|
||||
v
|
||||
send_message
|
||||
"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
# Create tools
|
||||
repeatable_tool_name = "first_secret_word"
|
||||
final_tool_name = "send_message"
|
||||
|
||||
repeatable_tool = client.create_or_update_tool(first_secret_word)
|
||||
send_message_tool = client.get_tool(client.get_tool_id(final_tool_name)) # Assume send_message is a default tool
|
||||
|
||||
# Define tool rules
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name=repeatable_tool_name),
|
||||
MaxCountPerStepToolRule(tool_name=repeatable_tool_name, max_count_limit=2),
|
||||
TerminalToolRule(tool_name=final_tool_name),
|
||||
]
|
||||
|
||||
tools = [repeatable_tool, send_message_tool]
|
||||
|
||||
# Setup agent
|
||||
agent_state = setup_agent(client, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
|
||||
# Start conversation
|
||||
response = client.user_message(
|
||||
agent_id=agent_state.id, message=f"Keep calling {repeatable_tool_name} nonstop without calling ANY other tool."
|
||||
)
|
||||
|
||||
# Make checks
|
||||
assert_sanity_checks(response)
|
||||
|
||||
# Ensure the repeatable tool is only called twice
|
||||
count = sum(1 for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == repeatable_tool_name)
|
||||
assert count == 2, f"Expected 'first_secret_word' to be called exactly 2 times, but got {count}"
|
||||
|
||||
# Ensure send_message was eventually called
|
||||
assert_invoked_function_call(response.messages, final_tool_name)
|
||||
|
||||
print(f"Got successful response from client: \n\n{response}")
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
@@ -2,7 +2,7 @@ import pytest
|
||||
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.tool_rule_solver import ToolRuleValidationError
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, MaxCountPerStepToolRule, TerminalToolRule
|
||||
|
||||
# Constants for tool names used in the tests
|
||||
START_TOOL = "start_tool"
|
||||
@@ -15,145 +15,163 @@ UNRECOGNIZED_TOOL = "unrecognized_tool"
|
||||
|
||||
|
||||
def test_get_allowed_tool_names_with_init_rules():
|
||||
# Setup: Initial tool rule configuration
|
||||
init_rule_1 = InitToolRule(tool_name=START_TOOL)
|
||||
init_rule_2 = InitToolRule(tool_name=PREP_TOOL)
|
||||
solver = ToolRulesSolver(init_tool_rules=[init_rule_1, init_rule_2], tool_rules=[], terminal_tool_rules=[])
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule_1, init_rule_2])
|
||||
|
||||
# Action: Get allowed tool names when no tool has been called
|
||||
allowed_tools = solver.get_allowed_tool_names()
|
||||
allowed_tools = solver.get_allowed_tool_names(set())
|
||||
|
||||
# Assert: Both init tools should be allowed initially
|
||||
assert allowed_tools == [START_TOOL, PREP_TOOL], "Should allow only InitToolRule tools at the start"
|
||||
|
||||
|
||||
def test_get_allowed_tool_names_with_subsequent_rule():
|
||||
# Setup: Tool rule sequence
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL, HELPER_TOOL])
|
||||
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1], terminal_tool_rules=[])
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule, rule_1])
|
||||
|
||||
# Action: Update usage and get allowed tools
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
allowed_tools = solver.get_allowed_tool_names()
|
||||
allowed_tools = solver.get_allowed_tool_names({START_TOOL, NEXT_TOOL, HELPER_TOOL})
|
||||
|
||||
# Assert: Only children of "start_tool" should be allowed
|
||||
assert allowed_tools == [NEXT_TOOL, HELPER_TOOL], "Should allow only children of the last tool used"
|
||||
assert sorted(allowed_tools) == sorted([NEXT_TOOL, HELPER_TOOL]), "Should allow only children of the last tool used"
|
||||
|
||||
|
||||
def test_is_terminal_tool():
|
||||
# Setup: Terminal tool rule configuration
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[terminal_rule])
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule, terminal_rule])
|
||||
|
||||
# Action & Assert: Verify terminal and non-terminal tools
|
||||
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as a terminal tool"
|
||||
assert solver.is_terminal_tool(START_TOOL) is False, "Should not recognize 'start_tool' as a terminal tool"
|
||||
|
||||
|
||||
def test_get_allowed_tool_names_no_matching_rule_warning():
|
||||
# Setup: Tool rules with no matching rule for the last tool
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[])
|
||||
|
||||
# Action: Set last tool to an unrecognized tool and check warnings
|
||||
solver.update_tool_usage(UNRECOGNIZED_TOOL)
|
||||
|
||||
# # NOTE: removed for now since this warning is getting triggered on every LLM call
|
||||
# with warnings.catch_warnings(record=True) as w:
|
||||
# allowed_tools = solver.get_allowed_tool_names()
|
||||
|
||||
# # Assert: Expecting a warning and an empty list of allowed tools
|
||||
# assert len(w) == 1, "Expected a warning for no matching rule"
|
||||
# assert "resolved to no more possible tool calls" in str(w[-1].message)
|
||||
# assert allowed_tools == [], "Should return an empty list if no matching rule"
|
||||
|
||||
|
||||
def test_get_allowed_tool_names_no_matching_rule_error():
|
||||
# Setup: Tool rules with no matching rule for the last tool
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[], terminal_tool_rules=[])
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule])
|
||||
|
||||
# Action & Assert: Set last tool to an unrecognized tool and expect ValueError
|
||||
solver.update_tool_usage(UNRECOGNIZED_TOOL)
|
||||
with pytest.raises(ValueError, match=f"No tool rule found for {UNRECOGNIZED_TOOL}"):
|
||||
solver.get_allowed_tool_names(error_on_empty=True)
|
||||
with pytest.raises(ValueError, match=f"No valid tools found based on tool rules."):
|
||||
solver.get_allowed_tool_names(set(), error_on_empty=True)
|
||||
|
||||
|
||||
def test_update_tool_usage_and_get_allowed_tool_names_combined():
|
||||
# Setup: More complex rule chaining
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
|
||||
rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[FINAL_TOOL])
|
||||
terminal_rule = TerminalToolRule(tool_name=FINAL_TOOL)
|
||||
solver = ToolRulesSolver(init_tool_rules=[init_rule], tool_rules=[rule_1, rule_2], terminal_tool_rules=[terminal_rule])
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, terminal_rule])
|
||||
|
||||
# Step 1: Initially allowed tools
|
||||
assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'"
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initial allowed tool should be 'start_tool'"
|
||||
|
||||
# Step 2: After using 'start_tool'
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
assert solver.get_allowed_tool_names() == [NEXT_TOOL], "After 'start_tool', should allow 'next_tool'"
|
||||
assert solver.get_allowed_tool_names({NEXT_TOOL}) == [NEXT_TOOL], "After 'start_tool', should allow 'next_tool'"
|
||||
|
||||
# Step 3: After using 'next_tool'
|
||||
solver.update_tool_usage(NEXT_TOOL)
|
||||
assert solver.get_allowed_tool_names() == [FINAL_TOOL], "After 'next_tool', should allow 'final_tool'"
|
||||
assert solver.get_allowed_tool_names({FINAL_TOOL}) == [FINAL_TOOL], "After 'next_tool', should allow 'final_tool'"
|
||||
|
||||
# Step 4: 'final_tool' should be terminal
|
||||
assert solver.is_terminal_tool(FINAL_TOOL) is True, "Should recognize 'final_tool' as terminal"
|
||||
|
||||
|
||||
def test_conditional_tool_rule():
|
||||
# Setup: Define a conditional tool rule
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
rule = ConditionalToolRule(tool_name=START_TOOL, default_child=None, child_output_mapping={True: END_TOOL, False: START_TOOL})
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule, rule, terminal_rule])
|
||||
|
||||
# Action & Assert: Verify the rule properties
|
||||
# Step 1: Initially allowed tools
|
||||
assert solver.get_allowed_tool_names() == [START_TOOL], "Initial allowed tool should be 'start_tool'"
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initial allowed tool should be 'start_tool'"
|
||||
|
||||
# Step 2: After using 'start_tool'
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
assert solver.get_allowed_tool_names(last_function_response='{"message": "true"}') == [
|
||||
assert solver.get_allowed_tool_names({END_TOOL}, last_function_response='{"message": "true"}') == [
|
||||
END_TOOL
|
||||
], "After 'start_tool' returns true, should allow 'end_tool'"
|
||||
assert solver.get_allowed_tool_names(last_function_response='{"message": "false"}') == [
|
||||
assert solver.get_allowed_tool_names({START_TOOL}, last_function_response='{"message": "false"}') == [
|
||||
START_TOOL
|
||||
], "After 'start_tool' returns false, should allow 'start_tool'"
|
||||
|
||||
# Step 3: After using 'end_tool'
|
||||
assert solver.is_terminal_tool(END_TOOL) is True, "Should recognize 'end_tool' as terminal"
|
||||
|
||||
|
||||
def test_invalid_conditional_tool_rule():
|
||||
# Setup: Define an invalid conditional tool rule
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
invalid_rule_1 = ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={})
|
||||
|
||||
# Test 1: Missing child output mapping
|
||||
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."):
|
||||
ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule])
|
||||
|
||||
|
||||
def test_tool_rules_with_invalid_path():
|
||||
# Setup: Define tool rules with both connected, disconnected nodes and a cycle
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
rule_1 = ChildToolRule(tool_name=START_TOOL, children=[NEXT_TOOL])
|
||||
rule_2 = ChildToolRule(tool_name=NEXT_TOOL, children=[HELPER_TOOL])
|
||||
rule_3 = ChildToolRule(tool_name=HELPER_TOOL, children=[START_TOOL]) # This creates a cycle: start -> next -> helper -> start
|
||||
rule_4 = ChildToolRule(tool_name=FINAL_TOOL, children=[END_TOOL]) # Disconnected rule, no cycle here
|
||||
rule_3 = ChildToolRule(tool_name=HELPER_TOOL, children=[START_TOOL])
|
||||
rule_4 = ChildToolRule(tool_name=FINAL_TOOL, children=[END_TOOL])
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, terminal_rule])
|
||||
|
||||
# Now: add a path from the start tool to the final tool
|
||||
rule_5 = ConditionalToolRule(
|
||||
tool_name=HELPER_TOOL,
|
||||
default_child=FINAL_TOOL,
|
||||
child_output_mapping={True: START_TOOL, False: FINAL_TOOL},
|
||||
)
|
||||
ToolRulesSolver(tool_rules=[init_rule, rule_1, rule_2, rule_3, rule_4, rule_5, terminal_rule])
|
||||
|
||||
|
||||
def test_max_count_per_step_tool_rule():
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
rule_1 = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2)
|
||||
solver = ToolRulesSolver(tool_rules=[init_rule, rule_1])
|
||||
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initially should allow 'start_tool'"
|
||||
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "After first use, should still allow 'start_tool'"
|
||||
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [], "After reaching max count, 'start_tool' should no longer be allowed"
|
||||
|
||||
|
||||
def test_max_count_per_step_tool_rule_allows_usage_up_to_limit():
|
||||
"""Ensure the tool is allowed exactly max_count_limit times."""
|
||||
rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=3)
|
||||
solver = ToolRulesSolver(tool_rules=[rule])
|
||||
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Initially should allow 'start_tool'"
|
||||
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 1 use"
|
||||
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should still allow 'start_tool' after 2 uses"
|
||||
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should no longer allow 'start_tool' after 3 uses"
|
||||
|
||||
|
||||
def test_max_count_per_step_tool_rule_does_not_affect_other_tools():
|
||||
"""Ensure exceeding max count for one tool does not impact others."""
|
||||
rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2)
|
||||
another_tool_rules = ChildToolRule(tool_name=NEXT_TOOL, children=[HELPER_TOOL])
|
||||
solver = ToolRulesSolver(tool_rules=[rule, another_tool_rules])
|
||||
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
|
||||
assert sorted(solver.get_allowed_tool_names({START_TOOL, NEXT_TOOL, HELPER_TOOL})) == sorted(
|
||||
[NEXT_TOOL, HELPER_TOOL]
|
||||
), "Other tools should still be allowed even if 'start_tool' is over limit"
|
||||
|
||||
|
||||
def test_max_count_per_step_tool_rule_resets_on_clear():
|
||||
"""Ensure clearing tool history resets the rule's limit."""
|
||||
rule = MaxCountPerStepToolRule(tool_name=START_TOOL, max_count_limit=2)
|
||||
solver = ToolRulesSolver(tool_rules=[rule])
|
||||
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
solver.update_tool_usage(START_TOOL)
|
||||
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [], "Should not allow 'start_tool' after reaching limit"
|
||||
|
||||
solver.clear_tool_history()
|
||||
|
||||
assert solver.get_allowed_tool_names({START_TOOL}) == [START_TOOL], "Should allow 'start_tool' again after clearing history"
|
||||
|
||||
Reference in New Issue
Block a user