From 72875c7f63bd11c2f8fbd41f6e1b036978dc5c71 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Wed, 19 Feb 2025 14:46:37 -0800 Subject: [PATCH] feat: add "always continue" tool rule and configure default tool rules (#1033) Co-authored-by: Shubham Naik --- letta/agent.py | 5 +++ letta/client/client.py | 2 + letta/functions/schema_generator.py | 12 +++--- letta/helpers/converters.py | 7 +-- letta/helpers/tool_rule_solver.py | 12 +++++- letta/orm/agent.py | 3 -- letta/schemas/agent.py | 3 ++ letta/schemas/tool_rule.py | 10 ++++- letta/services/agent_manager.py | 27 +++++++++--- tests/helpers/endpoints_helper.py | 2 + tests/integration_test_agent_tool_graph.py | 50 +++++++++++++++++++++- 11 files changed, 113 insertions(+), 20 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 4ac5d792..813f4801 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -633,6 +633,11 @@ class Agent(BaseAgent): elif self.tool_rules_solver.is_terminal_tool(function_name): heartbeat_request = False + # if continue tool rule, then must request a heartbeat + # TODO: dont even include heartbeats in the args + if self.tool_rules_solver.is_continue_tool(function_name): + heartbeat_request = True + log_telemetry(self.logger, "_handle_ai_response finish") return messages, heartbeat_request, function_failed diff --git a/letta/client/client.py b/letta/client/client.py index ed7a3220..58f680c9 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2351,6 +2351,7 @@ class LocalClient(AbstractClient): tool_rules: Optional[List[BaseToolRule]] = None, include_base_tools: Optional[bool] = True, include_multi_agent_tools: bool = False, + include_base_tool_rules: bool = True, # metadata metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA}, description: Optional[str] = None, @@ -2402,6 +2403,7 @@ class LocalClient(AbstractClient): "tool_rules": tool_rules, "include_base_tools": include_base_tools, "include_multi_agent_tools": include_multi_agent_tools, + "include_base_tool_rules": include_base_tool_rules, "system": system, "agent_type": agent_type, "llm_config": llm_config if llm_config else self._default_llm_config, diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index 3b1560e8..8273b4bb 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -394,12 +394,12 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[ # append the heartbeat # TODO: don't hard-code # TODO: if terminal, don't include this - if function.__name__ not in ["send_message"]: - schema["parameters"]["properties"]["request_heartbeat"] = { - "type": "boolean", - "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", - } - schema["parameters"]["required"].append("request_heartbeat") + # if function.__name__ not in ["send_message"]: + schema["parameters"]["properties"]["request_heartbeat"] = { + "type": "boolean", + "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", + } + schema["parameters"]["required"].append("request_heartbeat") return schema diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 56757ef8..f65d2653 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -9,7 +9,7 @@ from sqlalchemy import Dialect from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import ToolRuleType from letta.schemas.llm_config import LLMConfig -from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule, ToolRule +from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule, ToolRule # -------------------------- # LLMConfig Serialization @@ -74,7 +74,7 @@ def deserialize_tool_rules(data: Optional[List[Dict]]) -> List[Union[ChildToolRu return [deserialize_tool_rule(rule_data) for rule_data in data] -def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule]: +def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule]: """Deserialize a dictionary to the appropriate ToolRule subclass based on 'type'.""" rule_type = ToolRuleType(data.get("type")) @@ -86,7 +86,8 @@ def deserialize_tool_rule(data: Dict) -> Union[ChildToolRule, InitToolRule, Term return ChildToolRule(**data) elif rule_type == ToolRuleType.conditional: return ConditionalToolRule(**data) - + elif rule_type == ToolRuleType.continue_loop: + return ContinueToolRule(**data) raise ValueError(f"Unknown ToolRule type: {rule_type}") diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index cba8a0ca..ca885616 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -4,7 +4,7 @@ from typing import List, Optional, Union from pydantic import BaseModel, Field from letta.schemas.enums import ToolRuleType -from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import BaseToolRule, ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule class ToolRuleValidationError(Exception): @@ -18,6 +18,9 @@ class ToolRulesSolver(BaseModel): init_tool_rules: List[InitToolRule] = Field( default_factory=list, description="Initial tool rules to be used at the start of tool execution." ) + continue_tool_rules: List[ContinueToolRule] = Field( + default_factory=list, description="Continue tool rules to be used to continue tool execution." + ) tool_rules: List[Union[ChildToolRule, ConditionalToolRule]] = Field( default_factory=list, description="Standard tool rules for controlling execution sequence and allowed transitions." ) @@ -43,6 +46,9 @@ class ToolRulesSolver(BaseModel): elif rule.type == ToolRuleType.exit_loop: assert isinstance(rule, TerminalToolRule) self.terminal_tool_rules.append(rule) + elif rule.type == ToolRuleType.continue_loop: + assert isinstance(rule, ContinueToolRule) + self.continue_tool_rules.append(rule) def update_tool_usage(self, tool_name: str): """Update the internal state to track the last tool called.""" @@ -80,6 +86,10 @@ class ToolRulesSolver(BaseModel): """Check if the tool has children tools""" return any(rule.tool_name == tool_name for rule in self.tool_rules) + def is_continue_tool(self, tool_name): + """Check if the tool is defined as a continue tool in the tool rules.""" + return any(rule.tool_name == tool_name for rule in self.continue_tool_rules) + def validate_conditional_tool(self, rule: ConditionalToolRule): """ Validate a conditional tool rule diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 918e5fa9..d8b90b56 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -136,9 +136,6 @@ class Agent(SqlalchemyBase, OrganizationMixin): """converts to the basic pydantic model counterpart""" # add default rule for having send_message be a terminal tool tool_rules = self.tool_rules - if not tool_rules: - tool_rules = [TerminalToolRule(tool_name="send_message"), TerminalToolRule(tool_name="send_message_to_agent_async")] - state = { "id": self.id, "organization_id": self.organization_id, diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index 2be37939..089f2fd3 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -132,6 +132,9 @@ class CreateAgent(BaseModel, validate_assignment=True): # include_multi_agent_tools: bool = Field( False, description="If true, attaches the Letta multi-agent tools (e.g. sending a message to another agent)." ) + include_base_tool_rules: bool = Field( + True, description="If true, attaches the Letta base tool rules (e.g. deny all tools not explicitly allowed)." + ) description: Optional[str] = Field(None, description="The description of the agent.") metadata: Optional[Dict] = Field(None, description="The metadata of the agent.") model: Optional[str] = Field( diff --git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index fd1f66cd..e0065e68 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -48,7 +48,15 @@ class TerminalToolRule(BaseToolRule): type: Literal[ToolRuleType.exit_loop] = ToolRuleType.exit_loop +class ContinueToolRule(BaseToolRule): + """ + Represents a tool rule configuration where if this tool gets called, it must continue the agent loop. + """ + + type: Literal[ToolRuleType.continue_loop] = ToolRuleType.continue_loop + + ToolRule = Annotated[ - Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule], + Union[ChildToolRule, InitToolRule, TerminalToolRule, ConditionalToolRule, ContinueToolRule], Field(discriminator="type"), ] diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 23922326..d921540d 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -27,6 +27,8 @@ from letta.schemas.message import MessageCreate from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.source import Source as PydanticSource from letta.schemas.tool import Tool as PydanticTool +from letta.schemas.tool_rule import ContinueToolRule as PydanticContinueToolRule +from letta.schemas.tool_rule import TerminalToolRule as PydanticTerminalToolRule from letta.schemas.tool_rule import ToolRule as PydanticToolRule from letta.schemas.user import User as PydanticUser from letta.serialize_schemas import SerializedAgentSchema @@ -79,10 +81,6 @@ class AgentManager: if not agent_create.llm_config or not agent_create.embedding_config: raise ValueError("llm_config and embedding_config are required") - # Check tool rules are valid - if agent_create.tool_rules: - check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules) - # create blocks (note: cannot be linked into the agent_id is created) block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original if agent_create.memory_blocks: @@ -102,6 +100,25 @@ class AgentManager: # Remove duplicates tool_names = list(set(tool_names)) + # add default tool rules + if agent_create.include_base_tool_rules: + if not agent_create.tool_rules: + tool_rules = [] + else: + tool_rules = agent_create.tool_rules + + # apply default tool rules + for tool_name in tool_names: + if tool_name == "send_message" or tool_name == "send_message_to_agent_async": + tool_rules.append(PydanticTerminalToolRule(tool_name=tool_name)) + elif tool_name in BASE_TOOLS: + tool_rules.append(PydanticContinueToolRule(tool_name=tool_name)) + else: + tool_rules = agent_create.tool_rules + # Check tool rules are valid + if agent_create.tool_rules: + check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules) + tool_ids = agent_create.tool_ids or [] for tool_name in tool_names: tool = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor) @@ -123,7 +140,7 @@ class AgentManager: tags=agent_create.tags or [], description=agent_create.description, metadata=agent_create.metadata, - tool_rules=agent_create.tool_rules, + tool_rules=tool_rules, actor=actor, project_id=agent_create.project_id, template_id=agent_create.template_id, diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index 36fd66c7..2c721262 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -52,6 +52,7 @@ def setup_agent( tool_rules: Optional[List[BaseToolRule]] = None, agent_uuid: str = agent_uuid, include_base_tools: bool = True, + include_base_tool_rules: bool = True, ) -> AgentState: config_data = json.load(open(filename, "r")) llm_config = LLMConfig(**config_data) @@ -72,6 +73,7 @@ def setup_agent( tool_ids=tool_ids, tool_rules=tool_rules, include_base_tools=include_base_tools, + include_base_tool_rules=include_base_tool_rules, ) return agent_state diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 97b8709f..9c931ee2 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -5,7 +5,7 @@ import pytest from letta import create_client from letta.schemas.letta_message import ToolCallMessage -from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule +from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule from tests.helpers.endpoints_helper import ( assert_invoked_function_call, assert_invoked_send_message_with_keyword, @@ -720,3 +720,51 @@ def test_init_tool_rule_always_fails_multiple_tools(): tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls assert_invoked_function_call(response.messages, bad_tool.name) + + +def test_continue_tool_rule(): + """Test the continue tool rule by forcing the send_message tool to continue""" + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + continue_tool_rule = ContinueToolRule( + tool_name="send_message", + ) + terminal_tool_rule = TerminalToolRule( + tool_name="core_memory_append", + ) + rules = [continue_tool_rule, terminal_tool_rule] + + core_memory_append_tool = client.get_tool_id("core_memory_append") + send_message_tool = client.get_tool_id("send_message") + + # Set up agent with the tool rule + claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json" + agent_state = setup_agent( + client, + claude_config, + agent_uuid, + tool_rules=rules, + tool_ids=[core_memory_append_tool, send_message_tool], + include_base_tools=False, + include_base_tool_rules=False, + ) + + # Start conversation + response = client.user_message(agent_id=agent_state.id, message="blah blah blah") + + # Verify the tool calls + tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] + assert len(tool_calls) >= 1 + assert_invoked_function_call(response.messages, "send_message") + assert_invoked_function_call(response.messages, "core_memory_append") + + # ensure send_message called before core_memory_append + send_message_call_index = None + core_memory_append_call_index = None + for i, call in enumerate(tool_calls): + if call.tool_call.name == "send_message": + send_message_call_index = i + if call.tool_call.name == "core_memory_append": + core_memory_append_call_index = i + assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append"