From 2aae4bf0dbb47202725f3c70da092e6332df6355 Mon Sep 17 00:00:00 2001 From: jnjpng Date: Tue, 9 Sep 2025 18:11:02 -0700 Subject: [PATCH] feat: add resync tool endpoint (#2812) Co-authored-by: Jin Peng --- fern/openapi-overrides.yml | 3 + fern/openapi.json | 55 ++++++++ letta/functions/schema_generator.py | 2 +- letta/schemas/mcp.py | 10 +- letta/server/rest_api/routers/v1/tools.py | 42 ++++++ letta/services/mcp_manager.py | 158 ++++++++++++++++++---- tests/test_managers.py | 111 +++++++++++++++ 7 files changed, 353 insertions(+), 28 deletions(-) diff --git a/fern/openapi-overrides.yml b/fern/openapi-overrides.yml index 8c7d73a0..f91dd8aa 100644 --- a/fern/openapi-overrides.yml +++ b/fern/openapi-overrides.yml @@ -60,6 +60,9 @@ paths: /v1/tools/mcp/servers/{mcp_server_name}/{mcp_tool_name}: post: summary: "Add MCP Tool" + /v1/tools/mcp/servers/{mcp_server_name}/resync: + post: + x-fern-ignore: true /v1/tools/mcp/servers/{mcp_server_name}: patch: summary: "Update MCP Server" diff --git a/fern/openapi.json b/fern/openapi.json index f54701d9..721f0724 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -1005,6 +1005,61 @@ } } }, + "/v1/tools/mcp/servers/{mcp_server_name}/resync": { + "post": { + "tags": ["tools"], + "summary": "Resync Mcp Server Tools", + "description": "Resync tools for an MCP server by:\n1. Fetching current tools from the MCP server\n2. Deleting tools that no longer exist on the server\n3. Updating schemas for existing tools\n4. Adding new tools from the server\n\nReturns a summary of changes made.", + "operationId": "resync_mcp_server_tools", + "parameters": [ + { + "name": "mcp_server_name", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Mcp Server Name" + } + }, + { + "name": "agent_id", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Agent Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/tools/mcp/servers/{mcp_server_name}/{mcp_tool_name}": { "post": { "tags": ["tools"], diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index 8c657f43..545f8873 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -622,7 +622,7 @@ def generate_tool_schema_for_mcp( format_value = option["format"] if types: # Deduplicate types using set - field_props["type"] = list(set(types)) + field_props["type"] = list(dict.fromkeys(types)) # Only add format if the field is not optional (doesn't have null type) if format_value and len(field_props["type"]) == 1 and "null" not in field_props["type"]: field_props["format"] = format_value diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index e49f177a..5412bc33 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union from pydantic import Field @@ -175,3 +175,11 @@ class MCPOAuthSessionUpdate(BaseMCPOAuth): client_secret: Optional[str] = Field(None, description="OAuth client secret") redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI") status: Optional[OAuthSessionStatus] = Field(None, description="Session status") + + +class MCPServerResyncResult(LettaBase): + """Result of resyncing MCP server tools.""" + + deleted: List[str] = Field(default_factory=list, description="List of deleted tool names") + updated: List[str] = Field(default_factory=list, description="List of updated tool names") + added: List[str] = Field(default_factory=list, description="List of added tool names") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index efd03b0b..0838838d 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -587,6 +587,48 @@ async def list_mcp_tools_by_server( return mcp_tools +@router.post("/mcp/servers/{mcp_server_name}/resync", operation_id="resync_mcp_server_tools") +async def resync_mcp_server_tools( + mcp_server_name: str, + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), + agent_id: Optional[str] = None, +): + """ + Resync tools for an MCP server by: + 1. Fetching current tools from the MCP server + 2. Deleting tools that no longer exist on the server + 3. Updating schemas for existing tools + 4. Adding new tools from the server + + Returns a summary of changes made. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + + try: + result = await server.mcp_manager.resync_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor, agent_id=agent_id) + return result + except ValueError as e: + raise HTTPException( + status_code=404, + detail={ + "code": "MCPServerNotFoundError", + "message": str(e), + "mcp_server_name": mcp_server_name, + }, + ) + except Exception as e: + logger.error(f"Unexpected error refreshing MCP server tools: {e}") + raise HTTPException( + status_code=404, + detail={ + "code": "MCPRefreshError", + "message": f"Failed to refresh MCP server tools: {str(e)}", + "mcp_server_name": mcp_server_name, + }, + ) + + @router.post("/mcp/servers/{mcp_server_name}/{mcp_tool_name}", response_model=Tool, operation_id="add_mcp_tool") async def add_mcp_tool( mcp_server_name: str, diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 775181c9..8668984a 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -6,7 +6,7 @@ from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Tuple, Union from fastapi import HTTPException -from sqlalchemy import delete, null +from sqlalchemy import delete, desc, null, select from starlette.requests import Request import letta.constants as constants @@ -23,17 +23,19 @@ from letta.log import get_logger from letta.orm.errors import NoResultFound from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus from letta.orm.mcp_server import MCPServer as MCPServerModel +from letta.orm.tool import Tool as ToolModel from letta.schemas.mcp import ( MCPOAuthSession, MCPOAuthSessionCreate, MCPOAuthSessionUpdate, MCPServer, + MCPServerResyncResult, UpdateMCPServer, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer, ) -from letta.schemas.tool import Tool as PydanticTool, ToolCreate +from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient @@ -147,6 +149,117 @@ class MCPManager: # failed to add - handle error? return None + @enforce_types + async def resync_mcp_server_tools( + self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None + ) -> MCPServerResyncResult: + """ + Resync tools for an MCP server by: + 1. Fetching current tools from the MCP server + 2. Deleting tools that no longer exist on the server + 3. Updating schemas for existing tools + 4. Adding new tools from the server + + Returns a result with: + - deleted: List of deleted tool names + - updated: List of updated tool names + - added: List of added tool names + """ + # Get the MCP server ID + mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor) + if not mcp_server_id: + raise ValueError(f"MCP server '{mcp_server_name}' not found") + + # Fetch current tools from MCP server + try: + current_mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor, agent_id=agent_id) + except Exception as e: + logger.error(f"Failed to fetch tools from MCP server {mcp_server_name}: {e}") + raise HTTPException( + status_code=404, + detail={ + "code": "MCPServerUnavailable", + "message": f"Could not connect to MCP server {mcp_server_name} to resync tools", + "error": str(e), + }, + ) + + # Get all persisted tools for this MCP server + async with db_registry.async_session() as session: + # Query for tools with MCP metadata matching this server + # Using JSON path query to filter by metadata + persisted_tools = await ToolModel.list_async( + db_session=session, + organization_id=actor.organization_id, + ) + + # Filter tools that belong to this MCP server + mcp_tools = [] + for tool in persisted_tools: + if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_: + if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id: + mcp_tools.append(tool) + + # Create maps for easier comparison + current_tool_map = {tool.name: tool for tool in current_mcp_tools} + persisted_tool_map = {tool.name: tool for tool in mcp_tools} + + deleted_tools = [] + updated_tools = [] + added_tools = [] + + # 1. Delete tools that no longer exist on the server + for tool_name, persisted_tool in persisted_tool_map.items(): + if tool_name not in current_tool_map: + # Delete the tool (cascade will handle agent detachment) + await persisted_tool.hard_delete_async(db_session=session, actor=actor) + deleted_tools.append(tool_name) + logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}") + + # Commit deletions + await session.commit() + + # 2. Update existing tools and add new tools + for tool_name, current_tool in current_tool_map.items(): + if tool_name in persisted_tool_map: + # Update existing tool + persisted_tool = persisted_tool_map[tool_name] + tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool) + + # Check if schema has changed + if persisted_tool.json_schema != tool_create.json_schema: + # Update the tool + update_data = ToolUpdate( + description=tool_create.description, + json_schema=tool_create.json_schema, + source_code=tool_create.source_code, + ) + + await self.tool_manager.update_tool_by_id_async(tool_id=persisted_tool.id, tool_update=update_data, actor=actor) + updated_tools.append(tool_name) + logger.info(f"Updated MCP tool {tool_name} with new schema from server {mcp_server_name}") + else: + # Add new tool + # Skip INVALID tools + if current_tool.health and current_tool.health.status == "INVALID": + logger.warning( + f"Skipping invalid tool {tool_name} from MCP server {mcp_server_name}: {', '.join(current_tool.health.reasons)}" + ) + continue + + tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool) + await self.tool_manager.create_mcp_tool_async( + tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor + ) + added_tools.append(tool_name) + logger.info(f"Added new MCP tool {tool_name} from server {mcp_server_name}") + + return MCPServerResyncResult( + deleted=deleted_tools, + updated=updated_tools, + added=added_tools, + ) + @enforce_types async def list_mcp_servers(self, actor: PydanticUser) -> List[MCPServer]: """List all MCP servers available""" @@ -209,8 +322,6 @@ class MCPManager: # This ensures OAuth sessions created during testing get linked to the server server_url = getattr(mcp_server, "server_url", None) if server_url: - from sqlalchemy import select - result = await session.execute( select(MCPOAuth).where( MCPOAuth.server_url == server_url, @@ -326,26 +437,9 @@ class MCPManager: ) return mcp_server.to_pydantic() - # @enforce_types - # async def delete_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> None: - # """Delete an existing tool.""" - # with db_registry.session() as session: - # mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) - # mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) - # if not mcp_server: - # raise HTTPException( - # status_code=404, # Not Found - # detail={ - # "code": "MCPServerNotFoundError", - # "message": f"MCP server {mcp_server_name} not found", - # "mcp_server_name": mcp_server_name, - # }, - # ) - # mcp_server.delete(session, actor=actor) # Re-raise other database-related errors - @enforce_types async def delete_mcp_server_by_id(self, mcp_server_id: str, actor: PydanticUser) -> None: - """Delete a MCP server by its ID.""" + """Delete a MCP server by its ID and associated tools and OAuth sessions.""" async with db_registry.async_session() as session: try: mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) @@ -353,6 +447,22 @@ class MCPManager: raise NoResultFound(f"MCP server with id {mcp_server_id} not found.") server_url = getattr(mcp_server, "server_url", None) + # Get all tools with matching metadata + stmt = select(ToolModel).where(ToolModel.organization_id == actor.organization_id) + result = await session.execute(stmt) + all_tools = result.scalars().all() + + # Filter and delete tools that belong to this MCP server + tools_deleted = 0 + for tool in all_tools: + if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_: + if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id: + await tool.hard_delete_async(db_session=session, actor=actor) + tools_deleted = 1 + logger.info(f"Deleted MCP tool {tool.name} associated with MCP server {mcp_server_id}") + + if tools_deleted > 0: + logger.info(f"Deleted {tools_deleted} MCP tools associated with MCP server {mcp_server_id}") # Delete OAuth sessions for the same user and server URL in the same transaction # This handles orphaned sessions that were created during testing/connection @@ -557,8 +667,6 @@ class MCPManager: @enforce_types async def get_oauth_session_by_server(self, server_url: str, actor: PydanticUser) -> Optional[MCPOAuthSession]: """Get the latest OAuth session by server URL, organization, and user.""" - from sqlalchemy import desc, select - async with db_registry.async_session() as session: # Query for OAuth session matching organization, user, server URL, and status # Order by updated_at desc to get the most recent record @@ -673,8 +781,6 @@ class MCPManager: cutoff_time = datetime.now() - timedelta(hours=max_age_hours) async with db_registry.async_session() as session: - from sqlalchemy import select - # Find expired sessions result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time)) expired_sessions = result.scalars().all() diff --git a/tests/test_managers.py b/tests/test_managers.py index 73882f49..a751eec7 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -10931,6 +10931,117 @@ async def test_mcp_server_delete_removes_all_sessions_for_url_and_user(server, d assert await server.mcp_manager.get_oauth_session_by_id(orphaned.id, actor=default_user) is None +@pytest.mark.asyncio +async def test_mcp_server_resync_tools(server, default_user, default_organization): + """Test that resyncing MCP server tools correctly handles added, deleted, and updated tools.""" + from unittest.mock import AsyncMock, MagicMock, patch + + from letta.functions.mcp_client.types import MCPTool, MCPToolHealth + from letta.schemas.mcp import MCPServer as PydanticMCPServer, MCPServerType + from letta.schemas.tool import ToolCreate + + # Create MCP server + mcp_server = await server.mcp_manager.create_mcp_server( + PydanticMCPServer( + server_name=f"test_resync_{uuid.uuid4().hex[:8]}", + server_type=MCPServerType.SSE, + server_url="https://test-resync.example.com/mcp", + organization_id=default_organization.id, + ), + actor=default_user, + ) + mcp_server_id = mcp_server.id + + try: + # Create initial persisted tools (simulating previously added tools) + # Use sync method like in the existing mcp_tool fixture + tool1_create = ToolCreate.from_mcp( + mcp_server_name=mcp_server.server_name, + mcp_tool=MCPTool( + name="tool1", + description="Tool 1", + inputSchema={"type": "object", "properties": {"param1": {"type": "string"}}}, + ), + ) + tool1 = server.tool_manager.create_or_update_mcp_tool( + tool_create=tool1_create, + mcp_server_name=mcp_server.server_name, + mcp_server_id=mcp_server_id, + actor=default_user, + ) + + tool2_create = ToolCreate.from_mcp( + mcp_server_name=mcp_server.server_name, + mcp_tool=MCPTool( + name="tool2", + description="Tool 2 to be deleted", + inputSchema={"type": "object", "properties": {"param2": {"type": "number"}}}, + ), + ) + tool2 = server.tool_manager.create_or_update_mcp_tool( + tool_create=tool2_create, + mcp_server_name=mcp_server.server_name, + mcp_server_id=mcp_server_id, + actor=default_user, + ) + + # Mock the list_mcp_server_tools to return updated tools from server + # tool1 is updated, tool2 is deleted, tool3 is added + updated_tools = [ + MCPTool( + name="tool1", + description="Tool 1 Updated", + inputSchema={"type": "object", "properties": {"param1": {"type": "string"}, "param1b": {"type": "boolean"}}}, + health=MCPToolHealth(status="VALID", reasons=[]), + ), + MCPTool( + name="tool3", + description="Tool 3 New", + inputSchema={"type": "object", "properties": {"param3": {"type": "array"}}}, + health=MCPToolHealth(status="VALID", reasons=[]), + ), + ] + + with patch.object(server.mcp_manager, "list_mcp_server_tools", new_callable=AsyncMock) as mock_list_tools: + mock_list_tools.return_value = updated_tools + + # Run resync + result = await server.mcp_manager.resync_mcp_server_tools( + mcp_server_name=mcp_server.server_name, + actor=default_user, + ) + + # Verify the resync result + assert len(result.deleted) == 1 + assert "tool2" in result.deleted + + assert len(result.updated) == 1 + assert "tool1" in result.updated + + assert len(result.added) == 1 + assert "tool3" in result.added + + # Verify tool2 was actually deleted + try: + deleted_tool = server.tool_manager.get_tool_by_id(tool_id=tool2.id, actor=default_user) + assert False, "Tool2 should have been deleted" + except Exception: + pass # Expected - tool should be deleted + + # Verify tool1 was updated with new schema + updated_tool1 = server.tool_manager.get_tool_by_id(tool_id=tool1.id, actor=default_user) + assert "param1b" in updated_tool1.json_schema["parameters"]["properties"] + + # Verify tool3 was added + tools = await server.tool_manager.list_tools_async(actor=default_user, names=["tool3"]) + assert len(tools) == 1 + assert tools[0].name == "tool3" + + finally: + # Clean up + await server.mcp_manager.delete_mcp_server_by_id(mcp_server_id, actor=default_user) + + # ====================================================================================================================== # FileAgent Tests # ======================================================================================================================