fix: fix mcp for complex schemas and add tests (#5063)

This commit is contained in:
Sarah Wooders
2025-10-01 14:57:44 -07:00
committed by Caren Thomas
parent d7b2d3c6ba
commit 7b73b25a95
5 changed files with 1013 additions and 6 deletions

View File

@@ -588,11 +588,111 @@ def generate_schema_from_args_schema_v2(
return function_call_json
def normalize_mcp_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
    """
    Normalize an MCP JSON schema to fix common issues:
    1. Add explicit 'additionalProperties': false to all object types
    2. Add explicit 'type' field to properties (and array items) using $ref
    3. Process $defs recursively

    Args:
        schema: The JSON schema to normalize. The caller's object is left
            untouched; every change is applied to a deep copy.

    Returns:
        A normalized deep copy of ``schema``.
    """
    import copy

    # Work on a deep copy to avoid modifying the original
    schema = copy.deepcopy(schema)

    def add_type_to_ref(ref_schema: Dict[str, Any], defs: Optional[Dict[str, Any]]) -> None:
        """Give a schema that contains '$ref' an explicit 'type' if it lacks one."""
        if "type" in ref_schema:
            return
        ref = ref_schema["$ref"]
        # Only local "#/$defs/<name>" references can be resolved here
        if defs and isinstance(ref, str) and ref.startswith("#/$defs/"):
            target = defs.get(ref.split("/")[-1])
            if isinstance(target, dict) and "type" in target:
                ref_schema["type"] = target["type"]
        # If still no type, assume object (common case for model references)
        ref_schema.setdefault("type", "object")

    def normalize_object_schema(obj_schema: Dict[str, Any], defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
        """Recursively normalize an object schema (mutates and returns ``obj_schema``)."""
        # If this is an object type, add additionalProperties if missing
        if obj_schema.get("type") == "object":
            obj_schema.setdefault("additionalProperties", False)

        properties = obj_schema.get("properties")
        if isinstance(properties, dict):
            for prop_schema in properties.values():
                # JSON Schema also allows boolean property schemas (true/false);
                # there is nothing to normalize on those, and membership tests
                # against a bool would raise TypeError.
                if not isinstance(prop_schema, dict):
                    continue

                if "$ref" in prop_schema:
                    add_type_to_ref(prop_schema, defs)
                    # Don't add additionalProperties to properties with $ref
                    # The $ref schema itself will have additionalProperties
                    # Adding it here makes the validator think it allows empty objects
                    continue

                # Recursively normalize nested objects
                if prop_schema.get("type") == "object":
                    normalize_object_schema(prop_schema, defs)

                # Handle arrays with object items
                items = prop_schema.get("items")
                if prop_schema.get("type") == "array" and isinstance(items, dict):
                    if "$ref" in items:
                        add_type_to_ref(items, defs)
                    if items.get("type") == "object":
                        normalize_object_schema(items, defs)

                # Handle anyOf (complex union types)
                any_of = prop_schema.get("anyOf")
                if isinstance(any_of, list):
                    for option in any_of:
                        if isinstance(option, dict) and option.get("type") == "object":
                            normalize_object_schema(option, defs)

        # Handle array items at the top level
        top_items = obj_schema.get("items")
        if isinstance(top_items, dict) and top_items.get("type") == "object":
            normalize_object_schema(top_items, defs)

        return obj_schema

    # Process $defs first if they exist
    defs = schema.get("$defs", {})
    if not isinstance(defs, dict):
        # A malformed "$defs" cannot be used for $ref resolution
        defs = {}
    for def_schema in defs.values():
        if isinstance(def_schema, dict):
            normalize_object_schema(def_schema, defs)

    # Process the main schema
    normalize_object_schema(schema, defs)
    return schema
def generate_tool_schema_for_mcp(
mcp_tool: MCPTool,
append_heartbeat: bool = True,
strict: bool = False,
) -> Dict[str, Any]:
from letta.functions.schema_validator import validate_complete_json_schema
# MCP tool.inputSchema is a JSON schema
# https://github.com/modelcontextprotocol/python-sdk/blob/775f87981300660ee957b63c2a14b448ab9c3675/src/mcp/types.py#L678
parameters_schema = mcp_tool.inputSchema
@@ -603,11 +703,16 @@ def generate_tool_schema_for_mcp(
assert "properties" in parameters_schema, parameters_schema
# assert "required" in parameters_schema, parameters_schema
# Normalize the schema to fix common issues with MCP schemas
# This adds additionalProperties: false and explicit types for $ref properties
parameters_schema = normalize_mcp_schema(parameters_schema)
# Zero-arg tools often omit "required" because nothing is required.
# Normalise so downstream code can treat it consistently.
parameters_schema.setdefault("required", [])
# Process properties to handle anyOf types and make optional fields strict-compatible
# TODO: de-duplicate with handling in normalize_mcp_schema
if "properties" in parameters_schema:
for field_name, field_props in parameters_schema["properties"].items():
# Handle anyOf types by flattening to type array
@@ -660,6 +765,14 @@ def generate_tool_schema_for_mcp(
if REQUEST_HEARTBEAT_PARAM not in parameters_schema["required"]:
parameters_schema["required"].append(REQUEST_HEARTBEAT_PARAM)
# Re-validate the schema after normalization and update the health status
# This allows previously INVALID schemas to pass if normalization fixed them
if mcp_tool.health:
health_status, health_reasons = validate_complete_json_schema(parameters_schema)
mcp_tool.health.status = health_status.value
mcp_tool.health.reasons = health_reasons
logger.debug(f"MCP tool {name} schema health after normalization: {health_status.value}, reasons: {health_reasons}")
# Return the final schema
if strict:
# https://platform.openai.com/docs/guides/function-calling#strict-mode

View File

@@ -140,12 +140,44 @@ class MCPManager:
for mcp_tool in mcp_tools:
# TODO: @jnjpng move health check to tool class
if mcp_tool.name == mcp_tool_name:
# Check tool health - reject only INVALID tools
if mcp_tool.health:
if mcp_tool.health.status == "INVALID":
raise ValueError(
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid.Reasons: {', '.join(mcp_tool.health.reasons)}"
)
# Check tool health - but try normalization first for INVALID schemas
if mcp_tool.health and mcp_tool.health.status == "INVALID":
logger.info(f"Attempting to normalize INVALID schema for tool {mcp_tool_name}")
logger.info(f"Original health reasons: {mcp_tool.health.reasons}")
# Try to normalize the schema and re-validate
from letta.functions.schema_generator import normalize_mcp_schema
from letta.functions.schema_validator import validate_complete_json_schema
try:
# Normalize the schema to fix common issues
logger.debug(f"Normalizing schema for {mcp_tool_name}")
normalized_schema = normalize_mcp_schema(mcp_tool.inputSchema)
# Re-validate after normalization
logger.debug(f"Re-validating schema for {mcp_tool_name}")
health_status, health_reasons = validate_complete_json_schema(normalized_schema)
logger.info(f"After normalization: status={health_status.value}, reasons={health_reasons}")
# Update the tool's schema and health (use inputSchema, not input_schema)
mcp_tool.inputSchema = normalized_schema
mcp_tool.health.status = health_status.value
mcp_tool.health.reasons = health_reasons
# Log the normalization result
if health_status.value != "INVALID":
logger.info(f"✓ MCP tool {mcp_tool_name} schema normalized successfully: {health_status.value}")
else:
logger.warning(f"MCP tool {mcp_tool_name} still INVALID after normalization. Reasons: {health_reasons}")
except Exception as e:
logger.error(f"Failed to normalize schema for tool {mcp_tool_name}: {e}", exc_info=True)
# After normalization attempt, check if still INVALID
if mcp_tool.health and mcp_tool.health.status == "INVALID":
raise ValueError(
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid even after normalization. "
f"Reasons: {', '.join(mcp_tool.health.reasons)}"
)
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
return await self.tool_manager.create_mcp_tool_async(

View File

@@ -0,0 +1,396 @@
import os
import sys
import threading
import time
import uuid
from pathlib import Path
import pytest
import requests
from dotenv import load_dotenv
from letta_client import Letta, MessageCreate, ToolCallMessage, ToolReturnMessage
from letta.functions.mcp_client.types import StdioServerConfig
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provides the URL for the Letta server.

    If LETTA_SERVER_URL is not set, starts the server in a background daemon
    thread and polls its health endpoint until it responds with a non-5xx
    status or 30 seconds elapse.

    Yields:
        The base URL of the server under test.

    Raises:
        RuntimeError: If the locally started server does not become reachable
            within the timeout.
    """

    def _run_server() -> None:
        load_dotenv()
        # Imported here rather than at module level, so it only happens when
        # this test module starts its own server.
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll until the server is up (or timeout)
        timeout_seconds = 30
        # Monotonic clock: the deadline is unaffected by wall-clock adjustments
        deadline = time.monotonic() + timeout_seconds
        while time.monotonic() < deadline:
            try:
                # Per-request timeout so a stalled connection cannot block
                # past the overall deadline.
                resp = requests.get(url + "/v1/health", timeout=1)
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                # Server not accepting connections yet; retry after a short sleep
                pass
            time.sleep(0.1)
        else:
            # Loop ended without `break`: the server never answered in time
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    yield url
@pytest.fixture(scope="module")
def client(server_url: str) -> Letta:
    """
    Module-scoped synchronous Letta REST client bound to ``server_url``.
    """
    yield Letta(base_url=server_url)
@pytest.fixture(scope="function")
def mcp_server_name() -> str:
    """Return a fresh MCP server name (fixed prefix + 8 random hex chars) per test."""
    suffix = uuid.uuid4().hex[:8]
    return f"test-mcp-server-{suffix}"
@pytest.fixture(scope="function")
def mock_mcp_server_config(mcp_server_name: str) -> StdioServerConfig:
    """
    Build the stdio launch configuration for the mock MCP server script
    (``mock_mcp_server.py``, expected next to this test file).

    Raises:
        FileNotFoundError: If the mock server script does not exist.
    """
    mcp_server_path = Path(__file__).parent / "mock_mcp_server.py"

    if mcp_server_path.exists():
        return StdioServerConfig(
            server_name=mcp_server_name,
            # Launch the script with the interpreter running the tests
            command=sys.executable,
            args=[str(mcp_server_path)],
        )

    raise FileNotFoundError(f"Mock MCP server not found at {mcp_server_path}")
@pytest.fixture(scope="function")
def agent_state(client: Letta, mcp_server_name: str, mock_mcp_server_config: StdioServerConfig) -> AgentState:
    """
    Creates an agent with the MCP ``echo`` and ``add`` tools attached for testing.

    Registers the mock MCP server, attaches the two tools, creates the agent
    and yields it. On teardown the agent and the MCP server are deleted on a
    best-effort basis (failures are printed, not raised). Teardown also runs
    when setup fails part-way, so a registered server is not left behind.
    """
    # Register the MCP server
    client.tools.add_mcp_server(request=mock_mcp_server_config)

    agent = None
    try:
        # Verify server is registered
        servers = client.tools.list_mcp_servers()
        assert mcp_server_name in servers, f"MCP server {mcp_server_name} not found in {servers}"

        # List available MCP tools
        mcp_tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
        assert len(mcp_tools) > 0, "No tools found from MCP server"

        # Add the echo and add tools to Letta
        echo_tool = next((t for t in mcp_tools if t.name == "echo"), None)
        add_tool = next((t for t in mcp_tools if t.name == "add"), None)
        assert echo_tool is not None, "echo tool not found"
        assert add_tool is not None, "add tool not found"

        letta_echo_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name="echo")
        letta_add_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name="add")

        # Create agent with the MCP tools
        agent = client.agents.create(
            name=f"test_mcp_agent_{uuid.uuid4().hex[:8]}",
            include_base_tools=True,
            tool_ids=[letta_echo_tool.id, letta_add_tool.id],
            memory_blocks=[
                {
                    "label": "human",
                    "value": "Name: Test User",
                },
                {
                    "label": "persona",
                    "value": "You are a helpful assistant that can use MCP tools to help the user.",
                },
            ],
            llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
            tags=["test_mcp_agent"],
        )

        yield agent
    finally:
        # Cleanup (best effort): the agent only exists if setup got that far
        if agent is not None:
            try:
                client.agents.delete(agent.id)
            except Exception as e:
                print(f"Warning: Failed to delete agent {agent.id}: {e}")

        try:
            client.tools.delete_mcp_server(mcp_server_name=mcp_server_name)
        except Exception as e:
            print(f"Warning: Failed to delete MCP server {mcp_server_name}: {e}")
# ------------------------------
# Test Cases
# ------------------------------
def test_mcp_echo_tool(client: Letta, agent_state: AgentState):
    """
    The agent should invoke the MCP ``echo`` tool and get the message back.
    """
    payload = "Hello from MCP integration test!"
    prompt = MessageCreate(
        role="user",
        content=f"Use the echo tool to echo back this exact message: '{payload}'",
    )
    response = client.agents.messages.create(agent_id=agent_state.id, messages=[prompt])

    # The response must contain at least one tool call ...
    calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
    assert len(calls) > 0, "Expected at least one ToolCallMessage"

    # ... and the first call named "echo" is the one under test
    echo_call = None
    for call in calls:
        if call.tool_call.name == "echo":
            echo_call = call
            break
    assert echo_call is not None, f"No echo tool call found. Tool calls: {[m.tool_call.name for m in calls]}"

    # The response must also contain tool returns
    returns = [msg for msg in response.messages if isinstance(msg, ToolReturnMessage)]
    assert len(returns) > 0, "Expected at least one ToolReturnMessage"

    # Pair the echo call with its return through the tool_call_id
    wanted_id = echo_call.tool_call.tool_call_id
    echo_return = None
    for ret in returns:
        if ret.tool_call_id == wanted_id:
            echo_return = ret
            break
    assert echo_return is not None, "No tool return found for echo call"
    assert echo_return.status == "success", f"Echo tool failed with status: {echo_return.status}"

    # The returned text must include the original message
    assert payload in echo_return.tool_return, f"Expected '{payload}' in tool return, got: {echo_return.tool_return}"
def test_mcp_add_tool(client: Letta, agent_state: AgentState):
    """
    The agent should invoke the MCP ``add`` tool and the return should contain the sum.
    """
    first, second = 42, 58
    expected_sum = first + second

    prompt = MessageCreate(
        role="user",
        content=f"Use the add tool to add {first} and {second}.",
    )
    response = client.agents.messages.create(agent_id=agent_state.id, messages=[prompt])

    # At least one tool call must be present
    calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
    assert len(calls) > 0, "Expected at least one ToolCallMessage"

    # Pick the first call that targets the add tool
    add_call = None
    for call in calls:
        if call.tool_call.name == "add":
            add_call = call
            break
    assert add_call is not None, f"No add tool call found. Tool calls: {[m.tool_call.name for m in calls]}"

    # At least one tool return must be present
    returns = [msg for msg in response.messages if isinstance(msg, ToolReturnMessage)]
    assert len(returns) > 0, "Expected at least one ToolReturnMessage"

    # Match the return to the add call via tool_call_id
    wanted_id = add_call.tool_call.tool_call_id
    add_return = None
    for ret in returns:
        if ret.tool_call_id == wanted_id:
            add_return = ret
            break
    assert add_return is not None, "No tool return found for add call"
    assert add_return.status == "success", f"Add tool failed with status: {add_return.status}"

    # The sum must appear somewhere in the returned text
    assert str(expected_sum) in add_return.tool_return, f"Expected '{expected_sum}' in tool return, got: {add_return.tool_return}"
def test_mcp_multiple_tools_in_sequence(client: Letta, agent_state: AgentState):
    """
    A single user turn should be able to trigger both the ``add`` and ``echo`` MCP tools.
    """
    prompt = MessageCreate(
        role="user",
        content="First use the add tool to add 10 and 20. Then use the echo tool to echo back the result you got from the add tool.",
    )
    response = client.agents.messages.create(agent_id=agent_state.id, messages=[prompt])

    # Two or more tool calls are expected
    calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)]
    assert len(calls) >= 2, f"Expected at least 2 tool calls, got {len(calls)}"

    # Both tools must show up among the called names
    called_names = [call.tool_call.name for call in calls]
    assert "add" in called_names, f"add tool not called. Tools called: {called_names}"
    assert "echo" in called_names, f"echo tool not called. Tools called: {called_names}"

    # Two or more tool returns are expected
    returns = [msg for msg in response.messages if isinstance(msg, ToolReturnMessage)]
    assert len(returns) >= 2, f"Expected at least 2 tool returns, got {len(returns)}"

    # Every tool return must report success
    for ret in returns:
        assert ret.status == "success", f"Tool call failed with status: {ret.status}"
def test_mcp_server_listing(client: Letta, mcp_server_name: str, mock_mcp_server_config: StdioServerConfig):
    """
    Registering the mock MCP server should make it listable, expose the
    expected tools, and deleting it should remove it from the listing.
    """
    client.tools.add_mcp_server(request=mock_mcp_server_config)

    try:
        # The freshly registered server must be listed
        registered = client.tools.list_mcp_servers()
        assert mcp_server_name in registered, f"MCP server {mcp_server_name} not found in {registered}"

        # The server must expose at least one tool
        mcp_tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
        assert len(mcp_tools) > 0, "No tools found from MCP server"

        # The basic tools of the mock server must all be present
        tool_names = [tool.name for tool in mcp_tools]
        for expected_tool in ("echo", "add", "multiply", "reverse_string"):
            assert expected_tool in tool_names, f"Expected tool '{expected_tool}' not found. Available: {tool_names}"
    finally:
        # Remove the server again and confirm it is gone from the listing
        client.tools.delete_mcp_server(mcp_server_name=mcp_server_name)
        remaining = client.tools.list_mcp_servers()
        assert mcp_server_name not in remaining, f"MCP server {mcp_server_name} should be deleted but is still in {remaining}"
def test_mcp_complex_schema_tool(client: Letta, mcp_server_name: str, mock_mcp_server_config: StdioServerConfig):
    """
    Test that an agent can successfully call a tool with complex nested schema.

    This tests the get_parameter_type_description tool which has:
    - Enum-like preset parameter
    - Optional string field
    - Optional nested object with arrays of objects

    The agent and the MCP server are both removed in ``finally`` so that a
    failing assertion does not leave either of them behind on the server.
    """
    # Register the MCP server
    client.tools.add_mcp_server(request=mock_mcp_server_config)

    agent = None
    try:
        # List available tools
        mcp_tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)

        # Find the complex schema tool
        complex_tool = next((t for t in mcp_tools if t.name == "get_parameter_type_description"), None)
        assert complex_tool is not None, f"get_parameter_type_description tool not found. Available: {[t.name for t in mcp_tools]}"

        # Add it to Letta
        letta_complex_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name="get_parameter_type_description")

        # Create agent with the complex tool
        agent = client.agents.create(
            name=f"test_complex_schema_{uuid.uuid4().hex[:8]}",
            include_base_tools=True,
            tool_ids=[letta_complex_tool.id],
            memory_blocks=[
                {
                    "label": "human",
                    "value": "Name: Test User",
                },
                {
                    "label": "persona",
                    "value": "You are a helpful assistant that can use MCP tools with complex schemas.",
                },
            ],
            llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
            embedding_config=EmbeddingConfig.default_config(provider="openai"),
            tags=["test_complex_schema"],
        )

        # Test 1: Simple call with just preset
        response = client.agents.messages.create(
            agent_id=agent.id,
            messages=[
                MessageCreate(
                    role="user", content='Use the get_parameter_type_description tool with preset "a" to get parameter information.'
                )
            ],
        )

        tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
        assert len(tool_calls) > 0, "Expected at least one ToolCallMessage"

        complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None)
        assert complex_call is not None, f"No get_parameter_type_description call found. Calls: {[m.tool_call.name for m in tool_calls]}"

        tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
        assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage"

        complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None)
        assert complex_return is not None, "No tool return found for complex schema call"
        assert complex_return.status == "success", f"Complex schema tool failed with status: {complex_return.status}"
        assert "Preset: a" in complex_return.tool_return, f"Expected 'Preset: a' in return, got: {complex_return.tool_return}"

        # Test 2: Complex call with nested data
        response = client.agents.messages.create(
            agent_id=agent.id,
            messages=[
                MessageCreate(
                    role="user",
                    content="Use the get_parameter_type_description tool with these arguments: "
                    'preset="b", connected_service_descriptor="test-service", '
                    "and instantiation_data with isAbstract=true, isMultiplicity=false, "
                    'and one instantiation with doid="TEST123" and nodeFamilyId=42.',
                )
            ],
        )

        tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
        assert len(tool_calls) > 0, "Expected at least one ToolCallMessage for complex nested call"

        complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None)
        assert complex_call is not None, "No get_parameter_type_description call found for nested test"

        tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
        complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None)
        assert complex_return is not None, "No tool return found for complex nested call"
        assert complex_return.status == "success", f"Complex nested call failed with status: {complex_return.status}"

        # Verify the response contains our complex data
        assert "Preset: b" in complex_return.tool_return, "Expected preset 'b' in response"
        assert "test-service" in complex_return.tool_return, "Expected service descriptor in response"
    finally:
        try:
            # Cleanup agent (only exists if creation succeeded)
            if agent is not None:
                client.agents.delete(agent.id)
        finally:
            # Cleanup MCP server, even if deleting the agent failed
            client.tools.delete_mcp_server(mcp_server_name=mcp_server_name)

View File

@@ -320,6 +320,189 @@ async def test_create_mcp_server_with_tools(mock_get_client, server, default_use
assert len(remaining_mcp_tools) == 0, "Tools should be deleted when server is deleted"
@pytest.mark.asyncio
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_complex_schema_normalization(mock_get_client, server, default_user):
    """Test that complex MCP schemas with nested objects are normalized and accepted.

    Three mock tools are reported with health status INVALID (missing
    additionalProperties, $ref without explicit type). Adding each of them
    through ``add_tool_from_mcp_server`` must succeed, and the persisted
    schemas must carry the normalized fields.
    """
    from letta.functions.mcp_client.types import MCPTool, MCPToolHealth
    from letta.schemas.mcp import MCPServer, MCPServerType
    from letta.settings import tool_settings

    if tool_settings.mcp_read_from_config:
        return

    # Create mock tools with complex schemas that would normally be INVALID
    # These schemas have: nested $defs, $ref references, missing additionalProperties
    mock_tools = [
        # 1. Nested object with $ref (like create_person)
        MCPTool(
            name="create_person",
            description="Create a person with nested address",
            inputSchema={
                "$defs": {
                    "Address": {
                        "type": "object",
                        "properties": {
                            "street": {"type": "string"},
                            "city": {"type": "string"},
                            "zip_code": {"type": "string"},
                        },
                        "required": ["street", "city", "zip_code"],
                    },
                    "Person": {
                        "type": "object",
                        "properties": {
                            "name": {"type": "string"},
                            "age": {"type": "integer"},
                            "address": {"$ref": "#/$defs/Address"},
                        },
                        "required": ["name", "age"],
                    },
                },
                "type": "object",
                "properties": {"person": {"$ref": "#/$defs/Person"}},
                "required": ["person"],
            },
            health=MCPToolHealth(
                status="INVALID",
                reasons=["root: 'additionalProperties' not explicitly set", "root.properties.person: Missing 'type'"],
            ),
        ),
        # 2. List of objects (like manage_tasks)
        MCPTool(
            name="manage_tasks",
            description="Manage multiple tasks",
            inputSchema={
                "$defs": {
                    "TaskItem": {
                        "type": "object",
                        "properties": {
                            "title": {"type": "string"},
                            "priority": {"type": "integer", "default": 1},
                            "completed": {"type": "boolean", "default": False},
                            "tags": {"type": "array", "items": {"type": "string"}},
                        },
                        "required": ["title"],
                    }
                },
                "type": "object",
                "properties": {
                    "tasks": {
                        "type": "array",
                        "items": {"$ref": "#/$defs/TaskItem"},
                    }
                },
                "required": ["tasks"],
            },
            health=MCPToolHealth(
                status="INVALID",
                reasons=["root: 'additionalProperties' not explicitly set", "root.properties.tasks.items: Missing 'type'"],
            ),
        ),
        # 3. Complex filter object with optional fields
        MCPTool(
            name="search_with_filters",
            description="Search with complex filters",
            inputSchema={
                "$defs": {
                    "SearchFilter": {
                        "type": "object",
                        "properties": {
                            "keywords": {"type": "array", "items": {"type": "string"}},
                            "min_score": {"type": "number"},
                            "categories": {"type": "array", "items": {"type": "string"}},
                        },
                        "required": ["keywords"],
                    }
                },
                "type": "object",
                "properties": {
                    "query": {"type": "string"},
                    "filters": {"$ref": "#/$defs/SearchFilter"},
                },
                "required": ["query", "filters"],
            },
            health=MCPToolHealth(
                status="INVALID",
                reasons=["root: 'additionalProperties' not explicitly set", "root.properties.filters: Missing 'type'"],
            ),
        ),
    ]

    # Create mock client
    mock_client = AsyncMock()
    mock_client.connect_to_server = AsyncMock()
    mock_client.list_tools = AsyncMock(return_value=mock_tools)
    mock_client.cleanup = AsyncMock()
    mock_get_client.return_value = mock_client

    # Create MCP server
    server_name = f"test_complex_schema_{uuid.uuid4().hex[:8]}"
    server_url = "https://test-complex.example.com/sse"
    mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)

    # Bound before the try so the finally block can tell whether creation
    # succeeded instead of raising NameError and masking the real failure.
    created_server = None
    try:
        # Create server (this will auto-sync tools)
        created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
        assert created_server.server_name == server_name

        # Now attempt to add each tool - they should be normalized from INVALID to acceptable
        # The normalization happens in add_tool_from_mcp_server

        # Test 1: create_person should normalize successfully
        person_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "create_person", actor=default_user)
        assert person_tool is not None
        assert person_tool.name == "create_person"

        # Verify the schema has additionalProperties set
        assert person_tool.json_schema["parameters"]["additionalProperties"] is False

        # Verify nested $defs have additionalProperties
        if "$defs" in person_tool.json_schema["parameters"]:
            for def_name, def_schema in person_tool.json_schema["parameters"]["$defs"].items():
                if def_schema.get("type") == "object":
                    assert "additionalProperties" in def_schema, f"$defs.{def_name} missing additionalProperties after normalization"

        # Test 2: manage_tasks should normalize successfully
        tasks_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "manage_tasks", actor=default_user)
        assert tasks_tool is not None
        assert tasks_tool.name == "manage_tasks"

        # Verify array items have explicit type
        tasks_prop = tasks_tool.json_schema["parameters"]["properties"]["tasks"]
        assert "items" in tasks_prop
        assert "type" in tasks_prop["items"], "Array items should have explicit type after normalization"

        # Test 3: search_with_filters should normalize successfully
        search_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "search_with_filters", actor=default_user)
        assert search_tool is not None
        assert search_tool.name == "search_with_filters"

        # Verify all tools were persisted
        all_tools = await server.tool_manager.list_tools_async(
            actor=default_user, names=["create_person", "manage_tasks", "search_with_filters"]
        )

        # Filter to tools from our MCP server
        mcp_tools = [
            tool
            for tool in all_tools
            if tool.metadata_
            and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
            and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
        ]

        # All 3 complex schema tools should have been normalized and persisted
        assert len(mcp_tools) == 3, f"Expected 3 normalized tools, got {len(mcp_tools)}"

        # Verify they all have the correct MCP metadata
        for tool in mcp_tools:
            assert tool.tool_type == ToolType.EXTERNAL_MCP
            assert f"mcp:{server_name}" in tool.tags
    finally:
        # Clean up (only if the server was actually created)
        if created_server is not None:
            await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_tools_connection_failure(mock_get_client, server, default_user):
"""Test that MCP server creation succeeds even when tool sync fails (optimistic approach)."""

283
tests/mock_mcp_server.py Executable file
View File

@@ -0,0 +1,283 @@
#!/usr/bin/env python3
"""
Simple MCP test server with basic and complex tools for testing purposes.
"""
import json
import logging
from typing import List, Optional, Union
from mcp.server.fastmcp import FastMCP
from pydantic import BaseModel, ConfigDict, Field
# Configure logging to stderr (not stdout for STDIO servers)
# basicConfig's default stream handler writes to stderr, which keeps stdout
# free for the stdio transport that main() runs the server on.
logging.basicConfig(level=logging.INFO)
# Initialize FastMCP server
# The @mcp.tool() decorators further down attach each tool to this instance.
mcp = FastMCP("test-server")
# Complex Pydantic models for testing
# Nested model: used as the type of Person.address below.
# street, city and zip_code are required (Field(...)); country defaults to "USA".
class Address(BaseModel):
    """An address with street, city, and zip code."""
    street: str = Field(..., description="Street address")
    city: str = Field(..., description="City name")
    zip_code: str = Field(..., description="ZIP code")
    country: str = Field(default="USA", description="Country name")
# Argument model of the create_person tool.
# name and age are required; age is constrained to 0..150 (ge/le).
# email and address default to None; address nests the Address model.
class Person(BaseModel):
    """A person with name, age, and optional address."""
    name: str = Field(..., description="Person's full name")
    age: int = Field(..., description="Person's age", ge=0, le=150)
    email: Optional[str] = Field(None, description="Email address")
    address: Optional[Address] = Field(None, description="Home address")
# Element model of the list accepted by the manage_tasks tool.
# Only title is required; priority is constrained to 1..5 and defaults to 1,
# completed defaults to False, and tags defaults to a fresh empty list.
class TaskItem(BaseModel):
    """A task item with title, priority, and completion status."""
    title: str = Field(..., description="Task title")
    priority: int = Field(default=1, description="Priority level (1-5)", ge=1, le=5)
    completed: bool = Field(default=False, description="Whether the task is completed")
    tags: List[str] = Field(default_factory=list, description="List of tags")
# Argument model of the search_with_filters tool.
# keywords is required; min_score (constrained to 0.0..1.0) and categories
# default to None.
class SearchFilter(BaseModel):
    """Filter criteria for searching."""
    keywords: List[str] = Field(..., description="List of keywords to search for")
    min_score: Optional[float] = Field(None, description="Minimum score threshold", ge=0.0, le=1.0)
    categories: Optional[List[str]] = Field(None, description="Categories to filter by")
# Customer-reported schema models (matching mcp_schema.json pattern)
# Element model of InstantiationData.instantiations; every field defaults to None.
class Instantiation(BaseModel):
    """Instantiation object with optional node identifiers."""
    # NOTE(review): the model_config line below is disabled; presumably so the
    # generated schema omits "additionalProperties" and the server-side
    # normalization is what gets exercised - confirm before enabling.
    # model_config = ConfigDict(json_schema_extra={"additionalProperties": False})
    doid: Optional[str] = Field(None, description="DOID identifier")
    nodeFamilyId: Optional[int] = Field(None, description="Node family ID")
    nodeTypeId: Optional[int] = Field(None, description="Node type ID")
    nodePositionId: Optional[int] = Field(None, description="Node position ID")
# Type of the instantiation_data parameter of get_parameter_type_description.
class InstantiationData(BaseModel):
    """Instantiation data with abstract and multiplicity flags."""
    # NOTE(review): the model_config line below is disabled; presumably so the
    # generated schema omits "additionalProperties" and the server-side
    # normalization is what gets exercised - confirm before enabling.
    # model_config = ConfigDict(json_schema_extra={"additionalProperties": False})
    isAbstract: Optional[bool] = Field(None, description="Whether the instantiation is abstract")
    isMultiplicity: Optional[bool] = Field(None, description="Whether the instantiation has multiplicity")
    # NOTE(review): annotated as a plain List but defaulted to None (not
    # Optional[List[...]]). Possibly intentional to reproduce the
    # customer-reported schema shape - confirm before "fixing" the annotation.
    instantiations: List[Instantiation] = Field(None, description="List of instantiations")
# Wrapper model with a single required string field.
# NOTE(review): no tool in the visible part of this file references this model
# (get_parameter_type_description takes `preset` as a plain str) - confirm
# whether it is still needed.
class ParameterPreset(BaseModel):
    """Parameter preset enum values."""
    value: str = Field(..., description="Preset value (a, b, c, e, f, g, h, i, d, l, s, m, z, o, u, unknown)")
@mcp.tool()
async def echo(message: str) -> str:
    """Echo back the provided message.
    Args:
    message: The message to echo back
    """
    # Docstring wording left untouched: it is presumably published as the tool
    # description by FastMCP - confirm before rewording.
    reply = f"Echo: {message}"
    return reply
@mcp.tool()
async def add(a: float, b: float) -> str:
    """Add two numbers together.
    Args:
    a: First number
    b: Second number
    """
    # Docstring wording left untouched: it is presumably published as the tool
    # description by FastMCP - confirm before rewording.
    return f"{a} + {b} = {a + b}"
@mcp.tool()
async def multiply(a: float, b: float) -> str:
    """Multiply two numbers together.
    Args:
    a: First number
    b: Second number
    """
    # Docstring wording left untouched: it is presumably published as the tool
    # description by FastMCP - confirm before rewording.
    return f"{a} × {b} = {a * b}"
@mcp.tool()
async def reverse_string(text: str) -> str:
    """Reverse a string.
    Args:
    text: The string to reverse
    """
    # Docstring wording left untouched: it is presumably published as the tool
    # description by FastMCP - confirm before rewording.
    return "".join(reversed(text))
# Complex tools using Pydantic models
@mcp.tool()
async def create_person(person: Person) -> str:
    """Create a person profile with nested address information.
    Args:
    person: Person object with name, age, optional email and address
    """
    # Collect the report piece by piece and join once at the end.
    pieces = [
        "Created person profile:\n",
        f" Name: {person.name}\n",
        f" Age: {person.age}\n",
    ]

    # Email and address are only reported when they are set (truthy)
    if person.email:
        pieces.append(f" Email: {person.email}\n")

    if person.address:
        address = person.address
        pieces.extend(
            [
                " Address:\n",
                f" {address.street}\n",
                f" {address.city}, {address.zip_code}\n",
                f" {address.country}\n",
            ]
        )

    return "".join(pieces)
@mcp.tool()
async def manage_tasks(tasks: List[TaskItem]) -> str:
    """Manage multiple tasks with priorities and tags.
    Args:
    tasks: List of task items to manage
    """
    # Returns a plain-text report: one numbered entry per task followed by a
    # "completed/total" summary line, or a fixed message for an empty list.
    if not tasks:
        return "No tasks provided"

    result = f"Managing {len(tasks)} task(s):\n\n"
    for i, task in enumerate(tasks, 1):
        # Checkbox-style marker. Previously both branches were the empty
        # string, so completed and open tasks rendered identically as "[]".
        status = "x" if task.completed else " "
        result += f"{i}. [{status}] {task.title}\n"
        result += f" Priority: {task.priority}/5\n"
        if task.tags:
            result += f" Tags: {', '.join(task.tags)}\n"
        result += "\n"

    completed = sum(1 for t in tasks if t.completed)
    result += f"Summary: {completed}/{len(tasks)} completed"
    return result
@mcp.tool()
async def search_with_filters(query: str, filters: SearchFilter) -> str:
    """Search with complex filter criteria including keywords and categories.
    Args:
    query: The main search query
    filters: Complex filter object with keywords, score threshold, and categories
    """
    # Returns a plain-text report of the applied filters followed by three
    # canned (simulated) results; no real search is performed.
    result = f"Search Query: '{query}'\n\n"
    result += "Filters Applied:\n"
    result += f" Keywords: {', '.join(filters.keywords)}\n"
    if filters.min_score is not None:
        result += f" Minimum Score: {filters.min_score}\n"
    if filters.categories:
        result += f" Categories: {', '.join(filters.categories)}\n"

    # `keywords` is required but may legally be an empty list; fall back to
    # the query text instead of raising IndexError on keywords[0].
    top_match = filters.keywords[0] if filters.keywords else query

    # Simulate search results
    result += "\nFound 3 results matching criteria:\n"
    result += f" 1. Result matching '{top_match}' (score: 0.95)\n"
    result += f" 2. Result matching '{query}' (score: 0.87)\n"
    result += " 3. Result matching multiple keywords (score: 0.82)\n"
    return result
@mcp.tool()
async def process_nested_data(data: dict) -> str:
    """Process arbitrary nested dictionary data.
    Args:
    data: Nested dictionary with arbitrary structure
    """

    def walk(node, depth=0):
        # Returns (number of nested items below `node`, deepest level reached).
        # Only dicts (their values) and lists are descended into; anything
        # else is a leaf that contributes nothing on its own.
        if isinstance(node, dict):
            children = node.values()
        elif isinstance(node, list):
            children = node
        else:
            return 0, depth

        total, deepest = 0, depth
        for child in children:
            child_total, child_depth = walk(child, depth + 1)
            # +1 counts the child itself on top of everything nested inside it
            total += child_total + 1
            deepest = max(deepest, child_depth)
        return total, deepest

    # Pretty-printed dump first, then the root-level key count
    pieces = [
        "Processing nested data:\n",
        json.dumps(data, indent=2),
        "\n\nData structure stats:\n",
        f" Keys at root level: {len(data)}\n",
    ]

    total_items, max_depth = walk(data)
    pieces.append(f" Total nested items: {total_items}\n")
    pieces.append(f" Maximum nesting depth: {max_depth}\n")
    return "".join(pieces)
@mcp.tool()
async def get_parameter_type_description(
    preset: str,
    instantiation_data: InstantiationData,
    connected_service_descriptor: Optional[str] = None,
) -> str:
    """Get parameter type description with complex nested structure.
    This tool matches the customer-reported schema pattern with:
    - Enum-like preset parameter
    - Optional string field
    - Optional nested object with arrays of objects
    Args:
    preset: The parameter preset (a, b, c, e, f, g, h, i, d, l, s, m, z, o, u, unknown)
    connected_service_descriptor: Connected service descriptor string, if available
    instantiation_data: Instantiation data dict with isAbstract, isMultiplicity, and instantiations list
    """
    # NOTE(review): the docstring calls instantiation_data optional, but the
    # signature declares it without a default - confirm which one is intended.
    # Header: title, a 50-character rule, then the preset
    sections = [
        "Parameter Type Description\n",
        "=" * 50 + "\n\n",
        f"Preset: {preset}\n\n",
    ]

    # The service descriptor is only reported when it is set (truthy)
    if connected_service_descriptor:
        sections.append(f"Connected Service: {connected_service_descriptor}\n\n")

    if instantiation_data:
        sections.extend(
            [
                "Instantiation Data:\n",
                f" Is Abstract: {instantiation_data.isAbstract}\n",
                f" Is Multiplicity: {instantiation_data.isMultiplicity}\n",
                f" Instantiations: {instantiation_data.instantiations}\n",
            ]
        )

    return "".join(sections)
def main():
    """Start the FastMCP test server on the stdio transport."""
    mcp.run(transport="stdio")
if __name__ == "__main__":
main()