diff --git a/fern/openapi.json b/fern/openapi.json index f0c6b88c..edd18d07 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -12394,6 +12394,36 @@ {} ], "nullable": true + }, + "args": { + "oneOf": [ + { + "nullable": true + }, + { + "type": "string", + "format": "null", + "nullable": true + }, + { + "type": "array", + "items": { + "oneOf": [ + { + "nullable": true + }, + { + "type": "string", + "format": "null", + "nullable": true + } + ], + "nullable": true + } + }, + {} + ], + "nullable": true } }, "required": ["tool_name"] @@ -12895,6 +12925,36 @@ {} ], "nullable": true + }, + "args": { + "oneOf": [ + { + "nullable": true + }, + { + "type": "string", + "format": "null", + "nullable": true + }, + { + "type": "array", + "items": { + "oneOf": [ + { + "nullable": true + }, + { + "type": "string", + "format": "null", + "nullable": true + } + ], + "nullable": true + } + }, + {} + ], + "nullable": true } }, "required": ["tool_name"] @@ -26972,6 +27032,19 @@ ], "title": "Prompt Template", "description": "Optional template string (ignored). Rendering uses fast built-in formatting for performance." + }, + "args": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Args", + "description": "Optional prefilled arguments for this tool. When present, these values will override any LLM-provided arguments with the same keys during invocation. Keys must match the tool's parameter names and values must satisfy the tool's JSON schema. Supports partial prefill; non-overlapping parameters are left to the model." 
from typing import Any, Dict


def _is_json_integer(value: Any) -> bool:
    """Return True for Python ints that are not bools (bool subclasses int)."""
    return isinstance(value, int) and not isinstance(value, bool)


# Maps a JSON Schema primitive type name to the predicate that recognises it.
_JSON_TYPE_CHECKS = {
    "string": lambda v: isinstance(v, str),
    "integer": _is_json_integer,
    "number": lambda v: _is_json_integer(v) or isinstance(v, float),
    "boolean": lambda v: isinstance(v, bool),
    "object": lambda v: isinstance(v, dict),
    "array": lambda v: isinstance(v, list),
    "null": lambda v: v is None,
}


def _json_type_matches(value: Any, expected_type: Any) -> bool:
    """Basic JSON Schema ``type`` checking for the common primitive types.

    Args:
        value: The candidate value.
        expected_type: Either a single type name (e.g. ``"string"``) or a list
            of type names, which is treated as a union.

    Returns:
        True when ``value`` matches the type (or any member of the union).
        Unknown type names, and ``expected_type`` values that are neither a
        string nor a list, are accepted so that validation never over-rejects.
        An empty list matches nothing.
    """

    def match_one(candidate: Any, type_name: Any) -> bool:
        # Non-string entries cannot be looked up; treat them as unknown types.
        if not isinstance(type_name, str):
            return True
        check = _JSON_TYPE_CHECKS.get(type_name)
        # Fallback: don't over-reject on unknown types.
        return True if check is None else check(candidate)

    if isinstance(expected_type, list):
        return any(match_one(value, type_name) for type_name in expected_type)
    if isinstance(expected_type, str):
        return match_one(value, expected_type)
    return True


def _json_equal(left: Any, right: Any) -> bool:
    """Compare two values for ``enum``/``const`` purposes.

    Python treats ``True == 1`` and ``False == 0``; booleans and numbers are
    kept distinct here, matching the bool/int separation in the type checks.
    The comparison is shallow: nested containers fall back to plain ``==``.
    """
    if isinstance(left, bool) or isinstance(right, bool):
        return isinstance(left, bool) and isinstance(right, bool) and left == right
    return left == right


def _schema_accepts_value(prop_schema: Any, value: Any) -> bool:
    """Check whether ``value`` is acceptable for a property schema.

    Handles ``const``, ``enum``, ``anyOf``/``oneOf`` (shallow: at least one
    branch must accept; ``oneOf`` is not enforced as "exactly one") and
    ``type``.

    Args:
        prop_schema: The property's schema. ``None`` and unrecognised
            non-dict shapes accept everything; a boolean schema accepts
            everything (``True``) or nothing (``False``).
        value: The candidate value.

    Returns:
        True when no checked keyword rejects the value.
    """
    if prop_schema is None:
        return True
    if isinstance(prop_schema, bool):
        return prop_schema
    if not isinstance(prop_schema, dict):
        # Unrecognised schema shape: stay lenient instead of raising.
        return True

    # const has highest precedence
    if "const" in prop_schema:
        return _json_equal(value, prop_schema["const"])

    if "enum" in prop_schema:
        options = prop_schema["enum"]
        if not isinstance(options, (list, tuple)):
            # A malformed enum cannot accept anything.
            return False
        return any(_json_equal(value, option) for option in options)

    # Every union keyword present must be satisfied; a sibling "type" keyword
    # is still checked afterwards rather than being skipped.
    for union_key in ("anyOf", "oneOf"):
        branches = prop_schema.get(union_key)
        if isinstance(branches, list) and not any(_schema_accepts_value(branch, value) for branch in branches):
            return False

    if "type" in prop_schema:
        return _json_type_matches(value, prop_schema["type"])

    # No strict constraints specified: accept
    return True


