diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index b78eeccf..06190ed1 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -2,7 +2,6 @@ import json import logging from datetime import datetime from typing import Any, Dict, List, Optional, Union - from urllib.parse import urlparse from pydantic import Field, field_validator @@ -334,6 +333,7 @@ class MCPOAuthSessionCreate(BaseMCPOAuth): class MCPOAuthSessionUpdate(BaseMCPOAuth): """Update an existing OAuth session.""" + state: Optional[str] = Field(None, description="OAuth state parameter (for session lookup on callback)") authorization_url: Optional[str] = Field(None, description="OAuth authorization URL") authorization_code: Optional[str] = Field(None, description="OAuth authorization code") access_token: Optional[str] = Field(None, description="OAuth access token") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 0336ae29..89ae8e18 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -833,36 +833,48 @@ async def execute_mcp_tool( logger.warning(f"Error during MCP client cleanup: {cleanup_error}") -# TODO: @jnjpng need to route this through cloud API for production -@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback") +# Static OAuth callback endpoint - session is identified via state parameter +@router.get("/mcp/oauth/callback", operation_id="mcp_oauth_callback") async def mcp_oauth_callback( - session_id: str, code: Optional[str] = Query(None, description="OAuth authorization code"), state: Optional[str] = Query(None, description="OAuth state parameter"), error: Optional[str] = Query(None, description="OAuth error"), error_description: Optional[str] = Query(None, description="OAuth error description"), + server: SyncServer = Depends(get_letta_server), ): """ Handle OAuth callback for MCP server authentication. + Session is identified via the state parameter instead of URL path. """ try: - oauth_session = MCPOAuthSession(session_id) + if not state: + return {"status": "error", "message": "Missing state parameter"} + + # Look up OAuth session by state parameter + oauth_session = await server.mcp_server_manager.get_oauth_session_by_state(state) + if not oauth_session: + return {"status": "error", "message": "Invalid or expired state parameter"} + if error: error_msg = f"OAuth error: {error}" if error_description: error_msg += f" - {error_description}" - await oauth_session.update_session_status(OAuthSessionStatus.ERROR) + # Use the legacy MCPOAuthSession class to update status + legacy_session = MCPOAuthSession(oauth_session.id) + await legacy_session.update_session_status(OAuthSessionStatus.ERROR) return {"status": "error", "message": error_msg} - if not code or not state: - await oauth_session.update_session_status(OAuthSessionStatus.ERROR) - return {"status": "error", "message": "Missing authorization code or state"} + if not code: + legacy_session = MCPOAuthSession(oauth_session.id) + await legacy_session.update_session_status(OAuthSessionStatus.ERROR) + return {"status": "error", "message": "Missing authorization code"} - # Store authorization code - success = await oauth_session.store_authorization_code(code, state) + # Store authorization code using the legacy session class + legacy_session = MCPOAuthSession(oauth_session.id) + success = await legacy_session.store_authorization_code(code, state) if not success: - await oauth_session.update_session_status(OAuthSessionStatus.ERROR) - return {"status": "error", "message": "Invalid state parameter"} + await legacy_session.update_session_status(OAuthSessionStatus.ERROR) + return {"status": "error", "message": "Failed to store authorization code"} return {"status": "success", "message": "Authorization successful", "server_url": success.server_url} diff --git a/letta/services/mcp/server_side_oauth.py b/letta/services/mcp/server_side_oauth.py index 9e753848..17c8ec2f 100644 --- a/letta/services/mcp/server_side_oauth.py +++ b/letta/services/mcp/server_side_oauth.py @@ -11,7 +11,9 @@ by a web frontend rather than opening a local browser. import asyncio import time from typing import Callable, Optional, Tuple +from urllib.parse import parse_qs, urlencode, urlparse, urlunparse +import httpx from fastmcp.client.auth.oauth import OAuth from pydantic import AnyHttpUrl @@ -50,6 +52,8 @@ class ServerSideOAuth(OAuth): url_callback: Optional callback function called with the authorization URL logo_uri: Optional logo URI to include in OAuth client metadata scopes: OAuth scopes to request + exclude_resource_param: If True, prevents the RFC 8707 resource parameter from being + added to OAuth requests. Some servers (like Supabase) reject this parameter. """ def __init__( @@ -62,12 +66,14 @@ class ServerSideOAuth(OAuth): url_callback: Optional[Callable[[str], None]] = None, logo_uri: Optional[str] = None, scopes: Optional[str | list[str]] = None, + exclude_resource_param: bool = True, ): self.session_id = session_id self.mcp_manager = mcp_manager self.actor = actor self._redirect_uri = redirect_uri self._url_callback = url_callback + self._exclude_resource_param = exclude_resource_param # Initialize parent OAuth class (this creates FileTokenStorage internally) super().__init__( @@ -92,24 +98,100 @@ class ServerSideOAuth(OAuth): if logo_uri: self.context.client_metadata.logo_uri = logo_uri + async def _initialize(self) -> None: + """Load stored tokens and client info, properly setting token expiry.""" + await super()._initialize() + + # Some OAuth servers (like Supabase) don't accept the RFC 8707 resource parameter + # Clear protected_resource_metadata to prevent the SDK from adding it to requests + if self._exclude_resource_param: + self.context.protected_resource_metadata = None + + async def _handle_protected_resource_response(self, response: httpx.Response) -> None: + """Handle protected resource metadata response. + + This overrides the parent's method to: + 1. Let OAuth server discovery work (extracts auth_server_url from metadata) + 2. Then clear protected_resource_metadata to prevent RFC 8707 resource parameter + from being added to token exchange and other requests. + + Some OAuth servers (like Supabase) reject the resource parameter entirely. + """ + # Call parent to process metadata and extract auth_server_url + await super()._handle_protected_resource_response(response) + + # Clear the metadata to prevent resource parameter in subsequent requests + # The auth_server_url is already extracted, so OAuth discovery still works + if self._exclude_resource_param: + logger.debug("Clearing protected_resource_metadata to prevent resource parameter in token exchange") + self.context.protected_resource_metadata = None + + async def _handle_token_response(self, response: httpx.Response) -> None: + """Handle token exchange response, accepting both 200 and 201 status codes. + + Some OAuth servers (like Supabase) return 201 Created instead of 200 OK + for successful token exchange. The MCP SDK only accepts 200, so we override + this method to accept both. + """ + # Accept both 200 and 201 as success (Supabase returns 201) + if response.status_code == 201: + logger.debug("Token exchange returned 201 Created, treating as success") + # Monkey-patch the status code to 200 so parent method accepts it + response.status_code = 200 + + await super()._handle_token_response(response) + async def redirect_handler(self, authorization_url: str) -> None: """Store authorization URL in database and call optional callback. This overrides the parent's redirect_handler which would open a browser. Instead, we: - 1. Store the URL in the database for the API to return - 2. Call an optional callback (e.g., to yield to an SSE stream) + 1. Extract the state from the authorization URL (generated by MCP SDK) + 2. Optionally strip the resource parameter (some servers reject it) + 3. Store the URL and state in the database for the API to return + 4. Call an optional callback (e.g., to yield to an SSE stream) Args: authorization_url: The OAuth authorization URL to redirect the user to """ logger.info(f"OAuth redirect handler called with URL: {authorization_url}") - # Store URL in database for API response - session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url) + # Strip the resource parameter if exclude_resource_param is True + # Some OAuth servers (like Supabase) reject the RFC 8707 resource parameter + if self._exclude_resource_param: + parsed_url = urlparse(authorization_url) + query_params = parse_qs(parsed_url.query, keep_blank_values=True) + + if "resource" in query_params: + logger.debug(f"Stripping resource parameter from authorization URL: {query_params['resource']}") + del query_params["resource"] + + # Rebuild the URL without the resource parameter + # parse_qs returns lists, so flatten them for urlencode + flat_params = {k: v[0] if len(v) == 1 else v for k, v in query_params.items()} + new_query = urlencode(flat_params, doseq=True) + authorization_url = urlunparse( + ( + parsed_url.scheme, + parsed_url.netloc, + parsed_url.path, + parsed_url.params, + new_query, + parsed_url.fragment, + ) + ) + logger.info(f"Authorization URL after stripping resource: {authorization_url}") + + # Extract the state parameter from the authorization URL + parsed_url = urlparse(authorization_url) + query_params = parse_qs(parsed_url.query) + oauth_state = query_params.get("state", [None])[0] + + # Store URL and state in database for API response + session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url, state=oauth_state) await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor) - logger.info(f"OAuth authorization URL stored for session {self.session_id}") + logger.info(f"OAuth authorization URL stored for session {self.session_id} with state {oauth_state}") # Call the callback if provided (e.g., to yield URL to SSE stream) if self._url_callback: diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index f8f2fb97..c579de13 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -926,6 +926,18 @@ class MCPManager: return await self._oauth_orm_to_pydantic_async(oauth_session) + @enforce_types + async def get_oauth_session_by_state(self, state: str) -> Optional[MCPOAuthSession]: + """Get an OAuth session by its state parameter (used in static callback URI flow).""" + async with db_registry.async_session() as session: + result = await session.execute(select(MCPOAuth).where(MCPOAuth.state == state).limit(1)) + oauth_session = result.scalar_one_or_none() + + if not oauth_session: + return None + + return await self._oauth_orm_to_pydantic_async(oauth_session) + @enforce_types async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession: """Update an existing OAuth session.""" @@ -933,6 +945,8 @@ class MCPManager: oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) # Update fields that are provided + if session_update.state is not None: + oauth_session.state = session_update.state if session_update.authorization_url is not None: oauth_session.authorization_url = session_update.authorization_url @@ -1087,11 +1101,13 @@ class MCPManager: LETTA_AGENTS_ENDPOINT = os.getenv("LETTA_AGENTS_ENDPOINT") if is_web_request and NEXT_PUBLIC_CURRENT_HOST: - redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}" + # Use static callback URI - session is identified via state parameter + redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/mcp" logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg" elif LETTA_AGENTS_ENDPOINT: # API and SDK usage should call core server directly - redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}" + # Use static callback URI - session is identified via state parameter + redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback" else: logger.error( f"No redirect URI found for request and base urls: {http_request.headers if http_request else 'No headers'} {NEXT_PUBLIC_CURRENT_HOST} {LETTA_AGENTS_ENDPOINT}" diff --git a/letta/services/mcp_server_manager.py b/letta/services/mcp_server_manager.py index a8e7bce2..7604e6cb 100644 --- a/letta/services/mcp_server_manager.py +++ b/letta/services/mcp_server_manager.py @@ -969,7 +969,7 @@ class MCPServerManager: if oauth is None and hasattr(server_config, "server_url"): oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor) # Check if access token exists by attempting to decrypt it - if oauth_session and await oauth_session.get_access_token_secret().get_plaintext_async(): + if oauth_session and oauth_session.access_token_enc and await oauth_session.access_token_enc.get_plaintext_async(): # Create ServerSideOAuth from stored credentials oauth = ServerSideOAuth( mcp_url=oauth_session.server_url, @@ -1098,6 +1098,18 @@ class MCPServerManager: return await self._oauth_orm_to_pydantic_async(oauth_session) + @enforce_types + async def get_oauth_session_by_state(self, state: str) -> Optional[MCPOAuthSession]: + """Get an OAuth session by its state parameter (used in static callback URI flow).""" + async with db_registry.async_session() as session: + result = await session.execute(select(MCPOAuth).where(MCPOAuth.state == state).limit(1)) + oauth_session = result.scalar_one_or_none() + + if not oauth_session: + return None + + return await self._oauth_orm_to_pydantic_async(oauth_session) + @enforce_types async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession: """Update an existing OAuth session.""" @@ -1283,11 +1295,13 @@ class MCPServerManager: LETTA_AGENTS_ENDPOINT = os.getenv("LETTA_AGENTS_ENDPOINT") if is_web_request and NEXT_PUBLIC_CURRENT_HOST: - redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}" + # Use static callback URI - session is identified via state parameter + redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/mcp" logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg" elif LETTA_AGENTS_ENDPOINT: # API and SDK usage should call core server directly - redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}" + # Use static callback URI - session is identified via state parameter + redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback" else: logger.error( f"No redirect URI found for request and base urls: {http_request.headers if http_request else 'No headers'} {NEXT_PUBLIC_CURRENT_HOST} {LETTA_AGENTS_ENDPOINT}"