fix: handle faulty schemas from bad MCP servers better
Co-authored-by: jnjpng <jin@letta.com>
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -18,9 +18,20 @@ TEMPLATED_VARIABLE_REGEX = (
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MCPToolHealth(BaseModel):
    """Schema-health report attached to an MCP tool.

    Records whether the tool's JSON schema is usable with OpenAI strict
    mode, plus the reasons behind any downgrade.
    """

    # TODO: @jnjpng use the enum provided in schema_validator.py
    status: str = Field(..., description="Schema health status: STRICT_COMPLIANT, NON_STRICT_ONLY, or INVALID")
    reasons: List[str] = Field(default_factory=list, description="List of reasons for the health status")
||||
class MCPTool(Tool):
    """A simple wrapper around MCP's tool definition (to avoid conflict with our own)"""

    # Populated at runtime after schema validation; absent until then.
    health: Optional[MCPToolHealth] = Field(None, description="Schema health status for OpenAI strict mode")
class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
|
||||
187
letta/functions/schema_validator.py
Normal file
187
letta/functions/schema_validator.py
Normal file
@@ -0,0 +1,187 @@
|
||||
"""
|
||||
JSON Schema validator for OpenAI strict mode compliance.
|
||||
|
||||
This module provides validation for JSON schemas to ensure they comply with
|
||||
OpenAI's strict mode requirements for tool schemas.
|
||||
"""
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
|
||||
class SchemaHealth(Enum):
    """Health classification of a tool JSON schema for OpenAI strict mode."""

    STRICT_COMPLIANT = "STRICT_COMPLIANT"  # usable with strict mode as-is
    NON_STRICT_ONLY = "NON_STRICT_ONLY"  # well-formed JSON Schema, but too permissive for strict mode
    INVALID = "INVALID"  # unusable in either mode
||||
def validate_complete_json_schema(schema: Dict[str, Any]) -> Tuple[SchemaHealth, List[str]]:
    """
    Validate schema for OpenAI tool strict mode compliance.

    This validator checks for:
    - Valid JSON Schema structure
    - OpenAI strict mode requirements
    - Special cases like required properties with empty object schemas

    Args:
        schema: The JSON schema to validate

    Returns:
        A tuple of (SchemaHealth, list_of_reasons)
    """

    reasons: List[str] = []
    # Start optimistic; the mark_* helpers only ever downgrade this.
    status = SchemaHealth.STRICT_COMPLIANT

    def mark_non_strict(reason: str):
        """Mark schema as non-strict only (valid but not strict-compliant)."""
        nonlocal status
        # INVALID takes precedence, so only downgrade from STRICT_COMPLIANT.
        if status == SchemaHealth.STRICT_COMPLIANT:
            status = SchemaHealth.NON_STRICT_ONLY
        reasons.append(reason)

    def mark_invalid(reason: str):
        """Mark schema as invalid (terminal state; never upgraded)."""
        nonlocal status
        status = SchemaHealth.INVALID
        reasons.append(reason)

    def schema_allows_empty_object(obj_schema: Dict[str, Any]) -> bool:
        """
        Return True if this object schema allows {}, meaning no required props
        and no additionalProperties content.
        """
        if obj_schema.get("type") != "object":
            return False
        props = obj_schema.get("properties", {})
        required = obj_schema.get("required", [])
        additional = obj_schema.get("additionalProperties", True)

        # Empty object: no required props and additionalProperties is false
        if not required and additional is False:
            return True
        return False

    def schema_allows_empty_array(arr_schema: Dict[str, Any]) -> bool:
        """
        Return True if this array schema allows empty arrays with no constraints.
        """
        if arr_schema.get("type") != "array":
            return False

        # If minItems is set and > 0, it doesn't allow empty
        min_items = arr_schema.get("minItems", 0)
        if min_items > 0:
            return False

        # If items schema is not defined or very permissive, it allows empty
        items = arr_schema.get("items")
        if items is None:
            return True

        # Arrays with a defined items schema are acceptable even if empty.
        return False

    def recurse(node: Dict[str, Any], path: str, is_root: bool = False):
        """Recursively validate a schema node, recording findings via mark_*."""
        node_type = node.get("type")

        # Handle schemas without explicit type but with type-specific keywords
        if not node_type:
            # Check for type-specific keywords
            if "properties" in node or "additionalProperties" in node:
                node_type = "object"
            elif "items" in node:
                node_type = "array"
            elif any(kw in node for kw in ["anyOf", "oneOf", "allOf"]):
                # Union types don't require explicit type
                pass
            else:
                mark_invalid(f"{path}: Missing 'type'")
                return

        # OBJECT
        if node_type == "object":
            props = node.get("properties")
            if props is not None and not isinstance(props, dict):
                mark_invalid(f"{path}: 'properties' must be a dict for objects")
                return

            # Strict mode requires an explicit additionalProperties: false.
            if "additionalProperties" not in node:
                mark_non_strict(f"{path}: 'additionalProperties' not explicitly set")
            elif node["additionalProperties"] is not False:
                mark_non_strict(f"{path}: 'additionalProperties' is not false (free-form object)")

            required = node.get("required")
            if required is None:
                # Only mark as non-strict for nested objects, not root
                if not is_root:
                    mark_non_strict(f"{path}: 'required' not specified for object")
                required = []
            elif not isinstance(required, list):
                mark_invalid(f"{path}: 'required' must be a list if present")
                required = []

            # OpenAI strict-mode extra checks:
            # NOTE(review): when 'properties' is absent but 'required' is non-empty,
            # both branches below are skipped — confirm that is intended.
            for req_key in required:
                if props and req_key not in props:
                    mark_invalid(f"{path}: required contains '{req_key}' not found in properties")
                elif props:
                    req_schema = props[req_key]
                    if isinstance(req_schema, dict):
                        # Check for empty object issue
                        if schema_allows_empty_object(req_schema):
                            mark_invalid(f"{path}: required property '{req_key}' allows empty object (OpenAI will reject)")
                        # Check for empty array issue
                        if schema_allows_empty_array(req_schema):
                            mark_invalid(f"{path}: required property '{req_key}' allows empty array (OpenAI will reject)")

            # Recurse into properties
            if props:
                for prop_name, prop_schema in props.items():
                    if isinstance(prop_schema, dict):
                        recurse(prop_schema, f"{path}.properties.{prop_name}", is_root=False)
                    else:
                        mark_invalid(f"{path}.properties.{prop_name}: Not a valid schema dict")

        # ARRAY
        elif node_type == "array":
            items = node.get("items")
            if items is None:
                mark_invalid(f"{path}: 'items' must be defined for arrays in strict mode")
            elif not isinstance(items, dict):
                mark_invalid(f"{path}: 'items' must be a schema dict for arrays")
            else:
                recurse(items, f"{path}.items", is_root=False)

        # PRIMITIVE TYPES
        elif node_type in ["string", "number", "integer", "boolean", "null"]:
            # These are generally fine, but check for specific constraints
            pass

        # UNION TYPES
        # Runs for every node (not elif): a node may combine a type with anyOf/oneOf/allOf.
        for kw in ("anyOf", "oneOf", "allOf"):
            if kw in node:
                if not isinstance(node[kw], list):
                    mark_invalid(f"{path}: '{kw}' must be a list")
                else:
                    for idx, sub_schema in enumerate(node[kw]):
                        if isinstance(sub_schema, dict):
                            recurse(sub_schema, f"{path}.{kw}[{idx}]", is_root=False)
                        else:
                            mark_invalid(f"{path}.{kw}[{idx}]: Not a valid schema dict")

    # Start validation
    if not isinstance(schema, dict):
        return SchemaHealth.INVALID, ["Top-level schema must be a dict"]

    # OpenAI tools require top-level type to be object
    if schema.get("type") != "object":
        mark_invalid("Top-level schema 'type' must be 'object' for OpenAI tools")

    # Begin recursive validation
    recurse(schema, "root", is_root=True)

    return status, reasons
|
||||
@@ -2,21 +2,39 @@ from collections import OrderedDict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from letta.constants import PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.schemas.tool import MCP_TOOL_METADATA_SCHEMA_STATUS, MCP_TOOL_METADATA_SCHEMA_WARNINGS
|
||||
from letta.utils import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Enables strict mode for a tool schema by setting 'strict' to True and
|
||||
disallowing additional properties in the parameters.
|
||||
|
||||
If the tool schema is NON_STRICT_ONLY, strict mode will not be applied.
|
||||
|
||||
Args:
|
||||
tool_schema (Dict[str, Any]): The original tool schema.
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: A new tool schema with strict mode enabled.
|
||||
Dict[str, Any]: A new tool schema with strict mode conditionally enabled.
|
||||
"""
|
||||
schema = tool_schema.copy()
|
||||
|
||||
# Enable strict mode
|
||||
# Check if schema has status metadata indicating NON_STRICT_ONLY
|
||||
schema_status = schema.get(MCP_TOOL_METADATA_SCHEMA_STATUS)
|
||||
if schema_status == "NON_STRICT_ONLY":
|
||||
# Don't apply strict mode for non-strict schemas
|
||||
# Remove the metadata fields from the schema
|
||||
schema.pop(MCP_TOOL_METADATA_SCHEMA_STATUS, None)
|
||||
schema.pop(MCP_TOOL_METADATA_SCHEMA_WARNINGS, None)
|
||||
return schema
|
||||
elif schema_status == "INVALID":
|
||||
# We should not be hitting this and allowing invalid schemas to be used
|
||||
logger.error(f"Tool schema {schema} is invalid: {schema.get(MCP_TOOL_METADATA_SCHEMA_WARNINGS)}")
|
||||
|
||||
# Enable strict mode for STRICT_COMPLIANT or unspecified health status
|
||||
schema["strict"] = True
|
||||
|
||||
# Ensure parameters is a valid dictionary
|
||||
@@ -26,6 +44,11 @@ def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
# Set additionalProperties to False
|
||||
parameters["additionalProperties"] = False
|
||||
schema["parameters"] = parameters
|
||||
|
||||
# Remove the metadata fields from the schema
|
||||
schema.pop(MCP_TOOL_METADATA_SCHEMA_STATUS, None)
|
||||
schema.pop(MCP_TOOL_METADATA_SCHEMA_WARNINGS, None)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,10 @@ from letta.constants import (
|
||||
LETTA_VOICE_TOOL_MODULE_NAME,
|
||||
MCP_TOOL_TAG_NAME_PREFIX,
|
||||
)
|
||||
|
||||
# MCP Tool metadata constants for schema health status
|
||||
MCP_TOOL_METADATA_SCHEMA_STATUS = f"{MCP_TOOL_TAG_NAME_PREFIX}:SCHEMA_STATUS"
|
||||
MCP_TOOL_METADATA_SCHEMA_WARNINGS = f"{MCP_TOOL_TAG_NAME_PREFIX}:SCHEMA_WARNINGS"
|
||||
from letta.functions.ast_parsers import get_function_name_and_docstring
|
||||
from letta.functions.composio_helpers import generate_composio_tool_wrapper
|
||||
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
|
||||
@@ -171,6 +175,11 @@ class ToolCreate(LettaBase):
|
||||
# Pass the MCP tool to the schema generator
|
||||
json_schema = generate_tool_schema_for_mcp(mcp_tool=mcp_tool)
|
||||
|
||||
# Store health status in json_schema metadata if available
|
||||
if mcp_tool.health:
|
||||
json_schema[MCP_TOOL_METADATA_SCHEMA_STATUS] = mcp_tool.health.status
|
||||
json_schema[MCP_TOOL_METADATA_SCHEMA_WARNINGS] = mcp_tool.health.reasons
|
||||
|
||||
# Return a ToolCreate instance
|
||||
description = mcp_tool.description
|
||||
source_type = "python"
|
||||
|
||||
@@ -486,6 +486,20 @@ async def add_mcp_tool(
|
||||
},
|
||||
)
|
||||
|
||||
# Check tool health - reject only INVALID tools
|
||||
if mcp_tool.health:
|
||||
if mcp_tool.health.status == "INVALID":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"code": "MCPToolSchemaInvalid",
|
||||
"message": f"Tool {mcp_tool_name} has an invalid schema and cannot be attached",
|
||||
"mcp_tool_name": mcp_tool_name,
|
||||
"health_status": mcp_tool.health.status,
|
||||
"reasons": mcp_tool.health.reasons,
|
||||
},
|
||||
)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
# For config-based servers, use the server name as ID since they don't have database IDs
|
||||
mcp_server_id = mcp_server_name
|
||||
|
||||
@@ -23,7 +23,8 @@ from letta.config import LettaConfig
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.errors import HandleNotFoundError
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, MCPToolHealth, SSEServerConfig, StdioServerConfig
|
||||
from letta.functions.schema_validator import validate_complete_json_schema
|
||||
from letta.groups.helpers import load_multi_agent
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
@@ -2063,7 +2064,15 @@ class SyncServer(Server):
|
||||
if mcp_server_name not in self.mcp_clients:
|
||||
raise ValueError(f"No client was created for MCP server: {mcp_server_name}")
|
||||
|
||||
return await self.mcp_clients[mcp_server_name].list_tools()
|
||||
tools = await self.mcp_clients[mcp_server_name].list_tools()
|
||||
|
||||
# Add health information to each tool
|
||||
for tool in tools:
|
||||
if tool.inputSchema:
|
||||
health_status, reasons = validate_complete_json_schema(tool.inputSchema)
|
||||
tool.health = MCPToolHealth(status=health_status.value, reasons=reasons)
|
||||
|
||||
return tools
|
||||
|
||||
async def add_mcp_server_to_config(
|
||||
self, server_config: Union[SSEServerConfig, StdioServerConfig], allow_upsert: bool = True
|
||||
|
||||
@@ -10,7 +10,15 @@ from sqlalchemy import delete, null
|
||||
from starlette.requests import Request
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
from letta.functions.mcp_client.types import (
|
||||
MCPServerType,
|
||||
MCPTool,
|
||||
MCPToolHealth,
|
||||
SSEServerConfig,
|
||||
StdioServerConfig,
|
||||
StreamableHTTPServerConfig,
|
||||
)
|
||||
from letta.functions.schema_validator import validate_complete_json_schema
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
|
||||
@@ -59,6 +67,13 @@ class MCPManager:
|
||||
|
||||
# list tools
|
||||
tools = await mcp_client.list_tools()
|
||||
|
||||
# Add health information to each tool
|
||||
for tool in tools:
|
||||
if tool.inputSchema:
|
||||
health_status, reasons = validate_complete_json_schema(tool.inputSchema)
|
||||
tool.health = MCPToolHealth(status=health_status.value, reasons=reasons)
|
||||
|
||||
return tools
|
||||
except Exception as e:
|
||||
# MCP tool listing errors are often due to connection/configuration issues, not system errors
|
||||
@@ -116,7 +131,16 @@ class MCPManager:
|
||||
mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor)
|
||||
|
||||
for mcp_tool in mcp_tools:
|
||||
# TODO: @jnjpng move health check to tool class
|
||||
if mcp_tool.name == mcp_tool_name:
|
||||
# Check tool health - reject only INVALID tools
|
||||
if mcp_tool.health:
|
||||
if mcp_tool.health.status == "INVALID":
|
||||
raise ValueError(
|
||||
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid."
|
||||
f"Reasons: {', '.join(mcp_tool.health.reasons)}"
|
||||
)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return await self.tool_manager.create_mcp_tool_async(
|
||||
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
|
||||
|
||||
@@ -210,6 +210,83 @@ def test_stdio_mcp_server(client, agent_state):
|
||||
assert len(ret.tool_return.strip()) >= 10, f"Expected at least 10 characters in tool_return, got {len(ret.tool_return.strip())}"
|
||||
|
||||
|
||||
# Optional OpenAI validation test for MCP-normalized schema
# Skips unless OPENAI_API_KEY is set to avoid network flakiness in CI
# Fixture: a real-world-style schema whose "message" property is a free-form
# object ("additionalProperties": {}), which OpenAI strict mode rejects.
EXAMPLE_BAD_SCHEMA = {
    "type": "object",
    "properties": {
        "conversation_type": {
            "type": "string",
            "const": "Group",
            "description": "Specifies the type of conversation to be created. Must be 'Group' for this action.",
        },
        "message": {
            "type": "object",
            "additionalProperties": {},  # invalid for OpenAI: missing "type"
            "description": "Initial message payload",
        },
        "participant_ids": {
            "type": "array",
            "items": {"type": "string"},
            "description": "Participant IDs",
        },
    },
    "required": ["conversation_type", "message", "participant_ids"],
    "additionalProperties": False,
    "$schema": "http://json-schema.org/draft-07/schema#",
}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
    not os.getenv("OPENAI_API_KEY"),
    reason="Requires OPENAI_API_KEY to call OpenAI for schema validation",
)
def test_openai_rejects_untyped_additional_properties_and_accepts_normalized_schema():
    """Test written to check if our extra schema validation works.

    Some MCP servers will return faulty schemas that require correction, or they will brick the LLM client calls.
    """
    import copy

    try:
        from openai import OpenAI
    except Exception as e:  # pragma: no cover
        pytest.skip(f"openai package not available: {e}")

    client = OpenAI()

    def run_request_with_schema(schema: dict):
        # Wrap the schema in a strict-mode tool definition and issue a real
        # chat-completions request; OpenAI validates tool schemas server-side.
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "TWITTER_CREATE_A_NEW_DM_CONVERSATION",
                    "description": "Create a DM conversation",
                    "parameters": schema,
                    "strict": True,
                },
            }
        ]

        return client.chat.completions.create(
            model="gpt-4o-mini",
            messages=[{"role": "user", "content": "hello"}],
            tools=tools,
        )

    # Bad schema should raise
    with pytest.raises(Exception):
        run_request_with_schema(EXAMPLE_BAD_SCHEMA)

    # Normalized should succeed: pin down the free-form "message" object with
    # explicit properties, required, and additionalProperties: false.
    normalized = copy.deepcopy(EXAMPLE_BAD_SCHEMA)
    normalized["properties"]["message"]["additionalProperties"] = False
    normalized["properties"]["message"]["properties"] = {"text": {"type": "string"}}
    normalized["properties"]["message"]["required"] = ["text"]
    resp = run_request_with_schema(normalized)
    assert getattr(resp, "id", None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_streamable_http_mcp_server_update_schema_no_docstring_required(client, agent_state, server_url):
|
||||
"""
|
||||
|
||||
185
tests/mcp_tests/test_mcp_schema_validation.py
Normal file
185
tests/mcp_tests/test_mcp_schema_validation.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
Test MCP tool schema validation integration.
|
||||
"""
|
||||
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from letta.functions.mcp_client.types import MCPTool, MCPToolHealth
|
||||
from letta.functions.schema_validator import SchemaHealth, validate_complete_json_schema
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_mcp_tools_get_health_status():
    """Test that MCP tools receive health status when listed.

    Covers one tool per health class: STRICT_COMPLIANT, NON_STRICT_ONLY, INVALID.
    """
    from letta.server.server import SyncServer

    # Create mock tools with different schema types
    mock_tools = [
        # Strict compliant tool
        MCPTool(
            name="strict_tool",
            inputSchema={"type": "object", "properties": {"text": {"type": "string"}}, "required": ["text"], "additionalProperties": False},
        ),
        # Non-strict tool (free-form object)
        MCPTool(
            name="non_strict_tool",
            inputSchema={
                "type": "object",
                "properties": {"message": {"type": "object", "additionalProperties": {}}},  # Free-form object
                "required": ["message"],
                "additionalProperties": False,
            },
        ),
        # Invalid tool (missing type)
        MCPTool(name="invalid_tool", inputSchema={"properties": {"data": {"type": "string"}}, "required": ["data"]}),
    ]

    # Mock the server and client
    mock_client = AsyncMock()
    mock_client.list_tools = AsyncMock(return_value=mock_tools)

    # Call the method directly
    # __new__ bypasses SyncServer.__init__ so no real setup/IO happens.
    actual_server = SyncServer.__new__(SyncServer)
    actual_server.mcp_clients = {"test_server": mock_client}

    tools = await actual_server.get_tools_from_mcp_server("test_server")

    # Verify health status was added
    assert len(tools) == 3

    # Check strict tool
    strict_tool = tools[0]
    assert strict_tool.name == "strict_tool"
    assert strict_tool.health is not None
    assert strict_tool.health.status == SchemaHealth.STRICT_COMPLIANT.value
    assert strict_tool.health.reasons == []

    # Check non-strict tool
    non_strict_tool = tools[1]
    assert non_strict_tool.name == "non_strict_tool"
    assert non_strict_tool.health is not None
    assert non_strict_tool.health.status == SchemaHealth.NON_STRICT_ONLY.value
    assert len(non_strict_tool.health.reasons) > 0
    assert any("additionalProperties" in reason for reason in non_strict_tool.health.reasons)

    # Check invalid tool
    invalid_tool = tools[2]
    assert invalid_tool.name == "invalid_tool"
    assert invalid_tool.health is not None
    assert invalid_tool.health.status == SchemaHealth.INVALID.value
    assert len(invalid_tool.health.reasons) > 0
    assert any("type" in reason for reason in invalid_tool.health.reasons)
|
||||
|
||||
|
||||
def test_composio_like_schema_marked_non_strict():
    """Test that Composio-like schemas are correctly marked as NON_STRICT_ONLY."""

    # Example schema from Composio with free-form message object
    composio_schema = {
        "type": "object",
        "properties": {
            "message": {"type": "object", "additionalProperties": {}, "description": "Message to send"}  # Free-form, missing "type"
        },
        "required": ["message"],
        "additionalProperties": False,
    }

    status, reasons = validate_complete_json_schema(composio_schema)

    # Free-form nested object downgrades the schema but does not invalidate it.
    assert status == SchemaHealth.NON_STRICT_ONLY
    assert len(reasons) > 0
    assert any("additionalProperties" in reason for reason in reasons)
|
||||
|
||||
|
||||
def test_empty_object_in_required_marked_invalid():
    """Test that required properties allowing empty objects are marked INVALID."""

    schema = {
        "type": "object",
        "properties": {
            "config": {"type": "object", "properties": {}, "required": [], "additionalProperties": False}  # Empty object schema
        },
        "required": ["config"],  # Required but allows empty object
        "additionalProperties": False,
    }

    status, reasons = validate_complete_json_schema(schema)

    # A required property that can only ever be {} is rejected outright.
    assert status == SchemaHealth.INVALID
    assert any("empty object" in reason for reason in reasons)
    assert any("config" in reason for reason in reasons)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_mcp_tool_rejects_non_strict_schemas():
    """Test that adding MCP tools with non-strict schemas is rejected.

    NOTE(review): other code in this change rejects only INVALID tools when
    attaching ("reject only INVALID tools"); confirm the add_mcp_tool route
    is actually intended to reject NON_STRICT_ONLY as well.
    """
    from fastapi import HTTPException

    from letta.server.rest_api.routers.v1.tools import add_mcp_tool
    from letta.settings import tool_settings

    # Mock a non-strict tool
    non_strict_tool = MCPTool(
        name="test_tool",
        inputSchema={
            "type": "object",
            "properties": {"message": {"type": "object"}},  # Missing additionalProperties: false
            "required": ["message"],
            "additionalProperties": False,
        },
    )
    non_strict_tool.health = MCPToolHealth(status=SchemaHealth.NON_STRICT_ONLY.value, reasons=["Missing additionalProperties for message"])

    # Mock server response
    with patch("letta.server.rest_api.routers.v1.tools.get_letta_server") as mock_get_server:
        with patch.object(tool_settings, "mcp_read_from_config", True):  # Ensure we're using config path
            mock_server = AsyncMock()
            mock_server.get_tools_from_mcp_server = AsyncMock(return_value=[non_strict_tool])
            mock_server.user_manager.get_user_or_default = MagicMock()
            mock_get_server.return_value = mock_server

            # Should raise HTTPException for non-strict schema
            with pytest.raises(HTTPException) as exc_info:
                await add_mcp_tool(mcp_server_name="test_server", mcp_tool_name="test_tool", server=mock_server, actor_id=None)

            assert exc_info.value.status_code == 400
            assert "non-strict schema" in exc_info.value.detail["message"].lower()
            assert exc_info.value.detail["health_status"] == SchemaHealth.NON_STRICT_ONLY.value
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_add_mcp_tool_rejects_invalid_schemas():
    """Test that adding MCP tools with invalid schemas is rejected."""
    from fastapi import HTTPException

    from letta.server.rest_api.routers.v1.tools import add_mcp_tool
    from letta.settings import tool_settings

    # Mock an invalid tool
    invalid_tool = MCPTool(
        name="test_tool",
        inputSchema={
            "properties": {"data": {"type": "string"}},
            "required": ["data"],
            # Missing "type": "object"
        },
    )
    invalid_tool.health = MCPToolHealth(status=SchemaHealth.INVALID.value, reasons=["Missing 'type' at root level"])

    # Mock server response
    with patch("letta.server.rest_api.routers.v1.tools.get_letta_server") as mock_get_server:
        with patch.object(tool_settings, "mcp_read_from_config", True):  # Ensure we're using config path
            mock_server = AsyncMock()
            mock_server.get_tools_from_mcp_server = AsyncMock(return_value=[invalid_tool])
            mock_server.user_manager.get_user_or_default = MagicMock()
            mock_get_server.return_value = mock_server

            # Should raise HTTPException for invalid schema
            with pytest.raises(HTTPException) as exc_info:
                await add_mcp_tool(mcp_server_name="test_server", mcp_tool_name="test_tool", server=mock_server, actor_id=None)

            assert exc_info.value.status_code == 400
            assert "invalid schema" in exc_info.value.detail["message"].lower()
            assert exc_info.value.detail["health_status"] == SchemaHealth.INVALID.value
|
||||
314
tests/mcp_tests/test_schema_validator.py
Normal file
314
tests/mcp_tests/test_schema_validator.py
Normal file
@@ -0,0 +1,314 @@
|
||||
"""
|
||||
Unit tests for the JSON Schema validator for OpenAI strict mode compliance.
|
||||
"""
|
||||
|
||||
from letta.functions.schema_validator import SchemaHealth, validate_complete_json_schema
|
||||
|
||||
|
||||
class TestSchemaValidator:
|
||||
"""Test cases for the schema validator."""
|
||||
|
||||
def test_valid_strict_compliant_schema(self):
    """Test a fully strict-compliant schema."""
    # Every object level sets additionalProperties: false and lists required keys.
    schema = {
        "type": "object",
        "properties": {
            "name": {"type": "string", "description": "The name of the user"},
            "age": {"type": "integer", "description": "The age of the user"},
            "address": {
                "type": "object",
                "properties": {"street": {"type": "string"}, "city": {"type": "string"}},
                "required": ["street", "city"],
                "additionalProperties": False,
            },
        },
        "required": ["name", "age"],
        "additionalProperties": False,
    }

    status, reasons = validate_complete_json_schema(schema)
    assert status == SchemaHealth.STRICT_COMPLIANT
    assert reasons == []
|
||||
|
||||
def test_free_form_object_non_strict(self):
    """Test that free-form objects (like Composio message) are marked as NON_STRICT_ONLY."""
    schema = {
        "type": "object",
        "properties": {
            "message": {
                "type": "object",
                "description": "A message object",
                # Missing additionalProperties: false makes this free-form
            }
        },
        "required": ["message"],
        "additionalProperties": False,
    }

    status, reasons = validate_complete_json_schema(schema)
    assert status == SchemaHealth.NON_STRICT_ONLY
    assert any("additionalProperties" in reason for reason in reasons)
|
||||
|
||||
def test_empty_object_in_required_invalid(self):
    """Test that required properties allowing empty objects are marked INVALID."""
    schema = {
        "type": "object",
        "properties": {
            "config": {"type": "object", "properties": {}, "required": [], "additionalProperties": False}  # Empty object schema
        },
        "required": ["config"],  # Required but allows empty object
        "additionalProperties": False,
    }

    status, reasons = validate_complete_json_schema(schema)
    assert status == SchemaHealth.INVALID
    assert any("empty object" in reason for reason in reasons)
|
||||
|
||||
def test_missing_type_invalid(self):
    """Test that schemas missing type are marked INVALID."""
    # No "type" and no type-implying keywords at the nested level trigger INVALID.
    schema = {
        # Missing "type": "object"
        "properties": {"name": {"type": "string"}},
        "required": ["name"],
    }

    status, reasons = validate_complete_json_schema(schema)
    assert status == SchemaHealth.INVALID
    assert any("type" in reason.lower() for reason in reasons)
|
||||
|
||||
def test_missing_items_in_array_invalid(self):
    """Test that arrays without items definition are marked INVALID."""
    schema = {
        "type": "object",
        "properties": {
            "tags": {
                "type": "array"
                # Missing "items" definition
            }
        },
        "required": ["tags"],
        "additionalProperties": False,
    }

    status, reasons = validate_complete_json_schema(schema)
    assert status == SchemaHealth.INVALID
    assert any("items" in reason for reason in reasons)
|
||||
|
||||
def test_required_property_not_in_properties_invalid(self):
    """Test that required properties not defined in properties are marked INVALID."""
    schema = {
        "type": "object",
        "properties": {"name": {"type": "string"}},
        "required": ["name", "email"],  # "email" not in properties
        "additionalProperties": False,
    }

    status, reasons = validate_complete_json_schema(schema)
    assert status == SchemaHealth.INVALID
    assert any("email" in reason and "not found" in reason for reason in reasons)
|
||||
|
||||
def test_nested_object_validation(self):
    """Test that nested objects are properly validated."""
    # The two-levels-deep "profile" object omits additionalProperties and
    # required, which should be surfaced as NON_STRICT_ONLY warnings.
    schema = {
        "type": "object",
        "properties": {
            "user": {
                "type": "object",
                "properties": {
                    "profile": {
                        "type": "object",
                        "properties": {"bio": {"type": "string"}},
                        # Missing additionalProperties and required
                    }
                },
                "required": ["profile"],
                "additionalProperties": False,
            }
        },
        "required": ["user"],
        "additionalProperties": False,
    }

    status, reasons = validate_complete_json_schema(schema)
    assert status == SchemaHealth.NON_STRICT_ONLY
    # Should have warnings about nested profile object
    assert any("profile" in reason.lower() or "properties.profile" in reason for reason in reasons)
|
||||
|
||||
def test_union_types_with_anyof(self):
    """A property expressed as an ``anyOf`` union of primitives passes strict mode."""
    payload = {
        "type": "object",
        "properties": {"value": {"anyOf": [{"type": "string"}, {"type": "number"}]}},
        "required": ["value"],
        "additionalProperties": False,
    }

    health, messages = validate_complete_json_schema(payload)
    assert health == SchemaHealth.STRICT_COMPLIANT
    assert messages == []
def test_array_with_proper_items(self):
    """Arrays whose ``items`` schema is itself fully strict-compliant pass strict mode."""
    item_schema = {
        "type": "object",
        "properties": {"id": {"type": "string"}, "value": {"type": "number"}},
        "required": ["id", "value"],
        "additionalProperties": False,
    }
    payload = {
        "type": "object",
        "properties": {"items": {"type": "array", "items": item_schema}},
        "required": ["items"],
        "additionalProperties": False,
    }

    health, messages = validate_complete_json_schema(payload)
    assert health == SchemaHealth.STRICT_COMPLIANT
    assert messages == []
def test_empty_array_in_required_invalid(self):
    """A required array with typed items but no ``minItems`` is STRICT_COMPLIANT.

    NOTE: despite the legacy method name, an empty array is acceptable as
    long as ``items`` is defined — only a *missing* ``items`` definition
    makes an array property invalid.
    """
    schema = {
        "type": "object",
        "properties": {
            "tags": {
                "type": "array",
                "items": {"type": "string"},
                # No minItems constraint: an empty list is acceptable.
            }
        },
        "required": ["tags"],
        "additionalProperties": False,
    }

    status, _reasons = validate_complete_json_schema(schema)
    assert status == SchemaHealth.STRICT_COMPLIANT
def test_array_without_constraints_invalid(self):
    """A required array with no ``items`` definition at all is INVALID."""
    payload = {
        "type": "object",
        # "data" is a completely unconstrained array (no "items" key).
        "properties": {"data": {"type": "array"}},
        "required": ["data"],
        "additionalProperties": False,
    }

    health, messages = validate_complete_json_schema(payload)
    assert health == SchemaHealth.INVALID
    assert any("items" in msg for msg in messages)
def test_composio_like_schema(self):
    """A free-form object property (no nested ``properties`` or ``additionalProperties``) is NON_STRICT_ONLY."""
    payload = {
        "type": "object",
        "properties": {
            "message": {
                "type": "object",
                "description": "Message to send",
                # Free-form object: no nested properties, no additionalProperties: false.
            }
        },
        "required": ["message"],
        "additionalProperties": False,
    }

    health, messages = validate_complete_json_schema(payload)
    assert health == SchemaHealth.NON_STRICT_ONLY
    assert any("additionalProperties" in msg for msg in messages)
def test_non_dict_schema(self):
    """Anything that is not a dict cannot be a schema and is INVALID."""
    health, messages = validate_complete_json_schema("not a dict")
    assert health == SchemaHealth.INVALID
    assert any("dict" in msg for msg in messages)
def test_schema_with_defaults_strict_compliant(self):
    """Omitting ``required`` at the root level is allowed under strict mode."""
    payload = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "optional": {"type": "string"}},
        # No "required" key at the root — this is acceptable.
        "additionalProperties": False,
    }

    health, messages = validate_complete_json_schema(payload)
    assert health == SchemaHealth.STRICT_COMPLIANT
    assert messages == []
def test_composio_schema_with_optional_root_properties_strict_compliant(self):
    """Composio-style schemas where only some root properties are required pass strict mode."""
    payload = {
        "type": "object",
        "properties": {
            "thinking": {"type": "string", "description": "Deep inner monologue"},
            "connected_account_id": {"type": "string", "description": "Specific connected account ID"},
            "toolkit": {"type": "string", "description": "Name of the toolkit"},
            "request_heartbeat": {"type": "boolean", "description": "Request immediate heartbeat"},
        },
        # Only a subset of the declared properties is required.
        "required": ["thinking", "request_heartbeat"],
        "additionalProperties": False,
    }

    health, messages = validate_complete_json_schema(payload)
    assert health == SchemaHealth.STRICT_COMPLIANT
    assert messages == []
def test_root_level_without_required_strict_compliant(self):
    """A root object that never declares ``required`` is still STRICT_COMPLIANT."""
    payload = {
        "type": "object",
        "properties": {"name": {"type": "string"}, "age": {"type": "integer"}},
        # Deliberately no "required" entry at the root level.
        "additionalProperties": False,
    }

    health, messages = validate_complete_json_schema(payload)
    assert health == SchemaHealth.STRICT_COMPLIANT
    assert messages == []
def test_nested_object_without_required_non_strict(self):
    """A nested object missing ``required`` downgrades the schema to NON_STRICT_ONLY."""
    preferences = {
        "type": "object",
        "properties": {"theme": {"type": "string"}, "language": {"type": "string"}},
        # "required" is intentionally missing on this nested object.
        "additionalProperties": False,
    }
    payload = {
        "type": "object",
        "properties": {
            "user": {
                "type": "object",
                "properties": {
                    "preferences": preferences,
                    "name": {"type": "string"},
                },
                # "preferences" is left optional so the schema is not INVALID.
                "required": ["name"],
                "additionalProperties": False,
            }
        },
        "required": ["user"],
        "additionalProperties": False,
    }

    health, messages = validate_complete_json_schema(payload)
    assert health == SchemaHealth.NON_STRICT_ONLY
    # The warning should mention both 'required' and the 'preferences' path.
    assert any(("required" in msg) and ("preferences" in msg) for msg in messages)
Reference in New Issue
Block a user