feat: Support pre-filling arguments on InitToolRule [LET-4569] (#5057)

* Add args

* Add testing to tool rule solver

* Add live integration tests for args prefilling

* Add args override
This commit is contained in:
Matthew Zhou
2025-10-01 11:59:43 -07:00
committed by Caren Thomas
parent 6c7c12ad0f
commit 803b837c64
7 changed files with 621 additions and 5 deletions

View File

@@ -12394,6 +12394,36 @@
{}
],
"nullable": true
},
"args": {
"oneOf": [
{
"nullable": true
},
{
"type": "string",
"format": "null",
"nullable": true
},
{
"type": "array",
"items": {
"oneOf": [
{
"nullable": true
},
{
"type": "string",
"format": "null",
"nullable": true
}
],
"nullable": true
}
},
{}
],
"nullable": true
}
},
"required": ["tool_name"]
@@ -12895,6 +12925,36 @@
{}
],
"nullable": true
},
"args": {
"oneOf": [
{
"nullable": true
},
{
"type": "string",
"format": "null",
"nullable": true
},
{
"type": "array",
"items": {
"oneOf": [
{
"nullable": true
},
{
"type": "string",
"format": "null",
"nullable": true
}
],
"nullable": true
}
},
{}
],
"nullable": true
}
},
"required": ["tool_name"]
@@ -26972,6 +27032,19 @@
],
"title": "Prompt Template",
"description": "Optional template string (ignored). Rendering uses fast built-in formatting for performance."
},
"args": {
"anyOf": [
{
"additionalProperties": true,
"type": "object"
},
{
"type": "null"
}
],
"title": "Args",
"description": "Optional prefilled arguments for this tool. When present, these values will override any LLM-provided arguments with the same keys during invocation. Keys must match the tool's parameter names and values must satisfy the tool's JSON schema. Supports partial prefill; non-overlapping parameters are left to the model."
}
},
"additionalProperties": false,

View File

@@ -1,7 +1,7 @@
import json
import uuid
import xml.etree.ElementTree as ET
from typing import List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple
from uuid import UUID, uuid4
from letta.errors import PendingApprovalError
@@ -258,6 +258,106 @@ def _safe_load_tool_call_str(tool_call_args_str: str) -> dict:
return tool_args
def _json_type_matches(value: Any, expected_type: Any) -> bool:
"""Basic JSON Schema type checking for common types.
expected_type can be a string (e.g., "string") or a list (union).
This is intentionally lightweight; deeper validation can be added as needed.
"""
def match_one(v: Any, t: str) -> bool:
if t == "string":
return isinstance(v, str)
if t == "integer":
# bool is subclass of int in Python; exclude
return isinstance(v, int) and not isinstance(v, bool)
if t == "number":
return (isinstance(v, int) and not isinstance(v, bool)) or isinstance(v, float)
if t == "boolean":
return isinstance(v, bool)
if t == "object":
return isinstance(v, dict)
if t == "array":
return isinstance(v, list)
if t == "null":
return v is None
# Fallback: don't over-reject on unknown types
return True
if isinstance(expected_type, list):
return any(match_one(value, t) for t in expected_type)
if isinstance(expected_type, str):
return match_one(value, expected_type)
return True
def _schema_accepts_value(prop_schema: Dict[str, Any], value: Any) -> bool:
"""Check if a value is acceptable for a property schema.
Handles: type, enum, const, anyOf, oneOf (by shallow traversal).
"""
if prop_schema is None:
return True
# const has highest precedence
if "const" in prop_schema:
return value == prop_schema["const"]
# enums
if "enum" in prop_schema:
try:
return value in prop_schema["enum"]
except Exception:
return False
# unions
for union_key in ("anyOf", "oneOf"):
if union_key in prop_schema and isinstance(prop_schema[union_key], list):
for sub in prop_schema[union_key]:
if _schema_accepts_value(sub, value):
return True
return False
# type-based
if "type" in prop_schema:
if not _json_type_matches(value, prop_schema["type"]):
return False
# No strict constraints specified: accept
return True
def merge_and_validate_prefilled_args(tool: "Tool", llm_args: Dict[str, Any], prefilled_args: Dict[str, Any]) -> Dict[str, Any]:
    """Merge LLM-provided args with prefilled args from tool rules.

    - Overlapping keys are replaced by prefilled values (prefilled wins).
    - Validates that prefilled keys exist on the tool schema and that values satisfy
      basic JSON Schema constraints (type/enum/const/anyOf/oneOf).

    Args:
        tool: The tool whose ``json_schema["parameters"]["properties"]`` is validated against.
        llm_args: Arguments produced by the model; ``None`` is treated as empty.
        prefilled_args: Arguments supplied by tool rules; ``None`` is treated as empty.

    Returns:
        A new dict containing the LLM args overlaid with the prefilled args.

    Raises:
        ValueError: If ``tool`` is not a Tool instance (e.g. ``None`` because the lookup
            failed), if a prefilled key is not a declared parameter, or if a prefilled
            value fails schema validation. All argument problems are reported together.
    """
    from letta.schemas.tool import Tool  # local import to avoid circulars in type hints

    # Explicit check instead of `assert`: asserts are stripped under `python -O`, and a
    # ValueError is the single failure type this function documents.
    if not isinstance(tool, Tool):
        raise ValueError(f"Cannot validate prefilled arguments: expected a Tool, got {type(tool).__name__}.")

    prefilled_args = prefilled_args or {}

    # Tolerate missing or null "parameters"/"properties" entries in the schema.
    schema = (tool.json_schema or {}).get("parameters") or {}
    props: Dict[str, Any] = schema.get("properties") if isinstance(schema, dict) else {}
    if not isinstance(props, dict):
        props = {}

    errors: list[str] = []
    for k, v in prefilled_args.items():
        if k not in props:
            errors.append(f"Unknown argument '{k}' for tool '{tool.name}'.")
            continue
        prop_schema = props[k]
        if not _schema_accepts_value(prop_schema, v):
            # The property schema may not be a dict (or may lack "type"); report None then.
            expected = prop_schema.get("type") if isinstance(prop_schema, dict) else None
            errors.append(f"Invalid value for '{k}': {v!r} does not match expected schema type {expected!r}.")

    if errors:
        raise ValueError("; ".join(errors))

    merged = dict(llm_args or {})
    merged.update(prefilled_args)
    return merged
