fix: refactor enable strict mode for structured output (#8840)

* base

* test
This commit is contained in:
jnjpng
2026-01-16 12:52:42 -08:00
committed by Sarah Wooders
parent b62ce02930
commit a98bc31bf3
4 changed files with 383 additions and 51 deletions

View File

@@ -1647,7 +1647,9 @@ class LettaAgent(BaseAgent):
if len(valid_tool_names) == 1:
force_tool_call = valid_tool_names[0]
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
allowed_tools = [
enable_strict_mode(t.json_schema, strict=agent_state.llm_config.strict) for t in tools if t.name in set(valid_tool_names)
]
# Extract terminal tool names from tool rules
terminal_tool_names = {rule.tool_name for rule in tool_rules_solver.terminal_tool_rules}
allowed_tools = runtime_override_tool_json_schema(

View File

@@ -1,5 +1,6 @@
import copy
from collections import OrderedDict
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
from letta.schemas.tool import MCP_TOOL_METADATA_SCHEMA_STATUS, MCP_TOOL_METADATA_SCHEMA_WARNINGS
@@ -8,6 +9,83 @@ from letta.utils import get_logger
logger = get_logger(__name__)
def _make_field_nullable(field_props: Dict[str, Any]) -> None:
"""Make a field schema nullable by adding 'null' to its type.
This modifies field_props in place.
Args:
field_props: The field schema to make nullable
"""
if "type" in field_props:
field_type = field_props["type"]
if isinstance(field_type, list):
# Already an array of types - add null if not present
if "null" not in field_type:
field_type.append("null")
elif field_type != "null":
# Single type - convert to array with null
field_props["type"] = [field_type, "null"]
elif "anyOf" in field_props:
# Check if null is already one of the options
has_null = any(opt.get("type") == "null" for opt in field_props["anyOf"])
if not has_null:
field_props["anyOf"].append({"type": "null"})
elif "$ref" in field_props:
# For $ref schemas, wrap in anyOf with null option
ref_value = field_props.pop("$ref")
field_props["anyOf"] = [{"$ref": ref_value}, {"type": "null"}]
else:
# No type specified, add null type
field_props["type"] = "null"
def _process_property_for_strict_mode(prop: Dict[str, Any]) -> Dict[str, Any]:
"""Recursively process a property for strict mode.
Handles nested objects, arrays, and anyOf structures by setting
additionalProperties: False and adding all properties to required.
Args:
prop: The property schema to process
Returns:
The processed property schema
"""
# Handle anyOf structures
if "anyOf" in prop:
prop["anyOf"] = [_process_property_for_strict_mode(opt) for opt in prop["anyOf"]]
return prop
if "type" not in prop:
return prop
param_type = prop["type"]
# Handle type arrays (e.g., ["string", "null"])
if isinstance(param_type, list):
return prop
if param_type == "object":
if "properties" in prop:
properties = prop["properties"]
# Recursively process nested properties
for key, value in properties.items():
properties[key] = _process_property_for_strict_mode(value)
# Set additionalProperties to False and require all properties
prop["additionalProperties"] = False
prop["required"] = list(properties.keys())
return prop
elif param_type == "array":
if "items" in prop:
prop["items"] = _process_property_for_strict_mode(prop["items"])
return prop
# Simple types - return as-is
return prop
def enable_strict_mode(tool_schema: Dict[str, Any], strict: bool = True) -> Dict[str, Any]:
"""Enables strict mode for a tool schema by setting 'strict' to True and
disallowing additional properties in the parameters.
@@ -15,6 +93,12 @@ def enable_strict_mode(tool_schema: Dict[str, Any], strict: bool = True) -> Dict
If the tool schema is NON_STRICT_ONLY, strict mode will not be applied.
If strict=False, the function will only clean metadata without applying strict mode.
When strict mode is enabled:
- All properties are added to the 'required' array (OpenAI requirement)
- Optional properties are made nullable (type includes 'null') to preserve optionality
- additionalProperties is set to False
- Nested objects and arrays are recursively processed
Args:
tool_schema (Dict[str, Any]): The original tool schema.
strict (bool): Whether to enable strict mode. Defaults to True.
@@ -22,7 +106,8 @@ def enable_strict_mode(tool_schema: Dict[str, Any], strict: bool = True) -> Dict
Returns:
Dict[str, Any]: A new tool schema with strict mode conditionally enabled.
"""
schema = tool_schema.copy()
# Deep copy to avoid mutating the original schema
schema = copy.deepcopy(tool_schema)
# Check if schema has status metadata indicating NON_STRICT_ONLY
schema_status = schema.get(MCP_TOOL_METADATA_SCHEMA_STATUS)
@@ -48,9 +133,28 @@ def enable_strict_mode(tool_schema: Dict[str, Any], strict: bool = True) -> Dict
# Ensure parameters is a valid dictionary
parameters = schema.get("parameters", {})
if isinstance(parameters, dict) and parameters.get("type") == "object":
# Set additionalProperties to False
# Set additionalProperties to False (required for OpenAI strict mode)
parameters["additionalProperties"] = False
# Get properties and current required list
properties = parameters.get("properties", {})
current_required = set(parameters.get("required", []))
# Process each property recursively and handle required/nullable
for field_name, field_props in properties.items():
# Recursively process nested structures
properties[field_name] = _process_property_for_strict_mode(field_props)
# OpenAI strict mode requires ALL properties to be in the required array
# For optional properties, we add them to required but make them nullable
if field_name not in current_required:
# Make the field nullable to preserve optionality
_make_field_nullable(properties[field_name])
# Set all properties as required
parameters["required"] = list(properties.keys())
schema["parameters"] = parameters
# Remove the metadata fields from the schema
schema.pop(MCP_TOOL_METADATA_SCHEMA_STATUS, None)
schema.pop(MCP_TOOL_METADATA_SCHEMA_WARNINGS, None)

View File

@@ -31,7 +31,6 @@ from letta.llm_api.error_utils import is_context_window_overflow_message
from letta.llm_api.helpers import (
add_inner_thoughts_to_functions,
convert_response_format_to_responses_api,
convert_to_structured_output,
unpack_all_inner_thoughts_from_kwargs,
)
from letta.llm_api.llm_client_base import LLMClientBase
@@ -297,40 +296,20 @@ class OpenAIClient(LLMClientBase):
new_tools.append(tool.model_copy(deep=True))
typed_tools = new_tools
# Convert to strict mode when strict is enabled
if llm_config.strict and supports_structured_output(llm_config):
for tool in typed_tools:
try:
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
# Finally convert to a Responses-friendly dict
responses_tools = [
{
"type": "function",
"name": t.function.name,
"description": t.function.description,
"parameters": t.function.parameters,
"strict": True,
}
for t in typed_tools
]
else:
# Finally convert to a Responses-friendly dict
# Note: strict field is required by OpenAI SDK's FunctionToolParam type
responses_tools = [
{
"type": "function",
"name": t.function.name,
"description": t.function.description,
"parameters": t.function.parameters,
"strict": False,
}
for t in typed_tools
]
# Note: Tools are already processed by enable_strict_mode() in the workflow/agent code
# (temporal_letta_v1_agent_workflow.py or letta_agent_v3.py) before reaching here.
# enable_strict_mode() handles: strict flag, additionalProperties, required array, nullable fields
# Convert to a Responses-friendly dict, preserving the strict setting from the tool schema
responses_tools = [
{
"type": "function",
"name": t.function.name,
"description": t.function.description,
"parameters": t.function.parameters,
"strict": t.function.strict,
}
for t in typed_tools
]
else:
responses_tools = None
@@ -560,19 +539,15 @@ class OpenAIClient(LLMClientBase):
new_tools.append(tool.model_copy(deep=True))
data.tools = new_tools
# Note: Tools are already processed by enable_strict_mode() in the workflow/agent code
# (temporal_letta_v1_agent_workflow.py or letta_agent_v3.py) before reaching here.
# enable_strict_mode() handles: strict flag, additionalProperties, required array, nullable fields
# We only need to ensure strict is False for providers that don't support structured output
if data.tools is not None and len(data.tools) > 0:
# Convert to structured output style when strict is enabled
for tool in data.tools:
if llm_config.strict and supports_structured_output(llm_config):
try:
structured_output_version = convert_to_structured_output(tool.function.model_dump())
tool.function = FunctionSchema(**structured_output_version)
except ValueError as e:
logger.warning(f"Failed to convert tool function to structured output, tool={tool}, error={e}")
else:
# Ensure strict is False when not using structured output
# This overrides any strict: True that may have been set by enable_strict_mode()
tool.function.strict = False if not supports_structured_output(llm_config) else tool.function.strict
if not supports_structured_output(llm_config):
# Provider doesn't support structured output - ensure strict is False
tool.function.strict = False
request_data = data.model_dump(exclude_unset=True)
# If Ollama

View File

@@ -12,8 +12,9 @@ from pydantic import BaseModel
from letta.functions.functions import derive_openai_json_schema
from letta.functions.schema_generator import validate_google_style_docstring
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.llm_api.helpers import convert_to_structured_output
from letta.schemas.tool import Tool, ToolCreate
from letta.schemas.tool import MCP_TOOL_METADATA_SCHEMA_STATUS, Tool, ToolCreate
def _clean_diff(d1, d2):
@@ -674,3 +675,253 @@ def test_complex_nested_anyof_schema_to_structured_output():
except Exception as e:
pytest.fail(f"Failed to convert complex nested anyOf schema to structured output: {str(e)}")
# ========== enable_strict_mode tests ==========
def test_enable_strict_mode_adds_all_properties_to_required():
"""Test that enable_strict_mode adds all properties to required array."""
schema = {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"required_field": {"type": "string"},
"optional_field": {"type": "integer"},
},
"required": ["required_field"],
},
}
result = enable_strict_mode(schema, strict=True)
assert result["strict"] is True
assert set(result["parameters"]["required"]) == {"required_field", "optional_field"}
def test_enable_strict_mode_makes_optional_fields_nullable():
"""Test that optional fields are made nullable."""
schema = {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"required_field": {"type": "string"},
"optional_field": {"type": "integer"},
},
"required": ["required_field"],
},
}
result = enable_strict_mode(schema, strict=True)
# Required field should NOT be made nullable
assert result["parameters"]["properties"]["required_field"]["type"] == "string"
# Optional field should be made nullable
assert result["parameters"]["properties"]["optional_field"]["type"] == ["integer", "null"]
def test_enable_strict_mode_recursive_nested_objects():
"""Test recursive handling of nested objects."""
schema = {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"config": {
"type": "object",
"properties": {
"nested_field": {"type": "string"},
"another_nested": {"type": "integer"},
},
},
},
"required": ["config"],
},
}
result = enable_strict_mode(schema, strict=True)
nested = result["parameters"]["properties"]["config"]
assert nested["additionalProperties"] is False
assert set(nested["required"]) == {"nested_field", "another_nested"}
def test_enable_strict_mode_recursive_arrays():
"""Test recursive handling of arrays with object items."""
schema = {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"item_field": {"type": "string"},
},
},
},
},
"required": ["items"],
},
}
result = enable_strict_mode(schema, strict=True)
array_items = result["parameters"]["properties"]["items"]["items"]
assert array_items["additionalProperties"] is False
assert array_items["required"] == ["item_field"]
def test_enable_strict_mode_strict_false_no_modification():
"""Test that strict=False doesn't modify schema structure."""
schema = {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"field": {"type": "string"},
},
"required": [],
},
}
result = enable_strict_mode(schema, strict=False)
assert "strict" not in result
assert result["parameters"]["required"] == []
# Verify the field type is unchanged
assert result["parameters"]["properties"]["field"]["type"] == "string"
def test_enable_strict_mode_non_strict_only_tool():
"""Test that NON_STRICT_ONLY tools are not modified."""
schema = {
"name": "test_tool",
MCP_TOOL_METADATA_SCHEMA_STATUS: "NON_STRICT_ONLY",
"parameters": {
"type": "object",
"properties": {
"field": {"type": "string"},
},
"required": [],
},
}
result = enable_strict_mode(schema, strict=True)
# Strict mode should not be applied
assert "strict" not in result
# Metadata should be removed
assert MCP_TOOL_METADATA_SCHEMA_STATUS not in result
# Required should be unchanged
assert result["parameters"]["required"] == []
def test_enable_strict_mode_preserves_existing_required():
"""Test that fields already in required are not made nullable."""
schema = {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"already_required": {"type": "string"},
"optional_field": {"type": "integer"},
},
"required": ["already_required"],
},
}
result = enable_strict_mode(schema, strict=True)
# already_required should NOT be made nullable (it was already required)
assert result["parameters"]["properties"]["already_required"]["type"] == "string"
# optional_field should be made nullable
assert result["parameters"]["properties"]["optional_field"]["type"] == ["integer", "null"]
# Both should now be in required
assert set(result["parameters"]["required"]) == {"already_required", "optional_field"}
def test_enable_strict_mode_handles_anyof():
"""Test that anyOf structures are recursively processed."""
schema = {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"config": {
"anyOf": [
{
"type": "object",
"properties": {
"nested_field": {"type": "string"},
},
},
{"type": "null"},
],
},
},
"required": ["config"],
},
}
result = enable_strict_mode(schema, strict=True)
# The object inside anyOf should have additionalProperties and required set
anyof_options = result["parameters"]["properties"]["config"]["anyOf"]
object_option = next(opt for opt in anyof_options if opt.get("type") == "object")
assert object_option["additionalProperties"] is False
assert object_option["required"] == ["nested_field"]
def test_enable_strict_mode_handles_type_array_nullable():
"""Test that fields with type array (already nullable) are handled correctly."""
schema = {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"already_nullable": {"type": ["string", "null"]},
"not_nullable": {"type": "integer"},
},
"required": [],
},
}
result = enable_strict_mode(schema, strict=True)
# Already nullable field should not get duplicate null
already_nullable_type = result["parameters"]["properties"]["already_nullable"]["type"]
assert already_nullable_type.count("null") == 1
# Not nullable should become nullable
assert result["parameters"]["properties"]["not_nullable"]["type"] == ["integer", "null"]
def test_enable_strict_mode_does_not_mutate_original():
"""Test that the original schema is not mutated."""
schema = {
"name": "test_tool",
"parameters": {
"type": "object",
"properties": {
"field": {"type": "string"},
},
"required": [],
},
}
original_required = schema["parameters"]["required"].copy()
original_field_type = schema["parameters"]["properties"]["field"]["type"]
result = enable_strict_mode(schema, strict=True)
# Original should be unchanged
assert schema["parameters"]["required"] == original_required
assert schema["parameters"]["properties"]["field"]["type"] == original_field_type
assert "strict" not in schema
# Result should be different
assert result["strict"] is True
assert len(result["parameters"]["required"]) == 1