feat: Add required before exit tool rule (#2977)
This commit is contained in:
@@ -1,15 +1,15 @@
|
||||
import time
|
||||
import asyncio
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.agents.letta_agent import LettaAgent
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, TerminalToolRule
|
||||
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.telemetry_manager import NoopTelemetryManager
|
||||
from tests.helpers.endpoints_helper import (
|
||||
assert_invoked_function_call,
|
||||
assert_invoked_send_message_with_keyword,
|
||||
@@ -25,6 +25,13 @@ agent_uuid = str(uuid.uuid5(namespace, "test_agent_tool_graph"))
|
||||
config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
def event_loop():
    """Yield one freshly created asyncio event loop per test module and close it afterwards.

    NOTE(review): presumably this overrides the ``event_loop`` fixture consumed by the
    ``@pytest.mark.asyncio`` tests in this module -- confirm against the installed
    pytest-asyncio version.
    """
    module_loop = asyncio.new_event_loop()
    yield module_loop
    # Teardown: runs once the module-scoped fixture is finalized.
    module_loop.close()
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def server():
|
||||
config = LettaConfig.load()
|
||||
@@ -181,13 +188,83 @@ def auto_error_tool(server):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def save_data_tool(server):
    """Create (or update) the ``save_data`` tool through the server's tool manager and yield it."""

    # NOTE(review): create_tool_from_func presumably introspects this nested function
    # (name, docstring, source), so it is kept exactly as written -- confirm before editing it.
    def save_data():
        """
        Saves important data before exiting.

        Returns:
            str: Confirmation that data was saved.
        """
        return "Data saved successfully"

    owner = server.user_manager.get_user_or_default()
    tool_definition = create_tool_from_func(func=save_data)
    yield server.tool_manager.create_or_update_tool(tool_definition, actor=owner)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def cleanup_temp_files_tool(server):
    """Create (or update) the ``cleanup_temp_files`` tool through the server's tool manager and yield it."""

    # NOTE(review): create_tool_from_func presumably introspects this nested function
    # (name, docstring, source), so it is kept exactly as written -- confirm before editing it.
    def cleanup_temp_files():
        """
        Cleans up temporary files before exiting.

        Returns:
            str: Confirmation that cleanup was completed.
        """
        return "Temporary files cleaned up"

    owner = server.user_manager.get_user_or_default()
    tool_definition = create_tool_from_func(func=cleanup_temp_files)
    yield server.tool_manager.create_or_update_tool(tool_definition, actor=owner)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
def validate_work_tool(server):
    """Create (or update) the ``validate_work`` tool through the server's tool manager and yield it."""

    # NOTE(review): create_tool_from_func presumably introspects this nested function
    # (name, docstring, source), so it is kept exactly as written -- confirm before editing it.
    def validate_work():
        """
        Validates that work is complete before exiting.

        Returns:
            str: Validation result.
        """
        return "Work validation passed"

    owner = server.user_manager.get_user_or_default()
    tool_definition = create_tool_from_func(func=validate_work)
    yield server.tool_manager.create_or_update_tool(tool_definition, actor=owner)
|
||||
|
||||
|
||||
@pytest.fixture
def default_user(server):
    """Yield the user returned by the server's ``user_manager.get_user_or_default()``."""
    user = server.user_manager.get_user_or_default()
    yield user
|
||||
|
||||
|
||||
async def run_agent_step(server, agent_id, input_messages, actor):
    """Run one agent turn by driving ``LettaAgent`` directly rather than ``server.send_messages``.

    Args:
        server: Object exposing the manager attributes (message, agent, block, job,
            passage, step) that are handed to ``LettaAgent``.
        agent_id: Identifier of the agent to step.
        input_messages: Messages passed as the first argument to ``LettaAgent.step``.
        actor: User on whose behalf the agent runs.

    Returns:
        Whatever ``LettaAgent.step`` returns (the tests read ``.messages`` from it).
    """
    agent = LettaAgent(
        agent_id=agent_id,
        message_manager=server.message_manager,
        agent_manager=server.agent_manager,
        block_manager=server.block_manager,
        job_manager=server.job_manager,
        passage_manager=server.passage_manager,
        actor=actor,
        step_manager=server.step_manager,
        # Telemetry is swapped for the no-op implementation in these tests.
        telemetry_manager=NoopTelemetryManager(),
    )

    # Fixed step options used by every test in this module.
    step_options = {"max_steps": 50, "use_assistant_message": False}
    result = await agent.step(input_messages, **step_options)
    return result
|
||||
|
||||
|
||||
@pytest.mark.timeout(60) # Sets a 60-second timeout for the test since this could loop infinitely
|
||||
def test_single_path_agent_tool_call_graph(
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_path_agent_tool_call_graph(
|
||||
server, disable_e2b_api_key, first_secret_tool, second_secret_tool, third_secret_tool, fourth_secret_tool, auto_error_tool, default_user
|
||||
):
|
||||
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
|
||||
@@ -207,18 +284,11 @@ def test_single_path_agent_tool_call_graph(
|
||||
|
||||
# Make agent state
|
||||
agent_state = setup_agent(server, config_file, agent_uuid=agent_uuid, tool_ids=[t.id for t in tools], tool_rules=tool_rules)
|
||||
usage_stats = server.send_messages(
|
||||
actor=default_user,
|
||||
response = await run_agent_step(
|
||||
server=server,
|
||||
agent_id=agent_state.id,
|
||||
input_messages=[MessageCreate(role="user", content="What is the fourth secret word?")],
|
||||
)
|
||||
messages = [message for step_messages in usage_stats.steps_messages for message in step_messages]
|
||||
letta_messages = []
|
||||
for m in messages:
|
||||
letta_messages += m.to_letta_messages()
|
||||
|
||||
response = LettaResponse(
|
||||
messages=letta_messages, stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value), usage=usage_stats
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Make checks
|
||||
@@ -299,7 +369,8 @@ def test_check_tool_rules_with_different_models_parametrized(
|
||||
|
||||
|
||||
@pytest.mark.timeout(180)
|
||||
def test_claude_initial_tool_rule_enforced(
|
||||
@pytest.mark.asyncio
|
||||
async def test_claude_initial_tool_rule_enforced(
|
||||
server,
|
||||
disable_e2b_api_key,
|
||||
first_secret_tool,
|
||||
@@ -325,20 +396,11 @@ def test_claude_initial_tool_rule_enforced(
|
||||
tool_rules=tool_rules,
|
||||
)
|
||||
|
||||
usage_stats = server.send_messages(
|
||||
actor=default_user,
|
||||
response = await run_agent_step(
|
||||
server=server,
|
||||
agent_id=agent_state.id,
|
||||
input_messages=[MessageCreate(role="user", content="What is the second secret word?")],
|
||||
)
|
||||
messages = [m for step in usage_stats.steps_messages for m in step]
|
||||
letta_messages = []
|
||||
for m in messages:
|
||||
letta_messages += m.to_letta_messages()
|
||||
|
||||
response = LettaResponse(
|
||||
messages=letta_messages,
|
||||
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
|
||||
usage=usage_stats,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert_sanity_checks(response)
|
||||
@@ -359,7 +421,7 @@ def test_claude_initial_tool_rule_enforced(
|
||||
# Exponential backoff
|
||||
if i < 2:
|
||||
backoff_time = 10 * (2**i)
|
||||
time.sleep(backoff_time)
|
||||
await asyncio.sleep(backoff_time)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@@ -370,7 +432,8 @@ def test_claude_initial_tool_rule_enforced(
|
||||
"tests/configs/llm_model_configs/openai-gpt-4o.json",
|
||||
],
|
||||
)
|
||||
def test_agent_no_structured_output_with_one_child_tool_parametrized(
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_no_structured_output_with_one_child_tool_parametrized(
|
||||
server,
|
||||
disable_e2b_api_key,
|
||||
default_user,
|
||||
@@ -404,20 +467,11 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized(
|
||||
tool_rules=tool_rules,
|
||||
)
|
||||
|
||||
usage_stats = server.send_messages(
|
||||
actor=default_user,
|
||||
response = await run_agent_step(
|
||||
server=server,
|
||||
agent_id=agent_state.id,
|
||||
input_messages=[MessageCreate(role="user", content="hi. run archival memory search")],
|
||||
)
|
||||
messages = [m for step in usage_stats.steps_messages for m in step]
|
||||
letta_messages = []
|
||||
for m in messages:
|
||||
letta_messages += m.to_letta_messages()
|
||||
|
||||
response = LettaResponse(
|
||||
messages=letta_messages,
|
||||
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
|
||||
usage=usage_stats,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Run assertions
|
||||
@@ -448,7 +502,8 @@ def test_agent_no_structured_output_with_one_child_tool_parametrized(
|
||||
|
||||
@pytest.mark.timeout(30)
|
||||
@pytest.mark.parametrize("include_base_tools", [False, True])
|
||||
def test_init_tool_rule_always_fails(
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_tool_rule_always_fails(
|
||||
server,
|
||||
disable_e2b_api_key,
|
||||
auto_error_tool,
|
||||
@@ -469,17 +524,11 @@ def test_init_tool_rule_always_fails(
|
||||
include_base_tools=include_base_tools,
|
||||
)
|
||||
|
||||
usage_stats = server.send_messages(
|
||||
actor=default_user,
|
||||
response = await run_agent_step(
|
||||
server=server,
|
||||
agent_id=agent_state.id,
|
||||
input_messages=[MessageCreate(role="user", content="blah blah blah")],
|
||||
)
|
||||
messages = [m for step in usage_stats.steps_messages for m in step]
|
||||
letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
|
||||
response = LettaResponse(
|
||||
messages=letta_messages,
|
||||
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
|
||||
usage=usage_stats,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert_invoked_function_call(response.messages, auto_error_tool.name)
|
||||
@@ -487,7 +536,8 @@ def test_init_tool_rule_always_fails(
|
||||
cleanup(server=server, agent_uuid=agent_uuid, actor=default_user)
|
||||
|
||||
|
||||
def test_continue_tool_rule(server, default_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_continue_tool_rule(server, default_user):
|
||||
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
|
||||
config_file = "tests/configs/llm_model_configs/claude-3-5-sonnet.json"
|
||||
agent_uuid = str(uuid.uuid4())
|
||||
@@ -512,17 +562,11 @@ def test_continue_tool_rule(server, default_user):
|
||||
include_base_tool_rules=False,
|
||||
)
|
||||
|
||||
usage_stats = server.send_messages(
|
||||
actor=default_user,
|
||||
response = await run_agent_step(
|
||||
server=server,
|
||||
agent_id=agent_state.id,
|
||||
input_messages=[MessageCreate(role="user", content="Send me some messages, and then call core_memory_append to end your turn.")],
|
||||
)
|
||||
messages = [m for step in usage_stats.steps_messages for m in step]
|
||||
letta_messages = [msg for m in messages for msg in m.to_letta_messages()]
|
||||
response = LettaResponse(
|
||||
messages=letta_messages,
|
||||
stop_reason=LettaStopReason(stop_reason=StopReasonType.end_turn.value),
|
||||
usage=usage_stats,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
@@ -775,3 +819,180 @@ def test_continue_tool_rule(server, default_user):
|
||||
# assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin"
|
||||
#
|
||||
# cleanup(client, agent_uuid=agent_state.id)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_single_required_before_exit_tool(server, disable_e2b_api_key, save_data_tool, default_user):
    """Test that agent is forced to call a single required-before-exit tool before ending.

    The agent is created under a fixed name, so agents with that name are removed both
    before the run (stale state from an earlier run) and afterwards (even on failure),
    matching the cleanup convention used by the other tests in this module.
    """
    agent_name = "required_exit_single_tool_agent"
    config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"

    # Start from a clean slate: the agent name is not unique per run.
    cleanup(server=server, agent_uuid=agent_name, actor=default_user)

    # Set up tools and rules
    tools = [save_data_tool]
    tool_rules = [
        InitToolRule(tool_name="send_message"),
        RequiredBeforeExitToolRule(tool_name="save_data"),
        TerminalToolRule(tool_name="send_message"),
    ]

    # Create agent
    agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)

    try:
        # Send message that would normally cause exit
        response = await run_agent_step(
            server=server,
            agent_id=agent_state.id,
            input_messages=[MessageCreate(role="user", content="Please finish your work and send me a message.")],
            actor=default_user,
        )

        # Assertions
        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "save_data")
        assert_invoked_function_call(response.messages, "send_message")

        # The key test is that both tools were called - the agent was forced to call save_data
        # even when it tried to exit early with send_message
        tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
        save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"]
        send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"]

        assert len(save_data_calls) >= 1, "save_data should be called at least once"
        assert len(send_message_calls) >= 1, "send_message should be called at least once"

        print(f"✓ Agent '{agent_name}' successfully called required tool before exit")
    finally:
        # Do not leak the agent into later tests or later runs.
        cleanup(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_multiple_required_before_exit_tools(server, disable_e2b_api_key, save_data_tool, cleanup_temp_files_tool, default_user):
    """Test that agent calls all required-before-exit tools before ending.

    The agent is created under a fixed name, so agents with that name are removed both
    before the run (stale state from an earlier run) and afterwards (even on failure),
    matching the cleanup convention used by the other tests in this module.
    """
    agent_name = "required_exit_multi_tool_agent"
    config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"

    # Start from a clean slate: the agent name is not unique per run.
    cleanup(server=server, agent_uuid=agent_name, actor=default_user)

    # Set up tools and rules
    tools = [save_data_tool, cleanup_temp_files_tool]
    tool_rules = [
        InitToolRule(tool_name="send_message"),
        RequiredBeforeExitToolRule(tool_name="save_data"),
        RequiredBeforeExitToolRule(tool_name="cleanup_temp_files"),
        TerminalToolRule(tool_name="send_message"),
    ]

    # Create agent
    agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)

    try:
        # Send message that would normally cause exit
        response = await run_agent_step(
            server=server,
            agent_id=agent_state.id,
            input_messages=[MessageCreate(role="user", content="Complete all necessary tasks and then send me a message.")],
            actor=default_user,
        )

        # Assertions
        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "save_data")
        assert_invoked_function_call(response.messages, "cleanup_temp_files")
        assert_invoked_function_call(response.messages, "send_message")

        # Verify that all required tools were eventually called
        tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
        save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"]
        cleanup_calls = [tc for tc in tool_calls if tc.tool_call.name == "cleanup_temp_files"]
        send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"]

        assert len(save_data_calls) >= 1, "save_data should be called at least once"
        assert len(cleanup_calls) >= 1, "cleanup_temp_files should be called at least once"
        assert len(send_message_calls) >= 1, "send_message should be called at least once"

        print(f"✓ Agent '{agent_name}' successfully called all required tools before exit")
    finally:
        # Do not leak the agent into later tests or later runs.
        cleanup(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_required_before_exit_with_other_rules(server, disable_e2b_api_key, first_secret_tool, save_data_tool, default_user):
    """Test required-before-exit rules work alongside other tool rules.

    The agent is created under a fixed name, so agents with that name are removed both
    before the run (stale state from an earlier run) and afterwards (even on failure),
    matching the cleanup convention used by the other tests in this module.
    """
    agent_name = "required_exit_with_rules_agent"
    config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"

    # Start from a clean slate: the agent name is not unique per run.
    cleanup(server=server, agent_uuid=agent_name, actor=default_user)

    # Set up tools and rules - combine with child tool rules
    tools = [first_secret_tool, save_data_tool]
    tool_rules = [
        InitToolRule(tool_name="first_secret_word"),
        ChildToolRule(tool_name="first_secret_word", children=["send_message"]),
        RequiredBeforeExitToolRule(tool_name="save_data"),
        TerminalToolRule(tool_name="send_message"),
    ]

    # Create agent
    agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)

    try:
        # Send message that would trigger tool flow
        response = await run_agent_step(
            server=server,
            agent_id=agent_state.id,
            input_messages=[MessageCreate(role="user", content="Get the first secret word and then finish up.")],
            actor=default_user,
        )

        # Assertions
        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "first_secret_word")
        assert_invoked_function_call(response.messages, "save_data")
        assert_invoked_function_call(response.messages, "send_message")

        # Verify that all tools were called (first_secret_word due to InitToolRule, save_data due to RequiredBeforeExitToolRule)
        tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
        first_secret_calls = [tc for tc in tool_calls if tc.tool_call.name == "first_secret_word"]
        save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"]
        send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"]

        assert len(first_secret_calls) >= 1, "first_secret_word should be called due to InitToolRule"
        assert len(save_data_calls) >= 1, "save_data should be called due to RequiredBeforeExitToolRule"
        assert len(send_message_calls) >= 1, "send_message should be called eventually"

        print(f"✓ Agent '{agent_name}' successfully handled mixed tool rules")
    finally:
        # Do not leak the agent into later tests or later runs.
        cleanup(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_required_tools_called_during_normal_flow(server, disable_e2b_api_key, save_data_tool, default_user):
    """Test that agent can exit normally when required tools are called during regular operation.

    The agent is created under a fixed name, so agents with that name are removed both
    before the run (stale state from an earlier run) and afterwards (even on failure),
    matching the cleanup convention used by the other tests in this module.
    """
    agent_name = "required_exit_normal_flow_agent"
    config_file = "tests/configs/llm_model_configs/openai-gpt-4o.json"

    # Start from a clean slate: the agent name is not unique per run.
    cleanup(server=server, agent_uuid=agent_name, actor=default_user)

    # Set up tools and rules
    tools = [save_data_tool]
    tool_rules = [
        InitToolRule(tool_name="save_data"),
        RequiredBeforeExitToolRule(tool_name="send_message"),
        TerminalToolRule(tool_name="send_message"),
    ]

    # Create agent
    agent_state = setup_agent(server, config_file, agent_uuid=agent_name, tool_ids=[t.id for t in tools], tool_rules=tool_rules)

    try:
        # Send message that explicitly mentions calling the required tool
        response = await run_agent_step(
            server=server,
            agent_id=agent_state.id,
            input_messages=[MessageCreate(role="user", content="Please save data and then send me a message when done.")],
            actor=default_user,
        )

        # Assertions
        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "save_data")
        assert_invoked_function_call(response.messages, "send_message")

        # Should not have excessive tool calls - agent should exit cleanly after requirements are met
        tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
        save_data_calls = [tc for tc in tool_calls if tc.tool_call.name == "save_data"]
        send_message_calls = [tc for tc in tool_calls if tc.tool_call.name == "send_message"]

        assert len(save_data_calls) == 1, "Should call save_data exactly once"
        assert len(send_message_calls) == 1, "Should call send_message exactly once"

        print(f"✓ Agent '{agent_name}' exited cleanly after calling required tool normally")
    finally:
        # Do not leak the agent into later tests or later runs.
        cleanup(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
Reference in New Issue
Block a user