Files
letta-server/letta/helpers/tool_rule_solver.py
2025-06-27 16:23:55 -07:00

251 lines
11 KiB
Python

from typing import List, Optional, Set, Union
from pydantic import BaseModel, Field
from letta.schemas.block import Block
from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import (
BaseToolRule,
ChildToolRule,
ConditionalToolRule,
ContinueToolRule,
InitToolRule,
MaxCountPerStepToolRule,
ParentToolRule,
RequiredBeforeExitToolRule,
TerminalToolRule,
)
class ToolRuleValidationError(Exception):
    """Raised when a tool rule handed to ToolRulesSolver fails validation.

    The supplied message is stored with a ``"ToolRuleValidationError: "`` prefix,
    so ``str(exc)`` and ``exc.args[0]`` both carry that prefix.
    """

    def __init__(self, message: str):
        prefixed_message = "ToolRuleValidationError: {}".format(message)
        super().__init__(prefixed_message)
class ToolRulesSolver(BaseModel):
    """Hold an agent's tool rules plus its tool-call history, and evaluate them.

    Rules can be supplied already grouped through the per-category keyword arguments,
    or as a flat ``tool_rules`` list that ``__init__`` sorts into the category lists by
    ``rule.type``. The solver answers questions such as which tools may be called next
    (``get_allowed_tool_names``), whether a tool ends the loop (``is_terminal_tool``),
    and which required-before-exit tools are still outstanding
    (``get_uncalled_required_tools``).
    """

    init_tool_rules: List[InitToolRule] = Field(
        default_factory=list, description="Initial tool rules to be used at the start of tool execution."
    )
    continue_tool_rules: List[ContinueToolRule] = Field(
        default_factory=list, description="Continue tool rules to be used to continue tool execution."
    )
    # TODO: This should be renamed?
    # TODO: These are tools that control the set of allowed functions in the next turn
    child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field(
        default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
    )
    parent_tool_rules: List[ParentToolRule] = Field(
        default_factory=list, description="Filter tool rules to be used to filter out tools from the available set."
    )
    terminal_tool_rules: List[TerminalToolRule] = Field(
        default_factory=list, description="Terminal tool rules that end the agent loop if called."
    )
    required_before_exit_tool_rules: List[RequiredBeforeExitToolRule] = Field(
        default_factory=list, description="Tool rules that must be called before the agent can exit."
    )
    tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")

    def __init__(
        self,
        tool_rules: Optional[List[BaseToolRule]] = None,
        init_tool_rules: Optional[List[InitToolRule]] = None,
        continue_tool_rules: Optional[List[ContinueToolRule]] = None,
        child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
        parent_tool_rules: Optional[List[ParentToolRule]] = None,
        terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
        required_before_exit_tool_rules: Optional[List[RequiredBeforeExitToolRule]] = None,
        tool_call_history: Optional[List[str]] = None,
        **kwargs,
    ):
        """Build the solver from pre-grouped rule lists and/or a flat ``tool_rules`` list.

        Args:
            tool_rules: Flat list of rules. Each rule is appended to the category list that
                matches its ``type``; rules of any other ``type`` are not stored.
            init_tool_rules, continue_tool_rules, child_based_tool_rules, parent_tool_rules,
            terminal_tool_rules, required_before_exit_tool_rules, tool_call_history:
                Initial contents of the same-named fields; ``None`` means an empty list.
            **kwargs: Forwarded unchanged to ``BaseModel.__init__``.

        Raises:
            ToolRuleValidationError: If a rule in ``tool_rules`` is not an instance of the class
                its ``type`` implies, or a conditional rule has an empty child output mapping.
        """
        super().__init__(
            init_tool_rules=init_tool_rules or [],
            continue_tool_rules=continue_tool_rules or [],
            child_based_tool_rules=child_based_tool_rules or [],
            parent_tool_rules=parent_tool_rules or [],
            terminal_tool_rules=terminal_tool_rules or [],
            required_before_exit_tool_rules=required_before_exit_tool_rules or [],
            tool_call_history=tool_call_history or [],
            **kwargs,
        )
        for rule in tool_rules or []:
            self._add_tool_rule(rule)

    def _add_tool_rule(self, rule: BaseToolRule) -> None:
        """Append ``rule`` to the category list selected by ``rule.type``."""
        if rule.type == ToolRuleType.run_first:
            self._require_rule_class(rule, InitToolRule)
            self.init_tool_rules.append(rule)
        elif rule.type == ToolRuleType.constrain_child_tools:
            self._require_rule_class(rule, ChildToolRule)
            self.child_based_tool_rules.append(rule)
        elif rule.type == ToolRuleType.conditional:
            self._require_rule_class(rule, ConditionalToolRule)
            self.validate_conditional_tool(rule)
            self.child_based_tool_rules.append(rule)
        elif rule.type == ToolRuleType.exit_loop:
            self._require_rule_class(rule, TerminalToolRule)
            self.terminal_tool_rules.append(rule)
        elif rule.type == ToolRuleType.continue_loop:
            self._require_rule_class(rule, ContinueToolRule)
            self.continue_tool_rules.append(rule)
        elif rule.type == ToolRuleType.max_count_per_step:
            self._require_rule_class(rule, MaxCountPerStepToolRule)
            self.child_based_tool_rules.append(rule)
        elif rule.type == ToolRuleType.parent_last_tool:
            self._require_rule_class(rule, ParentToolRule)
            self.parent_tool_rules.append(rule)
        elif rule.type == ToolRuleType.required_before_exit:
            self._require_rule_class(rule, RequiredBeforeExitToolRule)
            self.required_before_exit_tool_rules.append(rule)
        # Any other rule type is not tracked by this solver and is silently skipped.

    @staticmethod
    def _require_rule_class(rule: BaseToolRule, expected_cls: type) -> None:
        """Raise if ``rule`` is not an instance of the class its ``type`` implies.

        An explicit check is used instead of ``assert`` so the validation is not stripped
        when Python runs with ``-O``.
        """
        if not isinstance(rule, expected_cls):
            raise ToolRuleValidationError(
                f"Tool rule with type {rule.type!r} must be a {expected_cls.__name__}, got {type(rule).__name__}."
            )

    def register_tool_call(self, tool_name: str):
        """Update the internal state to track tool call history."""
        self.tool_call_history.append(tool_name)

    def clear_tool_history(self):
        """Clear the history of tool calls."""
        self.tool_call_history.clear()

    def get_allowed_tool_names(
        self, available_tools: Set[str], error_on_empty: bool = False, last_function_response: Optional[str] = None
    ) -> List[str]:
        """Return the names of the tools that may be called next.

        Args:
            available_tools: Names of the tools the agent currently has.
            error_on_empty: If True, raise when the computed set is empty. Only checked once at
                least one tool has been called.
            last_function_response: Response of the previous tool call; forwarded to each rule's
                ``get_valid_tools``.

        Returns:
            Before any tool has been called: the init-rule tool names if init rules exist (these
            are not filtered against ``available_tools``), otherwise ``available_tools`` minus
            every child named by a parent rule. Afterwards: the intersection of
            ``get_valid_tools`` across all child-based and parent rules, or all of
            ``available_tools`` when there are no such rules. Except for the init-rule case the
            list is built from a set, so its order is unspecified.

        Raises:
            ValueError: If ``error_on_empty`` is True and the post-first-call result is empty.
        """
        # TODO: This piece of code here is quite ugly and deserves a refactor
        # TODO: There's some weird logic encoded here:
        # TODO: -> This only takes into consideration Init, and a set of Child/Conditional/MaxSteps tool rules
        # TODO: -> Init tool rules outputs are treated additively, Child/Conditional/MaxSteps are intersection based
        # TODO: -> Tool rules should probably be refactored to take in a set of tool names?
        if not self.tool_call_history:
            if self.init_tool_rules:
                # First step with init rules: only the init tools may run.
                return [rule.tool_name for rule in self.init_tool_rules]
            # First step without init rules: hide tools that a parent rule says must follow
            # some other tool, since nothing has been called yet.
            parent_children = set.union(set(), *(set(rule.children) for rule in self.parent_tool_rules))
            return list(available_tools - parent_children)

        # Later steps: every child-based and parent rule contributes a set; they are intersected.
        valid_tool_sets = [
            rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
            for rule in self.child_based_tool_rules + self.parent_tool_rules
        ]
        final_allowed_tools = set.intersection(*valid_tool_sets) if valid_tool_sets else available_tools

        if error_on_empty and not final_allowed_tools:
            raise ValueError("No valid tools found based on tool rules.")

        return list(final_allowed_tools)

    def is_terminal_tool(self, tool_name: str) -> bool:
        """Return True if ``tool_name`` has a terminal (exit-loop) tool rule.

        Only ``terminal_tool_rules`` are consulted; required-before-exit rules are not.
        """
        return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)

    def has_children_tools(self, tool_name: str) -> bool:
        """Return True if any child-based rule (child, conditional, or max-count) targets ``tool_name``."""
        return any(rule.tool_name == tool_name for rule in self.child_based_tool_rules)

    def is_continue_tool(self, tool_name: str) -> bool:
        """Return True if ``tool_name`` has a continue-loop tool rule."""
        return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)

    def has_required_tools_been_called(self, available_tools: Set[str]) -> bool:
        """Return True if every available required-before-exit tool has been called."""
        return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0

    def get_uncalled_required_tools(self, available_tools: Set[str]) -> List[str]:
        """Return required-before-exit tools that are available but not yet called.

        Required tools missing from ``available_tools`` are ignored. The list is built from a
        set, so its order is unspecified.
        """
        if not self.required_before_exit_tool_rules:
            return []  # No required tools means no uncalled tools
        required_tool_names = {rule.tool_name for rule in self.required_before_exit_tool_rules}
        called_tool_names = set(self.tool_call_history)
        return list((required_tool_names & available_tools) - called_tool_names)

    def get_ending_tool_names(self) -> List[str]:
        """Get the names of tools that are required before exit."""
        return [rule.tool_name for rule in self.required_before_exit_tool_rules]

    def _prompt_rules(self) -> list:
        """Return the rules whose prompts are rendered, in a fixed category order.

        Order: init, continue, child-based, parent, terminal. Required-before-exit rules are
        not included. NOTE(review): confirm whether that omission is intentional.
        """
        return (
            self.init_tool_rules
            + self.continue_tool_rules
            + self.child_based_tool_rules
            + self.parent_tool_rules
            + self.terminal_tool_rules
        )

    def compile_tool_rule_prompts(self) -> Optional[Block]:
        """Compile the rendered prompt of every tool rule into an ephemeral Block.

        Returns:
            Optional[Block]: A ``tool_usage_rules`` block whose value is the non-empty rendered
            prompts joined with newlines, or None if no rule renders a prompt.
        """
        compiled_prompts = [rendered for rendered in (rule.render_prompt() for rule in self._prompt_rules()) if rendered]
        if not compiled_prompts:
            return None
        return Block(
            label="tool_usage_rules",
            value="\n".join(compiled_prompts),
            description="The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. A single response may contain multiple tool calls.",
        )

    def guess_rule_violation(self, tool_name: str) -> List[str]:
        """Return rendered prompts of rules that mention ``tool_name`` or the previous tool.

        A rule matches when its ``tool_name`` equals ``tool_name`` or equals the most recent
        entry of the tool-call history (if any). Rules that render an empty prompt are skipped.

        Args:
            tool_name: The name of the tool to check for rule violations.

        Returns:
            Rendered prompt templates of the matching rules, in rule order.
        """
        previous_tool = self.tool_call_history[-1] if self.tool_call_history else None

        violated_rules = []
        for rule in self._prompt_rules():
            if rule.tool_name == tool_name or (previous_tool and rule.tool_name == previous_tool):
                rendered_prompt = rule.render_prompt()
                if rendered_prompt:
                    violated_rules.append(rendered_prompt)
        return violated_rules

    @staticmethod
    def validate_conditional_tool(rule: ConditionalToolRule):
        """Validate a conditional tool rule.

        Args:
            rule (ConditionalToolRule): The conditional tool rule to validate.

        Returns:
            True if the rule is valid.

        Raises:
            ToolRuleValidationError: If the rule's ``child_output_mapping`` is empty.
        """
        if len(rule.child_output_mapping) == 0:
            raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
        return True