feat: auto register mcp server tools as letta tools (#2847)
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -10691,6 +10691,179 @@ async def test_create_mcp_server(mock_get_client, server, default_user):
|
||||
print("TAGS", tool.tags)
|
||||
|
||||
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_create_mcp_server_with_tools(mock_get_client, server, default_user):
|
||||
"""Test that creating an MCP server automatically syncs and persists its tools."""
|
||||
from letta.functions.mcp_client.types import MCPToolHealth
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
return
|
||||
|
||||
# Create mock tools with different health statuses
|
||||
mock_tools = [
|
||||
MCPTool(
|
||||
name="valid_tool_1",
|
||||
description="A valid tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {"type": "string"},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
health=MCPToolHealth(status="VALID", reasons=[]),
|
||||
),
|
||||
MCPTool(
|
||||
name="valid_tool_2",
|
||||
description="Another valid tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param2": {"type": "number"},
|
||||
},
|
||||
},
|
||||
health=MCPToolHealth(status="VALID", reasons=[]),
|
||||
),
|
||||
MCPTool(
|
||||
name="invalid_tool",
|
||||
description="An invalid tool that should be skipped",
|
||||
inputSchema={
|
||||
"type": "invalid_type", # Invalid schema
|
||||
},
|
||||
health=MCPToolHealth(status="INVALID", reasons=["Invalid schema type"]),
|
||||
),
|
||||
MCPTool(
|
||||
name="warning_tool",
|
||||
description="A tool with warnings but should still be persisted",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
health=MCPToolHealth(status="WARNING", reasons=["No properties defined"]),
|
||||
),
|
||||
]
|
||||
|
||||
# Create mock client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.connect_to_server = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
||||
mock_client.cleanup = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Create MCP server config
|
||||
server_name = f"test_server_{uuid.uuid4().hex[:8]}"
|
||||
server_url = "https://test-with-tools.example.com/sse"
|
||||
mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)
|
||||
|
||||
# Create server with tools using the new method
|
||||
created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
|
||||
|
||||
# Verify server was created
|
||||
assert created_server.server_name == server_name
|
||||
assert created_server.server_type == MCPServerType.SSE
|
||||
assert created_server.server_url == server_url
|
||||
|
||||
# Verify tools were persisted (all except the invalid one)
|
||||
# Get all tools and filter by checking metadata
|
||||
all_tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool", "invalid_tool"]
|
||||
)
|
||||
|
||||
# Filter tools that belong to our MCP server
|
||||
persisted_tools = [
|
||||
tool
|
||||
for tool in all_tools
|
||||
if tool.metadata_
|
||||
and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
|
||||
and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
|
||||
]
|
||||
|
||||
# Should have 3 tools (2 valid + 1 warning, but not the invalid one)
|
||||
assert len(persisted_tools) == 3, f"Expected 3 tools, got {len(persisted_tools)}"
|
||||
|
||||
# Check tool names
|
||||
tool_names = {tool.name for tool in persisted_tools}
|
||||
assert "valid_tool_1" in tool_names
|
||||
assert "valid_tool_2" in tool_names
|
||||
assert "warning_tool" in tool_names
|
||||
assert "invalid_tool" not in tool_names # Invalid tool should be filtered out
|
||||
|
||||
# Verify each tool has correct metadata
|
||||
for tool in persisted_tools:
|
||||
assert tool.metadata_ is not None
|
||||
assert MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
|
||||
assert tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == server_name
|
||||
assert tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_id"] == created_server.id
|
||||
assert tool.tool_type == ToolType.EXTERNAL_MCP
|
||||
|
||||
# Clean up - delete the server
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
# Verify tools were also deleted (cascade) by trying to get them again
|
||||
remaining_tools = await server.tool_manager.list_tools_async(actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool"])
|
||||
|
||||
# Filter to see if any still belong to our deleted server
|
||||
remaining_mcp_tools = [
|
||||
tool
|
||||
for tool in remaining_tools
|
||||
if tool.metadata_
|
||||
and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
|
||||
and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
|
||||
]
|
||||
assert len(remaining_mcp_tools) == 0, "Tools should be deleted when server is deleted"
|
||||
|
||||
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_create_mcp_server_with_tools_connection_failure(mock_get_client, server, default_user):
|
||||
"""Test that MCP server creation succeeds even when tool sync fails (optimistic approach)."""
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
return
|
||||
|
||||
# Create mock client that fails to connect
|
||||
mock_client = AsyncMock()
|
||||
mock_client.connect_to_server = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_client.cleanup = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Create MCP server config
|
||||
server_name = f"test_server_fail_{uuid.uuid4().hex[:8]}"
|
||||
server_url = "https://test-fail.example.com/sse"
|
||||
mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)
|
||||
|
||||
# Create server with tools - should succeed despite connection failure
|
||||
created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
|
||||
|
||||
# Verify server was created successfully
|
||||
assert created_server.server_name == server_name
|
||||
assert created_server.server_type == MCPServerType.SSE
|
||||
assert created_server.server_url == server_url
|
||||
|
||||
# Verify no tools were persisted (due to connection failure)
|
||||
# Try to get tools by the names we would have expected
|
||||
all_tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user,
|
||||
names=["tool1", "tool2", "tool3"], # Generic names since we don't know what tools would have been listed
|
||||
)
|
||||
|
||||
# Filter to see if any belong to our server (there shouldn't be any)
|
||||
persisted_tools = [
|
||||
tool
|
||||
for tool in all_tools
|
||||
if tool.metadata_
|
||||
and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
|
||||
and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
|
||||
]
|
||||
assert len(persisted_tools) == 0, "No tools should be persisted when connection fails"
|
||||
|
||||
# Clean up
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
|
||||
async def test_get_mcp_servers_by_ids(server, default_user):
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig
|
||||
from letta.settings import tool_settings
|
||||
|
||||
Reference in New Issue
Block a user