diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index b1d8e9e6..a300dbb2 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -723,6 +723,73 @@ def generate_tool_schema_for_mcp( # Get $defs for $ref resolution defs = parameters_schema.get("$defs", {}) + def deduplicate_anyof(anyof_list): + """ + Deduplicate entries in an anyOf array based on their content. + + Rules: + 1. Remove exact duplicates (same type, same properties) + 2. For duplicate types with different metadata (e.g., format): + - Keep the most specific version (with format/constraints) + - If one has format and others don't, keep only the one with format + """ + if not anyof_list: + return anyof_list + + seen = [] + result = [] + + for item in anyof_list: + if not isinstance(item, dict): + if item not in seen: + seen.append(item) + result.append(item) + continue + + # Create a hashable representation for comparison + # Sort keys to ensure consistent comparison + item_type = item.get("type") + item_format = item.get("format") + + # Check if we've seen this exact item + is_duplicate = False + for existing_idx, existing in enumerate(result): + if not isinstance(existing, dict): + continue + + existing_type = existing.get("type") + existing_format = existing.get("format") + + # Exact match - skip this item + if item == existing: + is_duplicate = True + break + + # Same type with different format handling + if item_type and item_type == existing_type: + # Both have same type + if item_format and not existing_format: + # New item has format, existing doesn't - replace existing with new + result[existing_idx] = item + is_duplicate = True + break + elif not item_format and existing_format: + # Existing has format, new doesn't - keep existing, skip new + is_duplicate = True + break + elif item_format == existing_format: + # Same type and format (or both None) - compare full objects + # Prefer the one with more properties/constraints + if len(item) >= len(existing): + result[existing_idx] = item + is_duplicate = True + break + + if not is_duplicate: + result.append(item) + + return result + def inline_ref(schema_node, defs, depth=0, max_depth=10): """ Recursively inline all $ref references in a schema node. @@ -757,7 +824,10 @@ def generate_tool_schema_for_mcp( # Recursively process nested structures if "anyOf" in result: + # Inline refs in each anyOf option result["anyOf"] = [inline_ref(opt, defs, depth + 1, max_depth) for opt in result["anyOf"]] + # Deduplicate anyOf entries + result["anyOf"] = deduplicate_anyof(result["anyOf"]) if "properties" in result and isinstance(result["properties"], dict): result["properties"] = { prop_name: inline_ref(prop_schema, defs, depth + 1, max_depth) for prop_name, prop_schema in result["properties"].items() diff --git a/tests/mcp_tests/test_mcp_schema_validation.py b/tests/mcp_tests/test_mcp_schema_validation.py index c630d689..93eb021b 100644 --- a/tests/mcp_tests/test_mcp_schema_validation.py +++ b/tests/mcp_tests/test_mcp_schema_validation.py @@ -237,8 +237,12 @@ def test_mcp_schema_healing_with_anyof(): assert strict_schema["strict"] is True assert "a" in strict_schema["parameters"]["required"] assert "b" in strict_schema["parameters"]["required"] # Now required - # Type should be flattened array with deduplication - assert set(strict_schema["parameters"]["properties"]["b"]["type"]) == {"integer", "null"} + # anyOf should be preserved with integer and null types + b_prop = strict_schema["parameters"]["properties"]["b"] + assert "anyOf" in b_prop + assert len(b_prop["anyOf"]) == 2 + types_in_anyof = {opt.get("type") for opt in b_prop["anyOf"]} + assert types_in_anyof == {"integer", "null"} # Validate strict schema status, _ = validate_complete_json_schema(strict_schema["parameters"]) @@ -246,7 +250,7 @@ def test_mcp_schema_healing_with_anyof(): def test_mcp_schema_type_deduplication(): - """Test that duplicate types are deduplicated in schema generation.""" + """Test that anyOf duplicates are removed in schema generation.""" mcp_tool = MCPTool( name="test_tool", description="A test tool", @@ -270,10 +274,14 @@ def test_mcp_schema_type_deduplication(): # Generate strict schema strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True) - # Check that duplicates were removed - field_types = strict_schema["parameters"]["properties"]["field"]["type"] - assert len(field_types) == len(set(field_types)) # No duplicates - assert set(field_types) == {"string", "null"} + # Check that anyOf is preserved but duplicates are removed + field_prop = strict_schema["parameters"]["properties"]["field"] + assert "anyOf" in field_prop + types_in_anyof = [opt.get("type") for opt in field_prop["anyOf"]] + # Duplicates should be removed + assert len(types_in_anyof) == 2 # Deduplicated to 2 entries + assert types_in_anyof.count("string") == 1 # Only one string entry + assert types_in_anyof.count("null") == 1 # One null entry def test_mcp_schema_healing_preserves_existing_null(): @@ -333,7 +341,7 @@ def test_mcp_schema_healing_all_fields_already_required(): def test_mcp_schema_with_uuid_format(): - """Test handling of UUID format in anyOf schemas (root cause of duplicate string types).""" + """Test handling of UUID format in anyOf schemas (deduplicates but keeps format).""" mcp_tool = MCPTool( name="test_tool", description="A test tool with UUID formatted field", @@ -353,11 +361,17 @@ def test_mcp_schema_with_uuid_format(): # Generate strict schema strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True) - # Check that string type is not duplicated + # Check that anyOf is preserved with deduplication session_props = strict_schema["parameters"]["properties"]["session_id"] - assert set(session_props["type"]) == {"string", "null"} # No duplicate strings - # Format should NOT be preserved because field is optional (has null type) - assert "format" not in session_props + assert "anyOf" in session_props + # Deduplication should keep the string with format (more specific) + assert len(session_props["anyOf"]) == 2 # Deduplicated: string (with format) + null + types_in_anyof = [opt.get("type") for opt in session_props["anyOf"]] + assert types_in_anyof.count("string") == 1 # Only one string entry (the one with format) + assert "null" in types_in_anyof + # Verify the string entry has the uuid format + string_entry = next(opt for opt in session_props["anyOf"] if opt.get("type") == "string") + assert string_entry.get("format") == "uuid", "UUID format should be preserved" # Should be in required array (healed) assert "session_id" in strict_schema["parameters"]["required"] @@ -408,7 +422,7 @@ def test_mcp_schema_healing_only_in_strict_mode(): def test_mcp_schema_with_uuid_format_required_field(): - """Test that UUID format is preserved for required fields that don't have null type.""" + """Test that UUID format is preserved and duplicates are removed for required fields.""" mcp_tool = MCPTool( name="test_tool", description="A test tool with required UUID formatted field", @@ -428,11 +442,18 @@ def test_mcp_schema_with_uuid_format_required_field(): # Generate strict schema strict_schema = generate_tool_schema_for_mcp(mcp_tool, append_heartbeat=False, strict=True) - # Check that string type is not duplicated and format IS preserved + # Check that anyOf is deduplicated, keeping the more specific version session_props = strict_schema["parameters"]["properties"]["session_id"] - assert session_props["type"] == ["string"] # No null, no duplicates - assert "format" in session_props - assert session_props["format"] == "uuid" # Format should be preserved for non-optional field + assert "anyOf" in session_props + # Deduplication should keep only the string with format (more specific) + assert len(session_props["anyOf"]) == 1 # Deduplicated to 1 entry + types_in_anyof = [opt.get("type") for opt in session_props["anyOf"]] + assert types_in_anyof.count("string") == 1 # Only one string entry + assert "null" not in types_in_anyof # No null since it's required + # UUID format should be preserved + string_entry = session_props["anyOf"][0] + assert string_entry.get("type") == "string" + assert string_entry.get("format") == "uuid", "UUID format should be preserved" # Should be in required array assert "session_id" in strict_schema["parameters"]["required"]