diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index 545f8873..808bb622 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -588,11 +588,111 @@ def generate_schema_from_args_schema_v2( return function_call_json +def normalize_mcp_schema(schema: Dict[str, Any]) -> Dict[str, Any]: + """ + Normalize an MCP JSON schema to fix common issues: + 1. Add explicit 'additionalProperties': false to all object types + 2. Add explicit 'type' field to properties using $ref + 3. Process $defs recursively + + Args: + schema: The JSON schema to normalize (will be modified in-place) + + Returns: + The normalized schema (same object, modified in-place) + """ + import copy + + # Work on a deep copy to avoid modifying the original + schema = copy.deepcopy(schema) + + def normalize_object_schema(obj_schema: Dict[str, Any], defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + """Recursively normalize an object schema.""" + + # If this is an object type, add additionalProperties if missing + if obj_schema.get("type") == "object": + if "additionalProperties" not in obj_schema: + obj_schema["additionalProperties"] = False + + # Handle properties + if "properties" in obj_schema: + for prop_name, prop_schema in obj_schema["properties"].items(): + # Handle $ref references + if "$ref" in prop_schema: + # Add explicit type based on the reference + if "type" not in prop_schema: + # Try to resolve the type from $defs if available + if defs and prop_schema["$ref"].startswith("#/$defs/"): + def_name = prop_schema["$ref"].split("/")[-1] + if def_name in defs: + ref_schema = defs[def_name] + if "type" in ref_schema: + prop_schema["type"] = ref_schema["type"] + + # If still no type, assume object (common case for model references) + if "type" not in prop_schema: + prop_schema["type"] = "object" + + # Don't add additionalProperties to properties with $ref + # The $ref schema itself will have additionalProperties + # 
Adding it here makes the validator think it allows empty objects + continue + + # Recursively normalize nested objects + if isinstance(prop_schema, dict): + if prop_schema.get("type") == "object": + normalize_object_schema(prop_schema, defs) + + # Handle arrays with object items + if prop_schema.get("type") == "array" and "items" in prop_schema: + items = prop_schema["items"] + if isinstance(items, dict): + # Handle $ref in items + if "$ref" in items and "type" not in items: + if defs and items["$ref"].startswith("#/$defs/"): + def_name = items["$ref"].split("/")[-1] + if def_name in defs and "type" in defs[def_name]: + items["type"] = defs[def_name]["type"] + if "type" not in items: + items["type"] = "object" + + # Recursively normalize items + if items.get("type") == "object": + normalize_object_schema(items, defs) + + # Handle anyOf (complex union types) + if "anyOf" in prop_schema: + for option in prop_schema["anyOf"]: + if isinstance(option, dict) and option.get("type") == "object": + normalize_object_schema(option, defs) + + # Handle array items at the top level + if "items" in obj_schema and isinstance(obj_schema["items"], dict): + if obj_schema["items"].get("type") == "object": + normalize_object_schema(obj_schema["items"], defs) + + return obj_schema + + # Process $defs first if they exist + defs = schema.get("$defs", {}) + if defs: + for def_name, def_schema in defs.items(): + if isinstance(def_schema, dict): + normalize_object_schema(def_schema, defs) + + # Process the main schema + normalize_object_schema(schema, defs) + + return schema + + def generate_tool_schema_for_mcp( mcp_tool: MCPTool, append_heartbeat: bool = True, strict: bool = False, ) -> Dict[str, Any]: + from letta.functions.schema_validator import validate_complete_json_schema + # MCP tool.inputSchema is a JSON schema # https://github.com/modelcontextprotocol/python-sdk/blob/775f87981300660ee957b63c2a14b448ab9c3675/src/mcp/types.py#L678 parameters_schema = mcp_tool.inputSchema @@ -603,11 
+703,16 @@ def generate_tool_schema_for_mcp( assert "properties" in parameters_schema, parameters_schema # assert "required" in parameters_schema, parameters_schema + # Normalize the schema to fix common issues with MCP schemas + # This adds additionalProperties: false and explicit types for $ref properties + parameters_schema = normalize_mcp_schema(parameters_schema) + # Zero-arg tools often omit "required" because nothing is required. # Normalise so downstream code can treat it consistently. parameters_schema.setdefault("required", []) # Process properties to handle anyOf types and make optional fields strict-compatible + # TODO: de-duplicate with handling in normalize_mcp_schema if "properties" in parameters_schema: for field_name, field_props in parameters_schema["properties"].items(): # Handle anyOf types by flattening to type array @@ -660,6 +765,14 @@ def generate_tool_schema_for_mcp( if REQUEST_HEARTBEAT_PARAM not in parameters_schema["required"]: parameters_schema["required"].append(REQUEST_HEARTBEAT_PARAM) + # Re-validate the schema after normalization and update the health status + # This allows previously INVALID schemas to pass if normalization fixed them + if mcp_tool.health: + health_status, health_reasons = validate_complete_json_schema(parameters_schema) + mcp_tool.health.status = health_status.value + mcp_tool.health.reasons = health_reasons + logger.debug(f"MCP tool {name} schema health after normalization: {health_status.value}, reasons: {health_reasons}") + # Return the final schema if strict: # https://platform.openai.com/docs/guides/function-calling#strict-mode diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 89e320a3..67b1bf5e 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -140,12 +140,44 @@ class MCPManager: for mcp_tool in mcp_tools: # TODO: @jnjpng move health check to tool class if mcp_tool.name == mcp_tool_name: - # Check tool health - reject only INVALID tools - if 
mcp_tool.health: - if mcp_tool.health.status == "INVALID": - raise ValueError( - f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid.Reasons: {', '.join(mcp_tool.health.reasons)}" - ) + # Check tool health - but try normalization first for INVALID schemas + if mcp_tool.health and mcp_tool.health.status == "INVALID": + logger.info(f"Attempting to normalize INVALID schema for tool {mcp_tool_name}") + logger.info(f"Original health reasons: {mcp_tool.health.reasons}") + + # Try to normalize the schema and re-validate + from letta.functions.schema_generator import normalize_mcp_schema + from letta.functions.schema_validator import validate_complete_json_schema + + try: + # Normalize the schema to fix common issues + logger.debug(f"Normalizing schema for {mcp_tool_name}") + normalized_schema = normalize_mcp_schema(mcp_tool.inputSchema) + + # Re-validate after normalization + logger.debug(f"Re-validating schema for {mcp_tool_name}") + health_status, health_reasons = validate_complete_json_schema(normalized_schema) + logger.info(f"After normalization: status={health_status.value}, reasons={health_reasons}") + + # Update the tool's schema and health (use inputSchema, not input_schema) + mcp_tool.inputSchema = normalized_schema + mcp_tool.health.status = health_status.value + mcp_tool.health.reasons = health_reasons + + # Log the normalization result + if health_status.value != "INVALID": + logger.info(f"✓ MCP tool {mcp_tool_name} schema normalized successfully: {health_status.value}") + else: + logger.warning(f"MCP tool {mcp_tool_name} still INVALID after normalization. Reasons: {health_reasons}") + except Exception as e: + logger.error(f"Failed to normalize schema for tool {mcp_tool_name}: {e}", exc_info=True) + + # After normalization attempt, check if still INVALID + if mcp_tool.health and mcp_tool.health.status == "INVALID": + raise ValueError( + f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid even after normalization. 
" + f"Reasons: {', '.join(mcp_tool.health.reasons)}" + ) tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) return await self.tool_manager.create_mcp_tool_async( diff --git a/tests/integration_test_mcp.py b/tests/integration_test_mcp.py new file mode 100644 index 00000000..f3c3877e --- /dev/null +++ b/tests/integration_test_mcp.py @@ -0,0 +1,396 @@ +import os +import sys +import threading +import time +import uuid +from pathlib import Path + +import pytest +import requests +from dotenv import load_dotenv +from letta_client import Letta, MessageCreate, ToolCallMessage, ToolReturnMessage + +from letta.functions.mcp_client.types import StdioServerConfig +from letta.schemas.agent import AgentState +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.letta_message_content import TextContent +from letta.schemas.llm_config import LLMConfig + +# ------------------------------ +# Fixtures +# ------------------------------ + + +@pytest.fixture(scope="module") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it's accepting connections. 
+ """ + + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") + + yield url + + +@pytest.fixture(scope="module") +def client(server_url: str) -> Letta: + """ + Creates and returns a synchronous Letta REST client for testing. + """ + client_instance = Letta(base_url=server_url) + yield client_instance + + +@pytest.fixture(scope="function") +def mcp_server_name() -> str: + """Generate a unique MCP server name for each test.""" + return f"test-mcp-server-{uuid.uuid4().hex[:8]}" + + +@pytest.fixture(scope="function") +def mock_mcp_server_config(mcp_server_name: str) -> StdioServerConfig: + """ + Creates a stdio configuration for the mock MCP server. + """ + # Get path to mock_mcp_server.py + script_dir = Path(__file__).parent + mcp_server_path = script_dir / "mock_mcp_server.py" + + if not mcp_server_path.exists(): + raise FileNotFoundError(f"Mock MCP server not found at {mcp_server_path}") + + return StdioServerConfig( + server_name=mcp_server_name, + command=sys.executable, # Use the current Python interpreter + args=[str(mcp_server_path)], + ) + + +@pytest.fixture(scope="function") +def agent_state(client: Letta, mcp_server_name: str, mock_mcp_server_config: StdioServerConfig) -> AgentState: + """ + Creates an agent with MCP tools attached for testing. 
+ """ + # Register the MCP server + client.tools.add_mcp_server(request=mock_mcp_server_config) + + # Verify server is registered + servers = client.tools.list_mcp_servers() + assert mcp_server_name in servers, f"MCP server {mcp_server_name} not found in {servers}" + + # List available MCP tools + mcp_tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name) + assert len(mcp_tools) > 0, "No tools found from MCP server" + + # Add the echo and add tools to Letta + echo_tool = next((t for t in mcp_tools if t.name == "echo"), None) + add_tool = next((t for t in mcp_tools if t.name == "add"), None) + + assert echo_tool is not None, "echo tool not found" + assert add_tool is not None, "add tool not found" + + letta_echo_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name="echo") + letta_add_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name="add") + + # Create agent with the MCP tools + agent = client.agents.create( + name=f"test_mcp_agent_{uuid.uuid4().hex[:8]}", + include_base_tools=True, + tool_ids=[letta_echo_tool.id, letta_add_tool.id], + memory_blocks=[ + { + "label": "human", + "value": "Name: Test User", + }, + { + "label": "persona", + "value": "You are a helpful assistant that can use MCP tools to help the user.", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tags=["test_mcp_agent"], + ) + + yield agent + + # Cleanup + try: + client.agents.delete(agent.id) + except Exception as e: + print(f"Warning: Failed to delete agent {agent.id}: {e}") + + try: + client.tools.delete_mcp_server(mcp_server_name=mcp_server_name) + except Exception as e: + print(f"Warning: Failed to delete MCP server {mcp_server_name}: {e}") + + +# ------------------------------ +# Test Cases +# ------------------------------ + + +def test_mcp_echo_tool(client: Letta, agent_state: AgentState): + """ + Test that an 
agent can successfully call the echo tool from the MCP server. + """ + test_message = "Hello from MCP integration test!" + + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[ + MessageCreate( + role="user", + content=f"Use the echo tool to echo back this exact message: '{test_message}'", + ) + ], + ) + + # Check for tool call message + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) > 0, "Expected at least one ToolCallMessage" + + # Find the echo tool call + echo_call = next((m for m in tool_calls if m.tool_call.name == "echo"), None) + assert echo_call is not None, f"No echo tool call found. Tool calls: {[m.tool_call.name for m in tool_calls]}" + + # Check for tool return message + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage" + + # Find the return for the echo call + echo_return = next((m for m in tool_returns if m.tool_call_id == echo_call.tool_call.tool_call_id), None) + assert echo_return is not None, "No tool return found for echo call" + assert echo_return.status == "success", f"Echo tool failed with status: {echo_return.status}" + + # Verify the echo response contains our message + assert test_message in echo_return.tool_return, f"Expected '{test_message}' in tool return, got: {echo_return.tool_return}" + + +def test_mcp_add_tool(client: Letta, agent_state: AgentState): + """ + Test that an agent can successfully call the add tool from the MCP server. 
+ """ + a, b = 42, 58 + expected_sum = a + b + + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[ + MessageCreate( + role="user", + content=f"Use the add tool to add {a} and {b}.", + ) + ], + ) + + # Check for tool call message + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) > 0, "Expected at least one ToolCallMessage" + + # Find the add tool call + add_call = next((m for m in tool_calls if m.tool_call.name == "add"), None) + assert add_call is not None, f"No add tool call found. Tool calls: {[m.tool_call.name for m in tool_calls]}" + + # Check for tool return message + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage" + + # Find the return for the add call + add_return = next((m for m in tool_returns if m.tool_call_id == add_call.tool_call.tool_call_id), None) + assert add_return is not None, "No tool return found for add call" + assert add_return.status == "success", f"Add tool failed with status: {add_return.status}" + + # Verify the result contains the expected sum + assert str(expected_sum) in add_return.tool_return, f"Expected '{expected_sum}' in tool return, got: {add_return.tool_return}" + + +def test_mcp_multiple_tools_in_sequence(client: Letta, agent_state: AgentState): + """ + Test that an agent can call multiple MCP tools in sequence. + """ + response = client.agents.messages.create( + agent_id=agent_state.id, + messages=[ + MessageCreate( + role="user", + content="First use the add tool to add 10 and 20. 
Then use the echo tool to echo back the result you got from the add tool.", + ) + ], + ) + + # Check for tool call messages + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) >= 2, f"Expected at least 2 tool calls, got {len(tool_calls)}" + + # Verify both tools were called + tool_names = [m.tool_call.name for m in tool_calls] + assert "add" in tool_names, f"add tool not called. Tools called: {tool_names}" + assert "echo" in tool_names, f"echo tool not called. Tools called: {tool_names}" + + # Check for tool return messages + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert len(tool_returns) >= 2, f"Expected at least 2 tool returns, got {len(tool_returns)}" + + # Verify all tools succeeded + for tool_return in tool_returns: + assert tool_return.status == "success", f"Tool call failed with status: {tool_return.status}" + + +def test_mcp_server_listing(client: Letta, mcp_server_name: str, mock_mcp_server_config: StdioServerConfig): + """ + Test that MCP server registration and tool listing works correctly. + """ + # Register the MCP server + client.tools.add_mcp_server(request=mock_mcp_server_config) + + try: + # Verify server is in the list + servers = client.tools.list_mcp_servers() + assert mcp_server_name in servers, f"MCP server {mcp_server_name} not found in {servers}" + + # List available tools + mcp_tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name) + assert len(mcp_tools) > 0, "No tools found from MCP server" + + # Verify expected tools are present + tool_names = [t.name for t in mcp_tools] + expected_tools = ["echo", "add", "multiply", "reverse_string"] + for expected_tool in expected_tools: + assert expected_tool in tool_names, f"Expected tool '{expected_tool}' not found. 
Available: {tool_names}" + + finally: + # Cleanup + client.tools.delete_mcp_server(mcp_server_name=mcp_server_name) + servers = client.tools.list_mcp_servers() + assert mcp_server_name not in servers, f"MCP server {mcp_server_name} should be deleted but is still in {servers}" + + +def test_mcp_complex_schema_tool(client: Letta, mcp_server_name: str, mock_mcp_server_config: StdioServerConfig): + """ + Test that an agent can successfully call a tool with complex nested schema. + This tests the get_parameter_type_description tool which has: + - Enum-like preset parameter + - Optional string field + - Optional nested object with arrays of objects + """ + # Register the MCP server + client.tools.add_mcp_server(request=mock_mcp_server_config) + + try: + # List available tools + mcp_tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name) + + # Find the complex schema tool + complex_tool = next((t for t in mcp_tools if t.name == "get_parameter_type_description"), None) + assert complex_tool is not None, f"get_parameter_type_description tool not found. 
Available: {[t.name for t in mcp_tools]}" + + # Add it to Letta + letta_complex_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name="get_parameter_type_description") + + # Create agent with the complex tool + agent = client.agents.create( + name=f"test_complex_schema_{uuid.uuid4().hex[:8]}", + include_base_tools=True, + tool_ids=[letta_complex_tool.id], + memory_blocks=[ + { + "label": "human", + "value": "Name: Test User", + }, + { + "label": "persona", + "value": "You are a helpful assistant that can use MCP tools with complex schemas.", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tags=["test_complex_schema"], + ) + + # Test 1: Simple call with just preset + response = client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", content='Use the get_parameter_type_description tool with preset "a" to get parameter information.' + ) + ], + ) + + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) > 0, "Expected at least one ToolCallMessage" + + complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None) + assert complex_call is not None, f"No get_parameter_type_description call found. 
Calls: {[m.tool_call.name for m in tool_calls]}" + + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage" + + complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None) + assert complex_return is not None, "No tool return found for complex schema call" + assert complex_return.status == "success", f"Complex schema tool failed with status: {complex_return.status}" + assert "Preset: a" in complex_return.tool_return, f"Expected 'Preset: a' in return, got: {complex_return.tool_return}" + + # Test 2: Complex call with nested data + response = client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content="Use the get_parameter_type_description tool with these arguments: " + 'preset="b", connected_service_descriptor="test-service", ' + "and instantiation_data with isAbstract=true, isMultiplicity=false, " + 'and one instantiation with doid="TEST123" and nodeFamilyId=42.', + ) + ], + ) + + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) > 0, "Expected at least one ToolCallMessage for complex nested call" + + complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None) + assert complex_call is not None, "No get_parameter_type_description call found for nested test" + + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None) + assert complex_return is not None, "No tool return found for complex nested call" + assert complex_return.status == "success", f"Complex nested call failed with status: {complex_return.status}" + + # Verify the response contains our complex data + assert "Preset: b" in complex_return.tool_return, "Expected preset 'b' in 
response" + assert "test-service" in complex_return.tool_return, "Expected service descriptor in response" + + # Cleanup agent + client.agents.delete(agent.id) + + finally: + # Cleanup MCP server + client.tools.delete_mcp_server(mcp_server_name=mcp_server_name) diff --git a/tests/managers/test_mcp_manager.py b/tests/managers/test_mcp_manager.py index d8666089..2aad82c0 100644 --- a/tests/managers/test_mcp_manager.py +++ b/tests/managers/test_mcp_manager.py @@ -320,6 +320,189 @@ async def test_create_mcp_server_with_tools(mock_get_client, server, default_use assert len(remaining_mcp_tools) == 0, "Tools should be deleted when server is deleted" +@pytest.mark.asyncio +@patch("letta.services.mcp_manager.MCPManager.get_mcp_client") +async def test_complex_schema_normalization(mock_get_client, server, default_user): + """Test that complex MCP schemas with nested objects are normalized and accepted.""" + from letta.functions.mcp_client.types import MCPTool, MCPToolHealth + from letta.schemas.mcp import MCPServer, MCPServerType + from letta.settings import tool_settings + + if tool_settings.mcp_read_from_config: + return + + # Create mock tools with complex schemas that would normally be INVALID + # These schemas have: nested $defs, $ref references, missing additionalProperties + mock_tools = [ + # 1. 
Nested object with $ref (like create_person) + MCPTool( + name="create_person", + description="Create a person with nested address", + inputSchema={ + "$defs": { + "Address": { + "type": "object", + "properties": { + "street": {"type": "string"}, + "city": {"type": "string"}, + "zip_code": {"type": "string"}, + }, + "required": ["street", "city", "zip_code"], + }, + "Person": { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"}, + "address": {"$ref": "#/$defs/Address"}, + }, + "required": ["name", "age"], + }, + }, + "type": "object", + "properties": {"person": {"$ref": "#/$defs/Person"}}, + "required": ["person"], + }, + health=MCPToolHealth( + status="INVALID", + reasons=["root: 'additionalProperties' not explicitly set", "root.properties.person: Missing 'type'"], + ), + ), + # 2. List of objects (like manage_tasks) + MCPTool( + name="manage_tasks", + description="Manage multiple tasks", + inputSchema={ + "$defs": { + "TaskItem": { + "type": "object", + "properties": { + "title": {"type": "string"}, + "priority": {"type": "integer", "default": 1}, + "completed": {"type": "boolean", "default": False}, + "tags": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["title"], + } + }, + "type": "object", + "properties": { + "tasks": { + "type": "array", + "items": {"$ref": "#/$defs/TaskItem"}, + } + }, + "required": ["tasks"], + }, + health=MCPToolHealth( + status="INVALID", + reasons=["root: 'additionalProperties' not explicitly set", "root.properties.tasks.items: Missing 'type'"], + ), + ), + # 3. 
Complex filter object with optional fields + MCPTool( + name="search_with_filters", + description="Search with complex filters", + inputSchema={ + "$defs": { + "SearchFilter": { + "type": "object", + "properties": { + "keywords": {"type": "array", "items": {"type": "string"}}, + "min_score": {"type": "number"}, + "categories": {"type": "array", "items": {"type": "string"}}, + }, + "required": ["keywords"], + } + }, + "type": "object", + "properties": { + "query": {"type": "string"}, + "filters": {"$ref": "#/$defs/SearchFilter"}, + }, + "required": ["query", "filters"], + }, + health=MCPToolHealth( + status="INVALID", + reasons=["root: 'additionalProperties' not explicitly set", "root.properties.filters: Missing 'type'"], + ), + ), + ] + + # Create mock client + mock_client = AsyncMock() + mock_client.connect_to_server = AsyncMock() + mock_client.list_tools = AsyncMock(return_value=mock_tools) + mock_client.cleanup = AsyncMock() + mock_get_client.return_value = mock_client + + # Create MCP server + server_name = f"test_complex_schema_{uuid.uuid4().hex[:8]}" + server_url = "https://test-complex.example.com/sse" + mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url) + + try: + # Create server (this will auto-sync tools) + created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user) + + assert created_server.server_name == server_name + + # Now attempt to add each tool - they should be normalized from INVALID to acceptable + # The normalization happens in add_tool_from_mcp_server + + # Test 1: create_person should normalize successfully + person_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "create_person", actor=default_user) + assert person_tool is not None + assert person_tool.name == "create_person" + # Verify the schema has additionalProperties set + assert person_tool.json_schema["parameters"]["additionalProperties"] == False + # Verify nested $defs 
have additionalProperties + if "$defs" in person_tool.json_schema["parameters"]: + for def_name, def_schema in person_tool.json_schema["parameters"]["$defs"].items(): + if def_schema.get("type") == "object": + assert "additionalProperties" in def_schema, f"$defs.{def_name} missing additionalProperties after normalization" + + # Test 2: manage_tasks should normalize successfully + tasks_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "manage_tasks", actor=default_user) + assert tasks_tool is not None + assert tasks_tool.name == "manage_tasks" + # Verify array items have explicit type + tasks_prop = tasks_tool.json_schema["parameters"]["properties"]["tasks"] + assert "items" in tasks_prop + assert "type" in tasks_prop["items"], "Array items should have explicit type after normalization" + + # Test 3: search_with_filters should normalize successfully + search_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "search_with_filters", actor=default_user) + assert search_tool is not None + assert search_tool.name == "search_with_filters" + + # Verify all tools were persisted + all_tools = await server.tool_manager.list_tools_async( + actor=default_user, names=["create_person", "manage_tasks", "search_with_filters"] + ) + + # Filter to tools from our MCP server + mcp_tools = [ + tool + for tool in all_tools + if tool.metadata_ + and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_ + and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name + ] + + # All 3 complex schema tools should have been normalized and persisted + assert len(mcp_tools) == 3, f"Expected 3 normalized tools, got {len(mcp_tools)}" + + # Verify they all have the correct MCP metadata + for tool in mcp_tools: + assert tool.tool_type == ToolType.EXTERNAL_MCP + assert f"mcp:{server_name}" in tool.tags + + finally: + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + 
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
async def test_create_mcp_server_with_tools_connection_failure(mock_get_client, server, default_user):
    """Test that MCP server creation succeeds even when tool sync fails (optimistic approach)."""
diff --git a/tests/mock_mcp_server.py b/tests/mock_mcp_server.py
new file mode 100755
index 00000000..b3381720
--- /dev/null
+++ b/tests/mock_mcp_server.py
@@ -0,0 +1,283 @@
+#!/usr/bin/env python3
+"""
+Simple MCP test server with basic and complex tools for testing purposes.
+"""
+
+import json
+import logging
+from typing import List, Optional, Union
+
+from mcp.server.fastmcp import FastMCP
+from pydantic import BaseModel, ConfigDict, Field
+
+# Configure logging to stderr (not stdout for STDIO servers)
+logging.basicConfig(level=logging.INFO)
+
+# Initialize FastMCP server
+mcp = FastMCP("test-server")
+
+
+# Complex Pydantic models for testing
+class Address(BaseModel):
+    """An address with street, city, and zip code."""
+
+    street: str = Field(..., description="Street address")
+    city: str = Field(..., description="City name")
+    zip_code: str = Field(..., description="ZIP code")
+    country: str = Field(default="USA", description="Country name")
+
+
+class Person(BaseModel):
+    """A person with name, age, and optional address."""
+
+    name: str = Field(..., description="Person's full name")
+    age: int = Field(..., description="Person's age", ge=0, le=150)
+    email: Optional[str] = Field(None, description="Email address")
+    address: Optional[Address] = Field(None, description="Home address")
+
+
+class TaskItem(BaseModel):
+    """A task item with title, priority, and completion status."""
+
+    title: str = Field(..., description="Task title")
+    priority: int = Field(default=1, description="Priority level (1-5)", ge=1, le=5)
+    completed: bool = Field(default=False, description="Whether the task is completed")
+    tags: List[str] = Field(default_factory=list, description="List of tags")
+
+
+class SearchFilter(BaseModel):
+    """Filter criteria for searching."""
+
+    keywords: List[str] = Field(..., description="List of keywords to search for")
+    min_score: Optional[float] = Field(None, description="Minimum score threshold", ge=0.0, le=1.0)
+    categories: Optional[List[str]] = Field(None, description="Categories to filter by")
+
+
+# Customer-reported schema models (matching mcp_schema.json pattern)
+class Instantiation(BaseModel):
+    """Instantiation object with optional node identifiers."""
+
+    # model_config = ConfigDict(json_schema_extra={"additionalProperties": False})
+
+    doid: Optional[str] = Field(None, description="DOID identifier")
+    nodeFamilyId: Optional[int] = Field(None, description="Node family ID")
+    nodeTypeId: Optional[int] = Field(None, description="Node type ID")
+    nodePositionId: Optional[int] = Field(None, description="Node position ID")
+
+
+class InstantiationData(BaseModel):
+    """Instantiation data with abstract and multiplicity flags."""
+
+    # model_config = ConfigDict(json_schema_extra={"additionalProperties": False})
+
+    isAbstract: Optional[bool] = Field(None, description="Whether the instantiation is abstract")
+    isMultiplicity: Optional[bool] = Field(None, description="Whether the instantiation has multiplicity")
+    instantiations: Optional[List[Instantiation]] = Field(None, description="List of instantiations")  # Optional: default is None, matching the sibling nullable fields
+
+
+class ParameterPreset(BaseModel):
+    """Parameter preset enum values."""
+
+    value: str = Field(..., description="Preset value (a, b, c, e, f, g, h, i, d, l, s, m, z, o, u, unknown)")
+
+
+@mcp.tool()
+async def echo(message: str) -> str:
+    """Echo back the provided message.
+
+    Args:
+        message: The message to echo back
+    """
+    return f"Echo: {message}"
+
+
+@mcp.tool()
+async def add(a: float, b: float) -> str:
+    """Add two numbers together.
+
+    Args:
+        a: First number
+        b: Second number
+    """
+    result = a + b
+    return f"{a} + {b} = {result}"
+
+
+@mcp.tool()
+async def multiply(a: float, b: float) -> str:
+    """Multiply two numbers together.
+
+    Args:
+        a: First number
+        b: Second number
+    """
+    result = a * b
+    return f"{a} × {b} = {result}"
+
+
+@mcp.tool()
+async def reverse_string(text: str) -> str:
+    """Reverse a string.
+
+    Args:
+        text: The string to reverse
+    """
+    return text[::-1]
+
+
+# Complex tools using Pydantic models
+
+
+@mcp.tool()
+async def create_person(person: Person) -> str:
+    """Create a person profile with nested address information.
+
+    Args:
+        person: Person object with name, age, optional email and address
+    """
+    result = "Created person profile:\n"
+    result += f"  Name: {person.name}\n"
+    result += f"  Age: {person.age}\n"
+
+    if person.email:
+        result += f"  Email: {person.email}\n"
+
+    if person.address:
+        result += "  Address:\n"
+        result += f"    {person.address.street}\n"
+        result += f"    {person.address.city}, {person.address.zip_code}\n"
+        result += f"    {person.address.country}\n"
+
+    return result
+
+
+@mcp.tool()
+async def manage_tasks(tasks: List[TaskItem]) -> str:
+    """Manage multiple tasks with priorities and tags.
+
+    Args:
+        tasks: List of task items to manage
+    """
+    if not tasks:
+        return "No tasks provided"
+
+    result = f"Managing {len(tasks)} task(s):\n\n"
+
+    for i, task in enumerate(tasks, 1):
+        status = "✓" if task.completed else "○"
+        result += f"{i}. [{status}] {task.title}\n"
+        result += f"   Priority: {task.priority}/5\n"
+
+        if task.tags:
+            result += f"   Tags: {', '.join(task.tags)}\n"
+
+        result += "\n"
+
+    completed = sum(1 for t in tasks if t.completed)
+    result += f"Summary: {completed}/{len(tasks)} completed"
+
+    return result
+
+
+@mcp.tool()
+async def search_with_filters(query: str, filters: SearchFilter) -> str:
+    """Search with complex filter criteria including keywords and categories.
+
+    Args:
+        query: The main search query
+        filters: Complex filter object with keywords, score threshold, and categories
+    """
+    result = f"Search Query: '{query}'\n\n"
+    result += "Filters Applied:\n"
+    result += f"  Keywords: {', '.join(filters.keywords)}\n"
+
+    if filters.min_score is not None:
+        result += f"  Minimum Score: {filters.min_score}\n"
+
+    if filters.categories:
+        result += f"  Categories: {', '.join(filters.categories)}\n"
+
+    # Simulate search results
+    result += "\nFound 3 results matching criteria:\n"
+    result += f"  1. Result matching '{filters.keywords[0]}' (score: 0.95)\n"
+    result += f"  2. Result matching '{query}' (score: 0.87)\n"
+    result += "  3. Result matching multiple keywords (score: 0.82)\n"
+
+    return result
+
+
+@mcp.tool()
+async def process_nested_data(data: dict) -> str:
+    """Process arbitrary nested dictionary data.
+
+    Args:
+        data: Nested dictionary with arbitrary structure
+    """
+    result = "Processing nested data:\n"
+    result += json.dumps(data, indent=2)
+    result += "\n\nData structure stats:\n"
+    result += f"  Keys at root level: {len(data)}\n"
+
+    def count_nested_items(obj, depth=0):
+        count = 0
+        max_depth = depth
+        if isinstance(obj, dict):
+            for v in obj.values():
+                sub_count, sub_depth = count_nested_items(v, depth + 1)
+                count += sub_count + 1
+                max_depth = max(max_depth, sub_depth)
+        elif isinstance(obj, list):
+            for item in obj:
+                sub_count, sub_depth = count_nested_items(item, depth + 1)
+                count += sub_count + 1
+                max_depth = max(max_depth, sub_depth)
+        return count, max_depth
+
+    total_items, max_depth = count_nested_items(data)
+    result += f"  Total nested items: {total_items}\n"
+    result += f"  Maximum nesting depth: {max_depth}\n"
+
+    return result
+
+
+@mcp.tool()
+async def get_parameter_type_description(
+    preset: str,
+    instantiation_data: InstantiationData,
+    connected_service_descriptor: Optional[str] = None,
+) -> str:
+    """Get parameter type description with complex nested structure.
+
+    This tool matches the customer-reported schema pattern with:
+    - Enum-like preset parameter
+    - Optional string field
+    - Optional nested object with arrays of objects
+
+    Args:
+        preset: The parameter preset (a, b, c, e, f, g, h, i, d, l, s, m, z, o, u, unknown)
+        connected_service_descriptor: Connected service descriptor string, if available
+        instantiation_data: Instantiation data dict with isAbstract, isMultiplicity, and instantiations list
+    """
+    result = "Parameter Type Description\n"
+    result += "=" * 50 + "\n\n"
+    result += f"Preset: {preset}\n\n"
+
+    if connected_service_descriptor:
+        result += f"Connected Service: {connected_service_descriptor}\n\n"
+
+    if instantiation_data:
+        result += "Instantiation Data:\n"
+        result += f"  Is Abstract: {instantiation_data.isAbstract}\n"
+        result += f"  Is Multiplicity: {instantiation_data.isMultiplicity}\n"
+        result += f"  Instantiations: {instantiation_data.instantiations}\n"
+
+    return result
+
+
+def main():
+    # Initialize and run the server
+    mcp.run(transport="stdio")
+
+
+if __name__ == "__main__":
+    main()