From 93e41bbf0a290ef6498150aadff0a5c5f69fd31d Mon Sep 17 00:00:00 2001 From: jnjpng Date: Tue, 12 Aug 2025 15:32:33 -0700 Subject: [PATCH] fix: delete associated mcp oauth sessions on delete and link existing mcp oauth sessions on create Co-authored-by: Jin Peng --- letta/server/rest_api/routers/v1/tools.py | 2 +- letta/services/mcp_manager.py | 92 ++++++++++-- tests/test_managers.py | 171 ++++++++++++++++++++++ 3 files changed, 252 insertions(+), 13 deletions(-) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 6d33aaa6..864bbbf7 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -608,7 +608,7 @@ async def delete_mcp_server_from_config( actor_id: Optional[str] = Header(None, alias="user_id"), ): """ - Add a new MCP server to the Letta MCP server config + Delete a MCP server configuration """ if tool_settings.mcp_read_from_config: # write to config file diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index fffc90d9..9684e3c8 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -6,7 +6,7 @@ from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Tuple, Union from fastapi import HTTPException -from sqlalchemy import null +from sqlalchemy import delete, null from starlette.requests import Request import letta.constants as constants @@ -169,17 +169,50 @@ class MCPManager: async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: """Create a new MCP server.""" async with db_registry.async_session() as session: - # Set the organization id at the ORM layer - pydantic_mcp_server.organization_id = actor.organization_id - mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True) + try: + # Set the organization id at the ORM layer + pydantic_mcp_server.organization_id = actor.organization_id + mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True) - # Ensure custom_headers None is stored as SQL NULL, not JSON null - if mcp_server_data.get("custom_headers") is None: - mcp_server_data.pop("custom_headers", None) + # Ensure custom_headers None is stored as SQL NULL, not JSON null + if mcp_server_data.get("custom_headers") is None: + mcp_server_data.pop("custom_headers", None) - mcp_server = MCPServerModel(**mcp_server_data) - mcp_server = await mcp_server.create_async(session, actor=actor) - return mcp_server.to_pydantic() + mcp_server = MCPServerModel(**mcp_server_data) + mcp_server = await mcp_server.create_async(session, actor=actor, no_commit=True) + + # Link existing OAuth sessions for the same user and server URL + # This ensures OAuth sessions created during testing get linked to the server + server_url = getattr(mcp_server, "server_url", None) + if server_url: + from sqlalchemy import select + + result = await session.execute( + select(MCPOAuth).where( + MCPOAuth.server_url == server_url, + MCPOAuth.organization_id == actor.organization_id, + MCPOAuth.user_id == actor.id, # Only link sessions for the same user + MCPOAuth.server_id.is_(None), # Only update sessions not already linked + ) + ) + oauth_sessions = result.scalars().all() + + # TODO: @jnjpng we should upate sessions in bulk + for oauth_session in oauth_sessions: + oauth_session.server_id = mcp_server.id + await oauth_session.update_async(db_session=session, actor=actor, no_commit=True) + + if oauth_sessions: + logger.info( + f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}" + ) + + await session.commit() + return mcp_server.to_pydantic() + except Exception as e: + await session.rollback() + logger.error(f"Failed to create MCP server: {e}") + raise @enforce_types async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer: @@ -286,13 +319,48 @@ class MCPManager: @enforce_types async def delete_mcp_server_by_id(self, mcp_server_id: str, actor: PydanticUser) -> None: - """Delete a tool by its ID.""" + """Delete a MCP server by its ID.""" async with db_registry.async_session() as session: try: mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) - await mcp_server.hard_delete_async(db_session=session, actor=actor) + if not mcp_server: + raise NoResultFound(f"MCP server with id {mcp_server_id} not found.") + + server_url = getattr(mcp_server, "server_url", None) + + # Delete OAuth sessions for the same user and server URL in the same transaction + # This handles orphaned sessions that were created during testing/connection + oauth_count = 0 + if server_url: + result = await session.execute( + delete(MCPOAuth).where( + MCPOAuth.server_url == server_url, + MCPOAuth.organization_id == actor.organization_id, + MCPOAuth.user_id == actor.id, # Only delete sessions for the same user + ) + ) + oauth_count = result.rowcount + if oauth_count > 0: + logger.info( + f"Deleting {oauth_count} OAuth sessions for MCP server {mcp_server_id} (URL: {server_url}) for user {actor.id}" + ) + + # Delete the MCP server, will cascade delete to linked OAuth sessions + await session.execute( + delete(MCPServerModel).where( + MCPServerModel.id == mcp_server_id, + MCPServerModel.organization_id == actor.organization_id, + ) + ) + + await session.commit() except NoResultFound: + await session.rollback() raise ValueError(f"MCP server with id {mcp_server_id} not found.") + except Exception as e: + await session.rollback() + logger.error(f"Failed to delete MCP server {mcp_server_id}: {e}") + raise def read_mcp_config(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]]: mcp_server_list = {} diff --git a/tests/test_managers.py b/tests/test_managers.py index a6d6e4b9..ccd2802c 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -9168,6 +9168,177 @@ async def test_get_mcp_servers_by_ids(server, default_user, event_loop): assert all(s.organization_id == default_user.organization_id for s in bulk_fetched) +# Additional MCPManager OAuth session tests +@pytest.mark.asyncio +async def test_mcp_server_deletion_cascades_oauth_sessions(server, default_organization, default_user, event_loop): + """Deleting an MCP server deletes associated OAuth sessions (same user + URL).""" + + from letta.schemas.mcp import MCPOAuthSessionCreate + from letta.schemas.mcp import MCPServer as PydanticMCPServer + from letta.schemas.mcp import MCPServerType + + test_server_url = "https://test.example.com/mcp" + + # Create orphaned OAuth sessions (no server id) for same user and URL + created_session_ids: list[str] = [] + for i in range(3): + session = await server.mcp_manager.create_oauth_session( + MCPOAuthSessionCreate( + server_url=test_server_url, + server_name=f"test_mcp_server_{i}", + user_id=default_user.id, + organization_id=default_organization.id, + ), + actor=default_user, + ) + created_session_ids.append(session.id) + + # Create the MCP server with the same URL + created_server = await server.mcp_manager.create_mcp_server( + PydanticMCPServer( + server_name=f"test_mcp_server_{") + str(uuid.uuid4().hex[:8]) + ("}", # ensure unique name + server_type=MCPServerType.SSE, + server_url=test_server_url, + organization_id=default_organization.id, + ), + actor=default_user, + ) + + # Now delete the server via manager + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + # Verify all sessions are gone + for sid in created_session_ids: + session = await server.mcp_manager.get_oauth_session_by_id(sid, actor=default_user) + assert session is None, f"OAuth session {sid} should be deleted" + + +@pytest.mark.asyncio +async def test_oauth_sessions_with_different_url_persist(server, default_organization, default_user, event_loop): + """Sessions with different URL should not be deleted when deleting the server for another URL.""" + + from letta.schemas.mcp import MCPOAuthSessionCreate + from letta.schemas.mcp import MCPServer as PydanticMCPServer + from letta.schemas.mcp import MCPServerType + + server_url = "https://test.example.com/mcp" + other_url = "https://other.example.com/mcp" + + # Create a session for other_url (should persist) + other_session = await server.mcp_manager.create_oauth_session( + MCPOAuthSessionCreate( + server_url=other_url, + server_name="standalone_oauth", + user_id=default_user.id, + organization_id=default_organization.id, + ), + actor=default_user, + ) + + # Create the MCP server at server_url + created_server = await server.mcp_manager.create_mcp_server( + PydanticMCPServer( + server_name=f"test_mcp_server_{") + str(uuid.uuid4().hex[:8]) + ("}", + server_type=MCPServerType.SSE, + server_url=server_url, + organization_id=default_organization.id, + ), + actor=default_user, + ) + + # Delete the server at server_url + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + # Verify the session at other_url still exists + persisted = await server.mcp_manager.get_oauth_session_by_id(other_session.id, actor=default_user) + assert persisted is not None, "OAuth session with different URL should persist" + + +@pytest.mark.asyncio +async def test_mcp_server_creation_links_orphaned_sessions(server, default_organization, default_user, event_loop): + """Creating a server should link any existing orphaned sessions (same user + URL).""" + + from letta.schemas.mcp import MCPOAuthSessionCreate + from letta.schemas.mcp import MCPServer as PydanticMCPServer + from letta.schemas.mcp import MCPServerType + + server_url = "https://test-atomic-create.example.com/mcp" + + # Pre-create orphaned sessions (no server_id) for same user + URL + orphaned_ids: list[str] = [] + for i in range(3): + session = await server.mcp_manager.create_oauth_session( + MCPOAuthSessionCreate( + server_url=server_url, + server_name=f"atomic_session_{i}", + user_id=default_user.id, + organization_id=default_organization.id, + ), + actor=default_user, + ) + orphaned_ids.append(session.id) + + # Create server + created_server = await server.mcp_manager.create_mcp_server( + PydanticMCPServer( + server_name=f"test_atomic_server_{") + str(uuid.uuid4().hex[:8]) + ("}", + server_type=MCPServerType.SSE, + server_url=server_url, + organization_id=default_organization.id, + ), + actor=default_user, + ) + + # Sessions should still be retrievable via manager API + for sid in orphaned_ids: + s = await server.mcp_manager.get_oauth_session_by_id(sid, actor=default_user) + assert s is not None + + # Indirect verification: deleting the server removes sessions for that URL+user + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + for sid in orphaned_ids: + assert await server.mcp_manager.get_oauth_session_by_id(sid, actor=default_user) is None + + +@pytest.mark.asyncio +async def test_mcp_server_delete_removes_all_sessions_for_url_and_user(server, default_organization, default_user, event_loop): + """Deleting a server removes both linked and orphaned sessions for same user+URL.""" + + from letta.schemas.mcp import MCPOAuthSessionCreate + from letta.schemas.mcp import MCPServer as PydanticMCPServer + from letta.schemas.mcp import MCPServerType + + server_url = "https://test-atomic-cleanup.example.com/mcp" + + # Create orphaned session + orphaned = await server.mcp_manager.create_oauth_session( + MCPOAuthSessionCreate( + server_url=server_url, + server_name="orphaned", + user_id=default_user.id, + organization_id=default_organization.id, + ), + actor=default_user, + ) + + # Create server + created_server = await server.mcp_manager.create_mcp_server( + PydanticMCPServer( + server_name=f"cleanup_server_{") + str(uuid.uuid4().hex[:8]) + ("}", + server_type=MCPServerType.SSE, + server_url=server_url, + organization_id=default_organization.id, + ), + actor=default_user, + ) + + # Delete server + await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user) + + # Both orphaned and any linked sessions for that URL+user should be gone + assert await server.mcp_manager.get_oauth_session_by_id(orphaned.id, actor=default_user) is None + + # ====================================================================================================================== # FileAgent Tests # ======================================================================================================================