feat: Add prompting to guide tool rule usage (#2742)

This commit is contained in:
Matthew Zhou
2025-06-10 16:21:27 -07:00
committed by GitHub
parent 8ced9e57ba
commit ba3d59bba5
8 changed files with 94 additions and 9 deletions

View File

@@ -4,6 +4,7 @@ from typing import Any, AsyncGenerator, List, Optional, Union
import openai
from letta.constants import DEFAULT_MAX_STEPS
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.schemas.agent import AgentState
@@ -16,6 +17,7 @@ from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.utils import united_diff
logger = get_logger(__name__)
@@ -40,6 +42,8 @@ class BaseAgent(ABC):
self.openai_client = openai_client
self.message_manager = message_manager
self.agent_manager = agent_manager
# TODO: Pass this in
self.passage_manager = PassageManager()
self.actor = actor
self.logger = get_logger(agent_id)
@@ -78,8 +82,9 @@ class BaseAgent(ABC):
self,
in_context_messages: List[Message],
agent_state: AgentState,
num_messages: int | None = None, # storing these calculations is specific to the voice agent
num_archival_memories: int | None = None,
tool_rules_solver: Optional[ToolRulesSolver] = None,
num_messages: Optional[int] = None, # storing these calculations is specific to the voice agent
num_archival_memories: Optional[int] = None,
) -> List[Message]:
"""
Async version of function above. For now before breaking up components, changes should be made in both places.
@@ -113,6 +118,7 @@ class BaseAgent(ABC):
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
tool_rules_solver=tool_rules_solver,
)
diff = united_diff(curr_system_message_text, new_system_message_str)

View File

@@ -784,7 +784,11 @@ class LettaAgent(BaseAgent):
),
)
in_context_messages = await self._rebuild_memory_async(
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
in_context_messages,
agent_state,
num_messages=self.num_messages,
num_archival_memories=self.num_archival_memories,
tool_rules_solver=tool_rules_solver,
)
tools = [

View File

@@ -2,6 +2,7 @@ from typing import List, Optional, Set, Union
from pydantic import BaseModel, Field
from letta.schemas.block import Block
from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import (
BaseToolRule,
@@ -116,10 +117,10 @@ class ToolRulesSolver(BaseModel):
return list(available_tools)
else:
# Collect valid tools from all child-based rules
valid_tool_sets = [
rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
for rule in self.child_based_tool_rules + self.parent_tool_rules
]
valid_tool_sets = []
for rule in self.child_based_tool_rules + self.parent_tool_rules:
tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
valid_tool_sets.append(tools)
# Compute intersection of all valid tool sets
final_allowed_tools = set.intersection(*valid_tool_sets) if valid_tool_sets else available_tools
@@ -141,6 +142,36 @@ class ToolRulesSolver(BaseModel):
"""Check if the tool is defined as a continue tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
def compile_tool_rule_prompts(self) -> Optional[Block]:
    """
    Compile prompt templates from all tool rules into an ephemeral Block.

    Returns:
        Optional[Block]: An ephemeral Block labeled "tool_usage_rules" whose
        value joins the rendered prompt of every attached tool rule
        (newline-separated), or None if no rule produced a rendered prompt.
    """
    # Gather every rule category; this order determines the order of the
    # rendered constraints in the final prompt text.
    all_rules = (
        self.init_tool_rules
        + self.continue_tool_rules
        + self.child_based_tool_rules
        + self.parent_tool_rules
        + self.terminal_tool_rules
    )
    # render_prompt() returns None for rules without a template; keep only
    # the rules that actually rendered something.
    compiled_prompts = [rendered for rule in all_rules if (rendered := rule.render_prompt())]
    if not compiled_prompts:
        return None
    return Block(
        label="tool_usage_rules",
        value="\n".join(compiled_prompts),
        description="The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow.",
    )
