feat: Add prompting to guide tool rule usage (#2742)
This commit is contained in:
@@ -4,6 +4,7 @@ from typing import Any, AsyncGenerator, List, Optional, Union
|
||||
import openai
|
||||
|
||||
from letta.constants import DEFAULT_MAX_STEPS
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -16,6 +17,7 @@ from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.utils import united_diff
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -40,6 +42,8 @@ class BaseAgent(ABC):
|
||||
self.openai_client = openai_client
|
||||
self.message_manager = message_manager
|
||||
self.agent_manager = agent_manager
|
||||
# TODO: Pass this in
|
||||
self.passage_manager = PassageManager()
|
||||
self.actor = actor
|
||||
self.logger = get_logger(agent_id)
|
||||
|
||||
@@ -78,8 +82,9 @@ class BaseAgent(ABC):
|
||||
self,
|
||||
in_context_messages: List[Message],
|
||||
agent_state: AgentState,
|
||||
num_messages: int | None = None, # storing these calculations is specific to the voice agent
|
||||
num_archival_memories: int | None = None,
|
||||
tool_rules_solver: Optional[ToolRulesSolver] = None,
|
||||
num_messages: Optional[int] = None, # storing these calculations is specific to the voice agent
|
||||
num_archival_memories: Optional[int] = None,
|
||||
) -> List[Message]:
|
||||
"""
|
||||
Async version of function above. For now before breaking up components, changes should be made in both places.
|
||||
@@ -113,6 +118,7 @@ class BaseAgent(ABC):
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
previous_message_count=num_messages,
|
||||
archival_memory_size=num_archival_memories,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_text, new_system_message_str)
|
||||
|
||||
@@ -784,7 +784,11 @@ class LettaAgent(BaseAgent):
|
||||
),
|
||||
)
|
||||
in_context_messages = await self._rebuild_memory_async(
|
||||
in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories
|
||||
in_context_messages,
|
||||
agent_state,
|
||||
num_messages=self.num_messages,
|
||||
num_archival_memories=self.num_archival_memories,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
)
|
||||
|
||||
tools = [
|
||||
|
||||
@@ -2,6 +2,7 @@ from typing import List, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.tool_rule import (
|
||||
BaseToolRule,
|
||||
@@ -116,10 +117,10 @@ class ToolRulesSolver(BaseModel):
|
||||
return list(available_tools)
|
||||
else:
|
||||
# Collect valid tools from all child-based rules
|
||||
valid_tool_sets = [
|
||||
rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
|
||||
for rule in self.child_based_tool_rules + self.parent_tool_rules
|
||||
]
|
||||
valid_tool_sets = []
|
||||
for rule in self.child_based_tool_rules + self.parent_tool_rules:
|
||||
tools = rule.get_valid_tools(self.tool_call_history, available_tools, last_function_response)
|
||||
valid_tool_sets.append(tools)
|
||||
|
||||
# Compute intersection of all valid tool sets
|
||||
final_allowed_tools = set.intersection(*valid_tool_sets) if valid_tool_sets else available_tools
|
||||
@@ -141,6 +142,36 @@ class ToolRulesSolver(BaseModel):
|
||||
"""Check if the tool is defined as a continue tool in the tool rules."""
|
||||
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
||||
|
||||
def compile_tool_rule_prompts(self) -> Optional[Block]:
|
||||
"""
|
||||
Compile prompt templates from all tool rules into an ephemeral Block.
|
||||
|
||||
Returns:
|
||||
Optional[str]: Compiled prompt string with tool rule constraints, or None if no templates exist.
|
||||
"""
|
||||
compiled_prompts = []
|
||||
|
||||
all_rules = (
|
||||
self.init_tool_rules
|
||||
+ self.continue_tool_rules
|
||||
+ self.child_based_tool_rules
|
||||
+ self.parent_tool_rules
|
||||
+ self.terminal_tool_rules
|
||||
)
|
||||
|
||||
for rule in all_rules:
|
||||
rendered = rule.render_prompt()
|
||||
if rendered:
|
||||
compiled_prompts.append(rendered)
|
||||
|
||||
if compiled_prompts:
|
||||
return Block(
|
||||
label="tool_usage_rules",
|
||||
value="\n".join(compiled_prompts),
|
||||
description="The following constraints define rules for tool usage and guide desired behavior. These rules must be followed to ensure proper tool execution and workflow.",
|
||||
)
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def validate_conditional_tool(rule: ConditionalToolRule):
|
||||
"""
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import Dict, List, Optional, Union
|
||||
|
||||
@@ -271,6 +272,8 @@ class AnthropicClient(LLMClientBase):
|
||||
return data
|
||||
|
||||
async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int:
|
||||
logging.getLogger("httpx").setLevel(logging.WARNING)
|
||||
|
||||
client = anthropic.AsyncAnthropic()
|
||||
if messages and len(messages) == 0:
|
||||
messages = None
|
||||
|
||||
@@ -1,20 +1,43 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Set, Union
|
||||
|
||||
from jinja2 import Template
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaseToolRule(LettaBase):
|
||||
__id_prefix__ = "tool_rule"
|
||||
tool_name: str = Field(..., description="The name of the tool. Must exist in the database for the user's organization.")
|
||||
type: ToolRuleType = Field(..., description="The type of the message.")
|
||||
prompt_template: Optional[str] = Field(
|
||||
None,
|
||||
description="Optional Jinja2 template for generating agent prompt about this tool rule. Template can use variables like 'tool_name' and rule-specific attributes.",
|
||||
)
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> set[str]:
|
||||
raise NotImplementedError
|
||||
|
||||
def render_prompt(self) -> Optional[str]:
|
||||
"""Render the prompt template with this rule's attributes."""
|
||||
if not self.prompt_template:
|
||||
return None
|
||||
|
||||
try:
|
||||
template = Template(self.prompt_template)
|
||||
return template.render(**self.model_dump())
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to render prompt template for tool rule '{self.tool_name}' (type: {self.type}). "
|
||||
f"Template: '{self.prompt_template}'. Error: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
class ChildToolRule(BaseToolRule):
|
||||
"""
|
||||
@@ -128,6 +151,10 @@ class MaxCountPerStepToolRule(BaseToolRule):
|
||||
|
||||
type: Literal[ToolRuleType.max_count_per_step] = ToolRuleType.max_count_per_step
|
||||
max_count_limit: int = Field(..., description="The max limit for the total number of times this tool can be invoked in a single step.")
|
||||
prompt_template: Optional[str] = Field(
|
||||
default="<tool_constraint>{{ tool_name }}: max {{ max_count_limit }} use(s) per turn</tool_constraint>",
|
||||
description="Optional Jinja2 template for generating agent prompt about this tool rule.",
|
||||
)
|
||||
|
||||
def get_valid_tools(self, tool_call_history: List[str], available_tools: Set[str], last_function_response: Optional[str]) -> Set[str]:
|
||||
"""Restricts the tool if it has been called max_count_limit times in the current step."""
|
||||
|
||||
@@ -19,6 +19,7 @@ from letta.constants import (
|
||||
FILES_TOOLS,
|
||||
MULTI_AGENT_TOOLS,
|
||||
)
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
@@ -1444,7 +1445,7 @@ class AgentManager:
|
||||
@trace_method
|
||||
@enforce_types
|
||||
async def rebuild_system_prompt_async(
|
||||
self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True
|
||||
self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True, tool_rules_solver: Optional[ToolRulesSolver] = None
|
||||
) -> PydanticAgentState:
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||||
|
||||
@@ -1453,6 +1454,8 @@ class AgentManager:
|
||||
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
||||
"""
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, include_relationships=["memory"], actor=actor)
|
||||
if not tool_rules_solver:
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
|
||||
curr_system_message = await self.get_system_message_async(
|
||||
agent_id=agent_id, actor=actor
|
||||
@@ -1492,6 +1495,7 @@ class AgentManager:
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
previous_message_count=num_messages,
|
||||
archival_memory_size=num_archival_memories,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||||
|
||||
@@ -229,6 +229,7 @@ def compile_system_message(
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
tool_rules_solver: Optional[ToolRulesSolver] = None,
|
||||
) -> str:
|
||||
"""Prepare the final/full system message that will be fed into the LLM API
|
||||
|
||||
@@ -237,6 +238,11 @@ def compile_system_message(
|
||||
The following are reserved variables:
|
||||
- CORE_MEMORY: the in-context memory of the LLM
|
||||
"""
|
||||
# Add tool rule constraints if available
|
||||
if tool_rules_solver is not None:
|
||||
tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
|
||||
if tool_constraint_block: # There may not be any depending on if there are tool rules attached
|
||||
in_context_memory.blocks.append(tool_constraint_block)
|
||||
|
||||
if user_defined_variables is not None:
|
||||
# TODO eventually support the user defining their own variables to inject
|
||||
|
||||
@@ -26,7 +26,11 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
assert agent_state is not None, "Agent state is required for external Composio tools"
|
||||
if agent_state is None:
|
||||
return ToolExecutionResult(
|
||||
status="error",
|
||||
func_return="Agent state is required for external Composio tools. Please contact Letta support if you see this error.",
|
||||
)
|
||||
action_name = generate_composio_action_from_func_name(tool.name)
|
||||
|
||||
# Get entity ID from the agent_state
|
||||
|
||||
Reference in New Issue
Block a user