feat: add resync tool endpoint (#2812)
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -60,6 +60,9 @@ paths:
|
||||
/v1/tools/mcp/servers/{mcp_server_name}/{mcp_tool_name}:
|
||||
post:
|
||||
summary: "Add MCP Tool"
|
||||
/v1/tools/mcp/servers/{mcp_server_name}/resync:
|
||||
post:
|
||||
x-fern-ignore: true
|
||||
/v1/tools/mcp/servers/{mcp_server_name}:
|
||||
patch:
|
||||
summary: "Update MCP Server"
|
||||
|
||||
@@ -1005,6 +1005,61 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/tools/mcp/servers/{mcp_server_name}/resync": {
|
||||
"post": {
|
||||
"tags": ["tools"],
|
||||
"summary": "Resync Mcp Server Tools",
|
||||
"description": "Resync tools for an MCP server by:\n1. Fetching current tools from the MCP server\n2. Deleting tools that no longer exist on the server\n3. Updating schemas for existing tools\n4. Adding new tools from the server\n\nReturns a summary of changes made.",
|
||||
"operationId": "resync_mcp_server_tools",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "mcp_server_name",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"title": "Mcp Server Name"
|
||||
}
|
||||
},
|
||||
{
|
||||
"name": "agent_id",
|
||||
"in": "query",
|
||||
"required": false,
|
||||
"schema": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Agent Id"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/tools/mcp/servers/{mcp_server_name}/{mcp_tool_name}": {
|
||||
"post": {
|
||||
"tags": ["tools"],
|
||||
|
||||
@@ -622,7 +622,7 @@ def generate_tool_schema_for_mcp(
|
||||
format_value = option["format"]
|
||||
if types:
|
||||
# Deduplicate types using set
|
||||
field_props["type"] = list(set(types))
|
||||
field_props["type"] = list(dict.fromkeys(types))
|
||||
# Only add format if the field is not optional (doesn't have null type)
|
||||
if format_value and len(field_props["type"]) == 1 and "null" not in field_props["type"]:
|
||||
field_props["format"] = format_value
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -175,3 +175,11 @@ class MCPOAuthSessionUpdate(BaseMCPOAuth):
|
||||
client_secret: Optional[str] = Field(None, description="OAuth client secret")
|
||||
redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI")
|
||||
status: Optional[OAuthSessionStatus] = Field(None, description="Session status")
|
||||
|
||||
|
||||
class MCPServerResyncResult(LettaBase):
|
||||
"""Result of resyncing MCP server tools."""
|
||||
|
||||
deleted: List[str] = Field(default_factory=list, description="List of deleted tool names")
|
||||
updated: List[str] = Field(default_factory=list, description="List of updated tool names")
|
||||
added: List[str] = Field(default_factory=list, description="List of added tool names")
|
||||
|
||||
@@ -587,6 +587,48 @@ async def list_mcp_tools_by_server(
|
||||
return mcp_tools
|
||||
|
||||
|
||||
@router.post("/mcp/servers/{mcp_server_name}/resync", operation_id="resync_mcp_server_tools")
|
||||
async def resync_mcp_server_tools(
|
||||
mcp_server_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
agent_id: Optional[str] = None,
|
||||
):
|
||||
"""
|
||||
Resync tools for an MCP server by:
|
||||
1. Fetching current tools from the MCP server
|
||||
2. Deleting tools that no longer exist on the server
|
||||
3. Updating schemas for existing tools
|
||||
4. Adding new tools from the server
|
||||
|
||||
Returns a summary of changes made.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
result = await server.mcp_manager.resync_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor, agent_id=agent_id)
|
||||
return result
|
||||
except ValueError as e:
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error refreshing MCP server tools: {e}")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "MCPRefreshError",
|
||||
"message": f"Failed to refresh MCP server tools: {str(e)}",
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.post("/mcp/servers/{mcp_server_name}/{mcp_tool_name}", response_model=Tool, operation_id="add_mcp_tool")
|
||||
async def add_mcp_tool(
|
||||
mcp_server_name: str,
|
||||
|
||||
@@ -6,7 +6,7 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import delete, null
|
||||
from sqlalchemy import delete, desc, null, select
|
||||
from starlette.requests import Request
|
||||
|
||||
import letta.constants as constants
|
||||
@@ -23,17 +23,19 @@ from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
|
||||
from letta.orm.mcp_server import MCPServer as MCPServerModel
|
||||
from letta.orm.tool import Tool as ToolModel
|
||||
from letta.schemas.mcp import (
|
||||
MCPOAuthSession,
|
||||
MCPOAuthSessionCreate,
|
||||
MCPOAuthSessionUpdate,
|
||||
MCPServer,
|
||||
MCPServerResyncResult,
|
||||
UpdateMCPServer,
|
||||
UpdateSSEMCPServer,
|
||||
UpdateStdioMCPServer,
|
||||
UpdateStreamableHTTPMCPServer,
|
||||
)
|
||||
from letta.schemas.tool import Tool as PydanticTool, ToolCreate
|
||||
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
||||
@@ -147,6 +149,117 @@ class MCPManager:
|
||||
# failed to add - handle error?
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
async def resync_mcp_server_tools(
|
||||
self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None
|
||||
) -> MCPServerResyncResult:
|
||||
"""
|
||||
Resync tools for an MCP server by:
|
||||
1. Fetching current tools from the MCP server
|
||||
2. Deleting tools that no longer exist on the server
|
||||
3. Updating schemas for existing tools
|
||||
4. Adding new tools from the server
|
||||
|
||||
Returns a result with:
|
||||
- deleted: List of deleted tool names
|
||||
- updated: List of updated tool names
|
||||
- added: List of added tool names
|
||||
"""
|
||||
# Get the MCP server ID
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
if not mcp_server_id:
|
||||
raise ValueError(f"MCP server '{mcp_server_name}' not found")
|
||||
|
||||
# Fetch current tools from MCP server
|
||||
try:
|
||||
current_mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor, agent_id=agent_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to fetch tools from MCP server {mcp_server_name}: {e}")
|
||||
raise HTTPException(
|
||||
status_code=404,
|
||||
detail={
|
||||
"code": "MCPServerUnavailable",
|
||||
"message": f"Could not connect to MCP server {mcp_server_name} to resync tools",
|
||||
"error": str(e),
|
||||
},
|
||||
)
|
||||
|
||||
# Get all persisted tools for this MCP server
|
||||
async with db_registry.async_session() as session:
|
||||
# Query for tools with MCP metadata matching this server
|
||||
# Using JSON path query to filter by metadata
|
||||
persisted_tools = await ToolModel.list_async(
|
||||
db_session=session,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
|
||||
# Filter tools that belong to this MCP server
|
||||
mcp_tools = []
|
||||
for tool in persisted_tools:
|
||||
if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_:
|
||||
if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id:
|
||||
mcp_tools.append(tool)
|
||||
|
||||
# Create maps for easier comparison
|
||||
current_tool_map = {tool.name: tool for tool in current_mcp_tools}
|
||||
persisted_tool_map = {tool.name: tool for tool in mcp_tools}
|
||||
|
||||
deleted_tools = []
|
||||
updated_tools = []
|
||||
added_tools = []
|
||||
|
||||
# 1. Delete tools that no longer exist on the server
|
||||
for tool_name, persisted_tool in persisted_tool_map.items():
|
||||
if tool_name not in current_tool_map:
|
||||
# Delete the tool (cascade will handle agent detachment)
|
||||
await persisted_tool.hard_delete_async(db_session=session, actor=actor)
|
||||
deleted_tools.append(tool_name)
|
||||
logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}")
|
||||
|
||||
# Commit deletions
|
||||
await session.commit()
|
||||
|
||||
# 2. Update existing tools and add new tools
|
||||
for tool_name, current_tool in current_tool_map.items():
|
||||
if tool_name in persisted_tool_map:
|
||||
# Update existing tool
|
||||
persisted_tool = persisted_tool_map[tool_name]
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool)
|
||||
|
||||
# Check if schema has changed
|
||||
if persisted_tool.json_schema != tool_create.json_schema:
|
||||
# Update the tool
|
||||
update_data = ToolUpdate(
|
||||
description=tool_create.description,
|
||||
json_schema=tool_create.json_schema,
|
||||
source_code=tool_create.source_code,
|
||||
)
|
||||
|
||||
await self.tool_manager.update_tool_by_id_async(tool_id=persisted_tool.id, tool_update=update_data, actor=actor)
|
||||
updated_tools.append(tool_name)
|
||||
logger.info(f"Updated MCP tool {tool_name} with new schema from server {mcp_server_name}")
|
||||
else:
|
||||
# Add new tool
|
||||
# Skip INVALID tools
|
||||
if current_tool.health and current_tool.health.status == "INVALID":
|
||||
logger.warning(
|
||||
f"Skipping invalid tool {tool_name} from MCP server {mcp_server_name}: {', '.join(current_tool.health.reasons)}"
|
||||
)
|
||||
continue
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool)
|
||||
await self.tool_manager.create_mcp_tool_async(
|
||||
tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor
|
||||
)
|
||||
added_tools.append(tool_name)
|
||||
logger.info(f"Added new MCP tool {tool_name} from server {mcp_server_name}")
|
||||
|
||||
return MCPServerResyncResult(
|
||||
deleted=deleted_tools,
|
||||
updated=updated_tools,
|
||||
added=added_tools,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
async def list_mcp_servers(self, actor: PydanticUser) -> List[MCPServer]:
|
||||
"""List all MCP servers available"""
|
||||
@@ -209,8 +322,6 @@ class MCPManager:
|
||||
# This ensures OAuth sessions created during testing get linked to the server
|
||||
server_url = getattr(mcp_server, "server_url", None)
|
||||
if server_url:
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await session.execute(
|
||||
select(MCPOAuth).where(
|
||||
MCPOAuth.server_url == server_url,
|
||||
@@ -326,26 +437,9 @@ class MCPManager:
|
||||
)
|
||||
return mcp_server.to_pydantic()
|
||||
|
||||
# @enforce_types
|
||||
# async def delete_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> None:
|
||||
# """Delete an existing tool."""
|
||||
# with db_registry.session() as session:
|
||||
# mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
|
||||
# mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
# if not mcp_server:
|
||||
# raise HTTPException(
|
||||
# status_code=404, # Not Found
|
||||
# detail={
|
||||
# "code": "MCPServerNotFoundError",
|
||||
# "message": f"MCP server {mcp_server_name} not found",
|
||||
# "mcp_server_name": mcp_server_name,
|
||||
# },
|
||||
# )
|
||||
# mcp_server.delete(session, actor=actor) # Re-raise other database-related errors
|
||||
|
||||
@enforce_types
|
||||
async def delete_mcp_server_by_id(self, mcp_server_id: str, actor: PydanticUser) -> None:
|
||||
"""Delete a MCP server by its ID."""
|
||||
"""Delete a MCP server by its ID and associated tools and OAuth sessions."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
@@ -353,6 +447,22 @@ class MCPManager:
|
||||
raise NoResultFound(f"MCP server with id {mcp_server_id} not found.")
|
||||
|
||||
server_url = getattr(mcp_server, "server_url", None)
|
||||
# Get all tools with matching metadata
|
||||
stmt = select(ToolModel).where(ToolModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(stmt)
|
||||
all_tools = result.scalars().all()
|
||||
|
||||
# Filter and delete tools that belong to this MCP server
|
||||
tools_deleted = 0
|
||||
for tool in all_tools:
|
||||
if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_:
|
||||
if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id:
|
||||
await tool.hard_delete_async(db_session=session, actor=actor)
|
||||
tools_deleted = 1
|
||||
logger.info(f"Deleted MCP tool {tool.name} associated with MCP server {mcp_server_id}")
|
||||
|
||||
if tools_deleted > 0:
|
||||
logger.info(f"Deleted {tools_deleted} MCP tools associated with MCP server {mcp_server_id}")
|
||||
|
||||
# Delete OAuth sessions for the same user and server URL in the same transaction
|
||||
# This handles orphaned sessions that were created during testing/connection
|
||||
@@ -557,8 +667,6 @@ class MCPManager:
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_server(self, server_url: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
|
||||
"""Get the latest OAuth session by server URL, organization, and user."""
|
||||
from sqlalchemy import desc, select
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Query for OAuth session matching organization, user, server URL, and status
|
||||
# Order by updated_at desc to get the most recent record
|
||||
@@ -673,8 +781,6 @@ class MCPManager:
|
||||
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
# Find expired sessions
|
||||
result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time))
|
||||
expired_sessions = result.scalars().all()
|
||||
|
||||
@@ -10931,6 +10931,117 @@ async def test_mcp_server_delete_removes_all_sessions_for_url_and_user(server, d
|
||||
assert await server.mcp_manager.get_oauth_session_by_id(orphaned.id, actor=default_user) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_server_resync_tools(server, default_user, default_organization):
|
||||
"""Test that resyncing MCP server tools correctly handles added, deleted, and updated tools."""
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
from letta.functions.mcp_client.types import MCPTool, MCPToolHealth
|
||||
from letta.schemas.mcp import MCPServer as PydanticMCPServer, MCPServerType
|
||||
from letta.schemas.tool import ToolCreate
|
||||
|
||||
# Create MCP server
|
||||
mcp_server = await server.mcp_manager.create_mcp_server(
|
||||
PydanticMCPServer(
|
||||
server_name=f"test_resync_{uuid.uuid4().hex[:8]}",
|
||||
server_type=MCPServerType.SSE,
|
||||
server_url="https://test-resync.example.com/mcp",
|
||||
organization_id=default_organization.id,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
mcp_server_id = mcp_server.id
|
||||
|
||||
try:
|
||||
# Create initial persisted tools (simulating previously added tools)
|
||||
# Use sync method like in the existing mcp_tool fixture
|
||||
tool1_create = ToolCreate.from_mcp(
|
||||
mcp_server_name=mcp_server.server_name,
|
||||
mcp_tool=MCPTool(
|
||||
name="tool1",
|
||||
description="Tool 1",
|
||||
inputSchema={"type": "object", "properties": {"param1": {"type": "string"}}},
|
||||
),
|
||||
)
|
||||
tool1 = server.tool_manager.create_or_update_mcp_tool(
|
||||
tool_create=tool1_create,
|
||||
mcp_server_name=mcp_server.server_name,
|
||||
mcp_server_id=mcp_server_id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
tool2_create = ToolCreate.from_mcp(
|
||||
mcp_server_name=mcp_server.server_name,
|
||||
mcp_tool=MCPTool(
|
||||
name="tool2",
|
||||
description="Tool 2 to be deleted",
|
||||
inputSchema={"type": "object", "properties": {"param2": {"type": "number"}}},
|
||||
),
|
||||
)
|
||||
tool2 = server.tool_manager.create_or_update_mcp_tool(
|
||||
tool_create=tool2_create,
|
||||
mcp_server_name=mcp_server.server_name,
|
||||
mcp_server_id=mcp_server_id,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Mock the list_mcp_server_tools to return updated tools from server
|
||||
# tool1 is updated, tool2 is deleted, tool3 is added
|
||||
updated_tools = [
|
||||
MCPTool(
|
||||
name="tool1",
|
||||
description="Tool 1 Updated",
|
||||
inputSchema={"type": "object", "properties": {"param1": {"type": "string"}, "param1b": {"type": "boolean"}}},
|
||||
health=MCPToolHealth(status="VALID", reasons=[]),
|
||||
),
|
||||
MCPTool(
|
||||
name="tool3",
|
||||
description="Tool 3 New",
|
||||
inputSchema={"type": "object", "properties": {"param3": {"type": "array"}}},
|
||||
health=MCPToolHealth(status="VALID", reasons=[]),
|
||||
),
|
||||
]
|
||||
|
||||
with patch.object(server.mcp_manager, "list_mcp_server_tools", new_callable=AsyncMock) as mock_list_tools:
|
||||
mock_list_tools.return_value = updated_tools
|
||||
|
||||
# Run resync
|
||||
result = await server.mcp_manager.resync_mcp_server_tools(
|
||||
mcp_server_name=mcp_server.server_name,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify the resync result
|
||||
assert len(result.deleted) == 1
|
||||
assert "tool2" in result.deleted
|
||||
|
||||
assert len(result.updated) == 1
|
||||
assert "tool1" in result.updated
|
||||
|
||||
assert len(result.added) == 1
|
||||
assert "tool3" in result.added
|
||||
|
||||
# Verify tool2 was actually deleted
|
||||
try:
|
||||
deleted_tool = server.tool_manager.get_tool_by_id(tool_id=tool2.id, actor=default_user)
|
||||
assert False, "Tool2 should have been deleted"
|
||||
except Exception:
|
||||
pass # Expected - tool should be deleted
|
||||
|
||||
# Verify tool1 was updated with new schema
|
||||
updated_tool1 = server.tool_manager.get_tool_by_id(tool_id=tool1.id, actor=default_user)
|
||||
assert "param1b" in updated_tool1.json_schema["parameters"]["properties"]
|
||||
|
||||
# Verify tool3 was added
|
||||
tools = await server.tool_manager.list_tools_async(actor=default_user, names=["tool3"])
|
||||
assert len(tools) == 1
|
||||
assert tools[0].name == "tool3"
|
||||
|
||||
finally:
|
||||
# Clean up
|
||||
await server.mcp_manager.delete_mcp_server_by_id(mcp_server_id, actor=default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# FileAgent Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user