@staticmethod
def validate_conditional_tool(rule: ConditionalToolRule):
"""

View File

@@ -1,4 +1,5 @@
import json
import logging
import re
from typing import Dict, List, Optional, Union
@@ -271,6 +272,8 @@ class AnthropicClient(LLMClientBase):
return data
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
logging.getLogger("httpx").setLevel(logging.WARNING)
client = anthropic.AsyncAnthropic()
if messages and len(messages) == 0:
messages = None

View File

@@ -1,20 +1,43 @@
import json
import logging
from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union
from jinja2 import Template
from pydantic import Field
from letta.schemas.enums import ToolRuleType
from letta.schemas.letta_base import LettaBase
logger = logging.getLogger(__name__)
class BaseToolRule(LettaBase):
    """Base schema for a tool rule.

    Identifies a tool by name and optionally carries a Jinja2 prompt template
    used to surface the rule's constraint to the agent in the system prompt.
    """

    __id_prefix__ = "tool_rule"
    tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.")
    type: ToolRuleType = Field(..., description="The type of the message.")
    prompt_template: Optional[str] = Field(
        None,
        description="Optional Jinja2 template for generating agent prompt about this tool rule. Template can use variables like 'tool_name' and rule-specific attributes.",
    )

    def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
        """Return the subset of `available_tools` this rule permits; subclasses must override."""
        raise NotImplementedError

    def render_prompt(self) -> Optional[str]:
        """Render the prompt template with this rule's attributes.

        Returns:
            Optional[str]: The rendered constraint text, or None when no
            template is set or rendering fails. Render failures are logged
            rather than raised so a bad template never breaks prompt
            compilation.
        """
        if not self.prompt_template:
            return None
        try:
            # Expose every field of the rule (tool_name, type, and any
            # subclass-specific attributes) as template variables.
            template = Template(self.prompt_template)
            return template.render(**self.model_dump())
        except Exception as e:
            logger.warning(
                f"Failed to render prompt template for tool rule '{self.tool_name}' (type: {self.type}). "
                f"Template: '{self.prompt_template}'. Error: {e}"
            )
            return None
class ChildToolRule(BaseToolRule):
"""
@@ -128,6 +151,10 @@ class MaxCountPerStepToolRule(BaseToolRule):
type: Literal[ToolRuleType.max_count_per_step] = ToolRuleType.max_count_per_step
max_count_limit: int = Field(..., description="The max limit for the total number of times this tool can be invoked in a single step.")
prompt_template: Optional[str] = Field(
default="<tool_constraint>{{ tool_name }}: max {{ max_count_limit }} use(s) per turn</tool_constraint>",
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
)
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
"""Restricts the tool if it has been called max_count_limit times in the current step."""

View File

@@ -19,6 +19,7 @@ from letta.constants import (
FILES_TOOLS,
MULTI_AGENT_TOOLS,
)
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
@@ -1444,7 +1445,7 @@ class AgentManager:
@trace_method
@enforce_types
async def rebuild_system_prompt_async(
self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True
self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True, tool_rules_solver: Optional[ToolRulesSolver] = None
) -> PydanticAgentState:
"""Rebuilds the system message with the latest memory object and any shared memory block updates
@@ -1453,6 +1454,8 @@ class AgentManager:
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
"""
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
if not tool_rules_solver:
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
curr_system_message = await self.get_system_message_async(
agent_id=agent_id, actor=actor
@@ -1492,6 +1495,7 @@ class AgentManager:
in_context_memory_last_edit=memory_edit_timestamp,
previous_message_count=num_messages,
archival_memory_size=num_archival_memories,
tool_rules_solver=tool_rules_solver,
)
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)

View File

@@ -229,6 +229,7 @@ def compile_system_message(
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
previous_message_count: int = 0,
archival_memory_size: int = 0,
tool_rules_solver: Optional[ToolRulesSolver] = None,
) -> str:
"""Prepare the final/full system message that will be fed into the LLM API
@@ -237,6 +238,11 @@ def compile_system_message(
The following are reserved variables:
- CORE_MEMORY: the in-context memory of the LLM
"""
# Add tool rule constraints if available
if tool_rules_solver is not None:
tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
if tool_constraint_block: # There may not be any depending on if there are tool rules attached
in_context_memory.blocks.append(tool_constraint_block)
if user_defined_variables is not None:
# TODO eventually support the user defining their own variables to inject

View File

@@ -26,7 +26,11 @@ class ExternalComposioToolExecutor(ToolExecutor):
sandbox_config: Optional[SandboxConfig] = None,
sandbox_env_vars: Optional[Dict[str, Any]] = None,
) -> ToolExecutionResult:
assert agent_state is not None, "Agent state is required for external Composio tools"
if agent_state is None:
return ToolExecutionResult(
status="error",
func_return="Agent state is required for external Composio tools. Please contact Letta support if you see this error.",
)
action_name = generate_composio_action_from_func_name(tool.name)
# Get entity ID from the agent_state