* Add args * Add testing to tool rule solver * Add live integration tests for args prefilling * Add args override
286 lines
13 KiB
Python
286 lines
13 KiB
Python
from typing import TypeAlias
|
|
|
|
from pydantic import BaseModel, Field
|
|
|
|
from letta.schemas.block import Block
|
|
from letta.schemas.tool_rule import (
|
|
ChildToolRule,
|
|
ConditionalToolRule,
|
|
ContinueToolRule,
|
|
InitToolRule,
|
|
MaxCountPerStepToolRule,
|
|
ParentToolRule,
|
|
RequiredBeforeExitToolRule,
|
|
RequiresApprovalToolRule,
|
|
TerminalToolRule,
|
|
ToolRule,
|
|
)
|
|
|
|
ToolName: TypeAlias = str
|
|
|
|
COMPILED_PROMPT_DESCRIPTION = "The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. A single response may contain multiple tool calls."
|
|
|
|
|
|
class ToolRulesSolver(BaseModel):
    """Evaluate an agent's tool rules against its tool-call history.

    The raw ``tool_rules`` given at construction are sorted by type into the
    categorized lists below (see ``model_post_init``). Those lists are used to:

    * compute which tools may be called next (``get_allowed_tool_names``), caching
      any prefilled args supplied by the rules for the allowed tools;
    * answer per-tool questions (terminal, continue, requires approval, has children);
    * report which required-before-exit tools are still uncalled;
    * render rule prompts into a Block (``compile_tool_rule_prompts``) or into hints
      for a suspected violation (``guess_rule_violation``).
    """

    tool_rules: list[ToolRule] | None = Field(default=None, description="Input list of tool rules")

    # Categorized fields. They are derived from `tool_rules` in `model_post_init`
    # and excluded from serialization (exclude=True).
    init_tool_rules: list[InitToolRule] = Field(
        default_factory=list, description="Initial tool rules to be used at the start of tool execution.", exclude=True
    )
    continue_tool_rules: list[ContinueToolRule] = Field(
        default_factory=list, description="Continue tool rules to be used to continue tool execution.", exclude=True
    )
    # TODO: This should be renamed?
    # TODO: These are tools that control the set of allowed functions in the next turn
    child_based_tool_rules: list[ChildToolRule | ConditionalToolRule | MaxCountPerStepToolRule] = Field(
        default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions.", exclude=True
    )
    parent_tool_rules: list[ParentToolRule] = Field(
        default_factory=list, description="Filter tool rules to be used to filter out tools from the available set.", exclude=True
    )
    terminal_tool_rules: list[TerminalToolRule] = Field(
        default_factory=list, description="Terminal tool rules that end the agent loop if called.", exclude=True
    )
    required_before_exit_tool_rules: list[RequiredBeforeExitToolRule] = Field(
        default_factory=list, description="Tool rules that must be called before the agent can exit.", exclude=True
    )
    requires_approval_tool_rules: list[RequiresApprovalToolRule] = Field(
        default_factory=list, description="Tool rules that trigger an approval request for human-in-the-loop.", exclude=True
    )
    # Names of tools called so far, in call order; appended by `register_tool_call`.
    tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")

    # Last-evaluated prefilled args cache (per step). Overwritten by every call to
    # `get_allowed_tool_names` that returns normally (left untouched if it raises).
    last_prefilled_args_by_tool: dict[str, dict] = Field(
        default_factory=dict, description="Cached mapping of tool name to prefilled args from the last allowlist evaluation.", exclude=True
    )
    last_prefilled_args_provenance: dict[str, str] = Field(
        default_factory=dict,
        description="Cached mapping of tool name to a short description of which rule provided the prefilled args.",
        exclude=True,
    )

    def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs):
        # Lets callers write `ToolRulesSolver(rules)` positionally; pydantic's own
        # __init__ only accepts keyword data.
        super().__init__(tool_rules=tool_rules, **kwargs)

    def model_post_init(self, __context):
        """Sort each rule in `tool_rules` into its categorized list.

        The isinstance checks run in a fixed order and the first match wins; a rule
        matching none of them is ignored. Child, conditional and max-count-per-step
        rules all go into `child_based_tool_rules`.
        """
        if not self.tool_rules:
            return
        for rule in self.tool_rules:
            if isinstance(rule, InitToolRule):
                self.init_tool_rules.append(rule)
            elif isinstance(rule, (ChildToolRule, ConditionalToolRule)):
                self.child_based_tool_rules.append(rule)
            elif isinstance(rule, TerminalToolRule):
                self.terminal_tool_rules.append(rule)
            elif isinstance(rule, ContinueToolRule):
                self.continue_tool_rules.append(rule)
            elif isinstance(rule, MaxCountPerStepToolRule):
                self.child_based_tool_rules.append(rule)
            elif isinstance(rule, ParentToolRule):
                self.parent_tool_rules.append(rule)
            elif isinstance(rule, RequiredBeforeExitToolRule):
                self.required_before_exit_tool_rules.append(rule)
            elif isinstance(rule, RequiresApprovalToolRule):
                self.requires_approval_tool_rules.append(rule)

    def register_tool_call(self, tool_name: str):
        """Update the internal state to track tool call history."""
        self.tool_call_history.append(tool_name)

    def clear_tool_history(self):
        """Clear the history of tool calls."""
        self.tool_call_history.clear()

    def get_allowed_tool_names(
        self, available_tools: set[ToolName], error_on_empty: bool = True, last_function_response: str | None = None
    ) -> list[ToolName]:
        """Get a list of tool names allowed based on the last tool called.

        Side-effect: also caches any prefilled args provided by active rules into
        `last_prefilled_args_by_tool` and `last_prefilled_args_provenance`.

        The logic is as follows:
        1. if there are no previous tool calls, and we have InitToolRules, those are the only options for the first tool call
        2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options
        3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools

        Args:
            available_tools: Tool names the agent currently has.
            error_on_empty: Raise if the intersection in case 2 is empty.
            last_function_response: Forwarded to each rule's `get_valid_tools`.

        Returns:
            The allowed tool names. In case 1 these are the init-rule tool names as
            listed (not intersected with `available_tools`); in case 2 the order is
            that of iterating a set, so callers should not rely on it.

        Raises:
            ValueError: in case 2, when `error_on_empty` is set and no tool remains.
        """
        # Evaluated once so allowlist and prefill cache agree on which phase we are in.
        at_start = not self.tool_call_history and bool(self.init_tool_rules)

        if at_start:
            allowed = [rule.tool_name for rule in self.init_tool_rules]
        else:
            # Each rule narrows the set independently; a tool must satisfy all of them.
            valid_tool_sets = [
                rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
                for rule in self.child_based_tool_rules + self.parent_tool_rules
            ]
            final_allowed_tools = set.intersection(*valid_tool_sets) if valid_tool_sets else available_tools

            if error_on_empty and not final_allowed_tools:
                raise ValueError("No valid tools found based on tool rules.")

            allowed = list(final_allowed_tools)

        self._cache_prefilled_args(allowed, at_start)
        return allowed

    def _cache_prefilled_args(self, allowed: list[ToolName], at_start: bool) -> None:
        """Rebuild the prefilled-args caches for the tools in `allowed`.

        At the start (init phase) only init rules contribute; otherwise every other
        categorized rule list does. A rule contributes when it has a non-empty dict
        `args` and its `tool_name` is allowed. Several rules for the same tool are
        merged in iteration order (later keys overwrite earlier ones), and the
        provenance records the last contributing rule.
        """
        allowed_set = set(allowed)
        args_by_tool: dict[str, dict] = {}
        provenance_by_tool: dict[str, str] = {}

        if at_start:
            candidate_rules = self.init_tool_rules
        else:
            candidate_rules = (
                self.child_based_tool_rules
                + self.parent_tool_rules
                + self.continue_tool_rules
                + self.terminal_tool_rules
                + self.required_before_exit_tool_rules
                + self.requires_approval_tool_rules
            )

        for rule in candidate_rules:
            # Not every rule type is known to define `args`/`tool_name`; read defensively.
            args = getattr(rule, "args", None)
            tool_name = getattr(rule, "tool_name", None)
            if not args or not isinstance(args, dict) or tool_name not in allowed_set:
                continue
            args_by_tool.setdefault(tool_name, {}).update(args)  # last-write-wins
            provenance_by_tool[tool_name] = f"{rule.__class__.__name__}({tool_name})"

        self.last_prefilled_args_by_tool = args_by_tool
        self.last_prefilled_args_provenance = provenance_by_tool

    def is_terminal_tool(self, tool_name: ToolName) -> bool:
        """Check if the tool is defined as a terminal tool in the terminal tool rules.

        Required-before-exit rules are not consulted here.
        """
        return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)

    def has_children_tools(self, tool_name: ToolName) -> bool:
        """Check if the tool has children tools.

        Note: this matches any rule in `child_based_tool_rules`, which also holds
        conditional and max-count-per-step rules.
        """
        return any(rule.tool_name == tool_name for rule in self.child_based_tool_rules)

    def is_continue_tool(self, tool_name: ToolName) -> bool:
        """Check if the tool is defined as a continue tool in the tool rules."""
        return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)

    def is_requires_approval_tool(self, tool_name: ToolName) -> bool:
        """Check if the tool is defined as a requires-approval tool in the tool rules."""
        return any(rule.tool_name == tool_name for rule in self.requires_approval_tool_rules)

    def has_required_tools_been_called(self, available_tools: set[ToolName]) -> bool:
        """Check if all required-before-exit tools that are available have been called."""
        return not self.get_uncalled_required_tools(available_tools=available_tools)

    def get_requires_approval_tools(self, available_tools: set[ToolName]) -> list[ToolName]:
        """Get the list of tools that require approval.

        NOTE(review): `available_tools` is accepted for interface compatibility but is
        not used; every requires-approval rule's tool is returned, available or not.
        """
        return [rule.tool_name for rule in self.requires_approval_tool_rules]

    def get_uncalled_required_tools(self, available_tools: set[ToolName]) -> list[str]:
        """Get the required-before-exit tools that are available but not yet called.

        Required tools missing from `available_tools` are ignored. The result is built
        from a set, so its order is unspecified.
        """
        if not self.required_before_exit_tool_rules:
            return []  # No required tools means no uncalled tools

        required_tool_names = {rule.tool_name for rule in self.required_before_exit_tool_rules}
        called_tool_names = set(self.tool_call_history)
        return list((required_tool_names & available_tools) - called_tool_names)

    def _prompt_rules(self) -> list[ToolRule]:
        """Return the rules whose prompts are rendered, in a fixed order.

        Required-before-exit and requires-approval rules are not included.
        """
        return (
            self.init_tool_rules
            + self.continue_tool_rules
            + self.child_based_tool_rules
            + self.parent_tool_rules
            + self.terminal_tool_rules
        )

    def compile_tool_rule_prompts(self) -> Block | None:
        """
        Compile prompt templates from all tool rules into an ephemeral Block.

        Returns:
            Block | None: Compiled prompt block with tool rule constraints, or None if no templates exist.
        """
        # Rules whose render_prompt() returns a falsy value contribute nothing.
        compiled_prompts = [rendered for rule in self._prompt_rules() if (rendered := rule.render_prompt())]

        if not compiled_prompts:
            return None

        return Block(
            label="tool_usage_rules",
            value="\n".join(compiled_prompts),
            description=COMPILED_PROMPT_DESCRIPTION,
        )

    def guess_rule_violation(self, tool_name: ToolName) -> list[str]:
        """
        Check if the given tool name or the previous tool in history matches any tool rule,
        and return rendered prompt templates for matching rule violations.

        Args:
            tool_name: The name of the tool to check for rule violations

        Returns:
            list of rendered prompt templates from matching tool rules
        """
        # Get the previous tool from history if it exists
        previous_tool = self.tool_call_history[-1] if self.tool_call_history else None

        violated_rules = []
        for rule in self._prompt_rules():
            # Check if the current tool name or previous tool matches this rule's tool_name
            if rule.tool_name == tool_name or (previous_tool and rule.tool_name == previous_tool):
                rendered_prompt = rule.render_prompt()
                if rendered_prompt:
                    violated_rules.append(rendered_prompt)

        return violated_rules

    def should_force_tool_call(self) -> bool:
        """
        Determine if a tool call should be forced (using 'required' instead of 'auto') based on active constrained tool rules.

        Returns:
            bool: True if a constrained tool rule is currently active, False otherwise
        """
        # At the start with init rules, the first call is always constrained.
        if not self.tool_call_history and self.init_tool_rules:
            return True

        if not self.tool_call_history:
            return False

        last_tool = self.tool_call_history[-1]
        # Child-based rules (ChildToolRule, ConditionalToolRule, ...) and parent rules;
        # parent rules are checked via `requires_force_tool_call` for safety in case
        # that gets expanded.
        return any(
            rule.requires_force_tool_call and rule.tool_name == last_tool
            for rule in self.child_based_tool_rules + self.parent_tool_rules
        )