feat: Support pre-filling arguments on InitToolRule [LET-4569] (#5057)
* Add args * Add testing to tool rule solver * Add live integration tests for args prefilling * Add args override
This commit is contained in:
committed by
Caren Thomas
parent
6c7c12ad0f
commit
803b837c64
@@ -12394,6 +12394,36 @@
|
||||
{}
|
||||
],
|
||||
"nullable": true
|
||||
},
|
||||
"args": {
|
||||
"oneOf": [
|
||||
{
|
||||
"nullable": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"format": "null",
|
||||
"nullable": true
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{
|
||||
"nullable": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"format": "null",
|
||||
"nullable": true
|
||||
}
|
||||
],
|
||||
"nullable": true
|
||||
}
|
||||
},
|
||||
{}
|
||||
],
|
||||
"nullable": true
|
||||
}
|
||||
},
|
||||
"required": ["tool_name"]
|
||||
@@ -12895,6 +12925,36 @@
|
||||
{}
|
||||
],
|
||||
"nullable": true
|
||||
},
|
||||
"args": {
|
||||
"oneOf": [
|
||||
{
|
||||
"nullable": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"format": "null",
|
||||
"nullable": true
|
||||
},
|
||||
{
|
||||
"type": "array",
|
||||
"items": {
|
||||
"oneOf": [
|
||||
{
|
||||
"nullable": true
|
||||
},
|
||||
{
|
||||
"type": "string",
|
||||
"format": "null",
|
||||
"nullable": true
|
||||
}
|
||||
],
|
||||
"nullable": true
|
||||
}
|
||||
},
|
||||
{}
|
||||
],
|
||||
"nullable": true
|
||||
}
|
||||
},
|
||||
"required": ["tool_name"]
|
||||
@@ -26972,6 +27032,19 @@
|
||||
],
|
||||
"title": "Prompt Template",
|
||||
"description": "Optional template string (ignored). Rendering uses fast built-in formatting for performance."
|
||||
},
|
||||
"args": {
|
||||
"anyOf": [
|
||||
{
|
||||
"additionalProperties": true,
|
||||
"type": "object"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Args",
|
||||
"description": "Optional prefilled arguments for this tool. When present, these values will override any LLM-provided arguments with the same keys during invocation. Keys must match the tool's parameter names and values must satisfy the tool's JSON schema. Supports partial prefill; non-overlapping parameters are left to the model."
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import uuid
|
||||
import xml.etree.ElementTree as ET
|
||||
from typing import List, Optional, Tuple
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
from uuid import UUID, uuid4
|
||||
|
||||
from letta.errors import PendingApprovalError
|
||||
@@ -258,6 +258,106 @@ def _safe_load_tool_call_str(tool_call_args_str: str) -> dict:
|
||||
return tool_args
|
||||
|
||||
|
||||
def _json_type_matches(value: Any, expected_type: Any) -> bool:
|
||||
"""Basic JSON Schema type checking for common types.
|
||||
|
||||
expected_type can be a string (e.g., "string") or a list (union).
|
||||
This is intentionally lightweight; deeper validation can be added as needed.
|
||||
"""
|
||||
|
||||
def match_one(v: Any, t: str) -> bool:
|
||||
if t == "string":
|
||||
return isinstance(v, str)
|
||||
if t == "integer":
|
||||
# bool is subclass of int in Python; exclude
|
||||
return isinstance(v, int) and not isinstance(v, bool)
|
||||
if t == "number":
|
||||
return (isinstance(v, int) and not isinstance(v, bool)) or isinstance(v, float)
|
||||
if t == "boolean":
|
||||
return isinstance(v, bool)
|
||||
if t == "object":
|
||||
return isinstance(v, dict)
|
||||
if t == "array":
|
||||
return isinstance(v, list)
|
||||
if t == "null":
|
||||
return v is None
|
||||
# Fallback: don't over-reject on unknown types
|
||||
return True
|
||||
|
||||
if isinstance(expected_type, list):
|
||||
return any(match_one(value, t) for t in expected_type)
|
||||
if isinstance(expected_type, str):
|
||||
return match_one(value, expected_type)
|
||||
return True
|
||||
|
||||
|
||||
def _schema_accepts_value(prop_schema: Dict[str, Any], value: Any) -> bool:
|
||||
"""Check if a value is acceptable for a property schema.
|
||||
|
||||
Handles: type, enum, const, anyOf, oneOf (by shallow traversal).
|
||||
"""
|
||||
if prop_schema is None:
|
||||
return True
|
||||
|
||||
# const has highest precedence
|
||||
if "const" in prop_schema:
|
||||
return value == prop_schema["const"]
|
||||
|
||||
# enums
|
||||
if "enum" in prop_schema:
|
||||
try:
|
||||
return value in prop_schema["enum"]
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
# unions
|
||||
for union_key in ("anyOf", "oneOf"):
|
||||
if union_key in prop_schema and isinstance(prop_schema[union_key], list):
|
||||
for sub in prop_schema[union_key]:
|
||||
if _schema_accepts_value(sub, value):
|
||||
return True
|
||||
return False
|
||||
|
||||
# type-based
|
||||
if "type" in prop_schema:
|
||||
if not _json_type_matches(value, prop_schema["type"]):
|
||||
return False
|
||||
|
||||
# No strict constraints specified: accept
|
||||
return True
|
||||
|
||||
|
||||
def merge_and_validate_prefilled_args(tool: "Tool", llm_args: Dict[str, Any], prefilled_args: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Merge LLM-provided args with prefilled args from tool rules.
|
||||
|
||||
- Overlapping keys are replaced by prefilled values (prefilled wins).
|
||||
- Validates that prefilled keys exist on the tool schema and that values satisfy
|
||||
basic JSON Schema constraints (type/enum/const/anyOf/oneOf).
|
||||
- Returns merged args, or raises ValueError on invalid prefilled inputs.
|
||||
"""
|
||||
from letta.schemas.tool import Tool # local import to avoid circulars in type hints
|
||||
|
||||
assert isinstance(tool, Tool)
|
||||
schema = (tool.json_schema or {}).get("parameters", {})
|
||||
props: Dict[str, Any] = schema.get("properties", {}) if isinstance(schema, dict) else {}
|
||||
|
||||
errors: list[str] = []
|
||||
for k, v in prefilled_args.items():
|
||||
if k not in props:
|
||||
errors.append(f"Unknown argument '{k}' for tool '{tool.name}'.")
|
||||
continue
|
||||
if not _schema_accepts_value(props.get(k), v):
|
||||
expected = props.get(k, {}).get("type")
|
||||
errors.append(f"Invalid value for '{k}': {v!r} does not match expected schema type {expected!r}.")
|
||||
|
||||
if errors:
|
||||
raise ValueError("; ".join(errors))
|
||||
|
||||
merged = dict(llm_args or {})
|
||||
merged.update(prefilled_args)
|
||||
return merged
|
||||
|
||||
|
||||
def _pop_heartbeat(tool_args: dict) -> bool:
|
||||
hb = tool_args.pop("request_heartbeat", False)
|
||||
return str(hb).lower() == "true" if isinstance(hb, str) else bool(hb)
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.agents.helpers import (
|
||||
_prepare_in_context_messages_no_persist_async,
|
||||
_safe_load_tool_call_str,
|
||||
generate_step_id,
|
||||
merge_and_validate_prefilled_args,
|
||||
)
|
||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM
|
||||
@@ -678,6 +679,68 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if tool_rule_violated:
|
||||
tool_execution_result = _build_rule_violation_result(tool_call_name, valid_tool_names, tool_rules_solver)
|
||||
else:
|
||||
# Prefill + validate args if a rule provided them
|
||||
prefill_args = self.tool_rules_solver.last_prefilled_args_by_tool.get(tool_call_name)
|
||||
if prefill_args:
|
||||
# Find tool object for schema validation
|
||||
target_tool = next((t for t in agent_state.tools if t.name == tool_call_name), None)
|
||||
provenance = self.tool_rules_solver.last_prefilled_args_provenance.get(tool_call_name)
|
||||
try:
|
||||
tool_args = merge_and_validate_prefilled_args(
|
||||
tool=target_tool,
|
||||
llm_args=tool_args,
|
||||
prefilled_args=prefill_args,
|
||||
)
|
||||
except ValueError as ve:
|
||||
# Treat invalid prefilled args as user error and end the step
|
||||
error_prefix = "Invalid prefilled tool arguments from tool rules"
|
||||
prov_suffix = f" (source={provenance})" if provenance else ""
|
||||
err_msg = f"{error_prefix}{prov_suffix}: {str(ve)}"
|
||||
tool_execution_result = ToolExecutionResult(status="error", func_return=err_msg)
|
||||
|
||||
# Create messages and early return persistence path below
|
||||
continue_stepping, heartbeat_reason, stop_reason = (
|
||||
False,
|
||||
None,
|
||||
LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value),
|
||||
)
|
||||
tool_call_messages = create_letta_messages_from_llm_response(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call_name,
|
||||
function_arguments=tool_args,
|
||||
tool_execution_result=tool_execution_result,
|
||||
tool_call_id=tool_call_id,
|
||||
function_call_success=False,
|
||||
function_response=tool_execution_result.func_return,
|
||||
timezone=agent_state.timezone,
|
||||
actor=self.actor,
|
||||
continue_stepping=continue_stepping,
|
||||
heartbeat_reason=None,
|
||||
reasoning_content=content,
|
||||
pre_computed_assistant_message_id=pre_computed_assistant_message_id,
|
||||
step_id=step_id,
|
||||
run_id=run_id,
|
||||
is_approval_response=is_approval or is_denial,
|
||||
force_set_request_heartbeat=False,
|
||||
add_heartbeat_on_continue=False,
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
|
||||
# Set run_id on all messages before persisting
|
||||
for message in messages_to_persist:
|
||||
if message.run_id is None:
|
||||
message.run_id = run_id
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist,
|
||||
actor=self.actor,
|
||||
run_id=run_id,
|
||||
project_id=agent_state.project_id,
|
||||
template_id=agent_state.template_id,
|
||||
)
|
||||
return persisted_messages, continue_stepping, stop_reason
|
||||
|
||||
# Track tool execution time
|
||||
tool_start_time = get_utc_timestamp_ns()
|
||||
tool_execution_result = await self._execute_tool(
|
||||
|
||||
@@ -50,6 +50,16 @@ class ToolRulesSolver(BaseModel):
|
||||
)
|
||||
tool_call_history: list[str] = Field(default_factory=list, description="History of tool calls, updated with each tool call.")
|
||||
|
||||
# Last-evaluated prefilled args cache (per step)
|
||||
last_prefilled_args_by_tool: dict[str, dict] = Field(
|
||||
default_factory=dict, description="Cached mapping of tool name to prefilled args from the last allowlist evaluation.", exclude=True
|
||||
)
|
||||
last_prefilled_args_provenance: dict[str, str] = Field(
|
||||
default_factory=dict,
|
||||
description="Cached mapping of tool name to a short description of which rule provided the prefilled args.",
|
||||
exclude=True,
|
||||
)
|
||||
|
||||
def __init__(self, tool_rules: list[ToolRule] | None = None, **kwargs):
|
||||
super().__init__(tool_rules=tool_rules, **kwargs)
|
||||
|
||||
@@ -88,15 +98,17 @@ class ToolRulesSolver(BaseModel):
|
||||
) -> list[ToolName]:
|
||||
"""Get a list of tool names allowed based on the last tool called.
|
||||
|
||||
Side-effect: also caches any prefilled args provided by active rules into
|
||||
`last_prefilled_args_by_tool` and `last_prefilled_args_provenance`.
|
||||
|
||||
The logic is as follows:
|
||||
1. if there are no previous tool calls, and we have InitToolRules, those are the only options for the first tool call
|
||||
2. else we take the intersection of the Parent/Child/Conditional/MaxSteps as the options
|
||||
3. Continue/Terminal/RequiredBeforeExit rules are applied in the agent loop flow, not to restrict tools
|
||||
"""
|
||||
# TODO: This piece of code here is quite ugly and deserves a refactor
|
||||
# TODO: -> Tool rules should probably be refactored to take in a set of tool names?
|
||||
# Compute allowed tools first
|
||||
if not self.tool_call_history and self.init_tool_rules:
|
||||
return [rule.tool_name for rule in self.init_tool_rules]
|
||||
allowed = [rule.tool_name for rule in self.init_tool_rules]
|
||||
else:
|
||||
valid_tool_sets = []
|
||||
for rule in self.child_based_tool_rules + self.parent_tool_rules:
|
||||
@@ -109,7 +121,42 @@ class ToolRulesSolver(BaseModel):
|
||||
if error_on_empty and not final_allowed_tools:
|
||||
raise ValueError("No valid tools found based on tool rules.")
|
||||
|
||||
return list(final_allowed_tools)
|
||||
allowed = list(final_allowed_tools)
|
||||
|
||||
# Build prefilled args cache for current allowed set
|
||||
args_by_tool: dict[str, dict] = {}
|
||||
provenance_by_tool: dict[str, str] = {}
|
||||
|
||||
def _store_args(tool_name: str, args: dict, rule: BaseModel):
|
||||
if not isinstance(args, dict) or len(args) == 0:
|
||||
return
|
||||
if tool_name not in args_by_tool:
|
||||
args_by_tool[tool_name] = {}
|
||||
args_by_tool[tool_name].update(args) # last-write-wins
|
||||
provenance_by_tool[tool_name] = f"{rule.__class__.__name__}({getattr(rule, 'tool_name', tool_name)})"
|
||||
|
||||
allowed_set = set(allowed)
|
||||
|
||||
if not self.tool_call_history and self.init_tool_rules:
|
||||
for rule in self.init_tool_rules:
|
||||
if hasattr(rule, "args") and getattr(rule, "args") and rule.tool_name in allowed_set:
|
||||
_store_args(rule.tool_name, getattr(rule, "args"), rule)
|
||||
else:
|
||||
for rule in (
|
||||
self.child_based_tool_rules
|
||||
+ self.parent_tool_rules
|
||||
+ self.continue_tool_rules
|
||||
+ self.terminal_tool_rules
|
||||
+ self.required_before_exit_tool_rules
|
||||
+ self.requires_approval_tool_rules
|
||||
):
|
||||
if hasattr(rule, "args") and getattr(rule, "args") and getattr(rule, "tool_name", None) in allowed_set:
|
||||
_store_args(rule.tool_name, getattr(rule, "args"), rule)
|
||||
|
||||
self.last_prefilled_args_by_tool = args_by_tool
|
||||
self.last_prefilled_args_provenance = provenance_by_tool
|
||||
|
||||
return allowed
|
||||
|
||||
def is_terminal_tool(self, tool_name: ToolName) -> bool:
|
||||
"""Check if the tool is defined as a terminal tool in the terminal tool rules or required-before-exit tool rules."""
|
||||
|
||||
@@ -209,6 +209,14 @@ class InitToolRule(BaseToolRule):
|
||||
"""
|
||||
|
||||
type: Literal[ToolRuleType.run_first] = ToolRuleType.run_first
|
||||
args: Optional[Dict[str, Any]] = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Optional prefilled arguments for this tool. When present, these values will override any LLM-provided "
|
||||
"arguments with the same keys during invocation. Keys must match the tool's parameter names and values "
|
||||
"must satisfy the tool's JSON schema. Supports partial prefill; non-overlapping parameters are left to the model."
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def requires_force_tool_call(self) -> bool:
|
||||
|
||||
@@ -8,6 +8,7 @@ from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.agents.letta_agent_v3 import LettaAgentV3
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.letta_message import ToolCallMessage
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.tool_rule import ChildToolRule, ContinueToolRule, InitToolRule, RequiredBeforeExitToolRule, TerminalToolRule
|
||||
@@ -225,6 +226,29 @@ async def cleanup_temp_files_tool(server):
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def validate_api_key_tool(server):
|
||||
SECRET = "REAL_KEY_123"
|
||||
|
||||
def validate_api_key(secret_key: str):
|
||||
"""
|
||||
Validates an API key; errors if incorrect.
|
||||
|
||||
Args:
|
||||
secret_key (str): The provided key.
|
||||
|
||||
Returns:
|
||||
str: Confirmation string when key is valid.
|
||||
"""
|
||||
if secret_key != SECRET:
|
||||
raise RuntimeError(f"Invalid secret key: {secret_key}")
|
||||
return "api key accepted"
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async()
|
||||
tool = await server.tool_manager.create_or_update_tool_async(create_tool_from_func(func=validate_api_key), actor=actor)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
async def validate_work_tool(server):
|
||||
def validate_work():
|
||||
@@ -536,6 +560,126 @@ async def test_init_tool_rule_always_fails(
|
||||
await cleanup_async(server=server, agent_uuid=agent_uuid, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_tool_rule_args_override_llm_payload(server, disable_e2b_api_key, validate_api_key_tool, default_user):
|
||||
"""InitToolRule args should override LLM-provided args for the initial tool call."""
|
||||
REAL = "REAL_KEY_123"
|
||||
|
||||
tools = [validate_api_key_tool]
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name="validate_api_key", args={"secret_key": REAL}),
|
||||
ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
|
||||
TerminalToolRule(tool_name="send_message"),
|
||||
]
|
||||
|
||||
agent_name = str(uuid.uuid4())
|
||||
agent_state = await setup_agent(
|
||||
server,
|
||||
OPENAI_CONFIG,
|
||||
agent_uuid=agent_name,
|
||||
tool_ids=[t.id for t in tools],
|
||||
tool_rules=tool_rules,
|
||||
)
|
||||
|
||||
# Ask the model to call with a fake key; prefilled args should override
|
||||
response = await run_agent_step(
|
||||
agent_state=agent_state,
|
||||
input_messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="Please validate my API key: FAKE_KEY_999, then send me confirmation.",
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert_sanity_checks(response)
|
||||
assert_invoked_function_call(response.messages, "validate_api_key")
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
|
||||
# Verify the recorded tool-call arguments reflect the prefilled override
|
||||
for m in response.messages:
|
||||
if isinstance(m, ToolCallMessage) and m.tool_call.name == "validate_api_key":
|
||||
args = json.loads(m.tool_call.arguments)
|
||||
assert args.get("secret_key") == REAL
|
||||
break
|
||||
|
||||
await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_tool_rule_invalid_prefilled_type_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user):
|
||||
"""Invalid prefilled args should produce an error and prevent further flow (no send_message)."""
|
||||
tools = [validate_api_key_tool]
|
||||
# Provide wrong type for secret_key (expects string)
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name="validate_api_key", args={"secret_key": 123}),
|
||||
ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
|
||||
TerminalToolRule(tool_name="send_message"),
|
||||
]
|
||||
|
||||
agent_name = str(uuid.uuid4())
|
||||
agent_state = await setup_agent(
|
||||
server,
|
||||
OPENAI_CONFIG,
|
||||
agent_uuid=agent_name,
|
||||
tool_ids=[t.id for t in tools],
|
||||
tool_rules=tool_rules,
|
||||
)
|
||||
|
||||
response = await run_agent_step(
|
||||
agent_state=agent_state,
|
||||
input_messages=[MessageCreate(role="user", content="try validating my key")],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert response.stop_reason == LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call)
|
||||
|
||||
# Should attempt validate_api_key but not proceed to send_message
|
||||
assert_invoked_function_call(response.messages, "validate_api_key")
|
||||
with pytest.raises(Exception):
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
|
||||
await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.timeout(60)
|
||||
@pytest.mark.asyncio
|
||||
async def test_init_tool_rule_unknown_prefilled_key_blocks_flow(server, disable_e2b_api_key, validate_api_key_tool, default_user):
|
||||
"""Unknown prefilled arg key should error and block flow."""
|
||||
tools = [validate_api_key_tool]
|
||||
tool_rules = [
|
||||
InitToolRule(tool_name="validate_api_key", args={"not_a_param": "value"}),
|
||||
ChildToolRule(tool_name="validate_api_key", children=["send_message"]),
|
||||
TerminalToolRule(tool_name="send_message"),
|
||||
]
|
||||
|
||||
agent_name = str(uuid.uuid4())
|
||||
agent_state = await setup_agent(
|
||||
server,
|
||||
OPENAI_CONFIG,
|
||||
agent_uuid=agent_name,
|
||||
tool_ids=[t.id for t in tools],
|
||||
tool_rules=tool_rules,
|
||||
)
|
||||
|
||||
response = await run_agent_step(
|
||||
agent_state=agent_state,
|
||||
input_messages=[MessageCreate(role="user", content="validate with your best guess")],
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert response.stop_reason == LettaStopReason(message_type="stop_reason", stop_reason=StopReasonType.invalid_tool_call)
|
||||
|
||||
assert_invoked_function_call(response.messages, "validate_api_key")
|
||||
with pytest.raises(Exception):
|
||||
assert_invoked_function_call(response.messages, "send_message")
|
||||
|
||||
await cleanup_async(server=server, agent_uuid=agent_name, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_continue_tool_rule(server, default_user):
|
||||
"""Test the continue tool rule by forcing send_message to loop before ending with core_memory_append."""
|
||||
|
||||
@@ -1,6 +1,9 @@
|
||||
import pytest
|
||||
|
||||
from letta.agents.helpers import merge_and_validate_prefilled_args
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.schemas.enums import ToolType
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import (
|
||||
ChildToolRule,
|
||||
ConditionalToolRule,
|
||||
@@ -722,3 +725,181 @@ def test_should_force_tool_call_mixed_rules():
|
||||
|
||||
solver.register_tool_call(NEXT_TOOL)
|
||||
assert solver.should_force_tool_call() is False, "Should return False when no constraining rules are active"
|
||||
|
||||
|
||||
def make_tool(name: str, properties: dict) -> Tool:
|
||||
"""Helper to build a minimal custom Tool with a JSON schema."""
|
||||
return Tool(
|
||||
name=name,
|
||||
tool_type=ToolType.CUSTOM,
|
||||
json_schema={
|
||||
"name": name,
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"required": [],
|
||||
"additionalProperties": False,
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def test_init_rule_args_are_cached_in_solver():
|
||||
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha", args={"x": 1, "y": "s"})])
|
||||
allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})
|
||||
|
||||
assert set(allowed) == {"alpha"}
|
||||
# Cached mappings
|
||||
assert solver.last_prefilled_args_by_tool == {"alpha": {"x": 1, "y": "s"}}
|
||||
assert solver.last_prefilled_args_provenance.get("alpha") == "InitToolRule(alpha)"
|
||||
|
||||
|
||||
def test_cached_provenance_format():
|
||||
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="tool_one", args={"a": 123})])
|
||||
_ = solver.get_allowed_tool_names(available_tools={"tool_one"})
|
||||
prov = solver.last_prefilled_args_provenance.get("tool_one")
|
||||
assert prov.startswith("InitToolRule(") and prov.endswith(")") and "tool_one" in prov
|
||||
|
||||
|
||||
def test_cache_empty_when_no_args():
|
||||
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha")])
|
||||
allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})
|
||||
|
||||
assert set(allowed) == {"alpha"}
|
||||
assert solver.last_prefilled_args_by_tool == {}
|
||||
assert solver.last_prefilled_args_provenance == {}
|
||||
|
||||
|
||||
def test_cache_recomputed_on_next_call():
|
||||
# First call caches args for init tool
|
||||
solver = ToolRulesSolver(tool_rules=[InitToolRule(tool_name="alpha", args={"p": 5})])
|
||||
_ = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})
|
||||
assert solver.last_prefilled_args_by_tool == {"alpha": {"p": 5}}
|
||||
|
||||
# After a tool call, init rules no longer apply; next computation should clear caches
|
||||
solver.register_tool_call("alpha")
|
||||
_ = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})
|
||||
assert solver.last_prefilled_args_by_tool == {}
|
||||
assert solver.last_prefilled_args_provenance == {}
|
||||
|
||||
|
||||
def test_merge_and_validate_prefilled_args_overrides_llm_values():
|
||||
tool = make_tool("my_tool", properties={"a": {"type": "integer"}, "b": {"type": "string"}})
|
||||
llm_args = {"a": 1, "b": "hello"}
|
||||
prefilled = {"a": 42}
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, llm_args, prefilled)
|
||||
assert merged == {"a": 42, "b": "hello"}
|
||||
|
||||
|
||||
def test_merge_and_validate_prefilled_args_type_validation():
|
||||
tool = make_tool("typed_tool", properties={"a": {"type": "integer"}})
|
||||
llm_args = {"a": 1}
|
||||
prefilled = {"a": "not-an-int"}
|
||||
|
||||
with pytest.raises(ValueError) as ei:
|
||||
_ = merge_and_validate_prefilled_args(tool, llm_args, prefilled)
|
||||
assert "Invalid value for 'a'" in str(ei.value)
|
||||
assert "integer" in str(ei.value)
|
||||
|
||||
|
||||
def test_merge_and_validate_prefilled_args_unknown_key_fails():
|
||||
tool = make_tool("limited_tool", properties={"a": {"type": "integer"}})
|
||||
with pytest.raises(ValueError) as ei:
|
||||
_ = merge_and_validate_prefilled_args(tool, llm_args={}, prefilled_args={"z": 3})
|
||||
assert "Unknown argument 'z'" in str(ei.value)
|
||||
|
||||
|
||||
def test_merge_and_validate_prefilled_args_enum_const_anyof_oneof():
|
||||
tool = make_tool(
|
||||
"rich_tool",
|
||||
properties={
|
||||
"c": {"enum": ["x", "y"]},
|
||||
"d": {"const": 5},
|
||||
"e": {"anyOf": [{"type": "string"}, {"type": "integer"}]},
|
||||
"f": {"oneOf": [{"type": "string"}, {"type": "integer"}]},
|
||||
"g": {"type": "number"},
|
||||
},
|
||||
)
|
||||
|
||||
# Valid cases
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"c": "x"})
|
||||
assert merged["c"] == "x"
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"d": 5})
|
||||
assert merged["d"] == 5
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"e": 7})
|
||||
assert merged["e"] == 7
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"f": "hello"})
|
||||
assert merged["f"] == "hello"
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"g": 3.14})
|
||||
assert merged["g"] == 3.14
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"g": 3})
|
||||
assert merged["g"] == 3
|
||||
|
||||
# Invalid cases
|
||||
with pytest.raises(ValueError):
|
||||
_ = merge_and_validate_prefilled_args(tool, {}, {"c": "z"}) # enum fail
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = merge_and_validate_prefilled_args(tool, {}, {"d": 6}) # const fail
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = merge_and_validate_prefilled_args(tool, {}, {"e": []}) # anyOf none match
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = merge_and_validate_prefilled_args(tool, {}, {"f": []}) # oneOf none match
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = merge_and_validate_prefilled_args(tool, {}, {"g": True}) # bool not a number
|
||||
|
||||
|
||||
def test_merge_and_validate_prefilled_args_union_with_null():
|
||||
tool = make_tool("union_tool", properties={"h": {"type": ["string", "null"]}})
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"h": None})
|
||||
assert "h" in merged and merged["h"] is None
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"h": "ok"})
|
||||
assert merged["h"] == "ok"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = merge_and_validate_prefilled_args(tool, {}, {"h": 5})
|
||||
|
||||
|
||||
def test_merge_and_validate_prefilled_args_object_and_array_types():
|
||||
tool = make_tool(
|
||||
"container_tool",
|
||||
properties={
|
||||
"obj": {"type": "object"},
|
||||
"arr": {"type": "array"},
|
||||
},
|
||||
)
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"obj": {"k": 1}})
|
||||
assert merged["obj"] == {"k": 1}
|
||||
|
||||
merged = merge_and_validate_prefilled_args(tool, {}, {"arr": [1, 2, 3]})
|
||||
assert merged["arr"] == [1, 2, 3]
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_ = merge_and_validate_prefilled_args(tool, {}, {"obj": "nope"})
|
||||
with pytest.raises(ValueError):
|
||||
_ = merge_and_validate_prefilled_args(tool, {}, {"arr": {}})
|
||||
|
||||
|
||||
def test_multiple_rules_args_last_write_wins_and_provenance():
|
||||
# Two init rules for the same tool; the latter should overwrite overlapping keys and provenance
|
||||
r1 = InitToolRule(tool_name="alpha", args={"x": 1, "y": "first"})
|
||||
r2 = InitToolRule(tool_name="alpha", args={"y": "second", "z": True})
|
||||
solver = ToolRulesSolver(tool_rules=[r1, r2])
|
||||
|
||||
allowed = solver.get_allowed_tool_names(available_tools={"alpha", "beta"})
|
||||
assert set(allowed) == {"alpha"}
|
||||
|
||||
assert solver.last_prefilled_args_by_tool["alpha"] == {"x": 1, "y": "second", "z": True}
|
||||
assert solver.last_prefilled_args_provenance.get("alpha") == "InitToolRule(alpha)"
|
||||
|
||||
Reference in New Issue
Block a user