feat: add "always continue" tool rule and configure default tool rules (#1033)

Co-authored-by: Shubham Naik <shub@letta.com>
This commit is contained in:
Sarah Wooders
2025-02-19 14:46:37 -08:00
committed by GitHub
parent cc6a965db5
commit 72875c7f63
11 changed files with 113 additions and 20 deletions

View File

@@ -5,7 +5,7 @@ import pytest
from letta import create_client
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
assert_invoked_send_message_with_keyword,
@@ -720,3 +720,51 @@ def test_init_tool_rule_always_fails_multiple_tools():
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls
assert_invoked_function_call(response.messages, bad_tool.name)
def test_continue_tool_rule():
"""Test the continue tool rule by forcing the send_message tool to continue"""
client = create_client()
cleanup(client=client, agent_uuid=agent_uuid)
continue_tool_rule = ContinueToolRule(
tool_name="send_message",
)
terminal_tool_rule = TerminalToolRule(
tool_name="core_memory_append",
)
rules = [continue_tool_rule, terminal_tool_rule]
core_memory_append_tool = client.get_tool_id("core_memory_append")
send_message_tool = client.get_tool_id("send_message")
# Set up agent with the tool rule
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_state = setup_agent(
client,
claude_config,
agent_uuid,
tool_rules=rules,
tool_ids=[core_memory_append_tool, send_message_tool],
include_base_tools=False,
include_base_tool_rules=False,
)
# Start conversation
response = client.user_message(agent_id=agent_state.id, message="blah blah blah")
# Verify the tool calls
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
assert len(tool_calls) >= 1
assert_invoked_function_call(response.messages, "send_message")
assert_invoked_function_call(response.messages, "core_memory_append")
# ensure send_message called before core_memory_append
send_message_call_index = None
core_memory_append_call_index = None
for i, call in enumerate(tool_calls):
if call.tool_call.name == "send_message":
send_message_call_index = i
if call.tool_call.name == "core_memory_append":
core_memory_append_call_index = i
assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append"