def _pop_heartbeat(tool_args: dict) -> bool:
hb = tool_args.pop("request_heartbeat", False)
return str(hb).lower() == "true" if isinstance(hb, str) else bool(hb)

View File

@@ -13,6 +13,7 @@ from letta.agents.helpers import (
_prepare_in_context_messages_no_persist_async,
_safe_load_tool_call_str,
generate_step_id,
merge_and_validate_prefilled_args,
)
from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
@@ -678,6 +679,68 @@ class LettaAgentV3(LettaAgentV2):
if tool_rule_violated:
tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
else:
# Prefill + validate args if a rule provided them
prefill_args = self.tool_rules_solver.last_prefilled_args_by_tool.get(tool_call_name)
if prefill_args:
# Find tool object for schema validation
target_tool = next((t for t in agent_state.tools if t.name == tool_call_name), None)
provenance = self.tool_rules_solver.last_prefilled_args_provenance.get(tool_call_name)
try:
tool_args = merge_and_validate_prefilled_args(
tool=target_tool,
llm_args=tool_args,
prefilled_args=prefill_args,
)
except ValueError as ve:
# Treat invalid prefilled args as user error and end the step
error_prefix = "Invalid prefilled tool arguments from tool rules"
prov_suffix = f" (source={provenance})" if provenance else ""
err_msg = f"{error_prefix}{prov_suffix}: {str(ve)}"
tool_execution_result = ToolExecutionResult(status="error", func_return=err_msg)
# Create messages and early return persistence path below
continue_stepping, heartbeat_reason, stop_reason = (
False,
None,
LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value),
)
tool_call_messages = create_letta_messages_from_llm_response(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
function_name=tool_call_name,
function_arguments=tool_args,
tool_execution_result=tool_execution_result,
tool_call_id=tool_call_id,
function_call_success=False,
function_response=tool_execution_result.func_return,
timezone=agent_state.timezone,
actor=self.actor,
continue_stepping=continue_stepping,
heartbeat_reason=None,
reasoning_content=content,
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
step_id=step_id,
run_id=run_id,
is_approval_response=is_approval or is_denial,
force_set_request_heartbeat=False,
add_heartbeat_on_continue=False,
)
messages_to_persist = (initial_messages or []) + tool_call_messages
# Set run_id on all messages before persisting
for message in messages_to_persist:
if message.run_id is None:
message.run_id = run_id
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist,
actor=self.actor,
run_id=run_id,
project_id=agent_state.project_id,
template_id=agent_state.template_id,
)
return persisted_messages, continue_stepping, stop_reason
# Track tool execution time
tool_start_time = get_utc_timestamp_ns()
tool_execution_result = await self._execute_tool(

View File

@@ -50,6 +50,16 @@ class ToolRulesSolver(BaseModel):
)
tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
# Last-evaluated prefilled args cache (per step)
last_prefilled_args_by_tool: dict[str, dict] = Field(
default_factory=dict, description="Cached mapping of tool name to prefilled args from the last allowlist evaluation.", exclude=True
)
last_prefilled_args_provenance: dict[str, str] = Field(
default_factory=dict,
description="Cached mapping of tool name to a short description of which rule provided the prefilled args.",
exclude=True,
)
def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs):
super().__init__(tool_rules=tool_rules, **kwargs)
@@ -88,15 +98,17 @@ class ToolRulesSolver(BaseModel):
) -> list[ToolName]:
"""Get a list of tool names allowed based on the last tool called.
Side-effect: also caches any prefilled args provided by active rules into
`last_prefilled_args_by_tool` and `last_prefilled_args_provenance`.
The logic is as follows:
1. if there are no previous tool calls, and we have InitToolRules, those are the only options for the first tool call
2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options
3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools
"""
# TODO: This piece of code here is quite ugly and deserves a refactor
# TODO: -> Tool rules should probably be refactored to take in a set of tool names?
# Compute allowed tools first
if not self.tool_call_history and self.init_tool_rules:
return [rule.tool_name for rule in self.init_tool_rules]
allowed = [rule.tool_name for rule in self.init_tool_rules]
else:
valid_tool_sets = []
for rule in self.child_based_tool_rules + self.parent_tool_rules:
@@ -109,7 +121,42 @@ class ToolRulesSolver(BaseModel):
if error_on_empty and not final_allowed_tools:
raise ValueError("No valid tools found based on tool rules.")
return list(final_allowed_tools)
allowed = list(final_allowed_tools)
# Build prefilled args cache for current allowed set
args_by_tool: dict[str, dict] = {}
provenance_by_tool: dict[str, str] = {}
def _store_args(tool_name: str, args: dict, rule: BaseModel):
if not isinstance(args, dict) or len(args) == 0:
return
if tool_name not in args_by_tool:
args_by_tool[tool_name] = {}
args_by_tool[tool_name].update(args) # last-write-wins
provenance_by_tool[tool_name] = f"{rule.__class__.__name__}({getattr(rule, 'tool_name', tool_name)})"
allowed_set = set(allowed)
if not self.tool_call_history and self.init_tool_rules:
for rule in self.init_tool_rules:
if hasattr(rule, "args") and getattr(rule, "args") and rule.tool_name in allowed_set:
_store_args(rule.tool_name, getattr(rule, "args"), rule)
else:
for rule in (
self.child_based_tool_rules
+ self.parent_tool_rules
+ self.continue_tool_rules
+ self.terminal_tool_rules
+ self.required_before_exit_tool_rules
+ self.requires_approval_tool_rules
):
if hasattr(rule, "args") and getattr(rule, "args") and getattr(rule, "tool_name", None) in allowed_set:
_store_args(rule.tool_name, getattr(rule, "args"), rule)
self.last_prefilled_args_by_tool = args_by_tool
self.last_prefilled_args_provenance = provenance_by_tool
return allowed
def is_terminal_tool(self, tool_name: ToolName) -> bool:
"""Check if the tool is defined as a terminal tool in the terminal tool rules or required-before-exit tool rules."""

