feat: Support pre-filling arguments on InitToolRule [LET-4569] (#5057)
* Add args
* Add testing to tool rule solver
* Add live integration tests for args prefilling
* Add args override
This commit is contained in:
committed by
Caren Thomas
parent
6c7c12ad0f
commit
803b837c64
@@ -8,6 +8,7 @@ from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.agents.letta_agent_v3 import LettaAgentV3
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
|
||||
@@ -225,6 +226,29 @@ async def cleanup_temp_files_tool(server):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
async def validate_api_key_tool(server):
    """Create (or update) the ``validate_api_key`` tool and yield it.

    The tool is registered through ``server.tool_manager`` on behalf of the
    actor returned by ``server.user_manager.get_actor_or_default_async()``.
    """

    def validate_api_key(secret_key: str):
        """
        Validates an API key; errors if incorrect.

        Args:
            secret_key (str): The provided key.

        Returns:
            str: Confirmation string when key is valid.
        """
        # The expected key lives inside the function body so the function is
        # self-contained.
        # NOTE(review): presumably create_tool_from_func captures only this
        # function's source text, in which case a variable from the enclosing
        # fixture scope would be undefined when the tool runs -- confirm.
        expected_key = "REAL_KEY_123"
        if secret_key != expected_key:
            raise RuntimeError(f"Invalid secret key: {secret_key}")
        return "api key accepted"

    actor = await server.user_manager.get_actor_or_default_async()
    tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=validate_api_key), actor=actor)
    yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def validate_work_tool(server):
|
||||
def validate_work():
|
||||
@@ -536,6 +560,126 @@ async def test_init_tool_rule_always_fails(
|
||||
await cleanup_async(server=server, agent_uuid=agent_uuid, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_init_tool_rule_args_override_llm_payload(server, disable_e2b_api_key, validate_api_key_tool, default_user):
    """InitToolRule args should override LLM-provided args for the initial tool call."""
    REAL = "REAL_KEY_123"

    tools = [validate_api_key_tool]
    tool_rules = [
        InitToolRule(tool_name="validate_api_key", args={"secret_key": REAL}),
        ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
        TerminalToolRule(tool_name="send_message"),
    ]

    agent_name = str(uuid.uuid4())
    agent_state = await setup_agent(
        server,
        OPENAI_CONFIG,
        agent_uuid=agent_name,
        tool_ids=[t.id for t in tools],
        tool_rules=tool_rules,
    )

    # Everything after agent creation runs under try/finally so the agent is
    # removed even when an assertion fails.
    try:
        # Ask the model to call with a fake key; prefilled args should override
        response = await run_agent_step(
            agent_state=agent_state,
            input_messages=[
                MessageCreate(
                    role="user",
                    content="Please validate my API key: FAKE_KEY_999, then send me confirmation.",
                )
            ],
            actor=default_user,
        )

        assert_sanity_checks(response)
        assert_invoked_function_call(response.messages, "validate_api_key")
        assert_invoked_function_call(response.messages, "send_message")

        # Verify the recorded tool-call arguments reflect the prefilled override.
        # Only the first matching tool call is inspected. Its absence is an
        # explicit failure rather than a silently skipped check.
        first_call = next(
            (m for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == "validate_api_key"),
            None,
        )
        assert first_call is not None, "no ToolCallMessage recorded for validate_api_key"
        args = json.loads(first_call.tool_call.arguments)
        assert args.get("secret_key") == REAL
    finally:
        await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_init_tool_rule_invalid_prefilled_type_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user):
    """Invalid prefilled args should produce an error and prevent further flow (no send_message)."""
    tools = [validate_api_key_tool]
    # Provide wrong type for secret_key (expects string)
    tool_rules = [
        InitToolRule(tool_name="validate_api_key", args={"secret_key": 123}),
        ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
        TerminalToolRule(tool_name="send_message"),
    ]

    agent_name = str(uuid.uuid4())
    agent_state = await setup_agent(
        server,
        OPENAI_CONFIG,
        agent_uuid=agent_name,
        tool_ids=[t.id for t in tools],
        tool_rules=tool_rules,
    )

    # Everything after agent creation runs under try/finally so the agent is
    # removed even when an assertion fails.
    try:
        response = await run_agent_step(
            agent_state=agent_state,
            input_messages=[MessageCreate(role="user", content="try validating my key")],
            actor=default_user,
        )

        assert response.stop_reason == LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call)

        # Should attempt validate_api_key but not proceed to send_message
        assert_invoked_function_call(response.messages, "validate_api_key")
        # The helper is expected to raise when send_message was never invoked.
        with pytest.raises(Exception):
            assert_invoked_function_call(response.messages, "send_message")
    finally:
        await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_init_tool_rule_unknown_prefilled_key_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user):
    """Unknown prefilled arg key should error and block flow."""
    tools = [validate_api_key_tool]
    # "not_a_param" is not a parameter of validate_api_key.
    tool_rules = [
        InitToolRule(tool_name="validate_api_key", args={"not_a_param": "value"}),
        ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
        TerminalToolRule(tool_name="send_message"),
    ]

    agent_name = str(uuid.uuid4())
    agent_state = await setup_agent(
        server,
        OPENAI_CONFIG,
        agent_uuid=agent_name,
        tool_ids=[t.id for t in tools],
        tool_rules=tool_rules,
    )

    # Everything after agent creation runs under try/finally so the agent is
    # removed even when an assertion fails.
    try:
        response = await run_agent_step(
            agent_state=agent_state,
            input_messages=[MessageCreate(role="user", content="validate with your best guess")],
            actor=default_user,
        )

        assert response.stop_reason == LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call)

        assert_invoked_function_call(response.messages, "validate_api_key")
        # The helper is expected to raise when send_message was never invoked.
        with pytest.raises(Exception):
            assert_invoked_function_call(response.messages, "send_message")
    finally:
        await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_continue_tool_rule(server, default_user):
|
||||
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
|
||||
|
||||
Reference in New Issue
Block a user