feat: Support pre-filling arguments on InitToolRule [LET-4569] (#5057)

* Add args

* Add testing to tool rule solver

* Add live integration tests for args prefilling

* Add args override
This commit is contained in:
Matthew Zhou
2025-10-01 11:59:43 -07:00
committed by Caren Thomas
parent 6c7c12ad0f
commit 803b837c64
7 changed files with 621 additions and 5 deletions

View File

@@ -8,6 +8,7 @@ from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.agents.letta_agent_v3 import LettaAgentV3
from letta.config import LettaConfig
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import MessageCreate
from letta.schemas.run import Run
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
@@ -225,6 +226,29 @@ async def cleanup_temp_files_tool(server):
yield tool
@pytest.fixture(scope="function")
async def validate_api_key_tool(server):
    """Create (or update) the ``validate_api_key`` tool on the server and yield it.

    The registered tool accepts only the key ``REAL_KEY_123``; any other value
    makes it raise ``RuntimeError``.
    """

    # The expected key is defined inside the tool body instead of being read
    # from this fixture's closure, so the function is self-contained.
    # NOTE(review): presumably create_tool_from_func registers the tool from
    # its source text, in which case a closure variable would be undefined
    # when the tool is executed -- confirm against the helper.
    def validate_api_key(secret_key: str):
        """
        Validates an API key; errors if incorrect.

        Args:
            secret_key (str): The provided key.

        Returns:
            str: Confirmation string when key is valid.
        """
        expected_key = "REAL_KEY_123"
        if secret_key != expected_key:
            raise RuntimeError(f"Invalid secret key: {secret_key}")
        return "api key accepted"

    actor = await server.user_manager.get_actor_or_default_async()
    tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=validate_api_key), actor=actor)
    yield tool
@pytest.fixture(scope="function")
async def validate_work_tool(server):
def validate_work():
@@ -536,6 +560,126 @@ async def test_init_tool_rule_always_fails(
await cleanup_async(server=server, agent_uuid=agent_uuid, actor=default_user)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_init_tool_rule_args_override_llm_payload(server, disable_e2b_api_key, validate_api_key_tool, default_user):
    """Args prefilled on an InitToolRule win over whatever the LLM supplies for the first tool call."""
    expected_key = "REAL_KEY_123"

    # Rule graph: validate_api_key (prefilled) -> send_message (terminal).
    init_rule = InitToolRule(tool_name="validate_api_key", args={"secret_key": expected_key})
    child_rule = ChildToolRule(tool_name="validate_api_key", children=["send_message"])
    terminal_rule = TerminalToolRule(tool_name="send_message")

    agent_name = str(uuid.uuid4())
    agent_state = await setup_agent(
        server,
        OPENAI_CONFIG,
        agent_uuid=agent_name,
        tool_ids=[validate_api_key_tool.id],
        tool_rules=[init_rule, child_rule, terminal_rule],
    )

    # The user message pushes the model toward a fake key; the prefilled
    # argument is expected to replace it.
    user_message = MessageCreate(
        role="user",
        content="Please validate my API key: FAKE_KEY_999, then send me confirmation.",
    )
    response = await run_agent_step(
        agent_state=agent_state,
        input_messages=[user_message],
        actor=default_user,
    )

    assert_sanity_checks(response)
    assert_invoked_function_call(response.messages, "validate_api_key")
    assert_invoked_function_call(response.messages, "send_message")

    # Inspect only the first recorded validate_api_key call and confirm its
    # arguments carry the prefilled key.
    first_call = next(
        (msg for msg in response.messages if isinstance(msg, ToolCallMessage) and msg.tool_call.name == "validate_api_key"),
        None,
    )
    if first_call is not None:
        recorded_args = json.loads(first_call.tool_call.arguments)
        assert recorded_args.get("secret_key") == expected_key

    await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_init_tool_rule_invalid_prefilled_type_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user):
    """Invalid prefilled args should produce an error and prevent further flow (no send_message).

    The InitToolRule prefills ``secret_key`` with an int although the tool
    declares it as ``str``. The step is expected to stop with
    ``StopReasonType.invalid_tool_call`` after recording the
    ``validate_api_key`` call, without ever calling ``send_message``.
    """
    tools = [validate_api_key_tool]

    # Provide wrong type for secret_key (expects string)
    tool_rules = [
        InitToolRule(tool_name="validate_api_key", args={"secret_key": 123}),
        ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
        TerminalToolRule(tool_name="send_message"),
    ]

    agent_name = str(uuid.uuid4())
    agent_state = await setup_agent(
        server,
        OPENAI_CONFIG,
        agent_uuid=agent_name,
        tool_ids=[t.id for t in tools],
        tool_rules=tool_rules,
    )

    response = await run_agent_step(
        agent_state=agent_state,
        input_messages=[MessageCreate(role="user", content="try validating my key")],
        actor=default_user,
    )

    assert response.stop_reason == LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call)

    # Should attempt validate_api_key but not proceed to send_message
    assert_invoked_function_call(response.messages, "validate_api_key")

    # Check the absence directly instead of wrapping the helper in
    # pytest.raises(Exception): that form passes on *any* exception, so an
    # unrelated failure inside the helper would hide a real regression.
    assert not any(
        isinstance(m, ToolCallMessage) and m.tool_call.name == "send_message" for m in response.messages
    ), "send_message must not be invoked after an invalid prefilled tool call"

    await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_init_tool_rule_unknown_prefilled_key_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user):
    """Unknown prefilled arg key should error and block flow.

    The InitToolRule prefills ``not_a_param``, which is not a parameter of
    ``validate_api_key``. The step is expected to stop with
    ``StopReasonType.invalid_tool_call`` after recording the
    ``validate_api_key`` call, without ever calling ``send_message``.
    """
    tools = [validate_api_key_tool]

    tool_rules = [
        InitToolRule(tool_name="validate_api_key", args={"not_a_param": "value"}),
        ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
        TerminalToolRule(tool_name="send_message"),
    ]

    agent_name = str(uuid.uuid4())
    agent_state = await setup_agent(
        server,
        OPENAI_CONFIG,
        agent_uuid=agent_name,
        tool_ids=[t.id for t in tools],
        tool_rules=tool_rules,
    )

    response = await run_agent_step(
        agent_state=agent_state,
        input_messages=[MessageCreate(role="user", content="validate with your best guess")],
        actor=default_user,
    )

    assert response.stop_reason == LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call)

    assert_invoked_function_call(response.messages, "validate_api_key")

    # Check the absence directly instead of wrapping the helper in
    # pytest.raises(Exception): that form passes on *any* exception, so an
    # unrelated failure inside the helper would hide a real regression.
    assert not any(
        isinstance(m, ToolCallMessage) and m.tool_call.name == "send_message" for m in response.messages
    ), "send_message must not be invoked after an unknown prefilled argument key"

    await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
@pytest.mark.asyncio
async def test_continue_tool_rule(server, default_user):
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""