def merge_and_validate_prefilled_args(tool: "Tool", llm_args: Dict[str, Any], prefilled_args: Dict[str, Any]) -> Dict[str, Any]:
    """Merge LLM-provided args with prefilled args from tool rules.

    Overlapping keys are replaced by prefilled values (prefilled wins). Each
    prefilled key must exist in the tool schema's ``parameters.properties`` and
    its value must pass the lightweight checks in ``_schema_accepts_value``.

    Args:
        tool: Object exposing ``name`` and ``json_schema`` (a ``Tool``).
        llm_args: Arguments produced by the model; ``None`` is treated as empty.
            The input mapping is not mutated.
        prefilled_args: Arguments supplied by a tool rule; ``None`` is treated
            as empty.

    Returns:
        A new dict containing ``llm_args`` updated with ``prefilled_args``.

    Raises:
        ValueError: If ``tool`` is ``None``, or if any prefilled key is unknown
            or any prefilled value is rejected. All problems are reported in
            one message joined by ``"; "``.
    """
    # Raise ValueError rather than assert: callers handle ValueError, and
    # assert statements are stripped when Python runs with -O.
    if tool is None:
        raise ValueError("Cannot validate prefilled arguments: tool definition not found.")

    prefilled_args = prefilled_args or {}

    tool_schema = getattr(tool, "json_schema", None) or {}
    parameters = tool_schema.get("parameters") if isinstance(tool_schema, dict) else None
    props = parameters.get("properties") if isinstance(parameters, dict) else None
    if not isinstance(props, dict):
        props = {}

    errors: list[str] = []
    for k, v in prefilled_args.items():
        if k not in props:
            errors.append(f"Unknown argument '{k}' for tool '{tool.name}'.")
            continue
        prop_schema = props[k]
        if not _schema_accepts_value(prop_schema, v):
            expected = prop_schema.get("type") if isinstance(prop_schema, dict) else None
            errors.append(f"Invalid value for '{k}': {v!r} does not match expected schema type {expected!r}.")

    if errors:
        raise ValueError("; ".join(errors))

    merged = dict(llm_args or {})
    merged.update(prefilled_args)
    return merged


def _pop_heartbeat(tool_args: dict) -> bool:
    """Remove ``request_heartbeat`` from ``tool_args`` and return it as a bool.

    String values are true only when they equal ``"true"`` case-insensitively;
    any other value goes through ``bool()``. A missing key yields ``False``.
    """
    hb = tool_args.pop("request_heartbeat", False)
    return str(hb).lower() == "true" if isinstance(hb, str) else bool(hb)
