fix: refactor enable strict mode for structured output (#8840)
* base * test
This commit is contained in:
@@ -1647,7 +1647,9 @@ class LettaAgent(BaseAgent):
|
||||
if len(valid_tool_names) == 1:
|
||||
force_tool_call = valid_tool_names[0]
|
||||
|
||||
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
|
||||
allowed_tools = [
|
||||
enable_strict_mode(t.json_schema, strict=agent_state.llm_config.strict) for t in tools if t.name in set(valid_tool_names)
|
||||
]
|
||||
# Extract terminal tool names from tool rules
|
||||
terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules}
|
||||
allowed_tools = runtime_override_tool_json_schema(
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.schemas.tool import MCP_TOOL_METADATA_SCHEMA_STATUS, MCP_TOOL_METADATA_SCHEMA_WARNINGS
|
||||
@@ -8,6 +9,83 @@ from letta.utils import get_logger
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def _make_field_nullable(field_props: Dict[str, Any]) -> None:
|
||||
"""Make a field schema nullable by adding 'null' to its type.
|
||||
|
||||
This modifies field_props in place.
|
||||
|
||||
Args:
|
||||
field_props: The field schema to make nullable
|
||||
"""
|
||||
if "type" in field_props:
|
||||
field_type = field_props["type"]
|
||||
if isinstance(field_type, list):
|
||||
# Already an array of types - add null if not present
|
||||
if "null" not in field_type:
|
||||
field_type.append("null")
|
||||
elif field_type != "null":
|
||||
# Single type - convert to array with null
|
||||
field_props["type"] = [field_type, "null"]
|
||||
elif "anyOf" in field_props:
|
||||
# Check if null is already one of the options
|
||||
has_null = any(opt.get("type") == "null" for opt in field_props["anyOf"])
|
||||
if not has_null:
|
||||
field_props["anyOf"].append({"type": "null"})
|
||||
elif "$ref" in field_props:
|
||||
# For $ref schemas, wrap in anyOf with null option
|
||||
ref_value = field_props.pop("$ref")
|
||||
field_props["anyOf"] = [{"$ref": ref_value}, {"type": "null"}]
|
||||
else:
|
||||
# No type specified, add null type
|
||||
field_props["type"] = "null"
|
||||
|
||||
|
||||
def _process_property_for_strict_mode(prop: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Recursively process a property for strict mode.
|
||||
|
||||
Handles nested objects, arrays, and anyOf structures by setting
|
||||
additionalProperties: False and adding all properties to required.
|
||||
|
||||
Args:
|
||||
prop: The property schema to process
|
||||
|
||||
Returns:
|
||||
The processed property schema
|
||||
"""
|
||||
# Handle anyOf structures
|
||||
if "anyOf" in prop:
|
||||
prop["anyOf"] = [_process_property_for_strict_mode(opt) for opt in prop["anyOf"]]
|
||||
return prop
|
||||
|
||||
if "type" not in prop:
|
||||
return prop
|
||||
|
||||
param_type = prop["type"]
|
||||
|
||||
# Handle type arrays (e.g., ["string", "null"])
|
||||
if isinstance(param_type, list):
|
||||
return prop
|
||||
|
||||
if param_type == "object":
|
||||
if "properties" in prop:
|
||||
properties = prop["properties"]
|
||||
# Recursively process nested properties
|
||||
for key, value in properties.items():
|
||||
properties[key] = _process_property_for_strict_mode(value)
|
||||
# Set additionalProperties to False and require all properties
|
||||
prop["additionalProperties"] = False
|
||||
prop["required"] = list(properties.keys())
|
||||
return prop
|
||||
|
||||
elif param_type == "array":
|
||||
if "items" in prop:
|
||||
prop["items"] = _process_property_for_strict_mode(prop["items"])
|
||||
return prop
|
||||
|
||||
# Simple types - return as-is
|
||||
return prop
|
||||
|
||||
|
||||
def enable_strict_mode(tool_schema: Dict[str, Any], strict: bool = True) -> Dict[str, Any]:
|
||||
"""Enables strict mode for a tool schema by setting 'strict' to True and
|
||||
disallowing additional properties in the parameters.
|
||||
@@ -15,6 +93,12 @@ def enable_strict_mode(tool_schema: Dict[str, Any], strict: bool = True) -> Dict
|
||||
If the tool schema is NON_STRICT_ONLY, strict mode will not be applied.
|
||||
If strict=False, the function will only clean metadata without applying strict mode.
|
||||
|
||||
When strict mode is enabled:
|
||||
- All properties are added to the 'required' array (OpenAI requirement)
|
||||
- Optional properties are made nullable (type includes 'null') to preserve optionality
|
||||
- additionalProperties is set to False
|
||||
- Nested objects and arrays are recursively processed
|
||||
|
||||
Args:
|
||||
tool_schema (Dict[str, Any]): The original tool schema.
|
||||
strict (bool): Whether to enable strict mode. Defaults to True.
|
||||
@@ -22,7 +106,8 @@ def enable_strict_mode(tool_schema: Dict[str, Any], strict: bool = True) -> Dict
|
||||
Returns:
|
||||
Dict[str, Any]: A new tool schema with strict mode conditionally enabled.
|
||||
"""
|
||||
schema = tool_schema.copy()
|
||||
# Deep copy to avoid mutating the original schema
|
||||
schema = copy.deepcopy(tool_schema)
|
||||
|
||||
# Check if schema has status metadata indicating NON_STRICT_ONLY
|
||||
schema_status = schema.get(MCP_TOOL_METADATA_SCHEMA_STATUS)
|
||||
@@ -48,9 +133,28 @@ def enable_strict_mode(tool_schema: Dict[str, Any], strict: bool = True) -> Dict
|
||||
# Ensure parameters is a valid dictionary
|
||||
parameters = schema.get("parameters", {})
|
||||
if isinstance(parameters, dict) and parameters.get("type") == "object":
|
||||
# Set additionalProperties to False
|
||||
# Set additionalProperties to False (required for OpenAI strict mode)
|
||||
parameters["additionalProperties"] = False
|
||||
|
||||
# Get properties and current required list
|
||||
properties = parameters.get("properties", {})
|
||||
current_required = set(parameters.get("required", []))
|
||||
|
||||
# Process each property recursively and handle required/nullable
|
||||
for field_name, field_props in properties.items():
|
||||
# Recursively process nested structures
|
||||
properties[field_name] = _process_property_for_strict_mode(field_props)
|
||||
|
||||
# OpenAI strict mode requires ALL properties to be in the required array
|
||||
# For optional properties, we add them to required but make them nullable
|
||||
if field_name not in current_required:
|
||||
# Make the field nullable to preserve optionality
|
||||
_make_field_nullable(properties[field_name])
|
||||
|
||||
# Set all properties as required
|
||||
parameters["required"] = list(properties.keys())
|
||||
schema["parameters"] = parameters
|
||||
|
||||
# Remove the metadata fields from the schema
|
||||
schema.pop(MCP_TOOL_METADATA_SCHEMA_STATUS, None)
|
||||
schema.pop(MCP_TOOL_METADATA_SCHEMA_WARNINGS, None)
|
||||
|
||||
@@ -31,7 +31,6 @@ from letta.llm_api.error_utils import is_context_window_overflow_message
|
||||
from letta.llm_api.helpers import (
|
||||
add_inner_thoughts_to_functions,
|
||||
convert_response_format_to_responses_api,
|
||||
convert_to_structured_output,
|
||||
unpack_all_inner_thoughts_from_kwargs,
|
||||
)
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
@@ -297,40 +296,20 @@ class OpenAIClient(LLMClientBase):
|
||||
new_tools.append(tool.model_copy(deep=True))
|
||||
typed_tools = new_tools
|
||||
|
||||
# Convert to strict mode when strict is enabled
|
||||
if llm_config.strict and supports_structured_output(llm_config):
|
||||
for tool in typed_tools:
|
||||
try:
|
||||
structured_output_version = convert_to_structured_output(tool.function.model_dump())
|
||||
tool.function = FunctionSchema(**structured_output_version)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
|
||||
|
||||
# Finally convert to a Responses-friendly dict
|
||||
responses_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": t.function.name,
|
||||
"description": t.function.description,
|
||||
"parameters": t.function.parameters,
|
||||
"strict": True,
|
||||
}
|
||||
for t in typed_tools
|
||||
]
|
||||
|
||||
else:
|
||||
# Finally convert to a Responses-friendly dict
|
||||
# Note: strict field is required by OpenAI SDK's FunctionToolParam type
|
||||
responses_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": t.function.name,
|
||||
"description": t.function.description,
|
||||
"parameters": t.function.parameters,
|
||||
"strict": False,
|
||||
}
|
||||
for t in typed_tools
|
||||
]
|
||||
# Note: Tools are already processed by enable_strict_mode() in the workflow/agent code
|
||||
# (temporal_letta_v1_agent_workflow.py or letta_agent_v3.py) before reaching here.
|
||||
# enable_strict_mode() handles: strict flag, additionalProperties, required array, nullable fields
|
||||
# Convert to a Responses-friendly dict, preserving the strict setting from the tool schema
|
||||
responses_tools = [
|
||||
{
|
||||
"type": "function",
|
||||
"name": t.function.name,
|
||||
"description": t.function.description,
|
||||
"parameters": t.function.parameters,
|
||||
"strict": t.function.strict,
|
||||
}
|
||||
for t in typed_tools
|
||||
]
|
||||
else:
|
||||
responses_tools = None
|
||||
|
||||
@@ -560,19 +539,15 @@ class OpenAIClient(LLMClientBase):
|
||||
new_tools.append(tool.model_copy(deep=True))
|
||||
data.tools = new_tools
|
||||
|
||||
# Note: Tools are already processed by enable_strict_mode() in the workflow/agent code
|
||||
# (temporal_letta_v1_agent_workflow.py or letta_agent_v3.py) before reaching here.
|
||||
# enable_strict_mode() handles: strict flag, additionalProperties, required array, nullable fields
|
||||
# We only need to ensure strict is False for providers that don't support structured output
|
||||
if data.tools is not None and len(data.tools) > 0:
|
||||
# Convert to structured output style when strict is enabled
|
||||
for tool in data.tools:
|
||||
if llm_config.strict and supports_structured_output(llm_config):
|
||||
try:
|
||||
structured_output_version = convert_to_structured_output(tool.function.model_dump())
|
||||
tool.function = FunctionSchema(**structured_output_version)
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
|
||||
else:
|
||||
# Ensure strict is False when not using structured output
|
||||
# This overrides any strict: True that may have been set by enable_strict_mode()
|
||||
tool.function.strict = False if not supports_structured_output(llm_config) else tool.function.strict
|
||||
if not supports_structured_output(llm_config):
|
||||
# Provider doesn't support structured output - ensure strict is False
|
||||
tool.function.strict = False
|
||||
request_data = data.model_dump(exclude_unset=True)
|
||||
|
||||
# If Ollama
|
||||
|
||||
@@ -12,8 +12,9 @@ from pydantic import BaseModel
|
||||
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.schema_generator import validate_google_style_docstring
|
||||
from letta.helpers.tool_execution_helper import enable_strict_mode
|
||||
from letta.llm_api.helpers import convert_to_structured_output
|
||||
from letta.schemas.tool import Tool, ToolCreate
|
||||
from letta.schemas.tool import MCP_TOOL_METADATA_SCHEMA_STATUS, Tool, ToolCreate
|
||||
|
||||
|
||||
def _clean_diff(d1, d2):
|
||||
@@ -674,3 +675,253 @@ def test_complex_nested_anyof_schema_to_structured_output():
|
||||
|
||||
except Exception as e:
|
||||
pytest.fail(f"Failed to convert complex nested anyOf schema to structured output: {str(e)}")
|
||||
|
||||
|
||||
# ========== enable_strict_mode tests ==========
|
||||
|
||||
|
||||
def test_enable_strict_mode_adds_all_properties_to_required():
|
||||
"""Test that enable_strict_mode adds all properties to required array."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"required_field": {"type": "string"},
|
||||
"optional_field": {"type": "integer"},
|
||||
},
|
||||
"required": ["required_field"],
|
||||
},
|
||||
}
|
||||
|
||||
result = enable_strict_mode(schema, strict=True)
|
||||
|
||||
assert result["strict"] is True
|
||||
assert set(result["parameters"]["required"]) == {"required_field", "optional_field"}
|
||||
|
||||
|
||||
def test_enable_strict_mode_makes_optional_fields_nullable():
|
||||
"""Test that optional fields are made nullable."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"required_field": {"type": "string"},
|
||||
"optional_field": {"type": "integer"},
|
||||
},
|
||||
"required": ["required_field"],
|
||||
},
|
||||
}
|
||||
|
||||
result = enable_strict_mode(schema, strict=True)
|
||||
|
||||
# Required field should NOT be made nullable
|
||||
assert result["parameters"]["properties"]["required_field"]["type"] == "string"
|
||||
# Optional field should be made nullable
|
||||
assert result["parameters"]["properties"]["optional_field"]["type"] == ["integer", "null"]
|
||||
|
||||
|
||||
def test_enable_strict_mode_recursive_nested_objects():
|
||||
"""Test recursive handling of nested objects."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nested_field": {"type": "string"},
|
||||
"another_nested": {"type": "integer"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["config"],
|
||||
},
|
||||
}
|
||||
|
||||
result = enable_strict_mode(schema, strict=True)
|
||||
|
||||
nested = result["parameters"]["properties"]["config"]
|
||||
assert nested["additionalProperties"] is False
|
||||
assert set(nested["required"]) == {"nested_field", "another_nested"}
|
||||
|
||||
|
||||
def test_enable_strict_mode_recursive_arrays():
|
||||
"""Test recursive handling of arrays with object items."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"item_field": {"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": ["items"],
|
||||
},
|
||||
}
|
||||
|
||||
result = enable_strict_mode(schema, strict=True)
|
||||
|
||||
array_items = result["parameters"]["properties"]["items"]["items"]
|
||||
assert array_items["additionalProperties"] is False
|
||||
assert array_items["required"] == ["item_field"]
|
||||
|
||||
|
||||
def test_enable_strict_mode_strict_false_no_modification():
|
||||
"""Test that strict=False doesn't modify schema structure."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {"type": "string"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
result = enable_strict_mode(schema, strict=False)
|
||||
|
||||
assert "strict" not in result
|
||||
assert result["parameters"]["required"] == []
|
||||
# Verify the field type is unchanged
|
||||
assert result["parameters"]["properties"]["field"]["type"] == "string"
|
||||
|
||||
|
||||
def test_enable_strict_mode_non_strict_only_tool():
|
||||
"""Test that NON_STRICT_ONLY tools are not modified."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
MCP_TOOL_METADATA_SCHEMA_STATUS: "NON_STRICT_ONLY",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {"type": "string"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
result = enable_strict_mode(schema, strict=True)
|
||||
|
||||
# Strict mode should not be applied
|
||||
assert "strict" not in result
|
||||
# Metadata should be removed
|
||||
assert MCP_TOOL_METADATA_SCHEMA_STATUS not in result
|
||||
# Required should be unchanged
|
||||
assert result["parameters"]["required"] == []
|
||||
|
||||
|
||||
def test_enable_strict_mode_preserves_existing_required():
|
||||
"""Test that fields already in required are not made nullable."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"already_required": {"type": "string"},
|
||||
"optional_field": {"type": "integer"},
|
||||
},
|
||||
"required": ["already_required"],
|
||||
},
|
||||
}
|
||||
|
||||
result = enable_strict_mode(schema, strict=True)
|
||||
|
||||
# already_required should NOT be made nullable (it was already required)
|
||||
assert result["parameters"]["properties"]["already_required"]["type"] == "string"
|
||||
# optional_field should be made nullable
|
||||
assert result["parameters"]["properties"]["optional_field"]["type"] == ["integer", "null"]
|
||||
# Both should now be in required
|
||||
assert set(result["parameters"]["required"]) == {"already_required", "optional_field"}
|
||||
|
||||
|
||||
def test_enable_strict_mode_handles_anyof():
|
||||
"""Test that anyOf structures are recursively processed."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"nested_field": {"type": "string"},
|
||||
},
|
||||
},
|
||||
{"type": "null"},
|
||||
],
|
||||
},
|
||||
},
|
||||
"required": ["config"],
|
||||
},
|
||||
}
|
||||
|
||||
result = enable_strict_mode(schema, strict=True)
|
||||
|
||||
# The object inside anyOf should have additionalProperties and required set
|
||||
anyof_options = result["parameters"]["properties"]["config"]["anyOf"]
|
||||
object_option = next(opt for opt in anyof_options if opt.get("type") == "object")
|
||||
assert object_option["additionalProperties"] is False
|
||||
assert object_option["required"] == ["nested_field"]
|
||||
|
||||
|
||||
def test_enable_strict_mode_handles_type_array_nullable():
|
||||
"""Test that fields with type array (already nullable) are handled correctly."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"already_nullable": {"type": ["string", "null"]},
|
||||
"not_nullable": {"type": "integer"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
result = enable_strict_mode(schema, strict=True)
|
||||
|
||||
# Already nullable field should not get duplicate null
|
||||
already_nullable_type = result["parameters"]["properties"]["already_nullable"]["type"]
|
||||
assert already_nullable_type.count("null") == 1
|
||||
# Not nullable should become nullable
|
||||
assert result["parameters"]["properties"]["not_nullable"]["type"] == ["integer", "null"]
|
||||
|
||||
|
||||
def test_enable_strict_mode_does_not_mutate_original():
|
||||
"""Test that the original schema is not mutated."""
|
||||
schema = {
|
||||
"name": "test_tool",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"field": {"type": "string"},
|
||||
},
|
||||
"required": [],
|
||||
},
|
||||
}
|
||||
|
||||
original_required = schema["parameters"]["required"].copy()
|
||||
original_field_type = schema["parameters"]["properties"]["field"]["type"]
|
||||
|
||||
result = enable_strict_mode(schema, strict=True)
|
||||
|
||||
# Original should be unchanged
|
||||
assert schema["parameters"]["required"] == original_required
|
||||
assert schema["parameters"]["properties"]["field"]["type"] == original_field_type
|
||||
assert "strict" not in schema
|
||||
# Result should be different
|
||||
assert result["strict"] is True
|
||||
assert len(result["parameters"]["required"]) == 1
|
||||
|
||||
Reference in New Issue
Block a user