diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index c2b4ddab..daa92522 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -395,6 +395,24 @@ def deserialize_agent_step_state(data: Optional[Dict]) -> Optional[AgentStepStat if not data: return None + if solver_data := data.get("tool_rules_solver"): + # Get existing tool_rules or reconstruct from categorized fields for backwards compatibility + tool_rules_data = solver_data.get("tool_rules", []) + + if not tool_rules_data: + for field_name in ( + "init_tool_rules", + "continue_tool_rules", + "child_based_tool_rules", + "parent_tool_rules", + "terminal_tool_rules", + "required_before_exit_tool_rules", + ): + if field_data := solver_data.get(field_name): + tool_rules_data.extend(field_data) + + solver_data["tool_rules"] = deserialize_tool_rules(tool_rules_data) + return AgentStepState(**data) @@ -418,6 +436,7 @@ def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormat return JsonSchemaResponseFormat(**data) if data["type"] == ResponseFormatType.json_object: return JsonObjectResponseFormat(**data) + raise ValueError(f"Unknown Response Format type: {data['type']}") # -------------------------- diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index acf7c2dd..b3ffc402 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -1,11 +1,9 @@ -from typing import List, Optional, Union +from typing import TypeAlias from pydantic import BaseModel, Field from letta.schemas.block import Block -from letta.schemas.enums import ToolRuleType from letta.schemas.tool_rule import ( - BaseToolRule, ChildToolRule, ConditionalToolRule, ContinueToolRule, @@ -14,88 +12,61 @@ from letta.schemas.tool_rule import ( ParentToolRule, RequiredBeforeExitToolRule, TerminalToolRule, + ToolRule, ) +ToolName: TypeAlias = str -class ToolRuleValidationError(Exception): - """Custom exception for tool rule validation errors in ToolRulesSolver.""" - - def __init__(self, message: str): - super().__init__(f"ToolRuleValidationError: {message}") +COMPILED_PROMPT_DESCRIPTION = "The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. A single response may contain multiple tool calls." class ToolRulesSolver(BaseModel): - init_tool_rules: List[InitToolRule] = Field( - default_factory=list, description="Initial tool rules to be used at the start of tool execution." + tool_rules: list[ToolRule] | None = Field(default=None, description="Input list of tool rules") + + # Categorized fields + init_tool_rules: list[InitToolRule] = Field( + default_factory=list, description="Initial tool rules to be used at the start of tool execution.", exclude=True ) - continue_tool_rules: List[ContinueToolRule] = Field( - default_factory=list, description="Continue tool rules to be used to continue tool execution." + continue_tool_rules: list[ContinueToolRule] = Field( + default_factory=list, description="Continue tool rules to be used to continue tool execution.", exclude=True ) # TODO: This should be renamed? # TODO: These are tools that control the set of allowed functions in the next turn - child_based_tool_rules: List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]] = Field( - default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions." + child_based_tool_rules: list[ChildToolRule | ConditionalToolRule | MaxCountPerStepToolRule] = Field( + default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions.", exclude=True ) - parent_tool_rules: List[ParentToolRule] = Field( - default_factory=list, description="Filter tool rules to be used to filter out tools from the available set." + parent_tool_rules: list[ParentToolRule] = Field( + default_factory=list, description="Filter tool rules to be used to filter out tools from the available set.", exclude=True ) - terminal_tool_rules: List[TerminalToolRule] = Field( - default_factory=list, description="Terminal tool rules that end the agent loop if called." + terminal_tool_rules: list[TerminalToolRule] = Field( + default_factory=list, description="Terminal tool rules that end the agent loop if called.", exclude=True ) - required_before_exit_tool_rules: List[RequiredBeforeExitToolRule] = Field( - default_factory=list, description="Tool rules that must be called before the agent can exit." + required_before_exit_tool_rules: list[RequiredBeforeExitToolRule] = Field( + default_factory=list, description="Tool rules that must be called before the agent can exit.", exclude=True ) - tool_call_history: List[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.") + tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.") - def __init__( - self, - tool_rules: Optional[List[BaseToolRule]] = None, - init_tool_rules: Optional[List[InitToolRule]] = None, - continue_tool_rules: Optional[List[ContinueToolRule]] = None, - child_based_tool_rules: Optional[List[Union[ChildToolRule, ConditionalToolRule, MaxCountPerStepToolRule]]] = None, - parent_tool_rules: Optional[List[ParentToolRule]] = None, - terminal_tool_rules: Optional[List[TerminalToolRule]] = None, - required_before_exit_tool_rules: Optional[List[RequiredBeforeExitToolRule]] = None, - tool_call_history: Optional[List[str]] = None, - **kwargs, - ): - super().__init__( - init_tool_rules=init_tool_rules or [], - continue_tool_rules=continue_tool_rules or [], - child_based_tool_rules=child_based_tool_rules or [], - parent_tool_rules=parent_tool_rules or [], - terminal_tool_rules=terminal_tool_rules or [], - required_before_exit_tool_rules=required_before_exit_tool_rules or [], - tool_call_history=tool_call_history or [], - **kwargs, - ) + def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs): + super().__init__(tool_rules=tool_rules, **kwargs) - if tool_rules: - for rule in tool_rules: - if rule.type == ToolRuleType.run_first: - assert isinstance(rule, InitToolRule) + def model_post_init(self, __context): + if self.tool_rules: + for rule in self.tool_rules: + if isinstance(rule, InitToolRule): self.init_tool_rules.append(rule) - elif rule.type == ToolRuleType.constrain_child_tools: - assert isinstance(rule, ChildToolRule) + elif isinstance(rule, ChildToolRule): self.child_based_tool_rules.append(rule) - elif rule.type == ToolRuleType.conditional: - assert isinstance(rule, ConditionalToolRule) - self.validate_conditional_tool(rule) + elif isinstance(rule, ConditionalToolRule): self.child_based_tool_rules.append(rule) - elif rule.type == ToolRuleType.exit_loop: - assert isinstance(rule, TerminalToolRule) + elif isinstance(rule, TerminalToolRule): self.terminal_tool_rules.append(rule) - elif rule.type == ToolRuleType.continue_loop: - assert isinstance(rule, ContinueToolRule) + elif isinstance(rule, ContinueToolRule): self.continue_tool_rules.append(rule) - elif rule.type == ToolRuleType.max_count_per_step: - assert isinstance(rule, MaxCountPerStepToolRule) + elif isinstance(rule, MaxCountPerStepToolRule): self.child_based_tool_rules.append(rule) - elif rule.type == ToolRuleType.parent_last_tool: - assert isinstance(rule, ParentToolRule) + elif isinstance(rule, ParentToolRule): self.parent_tool_rules.append(rule) - elif rule.type == ToolRuleType.required_before_exit: - assert isinstance(rule, RequiredBeforeExitToolRule) + elif isinstance(rule, RequiredBeforeExitToolRule): self.required_before_exit_tool_rules.append(rule) def register_tool_call(self, tool_name: str): @@ -107,12 +78,12 @@ class ToolRulesSolver(BaseModel): self.tool_call_history.clear() def get_allowed_tool_names( - self, available_tools: set[str], error_on_empty: bool = True, last_function_response: str | None = None - ) -> List[str]: + self, available_tools: set[ToolName], error_on_empty: bool = True, last_function_response: str | None = None + ) -> list[ToolName]: """Get a list of tool names allowed based on the last tool called. The logic is as follows: - 1. if there are no previous tool calls and we have InitToolRules, those are the only options for the first tool call + 1. if there are no previous tool calls, and we have InitToolRules, those are the only options for the first tool call 2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options 3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools """ @@ -134,23 +105,23 @@ class ToolRulesSolver(BaseModel): return list(final_allowed_tools) - def is_terminal_tool(self, tool_name: str) -> bool: + def is_terminal_tool(self, tool_name: ToolName) -> bool: """Check if the tool is defined as a terminal tool in the terminal tool rules or required-before-exit tool rules.""" return any(rule.tool_name == tool_name for rule in self.terminal_tool_rules) - def has_children_tools(self, tool_name): + def has_children_tools(self, tool_name: ToolName): """Check if the tool has children tools""" return any(rule.tool_name == tool_name for rule in self.child_based_tool_rules) - def is_continue_tool(self, tool_name): + def is_continue_tool(self, tool_name: ToolName): """Check if the tool is defined as a continue tool in the tool rules.""" return any(rule.tool_name == tool_name for rule in self.continue_tool_rules) - def has_required_tools_been_called(self, available_tools: set[str]) -> bool: + def has_required_tools_been_called(self, available_tools: set[ToolName]) -> bool: """Check if all required-before-exit tools have been called.""" return len(self.get_uncalled_required_tools(available_tools=available_tools)) == 0 - def get_uncalled_required_tools(self, available_tools: set[str]) -> List[str]: + def get_uncalled_required_tools(self, available_tools: set[ToolName]) -> list[str]: """Get the list of required-before-exit tools that have not been called yet.""" if not self.required_before_exit_tool_rules: return [] # No required tools means no uncalled tools @@ -161,16 +132,12 @@ class ToolRulesSolver(BaseModel): # Get required tools that are uncalled AND available return list((required_tool_names & available_tools) - called_tool_names) - def get_ending_tool_names(self) -> List[str]: - """Get the names of tools that are required before exit.""" - return [rule.tool_name for rule in self.required_before_exit_tool_rules] - - def compile_tool_rule_prompts(self) -> Optional[Block]: + def compile_tool_rule_prompts(self) -> Block | None: """ Compile prompt templates from all tool rules into an ephemeral Block. Returns: - Optional[str]: Compiled prompt string with tool rule constraints, or None if no templates exist. + Block | None: Compiled prompt block with tool rule constraints, or None if no templates exist. """ compiled_prompts = [] @@ -191,20 +158,20 @@ class ToolRulesSolver(BaseModel): return Block( label="tool_usage_rules", value="\n".join(compiled_prompts), - description="The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow. A single response may contain multiple tool calls.", + description=COMPILED_PROMPT_DESCRIPTION, ) return None - def guess_rule_violation(self, tool_name: str) -> List[str]: + def guess_rule_violation(self, tool_name: ToolName) -> list[str]: """ Check if the given tool name or the previous tool in history matches any tool rule, - and return rendered prompt templates for matching rules. + and return rendered prompt templates for matching rule violations. Args: tool_name: The name of the tool to check for rule violations Returns: - List of rendered prompt templates from matching tool rules + list of rendered prompt templates from matching tool rules """ violated_rules = [] @@ -228,18 +195,3 @@ class ToolRulesSolver(BaseModel): violated_rules.append(rendered_prompt) return violated_rules - - @staticmethod - def validate_conditional_tool(rule: ConditionalToolRule): - """ - Validate a conditional tool rule - - Args: - rule (ConditionalToolRule): The conditional tool rule to validate - - Raises: - ToolRuleValidationError: If the rule is invalid - """ - if len(rule.child_output_mapping) == 0: - raise ToolRuleValidationError("Conditional tool rule must have at least one child tool.") - return True diff --git a/letta/schemas/response_format.py b/letta/schemas/response_format.py index 08197c57..e928b9cb 100644 --- a/letta/schemas/response_format.py +++ b/letta/schemas/response_format.py @@ -1,7 +1,7 @@ from enum import Enum from typing import Annotated, Any, Dict, Literal, Union -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator class ResponseFormatType(str, Enum): @@ -52,13 +52,12 @@ class JsonSchemaResponseFormat(ResponseFormat): description="The JSON schema of the response.", ) - @validator("json_schema") - def validate_json_schema(cls, v: Dict[str, Any]) -> Dict[str, Any]: + @field_validator("json_schema") + @classmethod + def validate_json_schema(cls, v: dict[str, Any]) -> Dict[str, Any]: """Validate that the provided schema is a valid JSON schema.""" - if not isinstance(v, dict): - raise ValueError("JSON schema must be a dictionary") if "schema" not in v: - raise ValueError("JSON schema should include a $schema property") + raise ValueError("JSON schema should include a schema property") return v diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index af347d9b..b4744abf 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -3,7 +3,7 @@ import logging from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union from jinja2 import Template -from pydantic import Field +from pydantic import Field, field_validator from letta.schemas.enums import ToolRuleType from letta.schemas.letta_base import LettaBase @@ -117,6 +117,13 @@ class ConditionalToolRule(BaseToolRule): return {self.default_child} if self.default_child else available_tools + @field_validator("child_output_mapping") + @classmethod + def validate_child_output_mapping(cls, v): + if len(v) == 0: + raise ValueError("Conditional tool rule must have at least one child tool.") + return v + @staticmethod def _matches_key(function_output: str, key: Any) -> bool: """Helper function to determine if function output matches a mapping key.""" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 9516a34e..5751ac40 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1706,6 +1706,7 @@ class AgentManager: else: return agent_state + # Do not remove comment. (cliandy) # TODO: This is probably one of the worst pieces of code I've ever written please rip up as you see wish @enforce_types @trace_method @@ -1715,7 +1716,6 @@ class AgentManager: actor: PydanticUser, force=False, update_timestamp=True, - tool_rules_solver: Optional[ToolRulesSolver] = None, dry_run: bool = False, ) -> Tuple[PydanticAgentState, Optional[PydanticMessage], int, int]: """Rebuilds the system message with the latest memory object and any shared memory block updates @@ -1728,8 +1728,7 @@ class AgentManager: num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id) agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory", "sources", "tools"], actor=actor) - if not tool_rules_solver: - tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) + tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) curr_system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor) diff --git a/letta/settings.py b/letta/settings.py index 1c047b47..975872ee 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -201,7 +201,7 @@ class DatabaseChoice(str, Enum): class Settings(BaseSettings): model_config = SettingsConfigDict(env_prefix="letta_", extra="ignore") - letta_dir: Optional[Path] = Field(Path.home() / ".letta", env="LETTA_DIR") + letta_dir: Optional[Path] = Field(Path.home() / ".letta", alias="LETTA_DIR") debug: Optional[bool] = False cors_origins: Optional[list] = cors_origins diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index 0e5f4ed5..25bb5d31 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -1,7 +1,6 @@ import pytest from letta.helpers import ToolRulesSolver -from letta.helpers.tool_rule_solver import ToolRuleValidationError from letta.schemas.tool_rule import ( ChildToolRule, ConditionalToolRule, @@ -101,12 +100,8 @@ def test_conditional_tool_rule(): def test_invalid_conditional_tool_rule(): - init_rule = InitToolRule(tool_name=START_TOOL) - terminal_rule = TerminalToolRule(tool_name=END_TOOL) - invalid_rule_1 = ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={}) - - with pytest.raises(ToolRuleValidationError, match="Conditional tool rule must have at least one child tool."): - ToolRulesSolver(tool_rules=[init_rule, invalid_rule_1, terminal_rule]) + with pytest.raises(ValueError, match="Conditional tool rule must have at least one child tool."): + ConditionalToolRule(tool_name=START_TOOL, default_child=END_TOOL, child_output_mapping={}) def test_tool_rules_with_invalid_path():