feat: add "always continue" tool rule and configure default tool rules (#1033)
Co-authored-by: Shubham Naik <shub@letta.com>
This commit is contained in:
@@ -5,7 +5,7 @@ import pytest
|
||||
|
||||
from letta import create_client
|
||||
from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import ChildToolRule, ConditionalToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
|
||||
from tests.helpers.endpoints_helper import (
|
||||
assert_invoked_function_call,
|
||||
assert_invoked_send_message_with_keyword,
|
||||
@@ -720,3 +720,51 @@ def test_init_tool_rule_always_fails_multiple_tools():
|
||||
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
|
||||
assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls
|
||||
assert_invoked_function_call(response.messages, bad_tool.name)
|
||||
|
||||
|
||||
def test_continue_tool_rule():
|
||||
"""Test the continue tool rule by forcing the send_message tool to continue"""
|
||||
client = create_client()
|
||||
cleanup(client=client, agent_uuid=agent_uuid)
|
||||
|
||||
continue_tool_rule = ContinueToolRule(
|
||||
tool_name="send_message",
|
||||
)
|
||||
terminal_tool_rule = TerminalToolRule(
|
||||
tool_name="core_memory_append",
|
||||
)
|
||||
rules = [continue_tool_rule, terminal_tool_rule]
|
||||
|
||||
core_memory_append_tool = client.get_tool_id("core_memory_append")
|
||||
send_message_tool = client.get_tool_id("send_message")
|
||||
|
||||
# Set up agent with the tool rule
|
||||
claude_config = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_state = setup_agent(
|
||||
client,
|
||||
claude_config,
|
||||
agent_uuid,
|
||||
tool_rules=rules,
|
||||
tool_ids=[core_memory_append_tool, send_message_tool],
|
||||
include_base_tools=False,
|
||||
include_base_tool_rules=False,
|
||||
)
|
||||
|
||||
# Start conversation
|
||||
response = client.user_message(agent_id=agent_state.id, message="blah blah blah")
|
||||
|
||||
# Verify the tool calls
|
||||
tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
|
||||
assert len(tool_calls) >= 1
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
assert_invoked_function_call(response.messages, "core_memory_append")
|
||||
|
||||
# ensure send_message called before core_memory_append
|
||||
send_message_call_index = None
|
||||
core_memory_append_call_index = None
|
||||
for i, call in enumerate(tool_calls):
|
||||
if call.tool_call.name == "send_message":
|
||||
send_message_call_index = i
|
||||
if call.tool_call.name == "core_memory_append":
|
||||
core_memory_append_call_index = i
|
||||
assert send_message_call_index < core_memory_append_call_index, "send_message should have been called before core_memory_append"
|
||||
|
||||
Reference in New Issue
Block a user