feat: add "always continue" tool rule and configure default tool rules (#1033)

Co-authored-by: Shubham Naik <shub@letta.com>
This commit is contained in:
Sarah Wooders
2025-02-19 14:46:37 -08:00
committed by GitHub
parent cc6a965db5
commit 72875c7f63
11 changed files with 113 additions and 20 deletions

View File

@@ -633,6 +633,11 @@ class Agent(BaseAgent):
elif self.tool_rules_solver.is_terminal_tool(function_name):
heartbeat_request = False
# if continue tool rule, then must request a heartbeat
# TODO: dont even include heartbeats in the args
if self.tool_rules_solver.is_continue_tool(function_name):
heartbeat_request = True
log_telemetry(self.logger, "_handle_ai_response finish")
return messages, heartbeat_request, function_failed

View File

@@ -2351,6 +2351,7 @@ class LocalClient(AbstractClient):
tool_rules: Optional[List[BaseToolRule]] = None,
include_base_tools: Optional[bool] = True,
include_multi_agent_tools: bool = False,
include_base_tool_rules: bool = True,
# metadata
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
@@ -2402,6 +2403,7 @@ class LocalClient(AbstractClient):
"tool_rules": tool_rules,
"include_base_tools": include_base_tools,
"include_multi_agent_tools": include_multi_agent_tools,
"include_base_tool_rules": include_base_tool_rules,
"system": system,
"agent_type": agent_type,
"llm_config": llm_config if llm_config else self._default_llm_config,

View File

@@ -394,12 +394,12 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
# append the heartbeat
# TODO: don't hard-code
# TODO: if terminal, don't include this
if function.__name__ not in ["send_message"]:
schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
schema["parameters"]["required"].append("request_heartbeat")
# if function.__name__ not in ["send_message"]:
schema["parameters"]["properties"]["request_heartbeat"] = {
"type": "boolean",
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
}
schema["parameters"]["required"].append("request_heartbeat")
return schema

View File

@@ -9,7 +9,7 @@ from sqlalchemy import Dialect
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import ToolRuleType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule, ToolRule
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule
# --------------------------
# LLMConfig Serialization
@@ -74,7 +74,7 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu
return [deserialize_tool_rule(rule_data) for rule_data in data]
def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule]:
"""Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
rule_type = ToolRuleType(data.get("type"))
@@ -86,7 +86,8 @@ def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, Term
return ChildToolRule(**data)
elif rule_type == ToolRuleType.conditional:
return ConditionalToolRule(**data)
elif rule_type == ToolRuleType.continue_loop:
return ContinueToolRule(**data)
raise ValueError(f"Unknown ToolRule type: {rule_type}")

View File

@@ -4,7 +4,7 @@ from typing import List, Optional, Union
from pydantic import BaseModel, Field
from letta.schemas.enums import ToolRuleType
from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
class ToolRuleValidationError(Exception):
@@ -18,6 +18,9 @@ class ToolRulesSolver(BaseModel):
init_tool_rules: List[InitToolRule] = Field(
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
)
continue_tool_rules: List[ContinueToolRule] = Field(
default_factory=list, description="Continue tool rules to be used to continue tool execution."
)
tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field(
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
)
@@ -43,6 +46,9 @@ class ToolRulesSolver(BaseModel):
elif rule.type == ToolRuleType.exit_loop:
assert isinstance(rule, TerminalToolRule)
self.terminal_tool_rules.append(rule)
elif rule.type == ToolRuleType.continue_loop:
assert isinstance(rule, ContinueToolRule)
self.continue_tool_rules.append(rule)
def update_tool_usage(self, tool_name: str):
"""Update the internal state to track the last tool called."""
@@ -80,6 +86,10 @@ class ToolRulesSolver(BaseModel):
"""Check if the tool has children tools"""
return any(rule.tool_name == tool_name for rule in self.tool_rules)
def is_continue_tool(self, tool_name):
"""Check if the tool is defined as a continue tool in the tool rules."""
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
def validate_conditional_tool(self, rule: ConditionalToolRule):
"""
Validate a conditional tool rule

View File