self.tool_rules_solver.last_prefilled_args_provenance.get(tool_call_name) + try: + tool_args = merge_and_validate_prefilled_args( + tool=target_tool, + llm_args=tool_args, + prefilled_args=prefill_args, + ) + except ValueError as ve: + # Treat invalid prefilled args as user error and end the step + error_prefix = "Invalid prefilled tool arguments from tool rules" + prov_suffix = f" (source={provenance})" if provenance else "" + err_msg = f"{error_prefix}{prov_suffix}: {str(ve)}" + tool_execution_result = ToolExecutionResult(status="error", func_return=err_msg) + + # Create messages and early return persistence path below + continue_stepping, heartbeat_reason, stop_reason = ( + False, + None, + LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value), + ) + tool_call_messages = create_letta_messages_from_llm_response( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + function_name=tool_call_name, + function_arguments=tool_args, + tool_execution_result=tool_execution_result, + tool_call_id=tool_call_id, + function_call_success=False, + function_response=tool_execution_result.func_return, + timezone=agent_state.timezone, + actor=self.actor, + continue_stepping=continue_stepping, + heartbeat_reason=None, + reasoning_content=content, + pre_computed_assistant_message_id=pre_computed_assistant_message_id, + step_id=step_id, + run_id=run_id, + is_approval_response=is_approval or is_denial, + force_set_request_heartbeat=False, + add_heartbeat_on_continue=False, + ) + messages_to_persist = (initial_messages or []) + tool_call_messages + + # Set run_id on all messages before persisting + for message in messages_to_persist: + if message.run_id is None: + message.run_id = run_id + + persisted_messages = await self.message_manager.create_many_messages_async( + messages_to_persist, + actor=self.actor, + run_id=run_id, + project_id=agent_state.project_id, + template_id=agent_state.template_id, + ) + return persisted_messages, continue_stepping, 
stop_reason + # Track tool execution time tool_start_time = get_utc_timestamp_ns() tool_execution_result = await self._execute_tool( diff --git a/letta/helpers/tool_rule_solver.py b/letta/helpers/tool_rule_solver.py index 4cb9c86c..637c19ce 100644 --- a/letta/helpers/tool_rule_solver.py +++ b/letta/helpers/tool_rule_solver.py @@ -50,6 +50,16 @@ class ToolRulesSolver(BaseModel): ) tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.") + # Last-evaluated prefilled args cache (per step) + last_prefilled_args_by_tool: dict[str, dict] = Field( + default_factory=dict, description="Cached mapping of tool name to prefilled args from the last allowlist evaluation.", exclude=True + ) + last_prefilled_args_provenance: dict[str, str] = Field( + default_factory=dict, + description="Cached mapping of tool name to a short description of which rule provided the prefilled args.", + exclude=True, + ) + def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs): super().__init__(tool_rules=tool_rules, **kwargs) @@ -88,15 +98,17 @@ class ToolRulesSolver(BaseModel): ) -> list[ToolName]: """Get a list of tool names allowed based on the last tool called. + Side-effect: also caches any prefilled args provided by active rules into + `last_prefilled_args_by_tool` and `last_prefilled_args_provenance`. + The logic is as follows: 1. if there are no previous tool calls, and we have InitToolRules, those are the only options for the first tool call 2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options 3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools """ - # TODO: This piece of code here is quite ugly and deserves a refactor - # TODO: -> Tool rules should probably be refactored to take in a set of tool names? 
+ # Compute allowed tools first if not self.tool_call_history and self.init_tool_rules: - return [rule.tool_name for rule in self.init_tool_rules] + allowed = [rule.tool_name for rule in self.init_tool_rules] else: valid_tool_sets = [] for rule in self.child_based_tool_rules + self.parent_tool_rules: @@ -109,7 +121,42 @@ class ToolRulesSolver(BaseModel): if error_on_empty and not final_allowed_tools: raise ValueError("No valid tools found based on tool rules.") - return list(final_allowed_tools) + allowed = list(final_allowed_tools) + + # Build prefilled args cache for current allowed set + args_by_tool: dict[str, dict] = {} + provenance_by_tool: dict[str, str] = {} + + def _store_args(tool_name: str, args: dict, rule: BaseModel): + if not isinstance(args, dict) or len(args) == 0: + return + if tool_name not in args_by_tool: + args_by_tool[tool_name] = {} + args_by_tool[tool_name].update(args) # last-write-wins + provenance_by_tool[tool_name] = f"{rule.__class__.__name__}({getattr(rule, 'tool_name', tool_name)})" + + allowed_set = set(allowed) + + if not self.tool_call_history and self.init_tool_rules: + for rule in self.init_tool_rules: + if hasattr(rule, "args") and getattr(rule, "args") and rule.tool_name in allowed_set: + _store_args(rule.tool_name, getattr(rule, "args"), rule) + else: + for rule in ( + self.child_based_tool_rules + + self.parent_tool_rules + + self.continue_tool_rules + + self.terminal_tool_rules + + self.required_before_exit_tool_rules + + self.requires_approval_tool_rules + ): + if hasattr(rule, "args") and getattr(rule, "args") and getattr(rule, "tool_name", None) in allowed_set: + _store_args(rule.tool_name, getattr(rule, "args"), rule) + + self.last_prefilled_args_by_tool = args_by_tool + self.last_prefilled_args_provenance = provenance_by_tool + + return allowed def is_terminal_tool(self, tool_name: ToolName) -> bool: """Check if the tool is defined as a terminal tool in the terminal tool rules or required-before-exit tool rules.""" diff 
--git a/letta/schemas/tool_rule.py b/letta/schemas/tool_rule.py index 94116e55..4cb4de5b 100644 --- a/letta/schemas/tool_rule.py +++ b/letta/schemas/tool_rule.py @@ -209,6 +209,14 @@ class InitToolRule(BaseToolRule): """ type: Literal[ToolRuleType.run_first] = ToolRuleType.run_first + args: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Optional prefilled arguments for this tool. When present, these values will override any LLM-provided " + "arguments with the same keys during invocation. Keys must match the tool's parameter names and values " + "must satisfy the tool's JSON schema. Supports partial prefill; non-overlapping parameters are left to the model." + ), + ) @property def requires_force_tool_call(self) -> bool: diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 162b0ee4..dac703fd 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -8,6 +8,7 @@ from letta.agents.letta_agent_v2 import LettaAgentV2 from letta.agents.letta_agent_v3 import LettaAgentV3 from letta.config import LettaConfig from letta.schemas.letta_message import ToolCallMessage +from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import MessageCreate from letta.schemas.run import Run from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule @@ -225,6 +226,29 @@ async def cleanup_temp_files_tool(server): yield tool +@pytest.fixture(scope="function") +async def validate_api_key_tool(server): + SECRET = "REAL_KEY_123" + + def validate_api_key(secret_key: str): + """ + Validates an API key; errors if incorrect. + + Args: + secret_key (str): The provided key. + + Returns: + str: Confirmation string when key is valid. 
+ """ + if secret_key != SECRET: + raise RuntimeError(f"Invalid secret key: {secret_key}") + return "api key accepted" + + actor = await server.user_manager.get_actor_or_default_async() + tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=validate_api_key), actor=actor) + yield tool + + @pytest.fixture(scope="function") async def validate_work_tool(server): def validate_work(): @@ -536,6 +560,126 @@ async def test_init_tool_rule_always_fails( await cleanup_async(server=server, agent_uuid=agent_uuid, actor=default_user) +@pytest.mark.timeout(60) +@pytest.mark.asyncio +async def test_init_tool_rule_args_override_llm_payload(server, disable_e2b_api_key, validate_api_key_tool, default_user): + """InitToolRule args should override LLM-provided args for the initial tool call.""" + REAL = "REAL_KEY_123" + + tools = [validate_api_key_tool] + tool_rules = [ + InitToolRule(tool_name="validate_api_key", args={"secret_key": REAL}), + ChildToolRule(tool_name="validate_api_key", children=["send_message"]), + TerminalToolRule(tool_name="send_message"), + ] + + agent_name = str(uuid.uuid4()) + agent_state = await setup_agent( + server, + OPENAI_CONFIG, + agent_uuid=agent_name, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, + ) + + # Ask the model to call with a fake key; prefilled args should override + response = await run_agent_step( + agent_state=agent_state, + input_messages=[ + MessageCreate( + role="user", + content="Please validate my API key: FAKE_KEY_999, then send me confirmation.", + ) + ], + actor=default_user, + ) + + assert_sanity_checks(response) + assert_invoked_function_call(response.messages, "validate_api_key") + assert_invoked_function_call(response.messages, "send_message") + + # Verify the recorded tool-call arguments reflect the prefilled override + for m in response.messages: + if isinstance(m, ToolCallMessage) and m.tool_call.name == "validate_api_key": + args = json.loads(m.tool_call.arguments) + assert 
args.get("secret_key") == REAL + break + + await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user) + + +@pytest.mark.timeout(60) +@pytest.mark.asyncio +async def test_init_tool_rule_invalid_prefilled_type_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user): + """Invalid prefilled args should produce an error and prevent further flow (no send_message).""" + tools = [validate_api_key_tool] + # Provide wrong type for secret_key (expects string) + tool_rules = [ + InitToolRule(tool_name="validate_api_key", args={"secret_key": 123}), + ChildToolRule(tool_name="validate_api_key", children=["send_message"]), + TerminalToolRule(tool_name="send_message"), + ] + + agent_name = str(uuid.uuid4()) + agent_state = await setup_agent( + server, + OPENAI_CONFIG, + agent_uuid=agent_name, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, + ) + + response = await run_agent_step( + agent_state=agent_state, + input_messages=[MessageCreate(role="user", content="try validating my key")], + actor=default_user, + ) + + assert response.stop_reason == LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call) + + # Should attempt validate_api_key but not proceed to send_message + assert_invoked_function_call(response.messages, "validate_api_key") + with pytest.raises(Exception): + assert_invoked_function_call(response.messages, "send_message") + + await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user) + + +@pytest.mark.timeout(60) +@pytest.mark.asyncio +async def test_init_tool_rule_unknown_prefilled_key_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user): + """Unknown prefilled arg key should error and block flow.""" + tools = [validate_api_key_tool] + tool_rules = [ + InitToolRule(tool_name="validate_api_key", args={"not_a_param": "value"}), + ChildToolRule(tool_name="validate_api_key", children=["send_message"]), + 
TerminalToolRule(tool_name="send_message"), + ] + + agent_name = str(uuid.uuid4()) + agent_state = await setup_agent( + server, + OPENAI_CONFIG, + agent_uuid=agent_name, + tool_ids=[t.id for t in tools], + tool_rules=tool_rules, + ) + + response = await run_agent_step( + agent_state=agent_state, + input_messages=[MessageCreate(role="user", content="validate with your best guess")], + actor=default_user, + ) + + assert response.stop_reason == LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call) + + assert_invoked_function_call(response.messages, "validate_api_key") + with pytest.raises(Exception): + assert_invoked_function_call(response.messages, "send_message") + + await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user) + + @pytest.mark.asyncio async def test_continue_tool_rule(server, default_user): """Test the continue tool rule by forcing send_message to loop before ending with core_memory_append.""" diff --git a/tests/test_tool_rule_solver.py b/tests/test_tool_rule_solver.py index 2609bae7..c5c3615b 100644 --- a/tests/test_tool_rule_solver.py +++ b/tests/test_tool_rule_solver.py @@ -1,6 +1,9 @@ import pytest +from letta.agents.helpers import merge_and_validate_prefilled_args from letta.helpers import ToolRulesSolver +from letta.schemas.enums import ToolType +from letta.schemas.tool import Tool from letta.schemas.tool_rule import ( ChildToolRule, ConditionalToolRule, @@ -722,3 +725,181 @@ def test_should_force_tool_call_mixed_rules(): solver.register_tool_call(NEXT_TOOL) assert solver.should_force_tool_call() is False, "Should return False when no constraining rules are active" + + +def make_tool(name: str, properties: dict) -> Tool: + """Helper to build a minimal custom Tool with a JSON schema.""" + return Tool( + name=name, + tool_type=ToolType.CUSTOM, + json_schema={ + "name": name, + "parameters": { + "type": "object", + "properties": properties, + "required": [], + "additionalProperties": False, + 
}, + }, + ) + + +def test_init_rule_args_are_cached_in_solver(): + solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha", args={"x": 1, "y": "s"})]) + allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"}) + + assert set(allowed) == {"alpha"} + # Cached mappings + assert solver.last_prefilled_args_by_tool == {"alpha": {"x": 1, "y": "s"}} + assert solver.last_prefilled_args_provenance.get("alpha") == "InitToolRule(alpha)" + + +def test_cached_provenance_format(): + solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="tool_one", args={"a": 123})]) + _ = solver.get_allowed_tool_names(available_tools={"tool_one"}) + prov = solver.last_prefilled_args_provenance.get("tool_one") + assert prov.startswith("InitToolRule(") and prov.endswith(")") and "tool_one" in prov + + +def test_cache_empty_when_no_args(): + solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha")]) + allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"}) + + assert set(allowed) == {"alpha"} + assert solver.last_prefilled_args_by_tool == {} + assert solver.last_prefilled_args_provenance == {} + + +def test_cache_recomputed_on_next_call(): + # First call caches args for init tool + solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha", args={"p": 5})]) + _ = solver.get_allowed_tool_names(available_tools={"alpha", "beta"}) + assert solver.last_prefilled_args_by_tool == {"alpha": {"p": 5}} + + # After a tool call, init rules no longer apply; next computation should clear caches + solver.register_tool_call("alpha") + _ = solver.get_allowed_tool_names(available_tools={"alpha", "beta"}) + assert solver.last_prefilled_args_by_tool == {} + assert solver.last_prefilled_args_provenance == {} + + +def test_merge_and_validate_prefilled_args_overrides_llm_values(): + tool = make_tool("my_tool", properties={"a": {"type": "integer"}, "b": {"type": "string"}}) + llm_args = {"a": 1, "b": "hello"} + prefilled = {"a": 42} + + merged = 
merge_and_validate_prefilled_args(tool, llm_args, prefilled) + assert merged == {"a": 42, "b": "hello"} + + +def test_merge_and_validate_prefilled_args_type_validation(): + tool = make_tool("typed_tool", properties={"a": {"type": "integer"}}) + llm_args = {"a": 1} + prefilled = {"a": "not-an-int"} + + with pytest.raises(ValueError) as ei: + _ = merge_and_validate_prefilled_args(tool, llm_args, prefilled) + assert "Invalid value for 'a'" in str(ei.value) + assert "integer" in str(ei.value) + + +def test_merge_and_validate_prefilled_args_unknown_key_fails(): + tool = make_tool("limited_tool", properties={"a": {"type": "integer"}}) + with pytest.raises(ValueError) as ei: + _ = merge_and_validate_prefilled_args(tool, llm_args={}, prefilled_args={"z": 3}) + assert "Unknown argument 'z'" in str(ei.value) + + +def test_merge_and_validate_prefilled_args_enum_const_anyof_oneof(): + tool = make_tool( + "rich_tool", + properties={ + "c": {"enum": ["x", "y"]}, + "d": {"const": 5}, + "e": {"anyOf": [{"type": "string"}, {"type": "integer"}]}, + "f": {"oneOf": [{"type": "string"}, {"type": "integer"}]}, + "g": {"type": "number"}, + }, + ) + + # Valid cases + merged = merge_and_validate_prefilled_args(tool, {}, {"c": "x"}) + assert merged["c"] == "x" + + merged = merge_and_validate_prefilled_args(tool, {}, {"d": 5}) + assert merged["d"] == 5 + + merged = merge_and_validate_prefilled_args(tool, {}, {"e": 7}) + assert merged["e"] == 7 + + merged = merge_and_validate_prefilled_args(tool, {}, {"f": "hello"}) + assert merged["f"] == "hello" + + merged = merge_and_validate_prefilled_args(tool, {}, {"g": 3.14}) + assert merged["g"] == 3.14 + + merged = merge_and_validate_prefilled_args(tool, {}, {"g": 3}) + assert merged["g"] == 3 + + # Invalid cases + with pytest.raises(ValueError): + _ = merge_and_validate_prefilled_args(tool, {}, {"c": "z"}) # enum fail + + with pytest.raises(ValueError): + _ = merge_and_validate_prefilled_args(tool, {}, {"d": 6}) # const fail + + with 
pytest.raises(ValueError): + _ = merge_and_validate_prefilled_args(tool, {}, {"e": []}) # anyOf none match + + with pytest.raises(ValueError): + _ = merge_and_validate_prefilled_args(tool, {}, {"f": []}) # oneOf none match + + with pytest.raises(ValueError): + _ = merge_and_validate_prefilled_args(tool, {}, {"g": True}) # bool not a number + + +def test_merge_and_validate_prefilled_args_union_with_null(): + tool = make_tool("union_tool", properties={"h": {"type": ["string", "null"]}}) + + merged = merge_and_validate_prefilled_args(tool, {}, {"h": None}) + assert "h" in merged and merged["h"] is None + + merged = merge_and_validate_prefilled_args(tool, {}, {"h": "ok"}) + assert merged["h"] == "ok" + + with pytest.raises(ValueError): + _ = merge_and_validate_prefilled_args(tool, {}, {"h": 5}) + + +def test_merge_and_validate_prefilled_args_object_and_array_types(): + tool = make_tool( + "container_tool", + properties={ + "obj": {"type": "object"}, + "arr": {"type": "array"}, + }, + ) + + merged = merge_and_validate_prefilled_args(tool, {}, {"obj": {"k": 1}}) + assert merged["obj"] == {"k": 1} + + merged = merge_and_validate_prefilled_args(tool, {}, {"arr": [1, 2, 3]}) + assert merged["arr"] == [1, 2, 3] + + with pytest.raises(ValueError): + _ = merge_and_validate_prefilled_args(tool, {}, {"obj": "nope"}) + with pytest.raises(ValueError): + _ = merge_and_validate_prefilled_args(tool, {}, {"arr": {}}) + + +def test_multiple_rules_args_last_write_wins_and_provenance(): + # Two init rules for the same tool; the latter should overwrite overlapping keys and provenance + r1 = InitToolRule(tool_name="alpha", args={"x": 1, "y": "first"}) + r2 = InitToolRule(tool_name="alpha", args={"y": "second", "z": True}) + solver = ToolRulesSolver(tool_rules=[r1, r2]) + + allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"}) + assert set(allowed) == {"alpha"} + + assert solver.last_prefilled_args_by_tool["alpha"] == {"x": 1, "y": "second", "z": True} + assert 
solver.last_prefilled_args_provenance.get("alpha") == "InitToolRule(alpha)"