View File

@@ -209,6 +209,14 @@ class InitToolRule(BaseToolRule):
"""
type: Literal[ToolRuleType.run_first] = ToolRuleType.run_first
args: Optional[Dict[str, Any]] = Field(
default=None,
description=(
"Optional prefilled arguments for this tool. When present, these values will override any LLM-provided "
"arguments with the same keys during invocation. Keys must match the tool's parameter names and values "
"must satisfy the tool's JSON schema. Supports partial prefill; non-overlapping parameters are left to the model."
),
)
@property
def requires_force_tool_call(self) -> bool:

View File

@@ -8,6 +8,7 @@ from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.agents.letta_agent_v3 import LettaAgentV3
from letta.config import LettaConfig
from letta.schemas.letta_message import ToolCallMessage
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import MessageCreate
from letta.schemas.run import Run
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
@@ -225,6 +226,29 @@ async def cleanup_temp_files_tool(server):
yield tool
@pytest.fixture(scope="function")
async def validate_api_key_tool(server):
    # Fixture: registers a `validate_api_key` tool through the server's tool
    # manager (create-or-update) and yields the resulting tool object.
    #
    # NOTE(review): validate_api_key reads SECRET from this enclosing scope. If
    # create_tool_from_func captures only the function's own source text, SECRET
    # would be undefined when the tool executes — confirm.
    SECRET = "REAL_KEY_123"

    def validate_api_key(secret_key: str):
        """
        Validates an API key; errors if incorrect.
        Args:
            secret_key (str): The provided key.
        Returns:
            str: Confirmation string when key is valid.
        """
        if secret_key != SECRET:
            raise RuntimeError(f"Invalid secret key: {secret_key}")
        return "api key accepted"

    # Resolve the acting user, then create or update the tool on its behalf.
    actor = await server.user_manager.get_actor_or_default_async()
    tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=validate_api_key), actor=actor)
    yield tool
@pytest.fixture(scope="function")
async def validate_work_tool(server):
def validate_work():
@@ -536,6 +560,126 @@ async def test_init_tool_rule_always_fails(
await cleanup_async(server=server, agent_uuid=agent_uuid, actor=default_user)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_init_tool_rule_args_override_llm_payload(server, disable_e2b_api_key, validate_api_key_tool, default_user):
    """InitToolRule args should override LLM-provided args for the initial tool call."""
    REAL = "REAL_KEY_123"

    tools = [validate_api_key_tool]
    tool_rules = [
        InitToolRule(tool_name="validate_api_key", args={"secret_key": REAL}),
        ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
        TerminalToolRule(tool_name="send_message"),
    ]

    agent_name = str(uuid.uuid4())
    agent_state = await setup_agent(
        server,
        OPENAI_CONFIG,
        agent_uuid=agent_name,
        tool_ids=[t.id for t in tools],
        tool_rules=tool_rules,
    )

    # Ask the model to call with a fake key; prefilled args should override
    response = await run_agent_step(
        agent_state=agent_state,
        input_messages=[
            MessageCreate(
                role="user",
                content="Please validate my API key: FAKE_KEY_999, then send me confirmation.",
            )
        ],
        actor=default_user,
    )

    assert_sanity_checks(response)
    assert_invoked_function_call(response.messages, "validate_api_key")
    assert_invoked_function_call(response.messages, "send_message")

    # Verify the recorded tool-call arguments reflect the prefilled override.
    # Require at least one matching message so the check cannot pass vacuously.
    validate_calls = [m for m in response.messages if isinstance(m, ToolCallMessage) and m.tool_call.name == "validate_api_key"]
    assert validate_calls, "Expected a ToolCallMessage for validate_api_key"
    args = json.loads(validate_calls[0].tool_call.arguments)
    assert args.get("secret_key") == REAL

    await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_init_tool_rule_invalid_prefilled_type_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user):
    """A prefilled value of the wrong type must end the step with an error before send_message runs."""
    # secret_key is declared as a string on the tool, so the integer prefill is invalid.
    rules = [
        InitToolRule(tool_name="validate_api_key", args={"secret_key": 123}),
        ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
        TerminalToolRule(tool_name="send_message"),
    ]

    agent_name = str(uuid.uuid4())
    agent_state = await setup_agent(
        server,
        OPENAI_CONFIG,
        agent_uuid=agent_name,
        tool_ids=[validate_api_key_tool.id],
        tool_rules=rules,
    )

    prompt = MessageCreate(role="user", content="try validating my key")
    response = await run_agent_step(agent_state=agent_state, input_messages=[prompt], actor=default_user)

    expected_stop = LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call)
    assert response.stop_reason == expected_stop

    # The validate_api_key call is attempted, but the flow never reaches send_message.
    assert_invoked_function_call(response.messages, "validate_api_key")
    with pytest.raises(Exception):
        assert_invoked_function_call(response.messages, "send_message")

    await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
@pytest.mark.timeout(60)
@pytest.mark.asyncio
async def test_init_tool_rule_unknown_prefilled_key_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user):
    """A prefilled key that is not a tool parameter must end the step with an error before send_message runs."""
    rules = [
        InitToolRule(tool_name="validate_api_key", args={"not_a_param": "value"}),
        ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
        TerminalToolRule(tool_name="send_message"),
    ]

    agent_name = str(uuid.uuid4())
    agent_state = await setup_agent(
        server,
        OPENAI_CONFIG,
        agent_uuid=agent_name,
        tool_ids=[validate_api_key_tool.id],
        tool_rules=rules,
    )

    prompt = MessageCreate(role="user", content="validate with your best guess")
    response = await run_agent_step(agent_state=agent_state, input_messages=[prompt], actor=default_user)

    expected_stop = LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call)
    assert response.stop_reason == expected_stop

    # The validate_api_key call is attempted, but the flow never reaches send_message.
    assert_invoked_function_call(response.messages, "validate_api_key")
    with pytest.raises(Exception):
        assert_invoked_function_call(response.messages, "send_message")

    await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
@pytest.mark.asyncio
async def test_continue_tool_rule(server, default_user):
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""

View File

@@ -1,6 +1,9 @@
import pytest
from letta.agents.helpers import merge_and_validate_prefilled_args
from letta.helpers import ToolRulesSolver
from letta.schemas.enums import ToolType
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import (
ChildToolRule,
ConditionalToolRule,
@@ -722,3 +725,181 @@ def test_should_force_tool_call_mixed_rules():
solver.register_tool_call(NEXT_TOOL)
assert solver.should_force_tool_call() is False, "Should return False when no constraining rules are active"
def make_tool(name: str, properties: dict) -> Tool:
    """Build a minimal custom Tool whose parameter schema exposes ``properties``.

    The schema declares no required parameters and disallows additional ones.
    """
    parameters = {
        "type": "object",
        "properties": properties,
        "required": [],
        "additionalProperties": False,
    }
    schema = {"name": name, "parameters": parameters}
    return Tool(name=name, tool_type=ToolType.CUSTOM, json_schema=schema)
def test_init_rule_args_are_cached_in_solver():
    """Evaluating the allowlist caches the init rule's args and their provenance."""
    init_rule = InitToolRule(tool_name="alpha", args={"x": 1, "y": "s"})
    solver = ToolRulesSolver(tool_rules=[init_rule])

    allowed_names = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})
    assert set(allowed_names) == {"alpha"}

    # Cached mappings
    assert solver.last_prefilled_args_by_tool == {"alpha": {"x": 1, "y": "s"}}
    assert solver.last_prefilled_args_provenance.get("alpha") == "InitToolRule(alpha)"
