feat: Add required before exit tool rule (#2977)

This commit is contained in:
Matthew Zhou
2025-06-23 17:02:40 -07:00
committed by GitHub
parent 343dbb5359
commit 54562d88d7
12 changed files with 495 additions and 82 deletions

View File

@@ -1,15 +1,15 @@
import time
import asyncio
import uuid
import pytest
from letta.agents.letta_agent import LettaAgent
from letta.config import LettaConfig
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import MessageCreate
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
from letta.server.server import SyncServer
from letta.services.telemetry_manager import NoopTelemetryManager
from tests.helpers.endpoints_helper import (
assert_invoked_function_call,
assert_invoked_send_message_with_keyword,
@@ -25,6 +25,13 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
@pytest.fixture(scope="module")
def event_loop():
    """Provide a fresh asyncio event loop shared by the tests of one module.

    The loop is created on first use and closed during fixture teardown.
    """
    module_loop = asyncio.new_event_loop()
    yield module_loop
    # Teardown: runs after the last test of the module has used the loop.
    module_loop.close()
@pytest.fixture()
def server():
config = LettaConfig.load()
@@ -181,13 +188,83 @@ def auto_error_tool(server):
yield tool
@pytest.fixture(scope="function")
def save_data_tool(server):
    """Create (or update) the ``save_data`` tool on the server and yield it.

    The tool is built from the nested ``save_data`` function with
    ``create_tool_from_func`` and stored for the default user.
    """

    # NOTE(review): keep this function's name and docstring unchanged. Tests
    # refer to the tool as "save_data", and the docstring is presumably what
    # create_tool_from_func turns into the tool schema -- confirm.
    def save_data():
        """
        Saves important data before exiting.
        Returns:
            str: Confirmation that data was saved.
        """
        return "Data saved successfully"

    owner = server.user_manager.get_user_or_default()
    tool_definition = create_tool_from_func(func=save_data)
    registered_tool = server.tool_manager.create_or_update_tool(tool_definition, actor=owner)
    yield registered_tool
@pytest.fixture(scope="function")
def cleanup_temp_files_tool(server):
    """Create (or update) the ``cleanup_temp_files`` tool on the server and yield it.

    The tool is built from the nested ``cleanup_temp_files`` function with
    ``create_tool_from_func`` and stored for the default user.
    """

    # NOTE(review): keep this function's name and docstring unchanged. Tests
    # refer to the tool as "cleanup_temp_files", and the docstring is presumably
    # what create_tool_from_func turns into the tool schema -- confirm.
    def cleanup_temp_files():
        """
        Cleans up temporary files before exiting.
        Returns:
            str: Confirmation that cleanup was completed.
        """
        return "Temporary files cleaned up"

    owner = server.user_manager.get_user_or_default()
    tool_definition = create_tool_from_func(func=cleanup_temp_files)
    registered_tool = server.tool_manager.create_or_update_tool(tool_definition, actor=owner)
    yield registered_tool
@pytest.fixture(scope="function")
def validate_work_tool(server):
    """Create (or update) the ``validate_work`` tool on the server and yield it.

    The tool is built from the nested ``validate_work`` function with
    ``create_tool_from_func`` and stored for the default user.
    """

    # NOTE(review): keep this function's name and docstring unchanged; the
    # docstring is presumably what create_tool_from_func turns into the tool
    # schema -- confirm.
    def validate_work():
        """
        Validates that work is complete before exiting.
        Returns:
            str: Validation result.
        """
        return "Work validation passed"

    owner = server.user_manager.get_user_or_default()
    tool_definition = create_tool_from_func(func=validate_work)
    registered_tool = server.tool_manager.create_or_update_tool(tool_definition, actor=owner)
    yield registered_tool
@pytest.fixture
def default_user(server):
    """Yield the user returned by ``server.user_manager.get_user_or_default()``."""
    user = server.user_manager.get_user_or_default()
    yield user
async def run_agent_step(server, agent_id, input_messages, actor):
    """Drive one ``LettaAgent.step`` call directly instead of going through ``server.send_messages``.

    Args:
        server: Server object whose managers (message, agent, block, job,
            passage, step) are handed to the agent loop.
        agent_id: ID of the agent to run.
        input_messages: Messages passed to ``LettaAgent.step`` as its input.
        actor: User on whose behalf the agent runs.

    Returns:
        Whatever ``LettaAgent.step`` returns; callers in this file read
        ``.messages`` from it.
    """
    agent = LettaAgent(
        agent_id=agent_id,
        message_manager=server.message_manager,
        agent_manager=server.agent_manager,
        block_manager=server.block_manager,
        job_manager=server.job_manager,
        passage_manager=server.passage_manager,
        actor=actor,
        step_manager=server.step_manager,
        # Telemetry is not under test here, so a no-op manager is used.
        telemetry_manager=NoopTelemetryManager(),
    )
    # max_steps caps the loop so a misbehaving rule set cannot run unbounded.
    step_result = await agent.step(
        input_messages,
        max_steps=50,
        use_assistant_message=False,
    )
    return step_result
@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
def test_single_path_agent_tool_call_graph(
@pytest.mark.asyncio
async def test_single_path_agent_tool_call_graph(
server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool, default_user
):
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
@@ -207,18 +284,11 @@ def test_single_path_agent_tool_call_graph(
# Make agent state
agent_state = setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
usage_stats = server.send_messages(
actor=default_user,
response = await run_agent_step(
server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")],
)
messages = [message for step_messages in usage_stats.steps_messages for message in step_messages]
letta_messages = []
for m in messages:
letta_messages += m.to_letta_messages()
response = LettaResponse(
messages=letta_messages, stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), usage=usage_stats
actor=default_user,
)
# Make checks
@@ -299,7 +369,8 @@ def test_check_tool_rules_with_different_models_parametrized(
@pytest.mark.timeout(180)
def test_claude_initial_tool_rule_enforced(
@pytest.mark.asyncio
async def test_claude_initial_tool_rule_enforced(
server,
disable_e2b_api_key,
first_secret_tool,
@@ -325,20 +396,11 @@ def test_claude_initial_tool_rule_enforced(
tool_rules=tool_rules,
)
usage_stats = server.send_messages(
actor=default_user,
response = await run_agent_step(
server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="What is the second secret word?")],
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = []
for m in messages:
letta_messages += m.to_letta_messages()
response = LettaResponse(
messages=letta_messages,
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
usage=usage_stats,
actor=default_user,
)
assert_sanity_checks(response)
@@ -359,7 +421,7 @@ def test_claude_initial_tool_rule_enforced(
# Exponential backoff
if i < 2:
backoff_time = 10 * (2**i)
time.sleep(backoff_time)
await asyncio.sleep(backoff_time)
@pytest.mark.timeout(60)
@@ -370,7 +432,8 @@ def test_claude_initial_tool_rule_enforced(
"tests/configs/llm_model_configs/openai-gpt-4o.json",
],
)
def test_agent_no_structured_output_with_one_child_tool_parametrized(
@pytest.mark.asyncio
async def test_agent_no_structured_output_with_one_child_tool_parametrized(
server,
disable_e2b_api_key,
default_user,
@@ -404,20 +467,11 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized(
tool_rules=tool_rules,
)
usage_stats = server.send_messages(
actor=default_user,
response = await run_agent_step(
server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="hi. run archival memory search")],
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = []
for m in messages:
letta_messages += m.to_letta_messages()
response = LettaResponse(
messages=letta_messages,
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
usage=usage_stats,
actor=default_user,
)
# Run assertions
@@ -448,7 +502,8 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized(
@pytest.mark.timeout(30)
@pytest.mark.parametrize("include_base_tools", [False, True])
def test_init_tool_rule_always_fails(
@pytest.mark.asyncio
async def test_init_tool_rule_always_fails(
server,
disable_e2b_api_key,
auto_error_tool,
@@ -469,17 +524,11 @@ def test_init_tool_rule_always_fails(
include_base_tools=include_base_tools,
)
usage_stats = server.send_messages(
actor=default_user,
response = await run_agent_step(
server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="blah blah blah")],
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
response = LettaResponse(
messages=letta_messages,
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
usage=usage_stats,
actor=default_user,
)
assert_invoked_function_call(response.messages, auto_error_tool.name)
@@ -487,7 +536,8 @@ def test_init_tool_rule_always_fails(
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
def test_continue_tool_rule(server, default_user):
@pytest.mark.asyncio
async def test_continue_tool_rule(server, default_user):
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
agent_uuid = str(uuid.uuid4())
@@ -512,17 +562,11 @@ def test_continue_tool_rule(server, default_user):
include_base_tool_rules=False,
)
usage_stats = server.send_messages(
actor=default_user,
response = await run_agent_step(
server=server,
agent_id=agent_state.id,
input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")],
)
messages = [m for step in usage_stats.steps_messages for m in step]
letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
response = LettaResponse(
messages=letta_messages,
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
usage=usage_stats,
actor=default_user,
)
assert_invoked_function_call(response.messages, "send_message")
@@ -775,3 +819,180 @@ def test_continue_tool_rule(server, default_user):
# assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin"
#
# cleanup(client, agent_uuid=agent_state.id)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_single_required_before_exit_tool(server, disable_e2b_api_key, save_data_tool, default_user):
    """Test that agent is forced to call a single required-before-exit tool before ending.

    Rules: ``send_message`` is both the init and the terminal tool, and
    ``save_data`` is required before exit. The test passes when both tools
    show up among the tool calls of the response.
    """
    agent_name = "required_exit_single_tool_agent"
    config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"

    # The agent name is fixed, so clear anything left behind by an earlier run.
    # This matches how the other tests in this file bracket setup_agent with cleanup.
    cleanup(server=server, agent_uuid=agent_name, actor=default_user)

    # Set up tools and rules
    tools = [save_data_tool]
    tool_rules = [
        InitToolRule(tool_name="send_message"),
        RequiredBeforeExitToolRule(tool_name="save_data"),
        TerminalToolRule(tool_name="send_message"),
    ]

    try:
        # Create agent
        agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)

        # Send message that would normally cause exit
        response = await run_agent_step(
            server=server,
            agent_id=agent_state.id,
            input_messages=[MessageCreate(role="user", content="Please finish your work and send me a message.")],
            actor=default_user,
        )

        # Assertions
        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "save_data")
        assert_invoked_function_call(response.messages, "send_message")

        # The key test is that both tools were called - the agent was forced to call save_data
        # even when it tried to exit early with send_message
        called_tool_names = [m.tool_call.name for m in response.messages if isinstance(m, ToolCallMessage)]
        assert called_tool_names.count("save_data") >= 1, "save_data should be called at least once"
        assert called_tool_names.count("send_message") >= 1, "send_message should be called at least once"

        print(f"✓ Agent '{agent_name}' successfully called required tool before exit")
    finally:
        # Always remove the agent so a failed run does not pollute later runs.
        cleanup(server=server, agent_uuid=agent_name, actor=default_user)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key, save_data_tool, cleanup_temp_files_tool, default_user):
    """Test that agent calls all required-before-exit tools before ending.

    Rules: ``send_message`` is both the init and the terminal tool, while
    ``save_data`` and ``cleanup_temp_files`` are each required before exit.
    The test passes when all three tools show up among the tool calls.
    """
    agent_name = "required_exit_multi_tool_agent"
    config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"

    # The agent name is fixed, so clear anything left behind by an earlier run.
    # This matches how the other tests in this file bracket setup_agent with cleanup.
    cleanup(server=server, agent_uuid=agent_name, actor=default_user)

    # Set up tools and rules
    tools = [save_data_tool, cleanup_temp_files_tool]
    tool_rules = [
        InitToolRule(tool_name="send_message"),
        RequiredBeforeExitToolRule(tool_name="save_data"),
        RequiredBeforeExitToolRule(tool_name="cleanup_temp_files"),
        TerminalToolRule(tool_name="send_message"),
    ]

    try:
        # Create agent
        agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)

        # Send message that would normally cause exit
        response = await run_agent_step(
            server=server,
            agent_id=agent_state.id,
            input_messages=[MessageCreate(role="user", content="Complete all necessary tasks and then send me a message.")],
            actor=default_user,
        )

        # Assertions
        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "save_data")
        assert_invoked_function_call(response.messages, "cleanup_temp_files")
        assert_invoked_function_call(response.messages, "send_message")

        # Verify that all required tools were eventually called
        called_tool_names = [m.tool_call.name for m in response.messages if isinstance(m, ToolCallMessage)]
        assert called_tool_names.count("save_data") >= 1, "save_data should be called at least once"
        assert called_tool_names.count("cleanup_temp_files") >= 1, "cleanup_temp_files should be called at least once"
        assert called_tool_names.count("send_message") >= 1, "send_message should be called at least once"

        print(f"✓ Agent '{agent_name}' successfully called all required tools before exit")
    finally:
        # Always remove the agent so a failed run does not pollute later runs.
        cleanup(server=server, agent_uuid=agent_name, actor=default_user)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key, first_secret_tool, save_data_tool, default_user):
    """Test required-before-exit rules work alongside other tool rules.

    Rules: ``first_secret_word`` is the init tool with ``send_message`` as its
    only child, ``save_data`` is required before exit, and ``send_message`` is
    terminal. The test passes when all three tools show up among the tool calls.
    """
    agent_name = "required_exit_with_rules_agent"
    config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"

    # The agent name is fixed, so clear anything left behind by an earlier run.
    # This matches how the other tests in this file bracket setup_agent with cleanup.
    cleanup(server=server, agent_uuid=agent_name, actor=default_user)

    # Set up tools and rules - combine with child tool rules
    tools = [first_secret_tool, save_data_tool]
    tool_rules = [
        InitToolRule(tool_name="first_secret_word"),
        ChildToolRule(tool_name="first_secret_word", children=["send_message"]),
        RequiredBeforeExitToolRule(tool_name="save_data"),
        TerminalToolRule(tool_name="send_message"),
    ]

    try:
        # Create agent
        agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)

        # Send message that would trigger tool flow
        response = await run_agent_step(
            server=server,
            agent_id=agent_state.id,
            input_messages=[MessageCreate(role="user", content="Get the first secret word and then finish up.")],
            actor=default_user,
        )

        # Assertions
        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "first_secret_word")
        assert_invoked_function_call(response.messages, "save_data")
        assert_invoked_function_call(response.messages, "send_message")

        # Verify that all tools were called (first_secret_word due to InitToolRule, save_data due to RequiredBeforeExitToolRule)
        called_tool_names = [m.tool_call.name for m in response.messages if isinstance(m, ToolCallMessage)]
        assert called_tool_names.count("first_secret_word") >= 1, "first_secret_word should be called due to InitToolRule"
        assert called_tool_names.count("save_data") >= 1, "save_data should be called due to RequiredBeforeExitToolRule"
        assert called_tool_names.count("send_message") >= 1, "send_message should be called eventually"

        print(f"✓ Agent '{agent_name}' successfully handled mixed tool rules")
    finally:
        # Always remove the agent so a failed run does not pollute later runs.
        cleanup(server=server, agent_uuid=agent_name, actor=default_user)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_key, save_data_tool, default_user):
    """Test that agent can exit normally when required tools are called during regular operation.

    Rules: ``save_data`` is the init tool, and ``send_message`` is both the
    required-before-exit tool and the terminal tool. The test passes when each
    of the two tools is called exactly once.
    """
    agent_name = "required_exit_normal_flow_agent"
    config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"

    # The agent name is fixed, so clear anything left behind by an earlier run.
    # This matches how the other tests in this file bracket setup_agent with cleanup.
    cleanup(server=server, agent_uuid=agent_name, actor=default_user)

    # Set up tools and rules
    tools = [save_data_tool]
    tool_rules = [
        InitToolRule(tool_name="save_data"),
        RequiredBeforeExitToolRule(tool_name="send_message"),
        TerminalToolRule(tool_name="send_message"),
    ]

    try:
        # Create agent
        agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)

        # Send message that explicitly mentions calling the required tool
        response = await run_agent_step(
            server=server,
            agent_id=agent_state.id,
            input_messages=[MessageCreate(role="user", content="Please save data and then send me a message when done.")],
            actor=default_user,
        )

        # Assertions
        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "save_data")
        assert_invoked_function_call(response.messages, "send_message")

        # Should not have excessive tool calls - agent should exit cleanly after requirements are met
        called_tool_names = [m.tool_call.name for m in response.messages if isinstance(m, ToolCallMessage)]
        assert called_tool_names.count("save_data") == 1, "Should call save_data exactly once"
        assert called_tool_names.count("send_message") == 1, "Should call send_message exactly once"

        print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
    finally:
        # Always remove the agent so a failed run does not pollute later runs.
        cleanup(server=server, agent_uuid=agent_name, actor=default_user)