feat: add "always continue" tool rule and configure default tool rules (#1033)
Co-authored-by: Shubham Naik <shub@letta.com>
This commit is contained in:
@@ -633,6 +633,11 @@ class Agent(BaseAgent):
|
||||
elif self.tool_rules_solver.is_terminal_tool(function_name):
|
||||
heartbeat_request = False
|
||||
|
||||
# if continue tool rule, then must request a heartbeat
|
||||
# TODO: dont even include heartbeats in the args
|
||||
if self.tool_rules_solver.is_continue_tool(function_name):
|
||||
heartbeat_request = True
|
||||
|
||||
log_telemetry(self.logger, "_handle_ai_response finish")
|
||||
return messages, heartbeat_request, function_failed
|
||||
|
||||
|
||||
@@ -2351,6 +2351,7 @@ class LocalClient(AbstractClient):
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
include_base_tools: Optional[bool] = True,
|
||||
include_multi_agent_tools: bool = False,
|
||||
include_base_tool_rules: bool = True,
|
||||
# metadata
|
||||
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
|
||||
description: Optional[str] = None,
|
||||
@@ -2402,6 +2403,7 @@ class LocalClient(AbstractClient):
|
||||
"tool_rules": tool_rules,
|
||||
"include_base_tools": include_base_tools,
|
||||
"include_multi_agent_tools": include_multi_agent_tools,
|
||||
"include_base_tool_rules": include_base_tool_rules,
|
||||
"system": system,
|
||||
"agent_type": agent_type,
|
||||
"llm_config": llm_config if llm_config else self._default_llm_config,
|
||||
|
||||
@@ -394,12 +394,12 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
|
||||
# append the heartbeat
|
||||
# TODO: don't hard-code
|
||||
# TODO: if terminal, don't include this
|
||||
if function.__name__ not in ["send_message"]:
|
||||
schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
schema["parameters"]["required"].append("request_heartbeat")
|
||||
# if function.__name__ not in ["send_message"]:
|
||||
schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from sqlalchemy import Dialect
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule, ToolRule
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule
|
||||
|
||||
# --------------------------
|
||||
# LLMConfig Serialization
|
||||
@@ -74,7 +74,7 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu
|
||||
return [deserialize_tool_rule(rule_data) for rule_data in data]
|
||||
|
||||
|
||||
def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]:
|
||||
def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule]:
|
||||
"""Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'."""
|
||||
rule_type = ToolRuleType(data.get("type"))
|
||||
|
||||
@@ -86,7 +86,8 @@ def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, Term
|
||||
return ChildToolRule(**data)
|
||||
elif rule_type == ToolRuleType.conditional:
|
||||
return ConditionalToolRule(**data)
|
||||
|
||||
elif rule_type == ToolRuleType.continue_loop:
|
||||
return ContinueToolRule(**data)
|
||||
raise ValueError(f"Unknown ToolRule type: {rule_type}")
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ from typing import List, Optional, Union
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.enums import ToolRuleType
|
||||
from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
|
||||
|
||||
|
||||
class ToolRuleValidationError(Exception):
|
||||
@@ -18,6 +18,9 @@ class ToolRulesSolver(BaseModel):
|
||||
init_tool_rules: List[InitToolRule] = Field(
|
||||
default_factory=list, description="Initial tool rules to be used at the start of tool execution."
|
||||
)
|
||||
continue_tool_rules: List[ContinueToolRule] = Field(
|
||||
default_factory=list, description="Continue tool rules to be used to continue tool execution."
|
||||
)
|
||||
tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field(
|
||||
default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions."
|
||||
)
|
||||
@@ -43,6 +46,9 @@ class ToolRulesSolver(BaseModel):
|
||||
elif rule.type == ToolRuleType.exit_loop:
|
||||
assert isinstance(rule, TerminalToolRule)
|
||||
self.terminal_tool_rules.append(rule)
|
||||
elif rule.type == ToolRuleType.continue_loop:
|
||||
assert isinstance(rule, ContinueToolRule)
|
||||
self.continue_tool_rules.append(rule)
|
||||
|
||||
def update_tool_usage(self, tool_name: str):
|
||||
"""Update the internal state to track the last tool called."""
|
||||
@@ -80,6 +86,10 @@ class ToolRulesSolver(BaseModel):
|
||||
"""Check if the tool has children tools"""
|
||||
return any(rule.tool_name == tool_name for rule in self.tool_rules)
|
||||
|
||||
def is_continue_tool(self, tool_name):
|
||||
"""Check if the tool is defined as a continue tool in the tool rules."""
|
||||
return any(rule.tool_name == tool_name for rule in self.continue_tool_rules)
|
||||
|
||||
def validate_conditional_tool(self, rule: ConditionalToolRule):
|
||||
"""
|
||||
Validate a conditional tool rule
|
||||
|
||||
@@ -136,9 +136,6 @@ class Agent(SqlalchemyBase, OrganizationMixin):
|
||||
"""converts to the basic pydantic model counterpart"""
|
||||
# add default rule for having send_message be a terminal tool
|
||||
tool_rules = self.tool_rules
|
||||
if not tool_rules:
|
||||
tool_rules = [TerminalToolRule(tool_name="send_message"), TerminalToolRule(tool_name="send_message_to_agent_async")]
|
||||
|
||||
state = {
|
||||
"id": self.id,
|
||||
"organization_id": self.organization_id,
|
||||
|
||||
@@ -132,6 +132,9 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
include_multi_agent_tools: bool = Field(
|
||||
False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)."
|
||||
)
|
||||
include_base_tool_rules: bool = Field(
|
||||
True, description="If true, attaches the Letta base tool rules (e.g. deny all tools not explicitly allowed)."
|
||||
)
|
||||
description: Optional[str] = Field(None, description="The description of the agent.")
|
||||
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
|
||||
model: Optional[str] = Field(
|
||||
|
||||
@@ -48,7 +48,15 @@ class TerminalToolRule(BaseToolRule):
|
||||
type: Literal[ToolRuleType.exit_loop] = ToolRuleType.exit_loop
|
||||
|
||||
|
||||
class ContinueToolRule(BaseToolRule):
|
||||
"""
|
||||
Represents a tool rule configuration where if this tool gets called, it must continue the agent loop.
|
||||
"""
|
||||
|
||||
type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop
|
||||
|
||||
|
||||
ToolRule = Annotated[
|
||||
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule],
|
||||
Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
|
||||
@@ -27,6 +27,8 @@ from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule
|
||||
from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.serialize_schemas import SerializedAgentSchema
|
||||
@@ -79,10 +81,6 @@ class AgentManager:
|
||||
if not agent_create.llm_config or not agent_create.embedding_config:
|
||||
raise ValueError("llm_config and embedding_config are required")
|
||||
|
||||
# Check tool rules are valid
|
||||
if agent_create.tool_rules:
|
||||
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
|
||||
|
||||
# create blocks (note: cannot be linked into the agent_id is created)
|
||||
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
|
||||
if agent_create.memory_blocks:
|
||||
@@ -102,6 +100,25 @@ class AgentManager:
|
||||
# Remove duplicates
|
||||
tool_names = list(set(tool_names))
|
||||
|
||||
# add default tool rules
|
||||
if agent_create.include_base_tool_rules:
|
||||
if not agent_create.tool_rules:
|
||||
tool_rules = []
|
||||
else:
|
||||
tool_rules = agent_create.tool_rules
|
||||
|
||||
# apply default tool rules
|
||||
for tool_name in tool_names:
|
||||
if tool_name == "send_message" or tool_name == "send_message_to_agent_async":
|
||||
tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name))
|
||||
elif tool_name in BASE_TOOLS:
|
||||
tool_rules.append(PydanticContinueToolRule(tool_name=tool_name))
|
||||
else:
|
||||
tool_rules = agent_create.tool_rules
|
||||
# Check tool rules are valid
|
||||
if agent_create.tool_rules:
|
||||
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
|
||||
|
||||
tool_ids = agent_create.tool_ids or []
|
||||
for tool_name in tool_names:
|
||||
tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
|
||||
@@ -123,7 +140,7 @@ class AgentManager:
|
||||
tags=agent_create.tags or [],
|
||||
description=agent_create.description,
|
||||
metadata=agent_create.metadata,
|
||||
tool_rules=agent_create.tool_rules,
|
||||
tool_rules=tool_rules,
|
||||
actor=actor,
|
||||
project_id=agent_create.project_id,
|
||||
template_id=agent_create.template_id,
|
||||
|
||||
@@ -52,6 +52,7 @@ def setup_agent(
|
||||
tool_rules: Optional[List[BaseToolRule]] = None,
|
||||
agent_uuid: str = agent_uuid,
|
||||
include_base_tools: bool = True,
|
||||
include_base_tool_rules: bool = True,
|
||||
) -> AgentState:
|
||||
config_data = json.load(open(filename, "r"))
|
||||
llm_config = LLMConfig(**config_data)
|
||||
@@ -72,6 +73,7 @@ def setup_agent(
|
||||
tool_ids=tool_ids,
|
||||
tool_rules=tool_rules,
|
||||
include_base_tools=include_base_tools,
|
||||
include_base_tool_rules=include_base_tool_rules,
|
||||
)
|
||||
|
||||
return agent_state
|
||||
|
||||
@@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
|
||||
from tests.helpers.endpoints_helper import (
|
||||
assert_invoked_function_call,
|
||||
assert_invoked_send_message_with_keyword,
|
||||
@@ -720,3 +720,51 @@ def test_init_tool_rule_always_fails_multiple_tools():
|
||||
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
|
||||
assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls
|
||||
assert_invoked_function_call(response.messages, bad_tool.name)
|
||||
|
||||
|
||||
def test_continue_tool_rule():
|
||||
"""Test the continue tool rule by forcing the send_message tool to continue"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
continue_tool_rule = ContinueToolRule(
|
||||
tool_name="send_message",
|
||||
)
|
||||
terminal_tool_rule = TerminalToolRule(
|
||||
tool_name="core_memory_append",
|
||||
)
|
||||
rules = [continue_tool_rule, terminal_tool_rule]
|
||||
|
||||
core_memory_append_tool = client.get_tool_id("core_memory_append")
|
||||
send_message_tool = client.get_tool_id("send_message")
|
||||
|
||||
# Set up agent with the tool rule
|
||||
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_state = setup_agent(
|
||||
client,
|
||||
claude_config,
|
||||
agent_uuid,
|
||||
tool_rules=rules,
|
||||
tool_ids=[core_memory_append_tool, send_message_tool],
|
||||
include_base_tools=False,
|
||||
include_base_tool_rules=False,
|
||||
)
|
||||
|
||||
# Start conversation
|
||||
response = client.user_message(agent_id=agent_state.id, message="blah blah blah")
|
||||
|
||||
# Verify the tool calls
|
||||
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
|
||||
assert len(tool_calls) >= 1
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
assert_invoked_function_call(response.messages, "core_memory_append")
|
||||
|
||||
# ensure send_message called before core_memory_append
|
||||
send_message_call_index = None
|
||||
core_memory_append_call_index = None
|
||||
for i, call in enumerate(tool_calls):
|
||||
if call.tool_call.name == "send_message":
|
||||
send_message_call_index = i
|
||||
if call.tool_call.name == "core_memory_append":
|
||||
core_memory_append_call_index = i
|
||||
assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append"
|
||||
|
||||
Reference in New Issue
Block a user