@@ -136,9 +136,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
"""converts to the basic pydantic model counterpart"""
# add default rule for having send_message be a terminal tool
tool_rules = self.tool_rules
if not tool_rules:
tool_rules = [TerminalToolRule(tool_name="send_message"), TerminalToolRule(tool_name="send_message_to_agent_async")]
state = {
"id": self.id,
"organization_id": self.organization_id,

View File

@@ -132,6 +132,9 @@ class CreateAgent(BaseModel, validate_assignment=True): #
include_multi_agent_tools: bool = Field(
False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)."
)
include_base_tool_rules: bool = Field(
True, description="If true, attaches the Letta base tool rules (e.g. deny all tools not explicitly allowed)."
)
description: Optional[str] = Field(None, description="The description of the agent.")
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
model: Optional[str] = Field(

View File

@@ -48,7 +48,15 @@ class TerminalToolRule(BaseToolRule):
type: Literal[ToolRuleType.exit_loop] = ToolRuleType.exit_loop
class ContinueToolRule(BaseToolRule):
"""
Represents a tool rule configuration where if this tool gets called, it must continue the agent loop.
"""
type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop
ToolRule = Annotated[
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule],
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule],
Field(discriminator="type"),
]

View File

@@ -27,6 +27,8 @@ from letta.schemas.message import MessageCreate
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule
from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser
from letta.serialize_schemas import SerializedAgentSchema
@@ -79,10 +81,6 @@ class AgentManager:
if not agent_create.llm_config or not agent_create.embedding_config:
raise ValueError("llm_config and embedding_config are required")
# Check tool rules are valid
if agent_create.tool_rules:
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
# create blocks (note: cannot be linked into the agent_id is created)
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
if agent_create.memory_blocks:
@@ -102,6 +100,25 @@ class AgentManager:
# Remove duplicates
tool_names = list(set(tool_names))
# add default tool rules
if agent_create.include_base_tool_rules:
if not agent_create.tool_rules:
tool_rules = []
else:
tool_rules = agent_create.tool_rules
# apply default tool rules
for tool_name in tool_names:
if tool_name == "send_message" or tool_name == "send_message_to_agent_async":
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
elif tool_name in BASE_TOOLS:
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
else:
tool_rules = agent_create.tool_rules
# Check tool rules are valid
if agent_create.tool_rules:
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
tool_ids = agent_create.tool_ids or []
for tool_name in tool_names:
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
@@ -123,7 +140,7 @@ class AgentManager:
tags=agent_create.tags or [],
description=agent_create.description,
metadata=agent_create.metadata,
tool_rules=agent_create.tool_rules,
tool_rules=tool_rules,
actor=actor,
project_id=agent_create.project_id,
template_id=agent_create.template_id,

View File

@@ -52,6 +52,7 @@ def setup_agent(
tool_rules: Optional[List[BaseToolRule]] = None,
agent_uuid: str = agent_uuid,
include_base_tools: bool = True,
include_base_tool_rules: bool = True,
) -> AgentState:
config_data = json.load(open(filename, "r"))
llm_config = LLMConfig(**config_data)
@@ -72,6 +73,7 @@ def setup_agent(
tool_ids=tool_ids,
tool_rules=tool_rules,
include_base_tools=include_base_tools,
include_base_tool_rules=include_base_tool_rules,
)
return agent_state

View File

@@ -5,7 +5,7 @@ import pytest
from letta import create_client
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
assert_invoked_send_message_with_keyword,
@@ -720,3 +720,51 @@ def test_init_tool_rule_always_fails_multiple_tools():
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls
assert_invoked_function_call(response.messages, bad_tool.name)
def test_continue_tool_rule():
"""Test the continue tool rule by forcing the send_message tool to continue"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
continue_tool_rule = ContinueToolRule(
tool_name="send_message",
)
terminal_tool_rule = TerminalToolRule(
tool_name="core_memory_append",
)
rules = [continue_tool_rule, terminal_tool_rule]
core_memory_append_tool = client.get_tool_id("core_memory_append")
send_message_tool = client.get_tool_id("send_message")
# Set up agent with the tool rule
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_state = setup_agent(
client,
claude_config,
agent_uuid,
tool_rules=rules,
tool_ids=[core_memory_append_tool, send_message_tool],
include_base_tools=False,
include_base_tool_rules=False,
)
# Start conversation
response = client.user_message(agent_id=agent_state.id, message="blah blah blah")
# Verify the tool calls
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
assert len(tool_calls) >= 1
assert_invoked_function_call(response.messages, "send_message")
assert_invoked_function_call(response.messages, "core_memory_append")
# ensure send_message called before core_memory_append
send_message_call_index = None
core_memory_append_call_index = None
for i, call in enumerate(tool_calls):
if call.tool_call.name == "send_message":
send_message_call_index = i
if call.tool_call.name == "core_memory_append":
core_memory_append_call_index = i
assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append"