def test_cached_provenance_format():
    """Provenance is rendered as ``InitToolRule(<tool name>)``."""
    solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="tool_one", args={"a": 123})])
    solver.get_allowed_tool_names(available_tools={"tool_one"})

    provenance = solver.last_prefilled_args_provenance.get("tool_one")
    assert provenance.startswith("InitToolRule(") and provenance.endswith(")") and "tool_one" in provenance
def test_cache_empty_when_no_args():
    """An init rule without args leaves both caches empty."""
    solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha")])

    allowed_names = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})

    assert set(allowed_names) == {"alpha"}
    assert solver.last_prefilled_args_by_tool == {}
    assert solver.last_prefilled_args_provenance == {}
def test_cache_recomputed_on_next_call():
    """The caches are rebuilt on every allowlist evaluation, not accumulated."""
    tool_pool = {"alpha", "beta"}
    solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha", args={"p": 5})])

    # The first evaluation picks up the init rule's args.
    solver.get_allowed_tool_names(available_tools=tool_pool)
    assert solver.last_prefilled_args_by_tool == {"alpha": {"p": 5}}

    # Once a tool call is registered the init rule no longer applies, so the
    # next evaluation must leave both caches empty.
    solver.register_tool_call("alpha")
    solver.get_allowed_tool_names(available_tools=tool_pool)
    assert solver.last_prefilled_args_by_tool == {}
    assert solver.last_prefilled_args_provenance == {}
