From ba3d59bba53caa8e5ad3d85bf0ed3d27ba9b2c91 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 10 Jun 2025 16:21:27 -0700 Subject: [PATCH] feat: Add prompting to guide tool rule usage (#2742) --- letta/agents/base_agent.py | 10 ++++- letta/agents/letta_agent.py | 6 ++- letta/helpers/tool_rule_solver.py | 39 +++++++++++++++++-- letta/llm_api/anthropic_client.py | 3 ++ letta/schemas/tool_rule.py | 27 +++++++++++++ letta/services/agent_manager.py | 6 ++- .../services/helpers/agent_manager_helper.py | 6 +++ .../tool_executor/composio_tool_executor.py | 6 ++- 8 files changed, 94 insertions(+), 9 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index e275f7b6..dcc0f2ec 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -4,6 +4,7 @@ from typing import Any, AsyncGenerator, List, Optional, Union import openai from letta.constants import DEFAULT_MAX_STEPS +from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger from letta.schemas.agent import AgentState @@ -16,6 +17,7 @@ from letta.schemas.user import User from letta.services.agent_manager import AgentManager from letta.services.helpers.agent_manager_helper import compile_system_message from letta.services.message_manager import MessageManager +from letta.services.passage_manager import PassageManager from letta.utils import united_diff logger = get_logger(__name__) @@ -40,6 +42,8 @@ class BaseAgent(ABC): self.openai_client = openai_client self.message_manager = message_manager self.agent_manager = agent_manager + # TODO: Pass this in + self.passage_manager = PassageManager() self.actor = actor self.logger = get_logger(agent_id) @@ -78,8 +82,9 @@ class BaseAgent(ABC): self, in_context_messages: List[Message], agent_state: AgentState, - num_messages: int | None = None, # storing these calculations is specific to the voice agent - num_archival_memories: int | None = None, + 
tool_rules_solver: Optional[ToolRulesSolver] = None, + num_messages: Optional[int] = None, # storing these calculations is specific to the voice agent + num_archival_memories: Optional[int] = None, ) -> List[Message]: """ Async version of function above. For now before breaking up components, changes should be made in both places. @@ -113,6 +118,7 @@ class BaseAgent(ABC): in_context_memory_last_edit=memory_edit_timestamp, previous_message_count=num_messages, archival_memory_size=num_archival_memories, + tool_rules_solver=tool_rules_solver, ) diff = united_diff(curr_system_message_text, new_system_message_str) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 7a0e0336..7d5149c3 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -784,7 +784,11 @@ class LettaAgent(BaseAgent): ), ) in_context_messages = await self._rebuild_memory_async( - in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories + in_context_messages, + agent_state, + num_messages=self.num_messages, + num_archival_memories=self.num_archival_memories, + tool_rules_solver=tool_rules_solver, ) tools = [ diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index b36d2b94..caecd271 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -2,6 +2,7 @@ from typing import List, Optional, Set, Union from pydantic import BaseModel, Field +from letta.schemas.block import Block from letta.schemas.enums import ToolRuleType from letta.schemas.tool_rule import ( BaseToolRule, @@ -116,10 +117,10 @@ class ToolRulesSolver(BaseModel): return list(available_tools) else: # Collect valid tools from all child-based rules - valid_tool_sets = [ - rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response) - for rule in self.child_based_tool_rules + self.parent_tool_rules - ] + valid_tool_sets = [] + for rule in 
self.child_based_tool_rules + self.parent_tool_rules: + tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response) + valid_tool_sets.append(tools) # Compute intersection of all valid tool sets final_allowed_tools = set.intersection(*valid_tool_sets) if valid_tool_sets else available_tools @@ -141,6 +142,36 @@ class ToolRulesSolver(BaseModel): """Check if the tool is defined as a continue tool in the tool rules.""" return any(rule.tool_name == tool_name for rule in self.continue_tool_rules) + def compile_tool_rule_prompts(self) -> Optional[Block]: + """ + Compile prompt templates from all tool rules into an ephemeral Block. + + Returns: + Optional[Block]: Block containing the compiled prompt string with tool rule constraints, or None if no templates exist. + """ + compiled_prompts = [] + + all_rules = ( + self.init_tool_rules + + self.continue_tool_rules + + self.child_based_tool_rules + + self.parent_tool_rules + + self.terminal_tool_rules + ) + + for rule in all_rules: + rendered = rule.render_prompt() + if rendered: + compiled_prompts.append(rendered) + + if compiled_prompts: + return Block( + label="tool_usage_rules", + value="\n".join(compiled_prompts), + description="The following constraints define rules for tool usage and guide desired behavior. 
These rules must be followed to ensure proper tool execution and workflow.", + ) + return None + @staticmethod def validate_conditional_tool(rule: ConditionalToolRule): """ diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 24c66bbf..5f790c22 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -1,4 +1,5 @@ import json +import logging import re from typing import Dict, List, Optional, Union @@ -271,6 +272,8 @@ class AnthropicClient(LLMClientBase): return data async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int: + logging.getLogger("httpx").setLevel(logging.WARNING) + client = anthropic.AsyncAnthropic() if messages and len(messages) == 0: messages = None diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index 4a658e2c..7efa07c0 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -1,20 +1,43 @@ import json +import logging from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union +from jinja2 import Template from pydantic import Field from letta.schemas.enums import ToolRuleType from letta.schemas.letta_base import LettaBase +logger = logging.getLogger(__name__) + class BaseToolRule(LettaBase): __id_prefix__ = "tool_rule" tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.") type: ToolRuleType = Field(..., description="The type of the message.") + prompt_template: Optional[str] = Field( + None, + description="Optional Jinja2 template for generating agent prompt about this tool rule. 
Template can use variables like 'tool_name' and rule-specific attributes.", + ) def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> set[str]: raise NotImplementedError + def render_prompt(self) -> Optional[str]: + """Render the prompt template with this rule's attributes.""" + if not self.prompt_template: + return None + + try: + template = Template(self.prompt_template) + return template.render(**self.model_dump()) + except Exception as e: + logger.warning( + f"Failed to render prompt template for tool rule '{self.tool_name}' (type: {self.type}). " + f"Template: '{self.prompt_template}'. Error: {e}" + ) + return None + class ChildToolRule(BaseToolRule): """ @@ -128,6 +151,10 @@ class MaxCountPerStepToolRule(BaseToolRule): type: Literal[ToolRuleType.max_count_per_step] = ToolRuleType.max_count_per_step max_count_limit: int = Field(..., description="The max limit for the total number of times this tool can be invoked in a single step.") + prompt_template: Optional[str] = Field( + default="{{ tool_name }}: max {{ max_count_limit }} use(s) per turn", + description="Optional Jinja2 template for generating agent prompt about this tool rule.", + ) def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]: """Restricts the tool if it has been called max_count_limit times in the current step.""" diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index a126ff25..a100c5cd 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -19,6 +19,7 @@ from letta.constants import ( FILES_TOOLS, MULTI_AGENT_TOOLS, ) +from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time from letta.llm_api.llm_client import LLMClient from letta.log import get_logger @@ -1444,7 +1445,7 @@ class AgentManager: @trace_method @enforce_types async def 
rebuild_system_prompt_async( - self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True + self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True, tool_rules_solver: Optional[ToolRulesSolver] = None ) -> PydanticAgentState: """Rebuilds the system message with the latest memory object and any shared memory block updates @@ -1453,6 +1454,8 @@ class AgentManager: Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages """ agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor) + if not tool_rules_solver: + tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) curr_system_message = await self.get_system_message_async( agent_id=agent_id, actor=actor @@ -1492,6 +1495,7 @@ class AgentManager: in_context_memory_last_edit=memory_edit_timestamp, previous_message_count=num_messages, archival_memory_size=num_archival_memories, + tool_rules_solver=tool_rules_solver, ) diff = united_diff(curr_system_message_openai["content"], new_system_message_str) diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index b4935ef8..6839f8d8 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -229,6 +229,7 @@ def compile_system_message( template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", previous_message_count: int = 0, archival_memory_size: int = 0, + tool_rules_solver: Optional[ToolRulesSolver] = None, ) -> str: """Prepare the final/full system message that will be fed into the LLM API @@ -237,6 +238,11 @@ def compile_system_message( The following are reserved variables: - CORE_MEMORY: the in-context memory of the LLM """ + # Add tool rule constraints if available + if tool_rules_solver is not None: + tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts() + if 
tool_constraint_block: # There may not be any depending on if there are tool rules attached + in_context_memory.blocks.append(tool_constraint_block) if user_defined_variables is not None: # TODO eventually support the user defining their own variables to inject diff --git a/letta/services/tool_executor/composio_tool_executor.py b/letta/services/tool_executor/composio_tool_executor.py index 8053c521..d2e8e64e 100644 --- a/letta/services/tool_executor/composio_tool_executor.py +++ b/letta/services/tool_executor/composio_tool_executor.py @@ -26,7 +26,11 @@ class ExternalComposioToolExecutor(ToolExecutor): sandbox_config: Optional[SandboxConfig] = None, sandbox_env_vars: Optional[Dict[str, Any]] = None, ) -> ToolExecutionResult: - assert agent_state is not None, "Agent state is required for external Composio tools" + if agent_state is None: + return ToolExecutionResult( + status="error", + func_return="Agent state is required for external Composio tools. Please contact Letta support if you see this error.", + ) action_name = generate_composio_action_from_func_name(tool.name) # Get entity ID from the agent_state