chore: clean up tool rule solver code

This commit is contained in:
Andy Li
2025-08-08 16:39:17 -07:00
committed by GitHub
parent 243a2b65e0
commit 3183c7b3c1
7 changed files with 84 additions and 113 deletions

View File

@@ -395,6 +395,24 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepStat
if not data:
return None
if solver_data := data.get("tool_rules_solver"):
# Get existing tool_rules or reconstruct from categorized fields for backwards compatibility
tool_rules_data = solver_data.get("tool_rules", [])
if not tool_rules_data:
for field_name in (
"init_tool_rules",
"continue_tool_rules",
"child_based_tool_rules",
"parent_tool_rules",
"terminal_tool_rules",
"required_before_exit_tool_rules",
):
if field_data := solver_data.get(field_name):
tool_rules_data.extend(field_data)
solver_data["tool_rules"] = deserialize_tool_rules(tool_rules_data)
return AgentStepState(**data)
@@ -418,6 +436,7 @@ def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormat
return JsonSchemaResponseFormat(**data)
if data["type"] == ResponseFormatType.json_object:
return JsonObjectResponseFormat(**data)
raise ValueError(f"Unknown Response Format type: {data['type']}")
# --------------------------

View File

@@ -1,11 +1,9 @@
from typing import List, Optional, Union
from typing import TypeAlias
from pydantic import BaseModel, Field
from letta.schemas.block import Block
from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import (
BaseToolRule,
ChildToolRule,
ConditionalToolRule,
ContinueToolRule,
@@ -14,88 +12,61 @@ from letta.schemas.tool_rule import (
ParentToolRule,
RequiredBeforeExitToolRule,
TerminalToolRule,
ToolRule,
)
ToolName: TypeAlias = str
class ToolRuleValidationError(Exception):
"""Custom exception for tool rule validation errors in ToolRulesSolver."""
def __init__(self, message: str):
super().__init__(f"ToolRuleValidationError: {message}")
COMPILED_PROMPT_DESCRIPTION = "The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. A single response may contain multiple tool calls."
class ToolRulesSolver(BaseModel):
init_tool_rules: List[InitToolRule] = Field(
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
tool_rules: list[ToolRule] | None = Field(default=None, description="Input list of tool rules")
# Categorized fields
init_tool_rules: list[InitToolRule] = Field(
default_factory=list, description="Initial tool rules to be used at the start of tool execution.", exclude=True
)
continue_tool_rules: List[ContinueToolRule] = Field(
default_factory=list, description="Continue tool rules to be used to continue tool execution."
continue_tool_rules: list[ContinueToolRule] = Field(
default_factory=list, description="Continue tool rules to be used to continue tool execution.", exclude=True
)
# TODO: This should be renamed?
# TODO: These are tools that control the set of allowed functions in the next turn
child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field(
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
child_based_tool_rules: list[ChildToolRule | ConditionalToolRule | MaxCountPerStepToolRule] = Field(
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions.", exclude=True
)
parent_tool_rules: List[ParentToolRule] = Field(
default_factory=list, description="Filter tool rules to be used to filter out tools from the available set."
parent_tool_rules: list[ParentToolRule] = Field(
default_factory=list, description="Filter tool rules to be used to filter out tools from the available set.", exclude=True
)
terminal_tool_rules: List[TerminalToolRule] = Field(
default_factory=list, description="Terminal tool rules that end the agent loop if called."
terminal_tool_rules: list[TerminalToolRule] = Field(
default_factory=list, description="Terminal tool rules that end the agent loop if called.", exclude=True
)
required_before_exit_tool_rules: List[RequiredBeforeExitToolRule] = Field(
default_factory=list, description="Tool rules that must be called before the agent can exit."
required_before_exit_tool_rules: list[RequiredBeforeExitToolRule] = Field(
default_factory=list, description="Tool rules that must be called before the agent can exit.", exclude=True
)
tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
def __init__(
self,
tool_rules: Optional[List[BaseToolRule]] = None,
init_tool_rules: Optional[List[InitToolRule]] = None,
continue_tool_rules: Optional[List[ContinueToolRule]] = None,
child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
parent_tool_rules: Optional[List[ParentToolRule]] = None,
terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
required_before_exit_tool_rules: Optional[List[RequiredBeforeExitToolRule]] = None,
tool_call_history: Optional[List[str]] = None,
**kwargs,
):
super().__init__(
init_tool_rules=init_tool_rules or [],
continue_tool_rules=continue_tool_rules or [],
child_based_tool_rules=child_based_tool_rules or [],
parent_tool_rules=parent_tool_rules or [],
terminal_tool_rules=terminal_tool_rules or [],
required_before_exit_tool_rules=required_before_exit_tool_rules or [],
tool_call_history=tool_call_history or [],
**kwargs,
)
def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs):
super().__init__(tool_rules=tool_rules, **kwargs)
if tool_rules:
for rule in tool_rules:
if rule.type == ToolRuleType.run_first:
assert isinstance(rule, InitToolRule)
def model_post_init(self, __context):
if self.tool_rules:
for rule in self.tool_rules:
if isinstance(rule, InitToolRule):
self.init_tool_rules.append(rule)
elif rule.type == ToolRuleType.constrain_child_tools:
assert isinstance(rule, ChildToolRule)
elif isinstance(rule, ChildToolRule):
self.child_based_tool_rules.append(rule)
elif rule.type == ToolRuleType.conditional:
assert isinstance(rule, ConditionalToolRule)
self.validate_conditional_tool(rule)
elif isinstance(rule, ConditionalToolRule):
self.child_based_tool_rules.append(rule)
elif rule.type == ToolRuleType.exit_loop:
assert isinstance(rule, TerminalToolRule)
elif isinstance(rule, TerminalToolRule):
self.terminal_tool_rules.append(rule)
elif rule.type == ToolRuleType.continue_loop:
assert isinstance(rule, ContinueToolRule)
elif isinstance(rule, ContinueToolRule):
self.continue_tool_rules.append(rule)
elif rule.type == ToolRuleType.max_count_per_step:
assert isinstance(rule, MaxCountPerStepToolRule)
elif isinstance(rule, MaxCountPerStepToolRule):
self.child_based_tool_rules.append(rule)
elif rule.type == ToolRuleType.parent_last_tool:
assert isinstance(rule, ParentToolRule)
elif isinstance(rule, ParentToolRule):
self.parent_tool_rules.append(rule)
elif rule.type == ToolRuleType.required_before_exit:
assert isinstance(rule, RequiredBeforeExitToolRule)
elif isinstance(rule, RequiredBeforeExitToolRule):
self.required_before_exit_tool_rules.append(rule)
def register_tool_call(self, tool_name: str):
@@ -107,12 +78,12 @@ class ToolRulesSolver(BaseModel):
self.tool_call_history.clear()
def get_allowed_tool_names(
self, available_tools: set[str], error_on_empty: bool = True, last_function_response: str | None = None
) -> List[str]:
self, available_tools: set[ToolName], error_on_empty: bool = True, last_function_response: str | None = None
) -> list[ToolName]:
"""Get a list of tool names allowed based on the last tool called.
The logic is as follows:
1. if there are no previous tool calls and we have InitToolRules, those are the only options for the first tool call
1. if there are no previous tool calls, and we have InitToolRules, those are the only options for the first tool call
2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options
3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools
"""
@@ -134,23 +105,23 @@ class ToolRulesSolver(BaseModel):
return list(final_allowed_tools)
def is_terminal_tool(self, tool_name: str) -> bool:
def is_terminal_tool(self, tool_name: ToolName) -> bool:
"""Check if the tool is defined as a terminal tool in the terminal tool rules or required-before-exit tool rules."""
return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
def has_children_tools(self, tool_name):
def has_children_tools(self, tool_name: ToolName):
"""Check if the tool has children tools"""
return any(rule.tool_name == tool_name for rule in self.child_based_tool_rules)
def is_continue_tool(self, tool_name):
def is_continue_tool(self, tool_name: ToolName):
"""Check if the tool is defined as a continue tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
def has_required_tools_been_called(self, available_tools: set[str]) -> bool:
def has_required_tools_been_called(self, available_tools: set[ToolName]) -> bool:
"""Check if all required-before-exit tools have been called."""
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
def get_uncalled_required_tools(self, available_tools: set[str]) -> List[str]:
def get_uncalled_required_tools(self, available_tools: set[ToolName]) -> list[str]:
"""Get the list of required-before-exit tools that have not been called yet."""
if not self.required_before_exit_tool_rules:
return [] # No required tools means no uncalled tools
@@ -161,16 +132,12 @@ class ToolRulesSolver(BaseModel):
# Get required tools that are uncalled AND available
return list((required_tool_names & available_tools) - called_tool_names)
def get_ending_tool_names(self) -> List[str]:
"""Get the names of tools that are required before exit."""
return [rule.tool_name for rule in self.required_before_exit_tool_rules]
def compile_tool_rule_prompts(self) -> Optional[Block]:
def compile_tool_rule_prompts(self) -> Block | None:
"""
Compile prompt templates from all tool rules into an ephemeral Block.
Returns:
Optional[str]: Compiled prompt string with tool rule constraints, or None if no templates exist.
Block | None: Compiled prompt block with tool rule constraints, or None if no templates exist.
"""
compiled_prompts = []
@@ -191,20 +158,20 @@ class ToolRulesSolver(BaseModel):
return Block(
label="tool_usage_rules",
value="\n".join(compiled_prompts),
description="The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. A single response may contain multiple tool calls.",
description=COMPILED_PROMPT_DESCRIPTION,
)
return None
def guess_rule_violation(self, tool_name: str) -> List[str]:
def guess_rule_violation(self, tool_name: ToolName) -> list[str]:
"""
Check if the given tool name or the previous tool in history matches any tool rule,
and return rendered prompt templates for matching rules.
and return rendered prompt templates for matching rule violations.
Args:
tool_name: The name of the tool to check for rule violations
Returns:
List of rendered prompt templates from matching tool rules
list of rendered prompt templates from matching tool rules
"""
violated_rules = []
@@ -228,18 +195,3 @@ class ToolRulesSolver(BaseModel):
violated_rules.append(rendered_prompt)
return violated_rules
@staticmethod
def validate_conditional_tool(rule: ConditionalToolRule):
"""
Validate a conditional tool rule
Args:
rule (ConditionalToolRule): The conditional tool rule to validate
Raises:
ToolRuleValidationError: If the rule is invalid
"""
if len(rule.child_output_mapping) == 0:
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
return True

View File

@@ -1,7 +1,7 @@
from enum import Enum
from typing import Annotated, Any, Dict, Literal, Union
from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, Field, field_validator
class ResponseFormatType(str, Enum):
@@ -52,13 +52,12 @@ class JsonSchemaResponseFormat(ResponseFormat):
description="The JSON schema of the response.",
)
@validator("json_schema")
def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]:
@field_validator("json_schema")
@classmethod
def validate_json_schema(cls, v: dict[str, Any]) -> Dict[str, Any]:
"""Validate that the provided schema is a valid JSON schema."""
if not isinstance(v, dict):
raise ValueError("JSON schema must be a dictionary")
if "schema" not in v:
raise ValueError("JSON schema should include a $schema property")
raise ValueError("JSON schema should include a schema property")
return v

View File

@@ -3,7 +3,7 @@ import logging
from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union
from jinja2 import Template
from pydantic import Field
from pydantic import Field, field_validator
from letta.schemas.enums import ToolRuleType
from letta.schemas.letta_base import LettaBase
@@ -117,6 +117,13 @@ class ConditionalToolRule(BaseToolRule):
return {self.default_child} if self.default_child else available_tools
@field_validator("child_output_mapping")
@classmethod
def validate_child_output_mapping(cls, v):
if len(v) == 0:
raise ValueError("Conditional tool rule must have at least one child tool.")
return v
@staticmethod
def _matches_key(function_output: str, key: Any) -> bool:
"""Helper function to determine if function output matches a mapping key."""

View File

@@ -1706,6 +1706,7 @@ class AgentManager:
else:
return agent_state
# Do not remove comment. (cliandy)
# TODO: This is probably one of the worst pieces of code I've ever written please rip up as you see wish
@enforce_types
@trace_method
@@ -1715,7 +1716,6 @@ class AgentManager:
actor: PydanticUser,
force=False,
update_timestamp=True,
tool_rules_solver: Optional[ToolRulesSolver] = None,
dry_run: bool = False,
) -> Tuple[PydanticAgentState, Optional[PydanticMessage], int, int]:
"""Rebuilds the system message with the latest memory object and any shared memory block updates
@@ -1728,8 +1728,7 @@ class AgentManager:
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
if not tool_rules_solver:
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)

View File

@@ -201,7 +201,7 @@ class DatabaseChoice(str, Enum):
class Settings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="letta_", extra="ignore")
letta_dir: Optional[Path] = Field(Path.home() / ".letta", env="LETTA_DIR")
letta_dir: Optional[Path] = Field(Path.home() / ".letta", alias="LETTA_DIR")
debug: Optional[bool] = False
cors_origins: Optional[list] = cors_origins

View File

@@ -1,7 +1,6 @@
import pytest
from letta.helpers import ToolRulesSolver
from letta.helpers.tool_rule_solver import ToolRuleValidationError
from letta.schemas.tool_rule import (
ChildToolRule,
ConditionalToolRule,
@@ -101,12 +100,8 @@ def test_conditional_tool_rule():
def test_invalid_conditional_tool_rule():
init_rule = InitToolRule(tool_name=START_TOOL)
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
invalid_rule_1 = ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={})
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."):
ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule])
with pytest.raises(ValueError, match="Conditional tool rule must have at least one child tool."):
ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={})
def test_tool_rules_with_invalid_path():