def test_merge_and_validate_prefilled_args_overrides_llm_values():
    """Prefilled values replace overlapping LLM values; other LLM values survive."""
    tool = make_tool("my_tool", properties={"a": {"type": "integer"}, "b": {"type": "string"}})

    result = merge_and_validate_prefilled_args(tool, {"a": 1, "b": "hello"}, {"a": 42})

    assert result == {"a": 42, "b": "hello"}
def test_merge_and_validate_prefilled_args_type_validation():
    """A prefilled value of the wrong type is rejected with a descriptive message."""
    tool = make_tool("typed_tool", properties={"a": {"type": "integer"}})

    with pytest.raises(ValueError) as exc_info:
        merge_and_validate_prefilled_args(tool, {"a": 1}, {"a": "not-an-int"})

    message = str(exc_info.value)
    assert "Invalid value for 'a'" in message
    assert "integer" in message
def test_merge_and_validate_prefilled_args_unknown_key_fails():
    """A prefilled key that is not declared on the tool schema is rejected."""
    tool = make_tool("limited_tool", properties={"a": {"type": "integer"}})

    with pytest.raises(ValueError) as exc_info:
        merge_and_validate_prefilled_args(tool, llm_args={}, prefilled_args={"z": 3})

    assert "Unknown argument 'z'" in str(exc_info.value)
