feat: auto register mcp server tools as letta tools (#2847)
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -738,7 +738,8 @@ async def add_mcp_server_to_config(
|
||||
custom_headers=request.custom_headers,
|
||||
)
|
||||
|
||||
await server.mcp_manager.create_mcp_server(mapped_request, actor=actor)
|
||||
# Create MCP server and optimistically sync tools
|
||||
await server.mcp_manager.create_mcp_server_with_tools(mapped_request, actor=actor)
|
||||
|
||||
# TODO: don't do this in the future (just return MCPServer)
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
|
||||
@@ -90,7 +90,6 @@ class MCPManager:
|
||||
logger.warning(f"Error listing tools for MCP server {mcp_server_name}: {e}")
|
||||
raise e
|
||||
|
||||
|
||||
@enforce_types
|
||||
async def execute_mcp_server_tool(
|
||||
self,
|
||||
@@ -355,6 +354,62 @@ class MCPManager:
|
||||
logger.error(f"Failed to create MCP server: {e}")
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_server_with_tools(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
|
||||
"""
|
||||
Create a new MCP server and optimistically sync its tools.
|
||||
|
||||
This method:
|
||||
1. Creates the MCP server record
|
||||
2. Attempts to connect and fetch tools
|
||||
3. Persists valid tools in parallel (best-effort)
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
# First, create the MCP server
|
||||
created_server = await self.create_mcp_server(pydantic_mcp_server, actor)
|
||||
|
||||
# Optimistically try to sync tools
|
||||
try:
|
||||
logger.info(f"Attempting to auto-sync tools from MCP server: {created_server.server_name}")
|
||||
|
||||
# List all tools from the MCP server
|
||||
mcp_tools = await self.list_mcp_server_tools(mcp_server_name=created_server.server_name, actor=actor)
|
||||
|
||||
# Filter out invalid tools
|
||||
valid_tools = [tool for tool in mcp_tools if not (tool.health and tool.health.status == "INVALID")]
|
||||
|
||||
# Register in parallel
|
||||
if valid_tools:
|
||||
tool_tasks = []
|
||||
for mcp_tool in valid_tools:
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=created_server.server_name, mcp_tool=mcp_tool)
|
||||
task = self.tool_manager.create_mcp_tool_async(
|
||||
tool_create=tool_create, mcp_server_name=created_server.server_name, mcp_server_id=created_server.id, actor=actor
|
||||
)
|
||||
tool_tasks.append(task)
|
||||
|
||||
results = await asyncio.gather(*tool_tasks, return_exceptions=True)
|
||||
|
||||
successful = sum(1 for r in results if not isinstance(r, Exception))
|
||||
failed = len(results) - successful
|
||||
logger.info(
|
||||
f"Auto-sync completed for MCP server {created_server.server_name}: "
|
||||
f"{successful} tools persisted, {failed} failed, "
|
||||
f"{len(mcp_tools) - len(valid_tools)} invalid tools skipped"
|
||||
)
|
||||
else:
|
||||
logger.info(f"No valid tools found to sync from MCP server {created_server.server_name}")
|
||||
|
||||
except Exception as e:
|
||||
# Log the error but don't fail the server creation
|
||||
logger.warning(
|
||||
f"Failed to auto-sync tools from MCP server {created_server.server_name}: {e}. "
|
||||
f"Server was created successfully but tools were not persisted."
|
||||
)
|
||||
|
||||
return created_server
|
||||
|
||||
@enforce_types
|
||||
async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
|
||||
@@ -10691,6 +10691,179 @@ async def test_create_mcp_server(mock_get_client, server, default_user):
|
||||
print("TAGS", tool.tags)
|
||||
|
||||
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_create_mcp_server_with_tools(mock_get_client, server, default_user):
|
||||
"""Test that creating an MCP server automatically syncs and persists its tools."""
|
||||
from letta.functions.mcp_client.types import MCPToolHealth
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
return
|
||||
|
||||
# Create mock tools with different health statuses
|
||||
mock_tools = [
|
||||
MCPTool(
|
||||
name="valid_tool_1",
|
||||
description="A valid tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {"type": "string"},
|
||||
},
|
||||
"required": ["param1"],
|
||||
},
|
||||
health=MCPToolHealth(status="VALID", reasons=[]),
|
||||
),
|
||||
MCPTool(
|
||||
name="valid_tool_2",
|
||||
description="Another valid tool",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param2": {"type": "number"},
|
||||
},
|
||||
},
|
||||
health=MCPToolHealth(status="VALID", reasons=[]),
|
||||
),
|
||||
MCPTool(
|
||||
name="invalid_tool",
|
||||
description="An invalid tool that should be skipped",
|
||||
inputSchema={
|
||||
"type": "invalid_type", # Invalid schema
|
||||
},
|
||||
health=MCPToolHealth(status="INVALID", reasons=["Invalid schema type"]),
|
||||
),
|
||||
MCPTool(
|
||||
name="warning_tool",
|
||||
description="A tool with warnings but should still be persisted",
|
||||
inputSchema={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
},
|
||||
health=MCPToolHealth(status="WARNING", reasons=["No properties defined"]),
|
||||
),
|
||||
]
|
||||
|
||||
# Create mock client
|
||||
mock_client = AsyncMock()
|
||||
mock_client.connect_to_server = AsyncMock()
|
||||
mock_client.list_tools = AsyncMock(return_value=mock_tools)
|
||||
mock_client.cleanup = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Create MCP server config
|
||||
server_name = f"test_server_{uuid.uuid4().hex[:8]}"
|
||||
server_url = "https://test-with-tools.example.com/sse"
|
||||
mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)
|
||||
|
||||
# Create server with tools using the new method
|
||||
created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
|
||||
|
||||
# Verify server was created
|
||||
assert created_server.server_name == server_name
|
||||
assert created_server.server_type == MCPServerType.SSE
|
||||
assert created_server.server_url == server_url
|
||||
|
||||
# Verify tools were persisted (all except the invalid one)
|
||||
# Get all tools and filter by checking metadata
|
||||
all_tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool", "invalid_tool"]
|
||||
)
|
||||
|
||||
# Filter tools that belong to our MCP server
|
||||
persisted_tools = [
|
||||
tool
|
||||
for tool in all_tools
|
||||
if tool.metadata_
|
||||
and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
|
||||
and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
|
||||
]
|
||||
|
||||
# Should have 3 tools (2 valid + 1 warning, but not the invalid one)
|
||||
assert len(persisted_tools) == 3, f"Expected 3 tools, got {len(persisted_tools)}"
|
||||
|
||||
# Check tool names
|
||||
tool_names = {tool.name for tool in persisted_tools}
|
||||
assert "valid_tool_1" in tool_names
|
||||
assert "valid_tool_2" in tool_names
|
||||
assert "warning_tool" in tool_names
|
||||
assert "invalid_tool" not in tool_names # Invalid tool should be filtered out
|
||||
|
||||
# Verify each tool has correct metadata
|
||||
for tool in persisted_tools:
|
||||
assert tool.metadata_ is not None
|
||||
assert MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
|
||||
assert tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_name"] == server_name
|
||||
assert tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX]["server_id"] == created_server.id
|
||||
assert tool.tool_type == ToolType.EXTERNAL_MCP
|
||||
|
||||
# Clean up - delete the server
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
# Verify tools were also deleted (cascade) by trying to get them again
|
||||
remaining_tools = await server.tool_manager.list_tools_async(actor=default_user, names=["valid_tool_1", "valid_tool_2", "warning_tool"])
|
||||
|
||||
# Filter to see if any still belong to our deleted server
|
||||
remaining_mcp_tools = [
|
||||
tool
|
||||
for tool in remaining_tools
|
||||
if tool.metadata_
|
||||
and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
|
||||
and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
|
||||
]
|
||||
assert len(remaining_mcp_tools) == 0, "Tools should be deleted when server is deleted"
|
||||
|
||||
|
||||
@patch("letta.services.mcp_manager.MCPManager.get_mcp_client")
|
||||
async def test_create_mcp_server_with_tools_connection_failure(mock_get_client, server, default_user):
|
||||
"""Test that MCP server creation succeeds even when tool sync fails (optimistic approach)."""
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
return
|
||||
|
||||
# Create mock client that fails to connect
|
||||
mock_client = AsyncMock()
|
||||
mock_client.connect_to_server = AsyncMock(side_effect=Exception("Connection failed"))
|
||||
mock_client.cleanup = AsyncMock()
|
||||
mock_get_client.return_value = mock_client
|
||||
|
||||
# Create MCP server config
|
||||
server_name = f"test_server_fail_{uuid.uuid4().hex[:8]}"
|
||||
server_url = "https://test-fail.example.com/sse"
|
||||
mcp_server = MCPServer(server_name=server_name, server_type=MCPServerType.SSE, server_url=server_url)
|
||||
|
||||
# Create server with tools - should succeed despite connection failure
|
||||
created_server = await server.mcp_manager.create_mcp_server_with_tools(mcp_server, actor=default_user)
|
||||
|
||||
# Verify server was created successfully
|
||||
assert created_server.server_name == server_name
|
||||
assert created_server.server_type == MCPServerType.SSE
|
||||
assert created_server.server_url == server_url
|
||||
|
||||
# Verify no tools were persisted (due to connection failure)
|
||||
# Try to get tools by the names we would have expected
|
||||
all_tools = await server.tool_manager.list_tools_async(
|
||||
actor=default_user,
|
||||
names=["tool1", "tool2", "tool3"], # Generic names since we don't know what tools would have been listed
|
||||
)
|
||||
|
||||
# Filter to see if any belong to our server (there shouldn't be any)
|
||||
persisted_tools = [
|
||||
tool
|
||||
for tool in all_tools
|
||||
if tool.metadata_
|
||||
and MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_
|
||||
and tool.metadata_[MCP_TOOL_TAG_NAME_PREFIX].get("server_name") == server_name
|
||||
]
|
||||
assert len(persisted_tools) == 0, "No tools should be persisted when connection fails"
|
||||
|
||||
# Clean up
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
|
||||
async def test_get_mcp_servers_by_ids(server, default_user):
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig
|
||||
from letta.settings import tool_settings
|
||||
|
||||
Reference in New Issue
Block a user