fix: delete associated mcp oauth sessions on delete and link existing mcp oauth sessions on create
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -608,7 +608,7 @@ async def delete_mcp_server_from_config(
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Add a new MCP server to the Letta MCP server config
|
||||
Delete a MCP server configuration
|
||||
"""
|
||||
if tool_settings.mcp_read_from_config:
|
||||
# write to config file
|
||||
|
||||
@@ -6,7 +6,7 @@ from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy import null
|
||||
from sqlalchemy import delete, null
|
||||
from starlette.requests import Request
|
||||
|
||||
import letta.constants as constants
|
||||
@@ -169,17 +169,50 @@ class MCPManager:
|
||||
async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
|
||||
"""Create a new MCP server."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Set the organization id at the ORM layer
|
||||
pydantic_mcp_server.organization_id = actor.organization_id
|
||||
mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
|
||||
try:
|
||||
# Set the organization id at the ORM layer
|
||||
pydantic_mcp_server.organization_id = actor.organization_id
|
||||
mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
|
||||
|
||||
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
||||
if mcp_server_data.get("custom_headers") is None:
|
||||
mcp_server_data.pop("custom_headers", None)
|
||||
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
||||
if mcp_server_data.get("custom_headers") is None:
|
||||
mcp_server_data.pop("custom_headers", None)
|
||||
|
||||
mcp_server = MCPServerModel(**mcp_server_data)
|
||||
mcp_server = await mcp_server.create_async(session, actor=actor)
|
||||
return mcp_server.to_pydantic()
|
||||
mcp_server = MCPServerModel(**mcp_server_data)
|
||||
mcp_server = await mcp_server.create_async(session, actor=actor, no_commit=True)
|
||||
|
||||
# Link existing OAuth sessions for the same user and server URL
|
||||
# This ensures OAuth sessions created during testing get linked to the server
|
||||
server_url = getattr(mcp_server, "server_url", None)
|
||||
if server_url:
|
||||
from sqlalchemy import select
|
||||
|
||||
result = await session.execute(
|
||||
select(MCPOAuth).where(
|
||||
MCPOAuth.server_url == server_url,
|
||||
MCPOAuth.organization_id == actor.organization_id,
|
||||
MCPOAuth.user_id == actor.id, # Only link sessions for the same user
|
||||
MCPOAuth.server_id.is_(None), # Only update sessions not already linked
|
||||
)
|
||||
)
|
||||
oauth_sessions = result.scalars().all()
|
||||
|
||||
# TODO: @jnjpng we should upate sessions in bulk
|
||||
for oauth_session in oauth_sessions:
|
||||
oauth_session.server_id = mcp_server.id
|
||||
await oauth_session.update_async(db_session=session, actor=actor, no_commit=True)
|
||||
|
||||
if oauth_sessions:
|
||||
logger.info(
|
||||
f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}"
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
return mcp_server.to_pydantic()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to create MCP server: {e}")
|
||||
raise
|
||||
|
||||
@enforce_types
|
||||
async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer:
|
||||
@@ -286,13 +319,48 @@ class MCPManager:
|
||||
|
||||
@enforce_types
|
||||
async def delete_mcp_server_by_id(self, mcp_server_id: str, actor: PydanticUser) -> None:
|
||||
"""Delete a tool by its ID."""
|
||||
"""Delete a MCP server by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
await mcp_server.hard_delete_async(db_session=session, actor=actor)
|
||||
if not mcp_server:
|
||||
raise NoResultFound(f"MCP server with id {mcp_server_id} not found.")
|
||||
|
||||
server_url = getattr(mcp_server, "server_url", None)
|
||||
|
||||
# Delete OAuth sessions for the same user and server URL in the same transaction
|
||||
# This handles orphaned sessions that were created during testing/connection
|
||||
oauth_count = 0
|
||||
if server_url:
|
||||
result = await session.execute(
|
||||
delete(MCPOAuth).where(
|
||||
MCPOAuth.server_url == server_url,
|
||||
MCPOAuth.organization_id == actor.organization_id,
|
||||
MCPOAuth.user_id == actor.id, # Only delete sessions for the same user
|
||||
)
|
||||
)
|
||||
oauth_count = result.rowcount
|
||||
if oauth_count > 0:
|
||||
logger.info(
|
||||
f"Deleting {oauth_count} OAuth sessions for MCP server {mcp_server_id} (URL: {server_url}) for user {actor.id}"
|
||||
)
|
||||
|
||||
# Delete the MCP server, will cascade delete to linked OAuth sessions
|
||||
await session.execute(
|
||||
delete(MCPServerModel).where(
|
||||
MCPServerModel.id == mcp_server_id,
|
||||
MCPServerModel.organization_id == actor.organization_id,
|
||||
)
|
||||
)
|
||||
|
||||
await session.commit()
|
||||
except NoResultFound:
|
||||
await session.rollback()
|
||||
raise ValueError(f"MCP server with id {mcp_server_id} not found.")
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.error(f"Failed to delete MCP server {mcp_server_id}: {e}")
|
||||
raise
|
||||
|
||||
def read_mcp_config(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]]:
|
||||
mcp_server_list = {}
|
||||
|
||||
@@ -9168,6 +9168,177 @@ async def test_get_mcp_servers_by_ids(server, default_user, event_loop):
|
||||
assert all(s.organization_id == default_user.organization_id for s in bulk_fetched)
|
||||
|
||||
|
||||
# Additional MCPManager OAuth session tests
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_server_deletion_cascades_oauth_sessions(server, default_organization, default_user, event_loop):
|
||||
"""Deleting an MCP server deletes associated OAuth sessions (same user + URL)."""
|
||||
|
||||
from letta.schemas.mcp import MCPOAuthSessionCreate
|
||||
from letta.schemas.mcp import MCPServer as PydanticMCPServer
|
||||
from letta.schemas.mcp import MCPServerType
|
||||
|
||||
test_server_url = "https://test.example.com/mcp"
|
||||
|
||||
# Create orphaned OAuth sessions (no server id) for same user and URL
|
||||
created_session_ids: list[str] = []
|
||||
for i in range(3):
|
||||
session = await server.mcp_manager.create_oauth_session(
|
||||
MCPOAuthSessionCreate(
|
||||
server_url=test_server_url,
|
||||
server_name=f"test_mcp_server_{i}",
|
||||
user_id=default_user.id,
|
||||
organization_id=default_organization.id,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
created_session_ids.append(session.id)
|
||||
|
||||
# Create the MCP server with the same URL
|
||||
created_server = await server.mcp_manager.create_mcp_server(
|
||||
PydanticMCPServer(
|
||||
server_name=f"test_mcp_server_{") + str(uuid.uuid4().hex[:8]) + ("}", # ensure unique name
|
||||
server_type=MCPServerType.SSE,
|
||||
server_url=test_server_url,
|
||||
organization_id=default_organization.id,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Now delete the server via manager
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
# Verify all sessions are gone
|
||||
for sid in created_session_ids:
|
||||
session = await server.mcp_manager.get_oauth_session_by_id(sid, actor=default_user)
|
||||
assert session is None, f"OAuth session {sid} should be deleted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_oauth_sessions_with_different_url_persist(server, default_organization, default_user, event_loop):
|
||||
"""Sessions with different URL should not be deleted when deleting the server for another URL."""
|
||||
|
||||
from letta.schemas.mcp import MCPOAuthSessionCreate
|
||||
from letta.schemas.mcp import MCPServer as PydanticMCPServer
|
||||
from letta.schemas.mcp import MCPServerType
|
||||
|
||||
server_url = "https://test.example.com/mcp"
|
||||
other_url = "https://other.example.com/mcp"
|
||||
|
||||
# Create a session for other_url (should persist)
|
||||
other_session = await server.mcp_manager.create_oauth_session(
|
||||
MCPOAuthSessionCreate(
|
||||
server_url=other_url,
|
||||
server_name="standalone_oauth",
|
||||
user_id=default_user.id,
|
||||
organization_id=default_organization.id,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create the MCP server at server_url
|
||||
created_server = await server.mcp_manager.create_mcp_server(
|
||||
PydanticMCPServer(
|
||||
server_name=f"test_mcp_server_{") + str(uuid.uuid4().hex[:8]) + ("}",
|
||||
server_type=MCPServerType.SSE,
|
||||
server_url=server_url,
|
||||
organization_id=default_organization.id,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Delete the server at server_url
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
# Verify the session at other_url still exists
|
||||
persisted = await server.mcp_manager.get_oauth_session_by_id(other_session.id, actor=default_user)
|
||||
assert persisted is not None, "OAuth session with different URL should persist"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_server_creation_links_orphaned_sessions(server, default_organization, default_user, event_loop):
|
||||
"""Creating a server should link any existing orphaned sessions (same user + URL)."""
|
||||
|
||||
from letta.schemas.mcp import MCPOAuthSessionCreate
|
||||
from letta.schemas.mcp import MCPServer as PydanticMCPServer
|
||||
from letta.schemas.mcp import MCPServerType
|
||||
|
||||
server_url = "https://test-atomic-create.example.com/mcp"
|
||||
|
||||
# Pre-create orphaned sessions (no server_id) for same user + URL
|
||||
orphaned_ids: list[str] = []
|
||||
for i in range(3):
|
||||
session = await server.mcp_manager.create_oauth_session(
|
||||
MCPOAuthSessionCreate(
|
||||
server_url=server_url,
|
||||
server_name=f"atomic_session_{i}",
|
||||
user_id=default_user.id,
|
||||
organization_id=default_organization.id,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
orphaned_ids.append(session.id)
|
||||
|
||||
# Create server
|
||||
created_server = await server.mcp_manager.create_mcp_server(
|
||||
PydanticMCPServer(
|
||||
server_name=f"test_atomic_server_{") + str(uuid.uuid4().hex[:8]) + ("}",
|
||||
server_type=MCPServerType.SSE,
|
||||
server_url=server_url,
|
||||
organization_id=default_organization.id,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Sessions should still be retrievable via manager API
|
||||
for sid in orphaned_ids:
|
||||
s = await server.mcp_manager.get_oauth_session_by_id(sid, actor=default_user)
|
||||
assert s is not None
|
||||
|
||||
# Indirect verification: deleting the server removes sessions for that URL+user
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
for sid in orphaned_ids:
|
||||
assert await server.mcp_manager.get_oauth_session_by_id(sid, actor=default_user) is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_mcp_server_delete_removes_all_sessions_for_url_and_user(server, default_organization, default_user, event_loop):
|
||||
"""Deleting a server removes both linked and orphaned sessions for same user+URL."""
|
||||
|
||||
from letta.schemas.mcp import MCPOAuthSessionCreate
|
||||
from letta.schemas.mcp import MCPServer as PydanticMCPServer
|
||||
from letta.schemas.mcp import MCPServerType
|
||||
|
||||
server_url = "https://test-atomic-cleanup.example.com/mcp"
|
||||
|
||||
# Create orphaned session
|
||||
orphaned = await server.mcp_manager.create_oauth_session(
|
||||
MCPOAuthSessionCreate(
|
||||
server_url=server_url,
|
||||
server_name="orphaned",
|
||||
user_id=default_user.id,
|
||||
organization_id=default_organization.id,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create server
|
||||
created_server = await server.mcp_manager.create_mcp_server(
|
||||
PydanticMCPServer(
|
||||
server_name=f"cleanup_server_{") + str(uuid.uuid4().hex[:8]) + ("}",
|
||||
server_type=MCPServerType.SSE,
|
||||
server_url=server_url,
|
||||
organization_id=default_organization.id,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Delete server
|
||||
await server.mcp_manager.delete_mcp_server_by_id(created_server.id, actor=default_user)
|
||||
|
||||
# Both orphaned and any linked sessions for that URL+user should be gone
|
||||
assert await server.mcp_manager.get_oauth_session_by_id(orphaned.id, actor=default_user) is None
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# FileAgent Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user