def test_merge_and_validate_prefilled_args_enum_const_anyof_oneof():
    """Exercise enum, const, anyOf, oneOf and number validation for prefilled args."""
    tool = make_tool(
        "rich_tool",
        properties={
            "c": {"enum": ["x", "y"]},
            "d": {"const": 5},
            "e": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
            "f": {"oneOf": [{"type": "string"}, {"type": "integer"}]},
            "g": {"type": "number"},
        },
    )

    # Valid cases: each value must be accepted and carried into the merged args.
    accepted = [("c", "x"), ("d", 5), ("e", 7), ("f", "hello"), ("g", 3.14), ("g", 3)]
    for key, value in accepted:
        merged = merge_and_validate_prefilled_args(tool, {}, {key: value})
        assert merged[key] == value

    # Invalid cases
    rejected = [
        ("c", "z"),  # enum fail
        ("d", 6),  # const fail
        ("e", []),  # anyOf none match
        ("f", []),  # oneOf none match
        ("g", True),  # bool not a number
    ]
    for key, value in rejected:
        with pytest.raises(ValueError):
            merge_and_validate_prefilled_args(tool, {}, {key: value})
def test_merge_and_validate_prefilled_args_union_with_null():
    """A ``type`` list of string and null accepts both and rejects other types."""
    tool = make_tool("union_tool", properties={"h": {"type": ["string", "null"]}})

    with_null = merge_and_validate_prefilled_args(tool, {}, {"h": None})
    assert "h" in with_null and with_null["h"] is None

    with_text = merge_and_validate_prefilled_args(tool, {}, {"h": "ok"})
    assert with_text["h"] == "ok"

    with pytest.raises(ValueError):
        merge_and_validate_prefilled_args(tool, {}, {"h": 5})
def test_merge_and_validate_prefilled_args_object_and_array_types():
    """Object and array typed parameters accept dicts and lists respectively, and nothing else."""
    properties = {
        "obj": {"type": "object"},
        "arr": {"type": "array"},
    }
    tool = make_tool("container_tool", properties=properties)

    for key, value in (("obj", {"k": 1}), ("arr", [1, 2, 3])):
        merged = merge_and_validate_prefilled_args(tool, {}, {key: value})
        assert merged[key] == value

    for key, value in (("obj", "nope"), ("arr", {})):
        with pytest.raises(ValueError):
            merge_and_validate_prefilled_args(tool, {}, {key: value})
def test_multiple_rules_args_last_write_wins_and_provenance():
    """With two init rules for one tool, later args overwrite overlapping keys and provenance."""
    first_rule = InitToolRule(tool_name="alpha", args={"x": 1, "y": "first"})
    second_rule = InitToolRule(tool_name="alpha", args={"y": "second", "z": True})
    solver = ToolRulesSolver(tool_rules=[first_rule, second_rule])

    allowed_names = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})

    assert set(allowed_names) == {"alpha"}
    assert solver.last_prefilled_args_by_tool["alpha"] == {"x": 1, "y": "second", "z": True}
    assert solver.last_prefilled_args_provenance.get("alpha") == "InitToolRule(alpha)"