feat: static redirect callback for mcp server oauth (#8611)

* base

* base

* more

* final

* remove

* pass
This commit is contained in:
jnjpng
2026-01-12 16:04:12 -08:00
committed by Sarah Wooders
parent 089ea415ab
commit c550457b60
5 changed files with 147 additions and 23 deletions

View File

@@ -2,7 +2,6 @@ import json
import logging
from datetime import datetime
from typing import Any, Dict, List, Optional, Union
from urllib.parse import urlparse
from pydantic import Field, field_validator
@@ -334,6 +333,7 @@ class MCPOAuthSessionCreate(BaseMCPOAuth):
class MCPOAuthSessionUpdate(BaseMCPOAuth):
"""Update an existing OAuth session."""
state: Optional[str] = Field(None, description="OAuth state parameter (for session lookup on callback)")
authorization_url: Optional[str] = Field(None, description="OAuth authorization URL")
authorization_code: Optional[str] = Field(None, description="OAuth authorization code")
access_token: Optional[str] = Field(None, description="OAuth access token")

View File

@@ -833,36 +833,48 @@ async def execute_mcp_tool(
logger.warning(f"Error during MCP client cleanup: {cleanup_error}")
# TODO: @jnjpng need to route this through cloud API for production
@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback")
# Static OAuth callback endpoint - session is identified via state parameter
@router.get("/mcp/oauth/callback", operation_id="mcp_oauth_callback")
async def mcp_oauth_callback(
session_id: str,
code: Optional[str] = Query(None, description="OAuth authorization code"),
state: Optional[str] = Query(None, description="OAuth state parameter"),
error: Optional[str] = Query(None, description="OAuth error"),
error_description: Optional[str] = Query(None, description="OAuth error description"),
server: SyncServer = Depends(get_letta_server),
):
"""
Handle OAuth callback for MCP server authentication.
Session is identified via the state parameter instead of URL path.
"""
try:
oauth_session = MCPOAuthSession(session_id)
if not state:
return {"status": "error", "message": "Missing state parameter"}
# Look up OAuth session by state parameter
oauth_session = await server.mcp_server_manager.get_oauth_session_by_state(state)
if not oauth_session:
return {"status": "error", "message": "Invalid or expired state parameter"}
if error:
error_msg = f"OAuth error: {error}"
if error_description:
error_msg += f" - {error_description}"
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
# Use the legacy MCPOAuthSession class to update status
legacy_session = MCPOAuthSession(oauth_session.id)
await legacy_session.update_session_status(OAuthSessionStatus.ERROR)
return {"status": "error", "message": error_msg}
if not code or not state:
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
return {"status": "error", "message": "Missing authorization code or state"}
if not code:
legacy_session = MCPOAuthSession(oauth_session.id)
await legacy_session.update_session_status(OAuthSessionStatus.ERROR)
return {"status": "error", "message": "Missing authorization code"}
# Store authorization code
success = await oauth_session.store_authorization_code(code, state)
# Store authorization code using the legacy session class
legacy_session = MCPOAuthSession(oauth_session.id)
success = await legacy_session.store_authorization_code(code, state)
if not success:
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
return {"status": "error", "message": "Invalid state parameter"}
await legacy_session.update_session_status(OAuthSessionStatus.ERROR)
return {"status": "error", "message": "Failed to store authorization code"}
return {"status": "success", "message": "Authorization successful", "server_url": success.server_url}

View File

@@ -11,7 +11,9 @@ by a web frontend rather than opening a local browser.
import asyncio
import time
from typing import Callable, Optional, Tuple
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
import httpx
from fastmcp.client.auth.oauth import OAuth
from pydantic import AnyHttpUrl
@@ -50,6 +52,8 @@ class ServerSideOAuth(OAuth):
url_callback: Optional callback function called with the authorization URL
logo_uri: Optional logo URI to include in OAuth client metadata
scopes: OAuth scopes to request
exclude_resource_param: If True, prevents the RFC 8707 resource parameter from being
added to OAuth requests. Some servers (like Supabase) reject this parameter.
"""
def __init__(
@@ -62,12 +66,14 @@ class ServerSideOAuth(OAuth):
url_callback: Optional[Callable[[str], None]] = None,
logo_uri: Optional[str] = None,
scopes: Optional[str | list[str]] = None,
exclude_resource_param: bool = True,
):
self.session_id = session_id
self.mcp_manager = mcp_manager
self.actor = actor
self._redirect_uri = redirect_uri
self._url_callback = url_callback
self._exclude_resource_param = exclude_resource_param
# Initialize parent OAuth class (this creates FileTokenStorage internally)
super().__init__(
@@ -92,24 +98,100 @@ class ServerSideOAuth(OAuth):
if logo_uri:
self.context.client_metadata.logo_uri = logo_uri
async def _initialize(self) -> None:
"""Load stored tokens and client info, properly setting token expiry."""
await super()._initialize()
# Some OAuth servers (like Supabase) don't accept the RFC 8707 resource parameter
# Clear protected_resource_metadata to prevent the SDK from adding it to requests
if self._exclude_resource_param:
self.context.protected_resource_metadata = None
async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
"""Handle protected resource metadata response.
This overrides the parent's method to:
1. Let OAuth server discovery work (extracts auth_server_url from metadata)
2. Then clear protected_resource_metadata to prevent RFC 8707 resource parameter
from being added to token exchange and other requests.
Some OAuth servers (like Supabase) reject the resource parameter entirely.
"""
# Call parent to process metadata and extract auth_server_url
await super()._handle_protected_resource_response(response)
# Clear the metadata to prevent resource parameter in subsequent requests
# The auth_server_url is already extracted, so OAuth discovery still works
if self._exclude_resource_param:
logger.debug("Clearing protected_resource_metadata to prevent resource parameter in token exchange")
self.context.protected_resource_metadata = None
async def _handle_token_response(self, response: httpx.Response) -> None:
"""Handle token exchange response, accepting both 200 and 201 status codes.
Some OAuth servers (like Supabase) return 201 Created instead of 200 OK
for successful token exchange. The MCP SDK only accepts 200, so we override
this method to accept both.
"""
# Accept both 200 and 201 as success (Supabase returns 201)
if response.status_code == 201:
logger.debug("Token exchange returned 201 Created, treating as success")
# Monkey-patch the status code to 200 so parent method accepts it
response.status_code = 200
await super()._handle_token_response(response)
async def redirect_handler(self, authorization_url: str) -> None:
"""Store authorization URL in database and call optional callback.
This overrides the parent's redirect_handler which would open a browser.
Instead, we:
1. Store the URL in the database for the API to return
2. Call an optional callback (e.g., to yield to an SSE stream)
1. Extract the state from the authorization URL (generated by MCP SDK)
2. Optionally strip the resource parameter (some servers reject it)
3. Store the URL and state in the database for the API to return
4. Call an optional callback (e.g., to yield to an SSE stream)
Args:
authorization_url: The OAuth authorization URL to redirect the user to
"""
logger.info(f"OAuth redirect handler called with URL: {authorization_url}")
# Store URL in database for API response
session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url)
# Strip the resource parameter if exclude_resource_param is True
# Some OAuth servers (like Supabase) reject the RFC 8707 resource parameter
if self._exclude_resource_param:
parsed_url = urlparse(authorization_url)
query_params = parse_qs(parsed_url.query, keep_blank_values=True)
if "resource" in query_params:
logger.debug(f"Stripping resource parameter from authorization URL: {query_params['resource']}")
del query_params["resource"]
# Rebuild the URL without the resource parameter
# parse_qs returns lists, so flatten them for urlencode
flat_params = {k: v[0] if len(v) == 1 else v for k, v in query_params.items()}
new_query = urlencode(flat_params, doseq=True)
authorization_url = urlunparse(
(
parsed_url.scheme,
parsed_url.netloc,
parsed_url.path,
parsed_url.params,
new_query,
parsed_url.fragment,
)
)
logger.info(f"Authorization URL after stripping resource: {authorization_url}")
# Extract the state parameter from the authorization URL
parsed_url = urlparse(authorization_url)
query_params = parse_qs(parsed_url.query)
oauth_state = query_params.get("state", [None])[0]
# Store URL and state in database for API response
session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url, state=oauth_state)
await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor)
logger.info(f"OAuth authorization URL stored for session {self.session_id}")
logger.info(f"OAuth authorization URL stored for session {self.session_id} with state {oauth_state}")
# Call the callback if provided (e.g., to yield URL to SSE stream)
if self._url_callback:

View File

@@ -926,6 +926,18 @@ class MCPManager:
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def get_oauth_session_by_state(self, state: str) -> Optional[MCPOAuthSession]:
"""Get an OAuth session by its state parameter (used in static callback URI flow)."""
async with db_registry.async_session() as session:
result = await session.execute(select(MCPOAuth).where(MCPOAuth.state == state).limit(1))
oauth_session = result.scalar_one_or_none()
if not oauth_session:
return None
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
"""Update an existing OAuth session."""
@@ -933,6 +945,8 @@ class MCPManager:
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
# Update fields that are provided
if session_update.state is not None:
oauth_session.state = session_update.state
if session_update.authorization_url is not None:
oauth_session.authorization_url = session_update.authorization_url
@@ -1087,11 +1101,13 @@ class MCPManager:
LETTA_AGENTS_ENDPOINT = os.getenv("LETTA_AGENTS_ENDPOINT")
if is_web_request and NEXT_PUBLIC_CURRENT_HOST:
redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}"
# Use static callback URI - session is identified via state parameter
redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/mcp"
logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg"
elif LETTA_AGENTS_ENDPOINT:
# API and SDK usage should call core server directly
redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}"
# Use static callback URI - session is identified via state parameter
redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback"
else:
logger.error(
f"No redirect URI found for request and base urls: {http_request.headers if http_request else 'No headers'} {NEXT_PUBLIC_CURRENT_HOST} {LETTA_AGENTS_ENDPOINT}"

View File

@@ -969,7 +969,7 @@ class MCPServerManager:
if oauth is None and hasattr(server_config, "server_url"):
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
# Check if access token exists by attempting to decrypt it
if oauth_session and await oauth_session.get_access_token_secret().get_plaintext_async():
if oauth_session and oauth_session.access_token_enc and await oauth_session.access_token_enc.get_plaintext_async():
# Create ServerSideOAuth from stored credentials
oauth = ServerSideOAuth(
mcp_url=oauth_session.server_url,
@@ -1098,6 +1098,18 @@ class MCPServerManager:
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def get_oauth_session_by_state(self, state: str) -> Optional[MCPOAuthSession]:
"""Get an OAuth session by its state parameter (used in static callback URI flow)."""
async with db_registry.async_session() as session:
result = await session.execute(select(MCPOAuth).where(MCPOAuth.state == state).limit(1))
oauth_session = result.scalar_one_or_none()
if not oauth_session:
return None
return await self._oauth_orm_to_pydantic_async(oauth_session)
@enforce_types
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
"""Update an existing OAuth session."""
@@ -1283,11 +1295,13 @@ class MCPServerManager:
LETTA_AGENTS_ENDPOINT = os.getenv("LETTA_AGENTS_ENDPOINT")
if is_web_request and NEXT_PUBLIC_CURRENT_HOST:
redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}"
# Use static callback URI - session is identified via state parameter
redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/mcp"
logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg"
elif LETTA_AGENTS_ENDPOINT:
# API and SDK usage should call core server directly
redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}"
# Use static callback URI - session is identified via state parameter
redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback"
else:
logger.error(
f"No redirect URI found for request and base urls: {http_request.headers if http_request else 'No headers'} {NEXT_PUBLIC_CURRENT_HOST} {LETTA_AGENTS_ENDPOINT}"