fix: fix mcp for complex schemas and add tests (#5063)
This commit is contained in:
committed by
Caren Thomas
parent
d7b2d3c6ba
commit
7b73b25a95
@@ -588,11 +588,111 @@ def generate_schema_from_args_schema_v2(
|
||||
return function_call_json
|
||||
|
||||
|
||||
def normalize_mcp_schema(schema: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Normalize an MCP JSON schema to fix common issues:
|
||||
1. Add explicit 'additionalProperties': false to all object types
|
||||
2. Add explicit 'type' field to properties using $ref
|
||||
3. Process $defs recursively
|
||||
|
||||
Args:
|
||||
schema: The JSON schema to normalize (will be modified in-place)
|
||||
|
||||
Returns:
|
||||
The normalized schema (same object, modified in-place)
|
||||
"""
|
||||
import copy
|
||||
|
||||
# Work on a deep copy to avoid modifying the original
|
||||
schema = copy.deepcopy(schema)
|
||||
|
||||
def normalize_object_schema(obj_schema: Dict[str, Any], defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
|
||||
"""Recursively normalize an object schema."""
|
||||
|
||||
# If this is an object type, add additionalProperties if missing
|
||||
if obj_schema.get("type") == "object":
|
||||
if "additionalProperties" not in obj_schema:
|
||||
obj_schema["additionalProperties"] = False
|
||||
|
||||
# Handle properties
|
||||
if "properties" in obj_schema:
|
||||
for prop_name, prop_schema in obj_schema["properties"].items():
|
||||
# Handle $ref references
|
||||
if "$ref" in prop_schema:
|
||||
# Add explicit type based on the reference
|
||||
if "type" not in prop_schema:
|
||||
# Try to resolve the type from $defs if available
|
||||
if defs and prop_schema["$ref"].startswith("#/$defs/"):
|
||||
def_name = prop_schema["$ref"].split("/")[-1]
|
||||
if def_name in defs:
|
||||
ref_schema = defs[def_name]
|
||||
if "type" in ref_schema:
|
||||
prop_schema["type"] = ref_schema["type"]
|
||||
|
||||
# If still no type, assume object (common case for model references)
|
||||
if "type" not in prop_schema:
|
||||
prop_schema["type"] = "object"
|
||||
|
||||
# Don't add additionalProperties to properties with $ref
|
||||
# The $ref schema itself will have additionalProperties
|
||||
# Adding it here makes the validator think it allows empty objects
|
||||
continue
|
||||
|
||||
# Recursively normalize nested objects
|
||||
if isinstance(prop_schema, dict):
|
||||
if prop_schema.get("type") == "object":
|
||||
normalize_object_schema(prop_schema, defs)
|
||||
|
||||
# Handle arrays with object items
|
||||
if prop_schema.get("type") == "array" and "items" in prop_schema:
|
||||
items = prop_schema["items"]
|
||||
if isinstance(items, dict):
|
||||
# Handle $ref in items
|
||||
if "$ref" in items and "type" not in items:
|
||||
if defs and items["$ref"].startswith("#/$defs/"):
|
||||
def_name = items["$ref"].split("/")[-1]
|
||||
if def_name in defs and "type" in defs[def_name]:
|
||||
items["type"] = defs[def_name]["type"]
|
||||
if "type" not in items:
|
||||
items["type"] = "object"
|
||||
|
||||
# Recursively normalize items
|
||||
if items.get("type") == "object":
|
||||
normalize_object_schema(items, defs)
|
||||
|
||||
# Handle anyOf (complex union types)
|
||||
if "anyOf" in prop_schema:
|
||||
for option in prop_schema["anyOf"]:
|
||||
if isinstance(option, dict) and option.get("type") == "object":
|
||||
normalize_object_schema(option, defs)
|
||||
|
||||
# Handle array items at the top level
|
||||
if "items" in obj_schema and isinstance(obj_schema["items"], dict):
|
||||
if obj_schema["items"].get("type") == "object":
|
||||
normalize_object_schema(obj_schema["items"], defs)
|
||||
|
||||
return obj_schema
|
||||
|
||||
# Process $defs first if they exist
|
||||
defs = schema.get("$defs", {})
|
||||
if defs:
|
||||
for def_name, def_schema in defs.items():
|
||||
if isinstance(def_schema, dict):
|
||||
normalize_object_schema(def_schema, defs)
|
||||
|
||||
# Process the main schema
|
||||
normalize_object_schema(schema, defs)
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
def generate_tool_schema_for_mcp(
|
||||
mcp_tool: MCPTool,
|
||||
append_heartbeat: bool = True,
|
||||
strict: bool = False,
|
||||
) -> Dict[str, Any]:
|
||||
from letta.functions.schema_validator import validate_complete_json_schema
|
||||
|
||||
# MCP tool.inputSchema is a JSON schema
|
||||
# https://github.com/modelcontextprotocol/python-sdk/blob/775f87981300660ee957b63c2a14b448ab9c3675/src/mcp/types.py#L678
|
||||
parameters_schema = mcp_tool.inputSchema
|
||||
@@ -603,11 +703,16 @@ def generate_tool_schema_for_mcp(
|
||||
assert "properties" in parameters_schema, parameters_schema
|
||||
# assert "required" in parameters_schema, parameters_schema
|
||||
|
||||
# Normalize the schema to fix common issues with MCP schemas
|
||||
# This adds additionalProperties: false and explicit types for $ref properties
|
||||
parameters_schema = normalize_mcp_schema(parameters_schema)
|
||||
|
||||
# Zero-arg tools often omit "required" because nothing is required.
|
||||
# Normalise so downstream code can treat it consistently.
|
||||
parameters_schema.setdefault("required", [])
|
||||
|
||||
# Process properties to handle anyOf types and make optional fields strict-compatible
|
||||
# TODO: de-duplicate with handling in normalize_mcp_schema
|
||||
if "properties" in parameters_schema:
|
||||
for field_name, field_props in parameters_schema["properties"].items():
|
||||
# Handle anyOf types by flattening to type array
|
||||
@@ -660,6 +765,14 @@ def generate_tool_schema_for_mcp(
|
||||
if REQUEST_HEARTBEAT_PARAM not in parameters_schema["required"]:
|
||||
parameters_schema["required"].append(REQUEST_HEARTBEAT_PARAM)
|
||||
|
||||
# Re-validate the schema after normalization and update the health status
|
||||
# This allows previously INVALID schemas to pass if normalization fixed them
|
||||
if mcp_tool.health:
|
||||
health_status, health_reasons = validate_complete_json_schema(parameters_schema)
|
||||
mcp_tool.health.status = health_status.value
|
||||
mcp_tool.health.reasons = health_reasons
|
||||
logger.debug(f"MCP tool {name} schema health after normalization: {health_status.value}, reasons: {health_reasons}")
|
||||
|
||||
# Return the final schema
|
||||
if strict:
|
||||
# https://platform.openai.com/docs/guides/function-calling#strict-mode
|
||||
|
||||
@@ -140,12 +140,44 @@ class MCPManager:
|
||||
for mcp_tool in mcp_tools:
|
||||
# TODO: @jnjpng move health check to tool class
|
||||
if mcp_tool.name == mcp_tool_name:
|
||||
# Check tool health - reject only INVALID tools
|
||||
if mcp_tool.health:
|
||||
if mcp_tool.health.status == "INVALID":
|
||||
raise ValueError(
|
||||
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid.Reasons: {', '.join(mcp_tool.health.reasons)}"
|
||||
)
|
||||
# Check tool health - but try normalization first for INVALID schemas
|
||||
if mcp_tool.health and mcp_tool.health.status == "INVALID":
|
||||
logger.info(f"Attempting to normalize INVALID schema for tool {mcp_tool_name}")
|
||||
logger.info(f"Original health reasons: {mcp_tool.health.reasons}")
|
||||
|
||||
# Try to normalize the schema and re-validate
|
||||
from letta.functions.schema_generator import normalize_mcp_schema
|
||||
from letta.functions.schema_validator import validate_complete_json_schema
|
||||
|
||||
try:
|
||||
# Normalize the schema to fix common issues
|
||||
logger.debug(f"Normalizing schema for {mcp_tool_name}")
|
||||
normalized_schema = normalize_mcp_schema(mcp_tool.inputSchema)
|
||||
|
||||
# Re-validate after normalization
|
||||
logger.debug(f"Re-validating schema for {mcp_tool_name}")
|
||||
health_status, health_reasons = validate_complete_json_schema(normalized_schema)
|
||||
logger.info(f"After normalization: status={health_status.value}, reasons={health_reasons}")
|
||||
|
||||
# Update the tool's schema and health (use inputSchema, not input_schema)
|
||||
mcp_tool.inputSchema = normalized_schema
|
||||
mcp_tool.health.status = health_status.value
|
||||
mcp_tool.health.reasons = health_reasons
|
||||
|
||||
# Log the normalization result
|
||||
if health_status.value != "INVALID":
|
||||
logger.info(f"✓ MCP tool {mcp_tool_name} schema normalized successfully: {health_status.value}")
|
||||
else:
|
||||
logger.warning(f"MCP tool {mcp_tool_name} still INVALID after normalization. Reasons: {health_reasons}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to normalize schema for tool {mcp_tool_name}: {e}", exc_info=True)
|
||||
|
||||
# After normalization attempt, check if still INVALID
|
||||
if mcp_tool.health and mcp_tool.health.status == "INVALID":
|
||||
raise ValueError(
|
||||
f"Tool {mcp_tool_name} cannot be attached, JSON schema is invalid even after normalization. "
|
||||
f"Reasons: {', '.join(mcp_tool.health.reasons)}"
|
||||
)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return await self.tool_manager.create_mcp_tool_async(
|
||||
|
||||
396
tests/integration_test_mcp.py
Normal file
396
tests/integration_test_mcp.py
Normal file
@@ -0,0 +1,396 @@
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta, MessageCreate, ToolCallMessage, ToolReturnMessage
|
||||
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
# ------------------------------
|
||||
# Fixtures
|
||||
# ------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it's accepting connections.
|
||||
"""
|
||||
|
||||
def _run_server() -> None:
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
yield url
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(server_url: str) -> Letta:
|
||||
"""
|
||||
Creates and returns a synchronous Letta REST client for testing.
|
||||
"""
|
||||
client_instance = Letta(base_url=server_url)
|
||||
yield client_instance
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mcp_server_name() -> str:
|
||||
"""Generate a unique MCP server name for each test."""
|
||||
return f"test-mcp-server-{uuid.uuid4().hex[:8]}"
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def mock_mcp_server_config(mcp_server_name: str) -> StdioServerConfig:
|
||||
"""
|
||||
Creates a stdio configuration for the mock MCP server.
|
||||
"""
|
||||
# Get path to mock_mcp_server.py
|
||||
script_dir = Path(__file__).parent
|
||||
mcp_server_path = script_dir / "mock_mcp_server.py"
|
||||
|
||||
if not mcp_server_path.exists():
|
||||
raise FileNotFoundError(f"Mock MCP server not found at {mcp_server_path}")
|
||||
|
||||
return StdioServerConfig(
|
||||
server_name=mcp_server_name,
|
||||
command=sys.executable, # Use the current Python interpreter
|
||||
args=[str(mcp_server_path)],
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="function")
|
||||
def agent_state(client: Letta, mcp_server_name: str, mock_mcp_server_config: StdioServerConfig) -> AgentState:
|
||||
"""
|
||||
Creates an agent with MCP tools attached for testing.
|
||||
"""
|
||||
# Register the MCP server
|
||||
client.tools.add_mcp_server(request=mock_mcp_server_config)
|
||||
|
||||
# Verify server is registered
|
||||
servers = client.tools.list_mcp_servers()
|
||||
assert mcp_server_name in servers, f"MCP server {mcp_server_name} not found in {servers}"
|
||||
|
||||
# List available MCP tools
|
||||
mcp_tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
|
||||
assert len(mcp_tools) > 0, "No tools found from MCP server"
|
||||
|
||||
# Add the echo and add tools to Letta
|
||||
echo_tool = next((t for t in mcp_tools if t.name == "echo"), None)
|
||||
add_tool = next((t for t in mcp_tools if t.name == "add"), None)
|
||||
|
||||
assert echo_tool is not None, "echo tool not found"
|
||||
assert add_tool is not None, "add tool not found"
|
||||
|
||||
letta_echo_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name="echo")
|
||||
letta_add_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name="add")
|
||||
|
||||
# Create agent with the MCP tools
|
||||
agent = client.agents.create(
|
||||
name=f"test_mcp_agent_{uuid.uuid4().hex[:8]}",
|
||||
include_base_tools=True,
|
||||
tool_ids=[letta_echo_tool.id, letta_add_tool.id],
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": "human",
|
||||
"value": "Name: Test User",
|
||||
},
|
||||
{
|
||||
"label": "persona",
|
||||
"value": "You are a helpful assistant that can use MCP tools to help the user.",
|
||||
},
|
||||
],
|
||||
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tags=["test_mcp_agent"],
|
||||
)
|
||||
|
||||
yield agent
|
||||
|
||||
# Cleanup
|
||||
try:
|
||||
client.agents.delete(agent.id)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to delete agent {agent.id}: {e}")
|
||||
|
||||
try:
|
||||
client.tools.delete_mcp_server(mcp_server_name=mcp_server_name)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to delete MCP server {mcp_server_name}: {e}")
|
||||
|
||||
|
||||
# ------------------------------
|
||||
# Test Cases
|
||||
# ------------------------------
|
||||
|
||||
|
||||
def test_mcp_echo_tool(client: Letta, agent_state: AgentState):
|
||||
"""
|
||||
Test that an agent can successfully call the echo tool from the MCP server.
|
||||
"""
|
||||
test_message = "Hello from MCP integration test!"
|
||||
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=f"Use the echo tool to echo back this exact message: '{test_message}'",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Check for tool call message
|
||||
tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
|
||||
assert len(tool_calls) > 0, "Expected at least one ToolCallMessage"
|
||||
|
||||
# Find the echo tool call
|
||||
echo_call = next((m for m in tool_calls if m.tool_call.name == "echo"), None)
|
||||
assert echo_call is not None, f"No echo tool call found. Tool calls: {[m.tool_call.name for m in tool_calls]}"
|
||||
|
||||
# Check for tool return message
|
||||
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
|
||||
assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage"
|
||||
|
||||
# Find the return for the echo call
|
||||
echo_return = next((m for m in tool_returns if m.tool_call_id == echo_call.tool_call.tool_call_id), None)
|
||||
assert echo_return is not None, "No tool return found for echo call"
|
||||
assert echo_return.status == "success", f"Echo tool failed with status: {echo_return.status}"
|
||||
|
||||
# Verify the echo response contains our message
|
||||
assert test_message in echo_return.tool_return, f"Expected '{test_message}' in tool return, got: {echo_return.tool_return}"
|
||||
|
||||
|
||||
def test_mcp_add_tool(client: Letta, agent_state: AgentState):
|
||||
"""
|
||||
Test that an agent can successfully call the add tool from the MCP server.
|
||||
"""
|
||||
a, b = 42, 58
|
||||
expected_sum = a + b
|
||||
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=f"Use the add tool to add {a} and {b}.",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Check for tool call message
|
||||
tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
|
||||
assert len(tool_calls) > 0, "Expected at least one ToolCallMessage"
|
||||
|
||||
# Find the add tool call
|
||||
add_call = next((m for m in tool_calls if m.tool_call.name == "add"), None)
|
||||
assert add_call is not None, f"No add tool call found. Tool calls: {[m.tool_call.name for m in tool_calls]}"
|
||||
|
||||
# Check for tool return message
|
||||
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
|
||||
assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage"
|
||||
|
||||
# Find the return for the add call
|
||||
add_return = next((m for m in tool_returns if m.tool_call_id == add_call.tool_call.tool_call_id), None)
|
||||
assert add_return is not None, "No tool return found for add call"
|
||||
assert add_return.status == "success", f"Add tool failed with status: {add_return.status}"
|
||||
|
||||
# Verify the result contains the expected sum
|
||||
assert str(expected_sum) in add_return.tool_return, f"Expected '{expected_sum}' in tool return, got: {add_return.tool_return}"
|
||||
|
||||
|
||||
def test_mcp_multiple_tools_in_sequence(client: Letta, agent_state: AgentState):
|
||||
"""
|
||||
Test that an agent can call multiple MCP tools in sequence.
|
||||
"""
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent_state.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="First use the add tool to add 10 and 20. Then use the echo tool to echo back the result you got from the add tool.",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Check for tool call messages
|
||||
tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
|
||||
assert len(tool_calls) >= 2, f"Expected at least 2 tool calls, got {len(tool_calls)}"
|
||||
|
||||
# Verify both tools were called
|
||||
tool_names = [m.tool_call.name for m in tool_calls]
|
||||
assert "add" in tool_names, f"add tool not called. Tools called: {tool_names}"
|
||||
assert "echo" in tool_names, f"echo tool not called. Tools called: {tool_names}"
|
||||
|
||||
# Check for tool return messages
|
||||
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
|
||||
assert len(tool_returns) >= 2, f"Expected at least 2 tool returns, got {len(tool_returns)}"
|
||||
|
||||
# Verify all tools succeeded
|
||||
for tool_return in tool_returns:
|
||||
assert tool_return.status == "success", f"Tool call failed with status: {tool_return.status}"
|
||||
|
||||
|
||||
def test_mcp_server_listing(client: Letta, mcp_server_name: str, mock_mcp_server_config: StdioServerConfig):
|
||||
"""
|
||||
Test that MCP server registration and tool listing works correctly.
|
||||
"""
|
||||
# Register the MCP server
|
||||
client.tools.add_mcp_server(request=mock_mcp_server_config)
|
||||
|
||||
try:
|
||||
# Verify server is in the list
|
||||
servers = client.tools.list_mcp_servers()
|
||||
assert mcp_server_name in servers, f"MCP server {mcp_server_name} not found in {servers}"
|
||||
|
||||
# List available tools
|
||||
mcp_tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
|
||||
assert len(mcp_tools) > 0, "No tools found from MCP server"
|
||||
|
||||
# Verify expected tools are present
|
||||
tool_names = [t.name for t in mcp_tools]
|
||||
expected_tools = ["echo", "add", "multiply", "reverse_string"]
|
||||
for expected_tool in expected_tools:
|
||||
assert expected_tool in tool_names, f"Expected tool '{expected_tool}' not found. Available: {tool_names}"
|
||||
|
||||
finally:
|
||||
# Cleanup
|
||||
client.tools.delete_mcp_server(mcp_server_name=mcp_server_name)
|
||||
servers = client.tools.list_mcp_servers()
|
||||
assert mcp_server_name not in servers, f"MCP server {mcp_server_name} should be deleted but is still in {servers}"
|
||||
|
||||
|
||||
def test_mcp_complex_schema_tool(client: Letta, mcp_server_name: str, mock_mcp_server_config: StdioServerConfig):
|
||||
"""
|
||||
Test that an agent can successfully call a tool with complex nested schema.
|
||||
This tests the get_parameter_type_description tool which has:
|
||||
- Enum-like preset parameter
|
||||
- Optional string field
|
||||
- Optional nested object with arrays of objects
|
||||
"""
|
||||
# Register the MCP server
|
||||
client.tools.add_mcp_server(request=mock_mcp_server_config)
|
||||
|
||||
try:
|
||||
# List available tools
|
||||
mcp_tools = client.tools.list_mcp_tools_by_server(mcp_server_name=mcp_server_name)
|
||||
|
||||
# Find the complex schema tool
|
||||
complex_tool = next((t for t in mcp_tools if t.name == "get_parameter_type_description"), None)
|
||||
assert complex_tool is not None, f"get_parameter_type_description tool not found. Available: {[t.name for t in mcp_tools]}"
|
||||
|
||||
# Add it to Letta
|
||||
letta_complex_tool = client.tools.add_mcp_tool(mcp_server_name=mcp_server_name, mcp_tool_name="get_parameter_type_description")
|
||||
|
||||
# Create agent with the complex tool
|
||||
agent = client.agents.create(
|
||||
name=f"test_complex_schema_{uuid.uuid4().hex[:8]}",
|
||||
include_base_tools=True,
|
||||
tool_ids=[letta_complex_tool.id],
|
||||
memory_blocks=[
|
||||
{
|
||||
"label": "human",
|
||||
"value": "Name: Test User",
|
||||
},
|
||||
{
|
||||
"label": "persona",
|
||||
"value": "You are a helpful assistant that can use MCP tools with complex schemas.",
|
||||
},
|
||||
],
|
||||
llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
tags=["test_complex_schema"],
|
||||
)
|
||||
|
||||
# Test 1: Simple call with just preset
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user", content='Use the get_parameter_type_description tool with preset "a" to get parameter information.'
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
|
||||
assert len(tool_calls) > 0, "Expected at least one ToolCallMessage"
|
||||
|
||||
complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None)
|
||||
assert complex_call is not None, f"No get_parameter_type_description call found. Calls: {[m.tool_call.name for m in tool_calls]}"
|
||||
|
||||
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
|
||||
assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage"
|
||||
|
||||
complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None)
|
||||
assert complex_return is not None, "No tool return found for complex schema call"
|
||||
assert complex_return.status == "success", f"Complex schema tool failed with status: {complex_return.status}"
|
||||
assert "Preset: a" in complex_return.tool_return, f"Expected 'Preset: a' in return, got: {complex_return.tool_return}"
|
||||
|
||||
# Test 2: Complex call with nested data
|
||||
response = client.agents.messages.create(
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content="Use the get_parameter_type_description tool with these arguments: "
|
||||
'preset="b", connected_service_descriptor="test-service", '
|
||||
"and instantiation_data with isAbstract=true, isMultiplicity=false, "
|
||||
'and one instantiation with doid="TEST123" and nodeFamilyId=42.',
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)]
|
||||
assert len(tool_calls) > 0, "Expected at least one ToolCallMessage for complex nested call"
|
||||
|
||||
complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None)
|
||||
assert complex_call is not None, "No get_parameter_type_description call found for nested test"
|
||||
|
||||
tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)]
|
||||
complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None)
|
||||
assert complex_return is not None, "No tool return found for complex nested call"
|
||||
assert complex_return.status == "success", f"Complex nested call failed with status: {complex_return.status}"
|
||||
|
||||
# Verify the response contains our complex data
|
||||
assert "Preset: b" in complex_return.tool_return, "Expected preset 'b' in response"
|
||||
assert "test-service" in complex_return.tool_return, "Expected service descriptor in response"
|
||||
|
||||
# Cleanup agent
|
||||
client.agents.delete(agent.id)
|
||||
|
||||
finally:
|
||||
# Cleanup MCP server
|
||||
client.tools.delete_mcp_server(mcp_server_name=mcp_server_name)
|
||||
@@ -320,6 +320,189 @@ async def test_create_mcp_server_with_tools(mock_get_client, server, default_use
|
||||
assert len(remaining_mcp_tools) == 0, "Tools should be deleted when server is deleted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_complex_schema_normalization(mock_get_client, server, default_user):
|
||||
"""Test that complex MCP schemas with nested objects are normalized and accepted."""
|
||||
from letta.functions.mcp_client.types import MCPTool, MCPToolHealth
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
return
|
||||
|
||||
# Create mock tools with complex schemas that would normally be INVALID
|
||||
# These schemas have: nested $defs, $ref references, missing additionalProperties
|
||||
mock_tools = [
|
||||
# 1. Nested object with $ref (like create_person)
|
||||
MCPTool(
|
||||
name="create_person",
|
||||
description="Create a person with nested address",
|
||||
inputSchema={
|
||||
"$defs": {
|
||||
"Address": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"street": {"type": "string"},
|
||||
"city": {"type": "string"},
|
||||
"zip_code": {"type": "string"},
|
||||
},
|
||||
"required": ["street", "city", "zip_code"],
|
||||
},
|
||||
"Person": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer"},
|
||||
"address": {"$ref": "#/$defs/Address"},
|
||||
},
|
||||
"required": ["name", "age"],
|
||||
},
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {"person": {"$ref": "#/$defs/Person"}},
|
||||
"required": ["person"],
|
||||
},
|
||||
health=MCPToolHealth(
|
||||
status="INVALID",
|
||||
reasons=["root: 'additionalProperties' not explicitly set", "root.properties.person: Missing 'type'"],
|
||||
),
|
||||
),
|
||||
# 2. List of objects (like manage_tasks)
|
||||
MCPTool(
|
||||
name="manage_tasks",
|
||||
description="Manage multiple tasks",
|
||||
inputSchema={
|
||||
"$defs": {
|
||||
"TaskItem": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"title": {"type": "string"},
|
||||
"priority": {"type": "integer", "default": 1},
|
||||
"completed": {"type": "boolean", "default": False},
|
||||
"tags": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["title"],
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tasks": {
|
||||
"type": "array",
|
||||
"items": {"$ref": "#/$defs/TaskItem"},
|
||||
}
|
||||
},
|
||||
"required": ["tasks"],
|
||||
},
|
||||
health=MCPToolHealth(
|
||||
status="INVALID",
|
||||
reasons=["root: 'additionalProperties' not explicitly set", "root.properties.tasks.items: Missing 'type'"],
|
||||
),
|
||||
),
|
||||
# 3. Complex filter object with optional fields
|
||||
MCPTool(
|
||||
name="search_with_filters",
|
||||
description="Search with complex filters",
|
||||
inputSchema={
|
||||
"$defs": {
|
||||
"SearchFilter": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"keywords": {"type": "array", "items": {"type": "string"}},
|
||||
"min_score": {"type": "number"},
|
||||
"categories": {"type": "array", "items": {"type": "string"}},
|
||||
},
|
||||
"required": ["keywords"],
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"filters": {"$ref": "#/$defs/SearchFilter"},
|
||||
},
|
||||
"required": ["query", "filters"],
|
||||
},
|
||||
health=MCPToolHealth(
|
||||
status="INVALID",
|
||||
reasons=["root: 'additionalProperties' not explicitly set", "root.properties.filters: Missing 'type'"],
|
||||
),
|
||||
),
|
||||
]
|
||||
|
||||
# Create mock client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.connect_to_server = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
||||
mock_client.cleanup = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Create MCP server
|
||||
server_name = f"test_complex_schema_{uuid.uuid4().hex[:8]}"
|
||||
server_url = "https://test-complex.example.com/sse"
|
||||
mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)
|
||||
|
||||
try:
|
||||
# Create server (this will auto-sync tools)
|
||||
created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
|
||||
|
||||
assert created_server.server_name == server_name
|
||||
|
||||
# Now attempt to add each tool - they should be normalized from INVALID to acceptable
|
||||
# The normalization happens in add_tool_from_mcp_server
|
||||
|
||||
# Test 1: create_person should normalize successfully
|
||||
person_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "create_person", actor=default_user)
|
||||
assert person_tool is not None
|
||||
assert person_tool.name == "create_person"
|
||||
# Verify the schema has additionalProperties set
|
||||
assert person_tool.json_schema["parameters"]["additionalProperties"] == False
|
||||
# Verify nested $defs have additionalProperties
|
||||
if "$defs" in person_tool.json_schema["parameters"]:
|
||||
for def_name, def_schema in person_tool.json_schema["parameters"]["$defs"].items():
|
||||
if def_schema.get("type") == "object":
|
||||
assert "additionalProperties" in def_schema, f"$defs.{def_name} missing additionalProperties after normalization"
|
||||
|
||||
# Test 2: manage_tasks should normalize successfully
|
||||
tasks_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "manage_tasks", actor=default_user)
|
||||
assert tasks_tool is not None
|
||||
assert tasks_tool.name == "manage_tasks"
|
||||
# Verify array items have explicit type
|
||||
tasks_prop = tasks_tool.json_schema["parameters"]["properties"]["tasks"]
|
||||
assert "items" in tasks_prop
|
||||
assert "type" in tasks_prop["items"], "Array items should have explicit type after normalization"
|
||||
|
||||
# Test 3: search_with_filters should normalize successfully
|
||||
search_tool = await server.mcp_manager.add_tool_from_mcp_server(server_name, "search_with_filters", actor=default_user)
|
||||
assert search_tool is not None
|
||||
assert search_tool.name == "search_with_filters"
|
||||
|
||||
# Verify all tools were persisted
|
||||
all_tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, names=["create_person", "manage_tasks", "search_with_filters"]
|
||||
)
|
||||
|
||||
# Filter to tools from our MCP server
|
||||
mcp_tools = [
|
||||
tool
|
||||
for tool in all_tools
|
||||
if tool.metadata_
|
||||
and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
|
||||
and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
|
||||
]
|
||||
|
||||
# All 3 complex schema tools should have been normalized and persisted
|
||||
assert len(mcp_tools) == 3, f"Expected 3 normalized tools, got {len(mcp_tools)}"
|
||||
|
||||
# Verify they all have the correct MCP metadata
|
||||
for tool in mcp_tools:
|
||||
assert tool.tool_type == ToolType.EXTERNAL_MCP
|
||||
assert f"mcp:{server_name}" in tool.tags
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_create_mcp_server_with_tools_connection_failure(mock_get_client, server, default_user):
|
||||
"""Test that MCP server creation succeeds even when tool sync fails (optimistic approach)."""
|
||||
|
||||
283
tests/mock_mcp_server.py
Executable file
283
tests/mock_mcp_server.py
Executable file
@@ -0,0 +1,283 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple MCP test server with basic and complex tools for testing purposes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
# Configure logging to stderr (not stdout for STDIO servers)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Initialize FastMCP server
|
||||
mcp = FastMCP("test-server")
|
||||
|
||||
|
||||
# Complex Pydantic models for testing
|
||||
class Address(BaseModel):
|
||||
"""An address with street, city, and zip code."""
|
||||
|
||||
street: str = Field(..., description="Street address")
|
||||
city: str = Field(..., description="City name")
|
||||
zip_code: str = Field(..., description="ZIP code")
|
||||
country: str = Field(default="USA", description="Country name")
|
||||
|
||||
|
||||
class Person(BaseModel):
|
||||
"""A person with name, age, and optional address."""
|
||||
|
||||
name: str = Field(..., description="Person's full name")
|
||||
age: int = Field(..., description="Person's age", ge=0, le=150)
|
||||
email: Optional[str] = Field(None, description="Email address")
|
||||
address: Optional[Address] = Field(None, description="Home address")
|
||||
|
||||
|
||||
class TaskItem(BaseModel):
|
||||
"""A task item with title, priority, and completion status."""
|
||||
|
||||
title: str = Field(..., description="Task title")
|
||||
priority: int = Field(default=1, description="Priority level (1-5)", ge=1, le=5)
|
||||
completed: bool = Field(default=False, description="Whether the task is completed")
|
||||
tags: List[str] = Field(default_factory=list, description="List of tags")
|
||||
|
||||
|
||||
class SearchFilter(BaseModel):
|
||||
"""Filter criteria for searching."""
|
||||
|
||||
keywords: List[str] = Field(..., description="List of keywords to search for")
|
||||
min_score: Optional[float] = Field(None, description="Minimum score threshold", ge=0.0, le=1.0)
|
||||
categories: Optional[List[str]] = Field(None, description="Categories to filter by")
|
||||
|
||||
|
||||
# Customer-reported schema models (matching mcp_schema.json pattern)
|
||||
class Instantiation(BaseModel):
|
||||
"""Instantiation object with optional node identifiers."""
|
||||
|
||||
# model_config = ConfigDict(json_schema_extra={"additionalProperties": False})
|
||||
|
||||
doid: Optional[str] = Field(None, description="DOID identifier")
|
||||
nodeFamilyId: Optional[int] = Field(None, description="Node family ID")
|
||||
nodeTypeId: Optional[int] = Field(None, description="Node type ID")
|
||||
nodePositionId: Optional[int] = Field(None, description="Node position ID")
|
||||
|
||||
|
||||
class InstantiationData(BaseModel):
|
||||
"""Instantiation data with abstract and multiplicity flags."""
|
||||
|
||||
# model_config = ConfigDict(json_schema_extra={"additionalProperties": False})
|
||||
|
||||
isAbstract: Optional[bool] = Field(None, description="Whether the instantiation is abstract")
|
||||
isMultiplicity: Optional[bool] = Field(None, description="Whether the instantiation has multiplicity")
|
||||
instantiations: List[Instantiation] = Field(None, description="List of instantiations")
|
||||
|
||||
|
||||
class ParameterPreset(BaseModel):
|
||||
"""Parameter preset enum values."""
|
||||
|
||||
value: str = Field(..., description="Preset value (a, b, c, e, f, g, h, i, d, l, s, m, z, o, u, unknown)")
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def echo(message: str) -> str:
|
||||
"""Echo back the provided message.
|
||||
|
||||
Args:
|
||||
message: The message to echo back
|
||||
"""
|
||||
return f"Echo: {message}"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def add(a: float, b: float) -> str:
|
||||
"""Add two numbers together.
|
||||
|
||||
Args:
|
||||
a: First number
|
||||
b: Second number
|
||||
"""
|
||||
result = a + b
|
||||
return f"{a} + {b} = {result}"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def multiply(a: float, b: float) -> str:
|
||||
"""Multiply two numbers together.
|
||||
|
||||
Args:
|
||||
a: First number
|
||||
b: Second number
|
||||
"""
|
||||
result = a * b
|
||||
return f"{a} × {b} = {result}"
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def reverse_string(text: str) -> str:
|
||||
"""Reverse a string.
|
||||
|
||||
Args:
|
||||
text: The string to reverse
|
||||
"""
|
||||
return text[::-1]
|
||||
|
||||
|
||||
# Complex tools using Pydantic models
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def create_person(person: Person) -> str:
|
||||
"""Create a person profile with nested address information.
|
||||
|
||||
Args:
|
||||
person: Person object with name, age, optional email and address
|
||||
"""
|
||||
result = "Created person profile:\n"
|
||||
result += f" Name: {person.name}\n"
|
||||
result += f" Age: {person.age}\n"
|
||||
|
||||
if person.email:
|
||||
result += f" Email: {person.email}\n"
|
||||
|
||||
if person.address:
|
||||
result += " Address:\n"
|
||||
result += f" {person.address.street}\n"
|
||||
result += f" {person.address.city}, {person.address.zip_code}\n"
|
||||
result += f" {person.address.country}\n"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def manage_tasks(tasks: List[TaskItem]) -> str:
|
||||
"""Manage multiple tasks with priorities and tags.
|
||||
|
||||
Args:
|
||||
tasks: List of task items to manage
|
||||
"""
|
||||
if not tasks:
|
||||
return "No tasks provided"
|
||||
|
||||
result = f"Managing {len(tasks)} task(s):\n\n"
|
||||
|
||||
for i, task in enumerate(tasks, 1):
|
||||
status = "✓" if task.completed else "○"
|
||||
result += f"{i}. [{status}] {task.title}\n"
|
||||
result += f" Priority: {task.priority}/5\n"
|
||||
|
||||
if task.tags:
|
||||
result += f" Tags: {', '.join(task.tags)}\n"
|
||||
|
||||
result += "\n"
|
||||
|
||||
completed = sum(1 for t in tasks if t.completed)
|
||||
result += f"Summary: {completed}/{len(tasks)} completed"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def search_with_filters(query: str, filters: SearchFilter) -> str:
|
||||
"""Search with complex filter criteria including keywords and categories.
|
||||
|
||||
Args:
|
||||
query: The main search query
|
||||
filters: Complex filter object with keywords, score threshold, and categories
|
||||
"""
|
||||
result = f"Search Query: '{query}'\n\n"
|
||||
result += "Filters Applied:\n"
|
||||
result += f" Keywords: {', '.join(filters.keywords)}\n"
|
||||
|
||||
if filters.min_score is not None:
|
||||
result += f" Minimum Score: {filters.min_score}\n"
|
||||
|
||||
if filters.categories:
|
||||
result += f" Categories: {', '.join(filters.categories)}\n"
|
||||
|
||||
# Simulate search results
|
||||
result += "\nFound 3 results matching criteria:\n"
|
||||
result += f" 1. Result matching '{filters.keywords[0]}' (score: 0.95)\n"
|
||||
result += f" 2. Result matching '{query}' (score: 0.87)\n"
|
||||
result += " 3. Result matching multiple keywords (score: 0.82)\n"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def process_nested_data(data: dict) -> str:
|
||||
"""Process arbitrary nested dictionary data.
|
||||
|
||||
Args:
|
||||
data: Nested dictionary with arbitrary structure
|
||||
"""
|
||||
result = "Processing nested data:\n"
|
||||
result += json.dumps(data, indent=2)
|
||||
result += "\n\nData structure stats:\n"
|
||||
result += f" Keys at root level: {len(data)}\n"
|
||||
|
||||
def count_nested_items(obj, depth=0):
|
||||
count = 0
|
||||
max_depth = depth
|
||||
if isinstance(obj, dict):
|
||||
for v in obj.values():
|
||||
sub_count, sub_depth = count_nested_items(v, depth + 1)
|
||||
count += sub_count + 1
|
||||
max_depth = max(max_depth, sub_depth)
|
||||
elif isinstance(obj, list):
|
||||
for item in obj:
|
||||
sub_count, sub_depth = count_nested_items(item, depth + 1)
|
||||
count += sub_count + 1
|
||||
max_depth = max(max_depth, sub_depth)
|
||||
return count, max_depth
|
||||
|
||||
total_items, max_depth = count_nested_items(data)
|
||||
result += f" Total nested items: {total_items}\n"
|
||||
result += f" Maximum nesting depth: {max_depth}\n"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@mcp.tool()
|
||||
async def get_parameter_type_description(
|
||||
preset: str,
|
||||
instantiation_data: InstantiationData,
|
||||
connected_service_descriptor: Optional[str] = None,
|
||||
) -> str:
|
||||
"""Get parameter type description with complex nested structure.
|
||||
|
||||
This tool matches the customer-reported schema pattern with:
|
||||
- Enum-like preset parameter
|
||||
- Optional string field
|
||||
- Optional nested object with arrays of objects
|
||||
|
||||
Args:
|
||||
preset: The parameter preset (a, b, c, e, f, g, h, i, d, l, s, m, z, o, u, unknown)
|
||||
connected_service_descriptor: Connected service descriptor string, if available
|
||||
instantiation_data: Instantiation data dict with isAbstract, isMultiplicity, and instantiations list
|
||||
"""
|
||||
result = "Parameter Type Description\n"
|
||||
result += "=" * 50 + "\n\n"
|
||||
result += f"Preset: {preset}\n\n"
|
||||
|
||||
if connected_service_descriptor:
|
||||
result += f"Connected Service: {connected_service_descriptor}\n\n"
|
||||
|
||||
if instantiation_data:
|
||||
result += "Instantiation Data:\n"
|
||||
result += f" Is Abstract: {instantiation_data.isAbstract}\n"
|
||||
result += f" Is Multiplicity: {instantiation_data.isMultiplicity}\n"
|
||||
result += f" Instantiations: {instantiation_data.instantiations}\n"
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
# Initialize and run the server
|
||||
mcp.run(transport="stdio")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
Reference in New Issue
Block a user