feat: static redirect callback for mcp server oauth (#8611)
* base * base * more * final * remove * pass
This commit is contained in:
@@ -2,7 +2,6 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
@@ -334,6 +333,7 @@ class MCPOAuthSessionCreate(BaseMCPOAuth):
|
||||
class MCPOAuthSessionUpdate(BaseMCPOAuth):
|
||||
"""Update an existing OAuth session."""
|
||||
|
||||
state: Optional[str] = Field(None, description="OAuth state parameter (for session lookup on callback)")
|
||||
authorization_url: Optional[str] = Field(None, description="OAuth authorization URL")
|
||||
authorization_code: Optional[str] = Field(None, description="OAuth authorization code")
|
||||
access_token: Optional[str] = Field(None, description="OAuth access token")
|
||||
|
||||
@@ -833,36 +833,48 @@ async def execute_mcp_tool(
|
||||
logger.warning(f"Error during MCP client cleanup: {cleanup_error}")
|
||||
|
||||
|
||||
# TODO: @jnjpng need to route this through cloud API for production
|
||||
@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback")
|
||||
# Static OAuth callback endpoint - session is identified via state parameter
|
||||
@router.get("/mcp/oauth/callback", operation_id="mcp_oauth_callback")
|
||||
async def mcp_oauth_callback(
|
||||
session_id: str,
|
||||
code: Optional[str] = Query(None, description="OAuth authorization code"),
|
||||
state: Optional[str] = Query(None, description="OAuth state parameter"),
|
||||
error: Optional[str] = Query(None, description="OAuth error"),
|
||||
error_description: Optional[str] = Query(None, description="OAuth error description"),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Handle OAuth callback for MCP server authentication.
|
||||
Session is identified via the state parameter instead of URL path.
|
||||
"""
|
||||
try:
|
||||
oauth_session = MCPOAuthSession(session_id)
|
||||
if not state:
|
||||
return {"status": "error", "message": "Missing state parameter"}
|
||||
|
||||
# Look up OAuth session by state parameter
|
||||
oauth_session = await server.mcp_server_manager.get_oauth_session_by_state(state)
|
||||
if not oauth_session:
|
||||
return {"status": "error", "message": "Invalid or expired state parameter"}
|
||||
|
||||
if error:
|
||||
error_msg = f"OAuth error: {error}"
|
||||
if error_description:
|
||||
error_msg += f" - {error_description}"
|
||||
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
|
||||
# Use the legacy MCPOAuthSession class to update status
|
||||
legacy_session = MCPOAuthSession(oauth_session.id)
|
||||
await legacy_session.update_session_status(OAuthSessionStatus.ERROR)
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if not code or not state:
|
||||
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
|
||||
return {"status": "error", "message": "Missing authorization code or state"}
|
||||
if not code:
|
||||
legacy_session = MCPOAuthSession(oauth_session.id)
|
||||
await legacy_session.update_session_status(OAuthSessionStatus.ERROR)
|
||||
return {"status": "error", "message": "Missing authorization code"}
|
||||
|
||||
# Store authorization code
|
||||
success = await oauth_session.store_authorization_code(code, state)
|
||||
# Store authorization code using the legacy session class
|
||||
legacy_session = MCPOAuthSession(oauth_session.id)
|
||||
success = await legacy_session.store_authorization_code(code, state)
|
||||
if not success:
|
||||
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
|
||||
return {"status": "error", "message": "Invalid state parameter"}
|
||||
await legacy_session.update_session_status(OAuthSessionStatus.ERROR)
|
||||
return {"status": "error", "message": "Failed to store authorization code"}
|
||||
|
||||
return {"status": "success", "message": "Authorization successful", "server_url": success.server_url}
|
||||
|
||||
|
||||
@@ -11,7 +11,9 @@ by a web frontend rather than opening a local browser.
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Callable, Optional, Tuple
|
||||
from urllib.parse import parse_qs, urlencode, urlparse, urlunparse
|
||||
|
||||
import httpx
|
||||
from fastmcp.client.auth.oauth import OAuth
|
||||
from pydantic import AnyHttpUrl
|
||||
|
||||
@@ -50,6 +52,8 @@ class ServerSideOAuth(OAuth):
|
||||
url_callback: Optional callback function called with the authorization URL
|
||||
logo_uri: Optional logo URI to include in OAuth client metadata
|
||||
scopes: OAuth scopes to request
|
||||
exclude_resource_param: If True, prevents the RFC 8707 resource parameter from being
|
||||
added to OAuth requests. Some servers (like Supabase) reject this parameter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -62,12 +66,14 @@ class ServerSideOAuth(OAuth):
|
||||
url_callback: Optional[Callable[[str], None]] = None,
|
||||
logo_uri: Optional[str] = None,
|
||||
scopes: Optional[str | list[str]] = None,
|
||||
exclude_resource_param: bool = True,
|
||||
):
|
||||
self.session_id = session_id
|
||||
self.mcp_manager = mcp_manager
|
||||
self.actor = actor
|
||||
self._redirect_uri = redirect_uri
|
||||
self._url_callback = url_callback
|
||||
self._exclude_resource_param = exclude_resource_param
|
||||
|
||||
# Initialize parent OAuth class (this creates FileTokenStorage internally)
|
||||
super().__init__(
|
||||
@@ -92,24 +98,100 @@ class ServerSideOAuth(OAuth):
|
||||
if logo_uri:
|
||||
self.context.client_metadata.logo_uri = logo_uri
|
||||
|
||||
async def _initialize(self) -> None:
|
||||
"""Load stored tokens and client info, properly setting token expiry."""
|
||||
await super()._initialize()
|
||||
|
||||
# Some OAuth servers (like Supabase) don't accept the RFC 8707 resource parameter
|
||||
# Clear protected_resource_metadata to prevent the SDK from adding it to requests
|
||||
if self._exclude_resource_param:
|
||||
self.context.protected_resource_metadata = None
|
||||
|
||||
async def _handle_protected_resource_response(self, response: httpx.Response) -> None:
|
||||
"""Handle protected resource metadata response.
|
||||
|
||||
This overrides the parent's method to:
|
||||
1. Let OAuth server discovery work (extracts auth_server_url from metadata)
|
||||
2. Then clear protected_resource_metadata to prevent RFC 8707 resource parameter
|
||||
from being added to token exchange and other requests.
|
||||
|
||||
Some OAuth servers (like Supabase) reject the resource parameter entirely.
|
||||
"""
|
||||
# Call parent to process metadata and extract auth_server_url
|
||||
await super()._handle_protected_resource_response(response)
|
||||
|
||||
# Clear the metadata to prevent resource parameter in subsequent requests
|
||||
# The auth_server_url is already extracted, so OAuth discovery still works
|
||||
if self._exclude_resource_param:
|
||||
logger.debug("Clearing protected_resource_metadata to prevent resource parameter in token exchange")
|
||||
self.context.protected_resource_metadata = None
|
||||
|
||||
async def _handle_token_response(self, response: httpx.Response) -> None:
|
||||
"""Handle token exchange response, accepting both 200 and 201 status codes.
|
||||
|
||||
Some OAuth servers (like Supabase) return 201 Created instead of 200 OK
|
||||
for successful token exchange. The MCP SDK only accepts 200, so we override
|
||||
this method to accept both.
|
||||
"""
|
||||
# Accept both 200 and 201 as success (Supabase returns 201)
|
||||
if response.status_code == 201:
|
||||
logger.debug("Token exchange returned 201 Created, treating as success")
|
||||
# Monkey-patch the status code to 200 so parent method accepts it
|
||||
response.status_code = 200
|
||||
|
||||
await super()._handle_token_response(response)
|
||||
|
||||
async def redirect_handler(self, authorization_url: str) -> None:
|
||||
"""Store authorization URL in database and call optional callback.
|
||||
|
||||
This overrides the parent's redirect_handler which would open a browser.
|
||||
Instead, we:
|
||||
1. Store the URL in the database for the API to return
|
||||
2. Call an optional callback (e.g., to yield to an SSE stream)
|
||||
1. Extract the state from the authorization URL (generated by MCP SDK)
|
||||
2. Optionally strip the resource parameter (some servers reject it)
|
||||
3. Store the URL and state in the database for the API to return
|
||||
4. Call an optional callback (e.g., to yield to an SSE stream)
|
||||
|
||||
Args:
|
||||
authorization_url: The OAuth authorization URL to redirect the user to
|
||||
"""
|
||||
logger.info(f"OAuth redirect handler called with URL: {authorization_url}")
|
||||
|
||||
# Store URL in database for API response
|
||||
session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url)
|
||||
# Strip the resource parameter if exclude_resource_param is True
|
||||
# Some OAuth servers (like Supabase) reject the RFC 8707 resource parameter
|
||||
if self._exclude_resource_param:
|
||||
parsed_url = urlparse(authorization_url)
|
||||
query_params = parse_qs(parsed_url.query, keep_blank_values=True)
|
||||
|
||||
if "resource" in query_params:
|
||||
logger.debug(f"Stripping resource parameter from authorization URL: {query_params['resource']}")
|
||||
del query_params["resource"]
|
||||
|
||||
# Rebuild the URL without the resource parameter
|
||||
# parse_qs returns lists, so flatten them for urlencode
|
||||
flat_params = {k: v[0] if len(v) == 1 else v for k, v in query_params.items()}
|
||||
new_query = urlencode(flat_params, doseq=True)
|
||||
authorization_url = urlunparse(
|
||||
(
|
||||
parsed_url.scheme,
|
||||
parsed_url.netloc,
|
||||
parsed_url.path,
|
||||
parsed_url.params,
|
||||
new_query,
|
||||
parsed_url.fragment,
|
||||
)
|
||||
)
|
||||
logger.info(f"Authorization URL after stripping resource: {authorization_url}")
|
||||
|
||||
# Extract the state parameter from the authorization URL
|
||||
parsed_url = urlparse(authorization_url)
|
||||
query_params = parse_qs(parsed_url.query)
|
||||
oauth_state = query_params.get("state", [None])[0]
|
||||
|
||||
# Store URL and state in database for API response
|
||||
session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url, state=oauth_state)
|
||||
await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor)
|
||||
|
||||
logger.info(f"OAuth authorization URL stored for session {self.session_id}")
|
||||
logger.info(f"OAuth authorization URL stored for session {self.session_id} with state {oauth_state}")
|
||||
|
||||
# Call the callback if provided (e.g., to yield URL to SSE stream)
|
||||
if self._url_callback:
|
||||
|
||||
@@ -926,6 +926,18 @@ class MCPManager:
|
||||
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_state(self, state: str) -> Optional[MCPOAuthSession]:
|
||||
"""Get an OAuth session by its state parameter (used in static callback URI flow)."""
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(MCPOAuth).where(MCPOAuth.state == state).limit(1))
|
||||
oauth_session = result.scalar_one_or_none()
|
||||
|
||||
if not oauth_session:
|
||||
return None
|
||||
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
"""Update an existing OAuth session."""
|
||||
@@ -933,6 +945,8 @@ class MCPManager:
|
||||
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
|
||||
|
||||
# Update fields that are provided
|
||||
if session_update.state is not None:
|
||||
oauth_session.state = session_update.state
|
||||
if session_update.authorization_url is not None:
|
||||
oauth_session.authorization_url = session_update.authorization_url
|
||||
|
||||
@@ -1087,11 +1101,13 @@ class MCPManager:
|
||||
LETTA_AGENTS_ENDPOINT = os.getenv("LETTA_AGENTS_ENDPOINT")
|
||||
|
||||
if is_web_request and NEXT_PUBLIC_CURRENT_HOST:
|
||||
redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}"
|
||||
# Use static callback URI - session is identified via state parameter
|
||||
redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/mcp"
|
||||
logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg"
|
||||
elif LETTA_AGENTS_ENDPOINT:
|
||||
# API and SDK usage should call core server directly
|
||||
redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}"
|
||||
# Use static callback URI - session is identified via state parameter
|
||||
redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback"
|
||||
else:
|
||||
logger.error(
|
||||
f"No redirect URI found for request and base urls: {http_request.headers if http_request else 'No headers'} {NEXT_PUBLIC_CURRENT_HOST} {LETTA_AGENTS_ENDPOINT}"
|
||||
|
||||
@@ -969,7 +969,7 @@ class MCPServerManager:
|
||||
if oauth is None and hasattr(server_config, "server_url"):
|
||||
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
|
||||
# Check if access token exists by attempting to decrypt it
|
||||
if oauth_session and await oauth_session.get_access_token_secret().get_plaintext_async():
|
||||
if oauth_session and oauth_session.access_token_enc and await oauth_session.access_token_enc.get_plaintext_async():
|
||||
# Create ServerSideOAuth from stored credentials
|
||||
oauth = ServerSideOAuth(
|
||||
mcp_url=oauth_session.server_url,
|
||||
@@ -1098,6 +1098,18 @@ class MCPServerManager:
|
||||
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_state(self, state: str) -> Optional[MCPOAuthSession]:
|
||||
"""Get an OAuth session by its state parameter (used in static callback URI flow)."""
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(MCPOAuth).where(MCPOAuth.state == state).limit(1))
|
||||
oauth_session = result.scalar_one_or_none()
|
||||
|
||||
if not oauth_session:
|
||||
return None
|
||||
|
||||
return await self._oauth_orm_to_pydantic_async(oauth_session)
|
||||
|
||||
@enforce_types
|
||||
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
"""Update an existing OAuth session."""
|
||||
@@ -1283,11 +1295,13 @@ class MCPServerManager:
|
||||
LETTA_AGENTS_ENDPOINT = os.getenv("LETTA_AGENTS_ENDPOINT")
|
||||
|
||||
if is_web_request and NEXT_PUBLIC_CURRENT_HOST:
|
||||
redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}"
|
||||
# Use static callback URI - session is identified via state parameter
|
||||
redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/mcp"
|
||||
logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg"
|
||||
elif LETTA_AGENTS_ENDPOINT:
|
||||
# API and SDK usage should call core server directly
|
||||
redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}"
|
||||
# Use static callback URI - session is identified via state parameter
|
||||
redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback"
|
||||
else:
|
||||
logger.error(
|
||||
f"No redirect URI found for request and base urls: {http_request.headers if http_request else 'No headers'} {NEXT_PUBLIC_CURRENT_HOST} {LETTA_AGENTS_ENDPOINT}"
|
||||
|
||||
Reference in New Issue
Block a user