chore: clean up tool rule solver code
This commit is contained in:
@@ -395,6 +395,24 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepStat
|
||||
if not data:
|
||||
return None
|
||||
|
||||
if solver_data := data.get("tool_rules_solver"):
|
||||
# Get existing tool_rules or reconstruct from categorized fields for backwards compatibility
|
||||
tool_rules_data = solver_data.get("tool_rules", [])
|
||||
|
||||
if not tool_rules_data:
|
||||
for field_name in (
|
||||
"init_tool_rules",
|
||||
"continue_tool_rules",
|
||||
"child_based_tool_rules",
|
||||
"parent_tool_rules",
|
||||
"terminal_tool_rules",
|
||||
"required_before_exit_tool_rules",
|
||||
):
|
||||
if field_data := solver_data.get(field_name):
|
||||
tool_rules_data.extend(field_data)
|
||||
|
||||
solver_data["tool_rules"] = deserialize_tool_rules(tool_rules_data)
|
||||
|
||||
return AgentStepState(**data)
|
||||
|
||||
|
||||
@@ -418,6 +436,7 @@ def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormat
|
||||
return JsonSchemaResponseFormat(**data)
|
||||
if data["type"] == ResponseFormatType.json_object:
|
||||
return JsonObjectResponseFormat(**data)
|
||||
raise ValueError(f"Unknown Response Format type: {data['type']}")
|
||||
|
||||
|
||||
# --------------------------
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from typing import List, Optional, Union
|
||||
from typing import TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.tool_rule import (
|
||||
BaseToolRule,
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
ContinueToolRule,
|
||||
@@ -14,88 +12,61 @@ from letta.schemas.tool_rule import (
|
||||
ParentToolRule,
|
||||
RequiredBeforeExitToolRule,
|
||||
TerminalToolRule,
|
||||
ToolRule,
|
||||
)
|
||||
|
||||
ToolName: TypeAlias = str
|
||||
|
||||
class ToolRuleValidationError(Exception):
|
||||
"""Custom exception for tool rule validation errors in ToolRulesSolver."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(f"ToolRuleValidationError: {message}")
|
||||
COMPILED_PROMPT_DESCRIPTION = "The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. A single response may contain multiple tool calls."
|
||||
|
||||
|
||||
class ToolRulesSolver(BaseModel):
|
||||
init_tool_rules: List[InitToolRule] = Field(
|
||||
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
|
||||
tool_rules: list[ToolRule] | None = Field(default=None, description="Input list of tool rules")
|
||||
|
||||
# Categorized fields
|
||||
init_tool_rules: list[InitToolRule] = Field(
|
||||
default_factory=list, description="Initial tool rules to be used at the start of tool execution.", exclude=True
|
||||
)
|
||||
continue_tool_rules: List[ContinueToolRule] = Field(
|
||||
default_factory=list, description="Continue tool rules to be used to continue tool execution."
|
||||
continue_tool_rules: list[ContinueToolRule] = Field(
|
||||
default_factory=list, description="Continue tool rules to be used to continue tool execution.", exclude=True
|
||||
)
|
||||
# TODO: This should be renamed?
|
||||
# TODO: These are tools that control the set of allowed functions in the next turn
|
||||
child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field(
|
||||
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
|
||||
child_based_tool_rules: list[ChildToolRule | ConditionalToolRule | MaxCountPerStepToolRule] = Field(
|
||||
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions.", exclude=True
|
||||
)
|
||||
parent_tool_rules: List[ParentToolRule] = Field(
|
||||
default_factory=list, description="Filter tool rules to be used to filter out tools from the available set."
|
||||
parent_tool_rules: list[ParentToolRule] = Field(
|
||||
default_factory=list, description="Filter tool rules to be used to filter out tools from the available set.", exclude=True
|
||||
)
|
||||
terminal_tool_rules: List[TerminalToolRule] = Field(
|
||||
default_factory=list, description="Terminal tool rules that end the agent loop if called."
|
||||
terminal_tool_rules: list[TerminalToolRule] = Field(
|
||||
default_factory=list, description="Terminal tool rules that end the agent loop if called.", exclude=True
|
||||
)
|
||||
required_before_exit_tool_rules: List[RequiredBeforeExitToolRule] = Field(
|
||||
default_factory=list, description="Tool rules that must be called before the agent can exit."
|
||||
required_before_exit_tool_rules: list[RequiredBeforeExitToolRule] = Field(
|
||||
default_factory=list, description="Tool rules that must be called before the agent can exit.", exclude=True
|
||||
)
|
||||
tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
|
||||
tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
init_tool_rules: Optional[List[InitToolRule]] = None,
|
||||
continue_tool_rules: Optional[List[ContinueToolRule]] = None,
|
||||
child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None,
|
||||
parent_tool_rules: Optional[List[ParentToolRule]] = None,
|
||||
terminal_tool_rules: Optional[List[TerminalToolRule]] = None,
|
||||
required_before_exit_tool_rules: Optional[List[RequiredBeforeExitToolRule]] = None,
|
||||
tool_call_history: Optional[List[str]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
init_tool_rules=init_tool_rules or [],
|
||||
continue_tool_rules=continue_tool_rules or [],
|
||||
child_based_tool_rules=child_based_tool_rules or [],
|
||||
parent_tool_rules=parent_tool_rules or [],
|
||||
terminal_tool_rules=terminal_tool_rules or [],
|
||||
required_before_exit_tool_rules=required_before_exit_tool_rules or [],
|
||||
tool_call_history=tool_call_history or [],
|
||||
**kwargs,
|
||||
)
|
||||
def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs):
|
||||
super().__init__(tool_rules=tool_rules, **kwargs)
|
||||
|
||||
if tool_rules:
|
||||
for rule in tool_rules:
|
||||
if rule.type == ToolRuleType.run_first:
|
||||
assert isinstance(rule, InitToolRule)
|
||||
def model_post_init(self, __context):
|
||||
if self.tool_rules:
|
||||
for rule in self.tool_rules:
|
||||
if isinstance(rule, InitToolRule):
|
||||
self.init_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.constrain_child_tools:
|
||||
assert isinstance(rule, ChildToolRule)
|
||||
elif isinstance(rule, ChildToolRule):
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.conditional:
|
||||
assert isinstance(rule, ConditionalToolRule)
|
||||
self.validate_conditional_tool(rule)
|
||||
elif isinstance(rule, ConditionalToolRule):
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.exit_loop:
|
||||
assert isinstance(rule, TerminalToolRule)
|
||||
elif isinstance(rule, TerminalToolRule):
|
||||
self.terminal_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.continue_loop:
|
||||
assert isinstance(rule, ContinueToolRule)
|
||||
elif isinstance(rule, ContinueToolRule):
|
||||
self.continue_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.max_count_per_step:
|
||||
assert isinstance(rule, MaxCountPerStepToolRule)
|
||||
elif isinstance(rule, MaxCountPerStepToolRule):
|
||||
self.child_based_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.parent_last_tool:
|
||||
assert isinstance(rule, ParentToolRule)
|
||||
elif isinstance(rule, ParentToolRule):
|
||||
self.parent_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.required_before_exit:
|
||||
assert isinstance(rule, RequiredBeforeExitToolRule)
|
||||
elif isinstance(rule, RequiredBeforeExitToolRule):
|
||||
self.required_before_exit_tool_rules.append(rule)
|
||||
|
||||
def register_tool_call(self, tool_name: str):
|
||||
@@ -107,12 +78,12 @@ class ToolRulesSolver(BaseModel):
|
||||
self.tool_call_history.clear()
|
||||
|
||||
def get_allowed_tool_names(
|
||||
self, available_tools: set[str], error_on_empty: bool = True, last_function_response: str | None = None
|
||||
) -> List[str]:
|
||||
self, available_tools: set[ToolName], error_on_empty: bool = True, last_function_response: str | None = None
|
||||
) -> list[ToolName]:
|
||||
"""Get a list of tool names allowed based on the last tool called.
|
||||
|
||||
The logic is as follows:
|
||||
1. if there are no previous tool calls and we have InitToolRules, those are the only options for the first tool call
|
||||
1. if there are no previous tool calls, and we have InitToolRules, those are the only options for the first tool call
|
||||
2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options
|
||||
3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools
|
||||
"""
|
||||
@@ -134,23 +105,23 @@ class ToolRulesSolver(BaseModel):
|
||||
|
||||
return list(final_allowed_tools)
|
||||
|
||||
def is_terminal_tool(self, tool_name: str) -> bool:
|
||||
def is_terminal_tool(self, tool_name: ToolName) -> bool:
|
||||
"""Check if the tool is defined as a terminal tool in the terminal tool rules or required-before-exit tool rules."""
|
||||
return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules)
|
||||
|
||||
def has_children_tools(self, tool_name):
|
||||
def has_children_tools(self, tool_name: ToolName):
|
||||
"""Check if the tool has children tools"""
|
||||
return any(rule.tool_name == tool_name for rule in self.child_based_tool_rules)
|
||||
|
||||
def is_continue_tool(self, tool_name):
|
||||
def is_continue_tool(self, tool_name: ToolName):
|
||||
"""Check if the tool is defined as a continue tool in the tool rules."""
|
||||
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
||||
|
||||
def has_required_tools_been_called(self, available_tools: set[str]) -> bool:
|
||||
def has_required_tools_been_called(self, available_tools: set[ToolName]) -> bool:
|
||||
"""Check if all required-before-exit tools have been called."""
|
||||
return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0
|
||||
|
||||
def get_uncalled_required_tools(self, available_tools: set[str]) -> List[str]:
|
||||
def get_uncalled_required_tools(self, available_tools: set[ToolName]) -> list[str]:
|
||||
"""Get the list of required-before-exit tools that have not been called yet."""
|
||||
if not self.required_before_exit_tool_rules:
|
||||
return [] # No required tools means no uncalled tools
|
||||
@@ -161,16 +132,12 @@ class ToolRulesSolver(BaseModel):
|
||||
# Get required tools that are uncalled AND available
|
||||
return list((required_tool_names & available_tools) - called_tool_names)
|
||||
|
||||
def get_ending_tool_names(self) -> List[str]:
|
||||
"""Get the names of tools that are required before exit."""
|
||||
return [rule.tool_name for rule in self.required_before_exit_tool_rules]
|
||||
|
||||
def compile_tool_rule_prompts(self) -> Optional[Block]:
|
||||
def compile_tool_rule_prompts(self) -> Block | None:
|
||||
"""
|
||||
Compile prompt templates from all tool rules into an ephemeral Block.
|
||||
|
||||
Returns:
|
||||
Optional[str]: Compiled prompt string with tool rule constraints, or None if no templates exist.
|
||||
Block | None: Compiled prompt block with tool rule constraints, or None if no templates exist.
|
||||
"""
|
||||
compiled_prompts = []
|
||||
|
||||
@@ -191,20 +158,20 @@ class ToolRulesSolver(BaseModel):
|
||||
return Block(
|
||||
label="tool_usage_rules",
|
||||
value="\n".join(compiled_prompts),
|
||||
description="The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. A single response may contain multiple tool calls.",
|
||||
description=COMPILED_PROMPT_DESCRIPTION,
|
||||
)
|
||||
return None
|
||||
|
||||
def guess_rule_violation(self, tool_name: str) -> List[str]:
|
||||
def guess_rule_violation(self, tool_name: ToolName) -> list[str]:
|
||||
"""
|
||||
Check if the given tool name or the previous tool in history matches any tool rule,
|
||||
and return rendered prompt templates for matching rules.
|
||||
and return rendered prompt templates for matching rule violations.
|
||||
|
||||
Args:
|
||||
tool_name: The name of the tool to check for rule violations
|
||||
|
||||
Returns:
|
||||
List of rendered prompt templates from matching tool rules
|
||||
list of rendered prompt templates from matching tool rules
|
||||
"""
|
||||
violated_rules = []
|
||||
|
||||
@@ -228,18 +195,3 @@ class ToolRulesSolver(BaseModel):
|
||||
violated_rules.append(rendered_prompt)
|
||||
|
||||
return violated_rules
|
||||
|
||||
@staticmethod
|
||||
def validate_conditional_tool(rule: ConditionalToolRule):
|
||||
"""
|
||||
Validate a conditional tool rule
|
||||
|
||||
Args:
|
||||
rule (ConditionalToolRule): The conditional tool rule to validate
|
||||
|
||||
Raises:
|
||||
ToolRuleValidationError: If the rule is invalid
|
||||
"""
|
||||
if len(rule.child_output_mapping) == 0:
|
||||
raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.")
|
||||
return True
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from enum import Enum
|
||||
from typing import Annotated, Any, Dict, Literal, Union
|
||||
|
||||
from pydantic import BaseModel, Field, validator
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
class ResponseFormatType(str, Enum):
|
||||
@@ -52,13 +52,12 @@ class JsonSchemaResponseFormat(ResponseFormat):
|
||||
description="The JSON schema of the response.",
|
||||
)
|
||||
|
||||
@validator("json_schema")
|
||||
def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]:
|
||||
@field_validator("json_schema")
|
||||
@classmethod
|
||||
def validate_json_schema(cls, v: dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Validate that the provided schema is a valid JSON schema."""
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError("JSON schema must be a dictionary")
|
||||
if "schema" not in v:
|
||||
raise ValueError("JSON schema should include a $schema property")
|
||||
raise ValueError("JSON schema should include a schema property")
|
||||
return v
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from jinja2 import Template
|
||||
from pydantic import Field
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
@@ -117,6 +117,13 @@ class ConditionalToolRule(BaseToolRule):
|
||||
|
||||
return {self.default_child} if self.default_child else available_tools
|
||||
|
||||
@field_validator("child_output_mapping")
|
||||
@classmethod
|
||||
def validate_child_output_mapping(cls, v):
|
||||
if len(v) == 0:
|
||||
raise ValueError("Conditional tool rule must have at least one child tool.")
|
||||
return v
|
||||
|
||||
@staticmethod
|
||||
def _matches_key(function_output: str, key: Any) -> bool:
|
||||
"""Helper function to determine if function output matches a mapping key."""
|
||||
|
||||
@@ -1706,6 +1706,7 @@ class AgentManager:
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
# Do not remove comment. (cliandy)
|
||||
# TODO: This is probably one of the worst pieces of code I've ever written please rip up as you see wish
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -1715,7 +1716,6 @@ class AgentManager:
|
||||
actor: PydanticUser,
|
||||
force=False,
|
||||
update_timestamp=True,
|
||||
tool_rules_solver: Optional[ToolRulesSolver] = None,
|
||||
dry_run: bool = False,
|
||||
) -> Tuple[PydanticAgentState, Optional[PydanticMessage], int, int]:
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||||
@@ -1728,8 +1728,7 @@ class AgentManager:
|
||||
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor)
|
||||
|
||||
if not tool_rules_solver:
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
|
||||
curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
||||
|
||||
|
||||
@@ -201,7 +201,7 @@ class DatabaseChoice(str, Enum):
|
||||
class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix="letta_", extra="ignore")
|
||||
|
||||
letta_dir: Optional[Path] = Field(Path.home() / ".letta", env="LETTA_DIR")
|
||||
letta_dir: Optional[Path] = Field(Path.home() / ".letta", alias="LETTA_DIR")
|
||||
debug: Optional[bool] = False
|
||||
cors_origins: Optional[list] = cors_origins
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import pytest
|
||||
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.tool_rule_solver import ToolRuleValidationError
|
||||
from letta.schemas.tool_rule import (
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
@@ -101,12 +100,8 @@ def test_conditional_tool_rule():
|
||||
|
||||
|
||||
def test_invalid_conditional_tool_rule():
|
||||
init_rule = InitToolRule(tool_name=START_TOOL)
|
||||
terminal_rule = TerminalToolRule(tool_name=END_TOOL)
|
||||
invalid_rule_1 = ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={})
|
||||
|
||||
with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."):
|
||||
ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule])
|
||||
with pytest.raises(ValueError, match="Conditional tool rule must have at least one child tool."):
|
||||
ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={})
|
||||
|
||||
|
||||
def test_tool_rules_with_invalid_path():
|
||||
|
||||
Reference in New Issue
Block a user