"""OAuth utilities for MCP server authentication.""" import asyncio import json import secrets import time import uuid from datetime import datetime, timedelta from typing import TYPE_CHECKING, Callable, Optional, Tuple from mcp.client.auth import OAuthClientProvider, TokenStorage from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken from sqlalchemy import select from letta.log import get_logger from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus from letta.schemas.mcp import MCPOAuthSessionUpdate from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry from letta.services.mcp.types import OauthStreamEvent if TYPE_CHECKING: from letta.services.mcp_manager import MCPManager logger = get_logger(__name__) class DatabaseTokenStorage(TokenStorage): """Database-backed token storage using MCPOAuth table via mcp_manager.""" def __init__(self, session_id: str, mcp_manager: "MCPManager", actor: PydanticUser): self.session_id = session_id self.mcp_manager = mcp_manager self.actor = actor async def get_tokens(self) -> Optional[OAuthToken]: """Retrieve tokens from database.""" oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor) if not oauth_session: return None # Read tokens directly from _enc columns access_token = await oauth_session.access_token_enc.get_plaintext_async() if oauth_session.access_token_enc else None if not access_token: return None refresh_token = await oauth_session.refresh_token_enc.get_plaintext_async() if oauth_session.refresh_token_enc else None return OAuthToken( access_token=access_token, refresh_token=refresh_token, token_type=oauth_session.token_type, expires_in=int(oauth_session.expires_at.timestamp() - time.time()), scope=oauth_session.scope, ) async def set_tokens(self, tokens: OAuthToken) -> None: """Store tokens in database.""" session_update = MCPOAuthSessionUpdate( access_token=tokens.access_token, refresh_token=tokens.refresh_token, token_type=tokens.token_type, expires_at=datetime.fromtimestamp(tokens.expires_in + time.time()), scope=tokens.scope, status=OAuthSessionStatus.AUTHORIZED, ) await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor) async def get_client_info(self) -> Optional[OAuthClientInformationFull]: """Retrieve client information from database.""" oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor) if not oauth_session or not oauth_session.client_id: return None # Read client secret directly from _enc column client_secret = await oauth_session.client_secret_enc.get_plaintext_async() if oauth_session.client_secret_enc else None return OAuthClientInformationFull( client_id=oauth_session.client_id, client_secret=client_secret, redirect_uris=[oauth_session.redirect_uri] if oauth_session.redirect_uri else [], ) async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: """Store client information in database.""" session_update = MCPOAuthSessionUpdate( client_id=client_info.client_id, client_secret=client_info.client_secret, redirect_uri=str(client_info.redirect_uris[0]) if client_info.redirect_uris else None, ) await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor) class MCPOAuthSession: """Legacy OAuth session class - deprecated, use mcp_manager directly.""" def __init__(self, server_url: str, server_name: str, user_id: Optional[str], organization_id: str): self.server_url = server_url self.server_name = server_name self.user_id = user_id self.organization_id = organization_id self.session_id = str(uuid.uuid4()) self.state = secrets.token_urlsafe(32) def __init__(self, session_id: str): self.session_id = session_id # TODO: consolidate / deprecate this in favor of mcp_manager access async def create_session(self) -> str: """Create a new OAuth session in the database.""" async with db_registry.async_session() as session: oauth_record = MCPOAuth( id=self.session_id, state=self.state, server_url=self.server_url, server_name=self.server_name, user_id=self.user_id, organization_id=self.organization_id, status=OAuthSessionStatus.PENDING, created_at=datetime.now(), updated_at=datetime.now(), ) oauth_record = await oauth_record.create_async(session, actor=None) return self.session_id async def get_session_status(self) -> OAuthSessionStatus: """Get the current status of the OAuth session.""" async with db_registry.async_session() as session: try: oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) return oauth_record.status except Exception: return OAuthSessionStatus.ERROR async def update_session_status(self, status: OAuthSessionStatus) -> None: """Update the session status.""" async with db_registry.async_session() as session: try: oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) oauth_record.status = status oauth_record.updated_at = datetime.now() await oauth_record.update_async(db_session=session, actor=None) except Exception: pass async def store_authorization_code(self, code: str, state: str) -> Optional[MCPOAuth]: """Store the authorization code from OAuth callback.""" from letta.schemas.secret import Secret async with db_registry.async_session() as session: try: oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) # Encrypt the authorization_code and store only in _enc column (async to avoid blocking event loop) if code is not None: code_secret = await Secret.from_plaintext_async(code) oauth_record.authorization_code_enc = code_secret.get_encrypted() oauth_record.status = OAuthSessionStatus.AUTHORIZED oauth_record.state = state return await oauth_record.update_async(db_session=session, actor=None) except Exception: return None async def get_authorization_url(self) -> Optional[str]: """Get the authorization URL for this session.""" async with db_registry.async_session() as session: try: oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) return oauth_record.authorization_url except Exception: return None async def set_authorization_url(self, url: str) -> None: """Set the authorization URL for this session.""" async with db_registry.async_session() as session: try: oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) oauth_record.authorization_url = url oauth_record.updated_at = datetime.now() await oauth_record.update_async(db_session=session, actor=None) except Exception: pass async def create_oauth_provider( session_id: str, server_url: str, redirect_uri: str, mcp_manager: "MCPManager", actor: PydanticUser, logo_uri: Optional[str] = None, url_callback: Optional[Callable[[str], None]] = None, ) -> OAuthClientProvider: """Create an OAuth provider for MCP server authentication. DEPRECATED: Use ServerSideOAuth from letta.services.mcp.server_side_oauth instead. This function is kept for backwards compatibility but will be removed in a future version. """ logger.warning("create_oauth_provider is deprecated. Use ServerSideOAuth from letta.services.mcp.server_side_oauth instead.") client_metadata_dict = { "client_name": "Letta", "redirect_uris": [redirect_uri], "grant_types": ["authorization_code", "refresh_token"], "response_types": ["code"], "token_endpoint_auth_method": "client_secret_post", "logo_uri": logo_uri, } # Use manager-based storage storage = DatabaseTokenStorage(session_id, mcp_manager, actor) # Extract base URL (remove /mcp endpoint if present) oauth_server_url = server_url.rstrip("/").removesuffix("/sse").removesuffix("/mcp") async def redirect_handler(authorization_url: str) -> None: """Handle OAuth redirect by storing the authorization URL.""" logger.info(f"OAuth redirect handler called with URL: {authorization_url}") session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url) await mcp_manager.update_oauth_session(session_id, session_update, actor) logger.info(f"OAuth authorization URL stored: {authorization_url}") # Call the callback if provided (e.g., to yield URL to SSE stream) if url_callback: url_callback(authorization_url) async def callback_handler() -> Tuple[str, Optional[str]]: """Handle OAuth callback by waiting for authorization code.""" timeout = 300 # 5 minutes start_time = time.time() logger.info(f"Waiting for authorization code for session {session_id}") while time.time() - start_time < timeout: oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor) if oauth_session and oauth_session.authorization_code_enc: # Read authorization code directly from _enc column auth_code = await oauth_session.authorization_code_enc.get_plaintext_async() return auth_code, oauth_session.state elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR: raise Exception("OAuth authorization failed") await asyncio.sleep(1) raise Exception(f"Timeout waiting for OAuth callback after {timeout} seconds") return OAuthClientProvider( server_url=oauth_server_url, client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), storage=storage, redirect_handler=redirect_handler, callback_handler=callback_handler, ) async def cleanup_expired_oauth_sessions(max_age_hours: int = 24) -> None: """Clean up expired OAuth sessions.""" cutoff_time = datetime.now() - timedelta(hours=max_age_hours) async with db_registry.async_session() as session: result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time)) expired_sessions = result.scalars().all() for oauth_session in expired_sessions: await oauth_session.hard_delete_async(db_session=session, actor=None) if expired_sessions: logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions") def oauth_stream_event(event: OauthStreamEvent, **kwargs) -> str: data = {"event": event.value} data.update(kwargs) return f"data: {json.dumps(data)}\n\n" def drill_down_exception(exception, depth=0, max_depth=5): """Recursively drill down into nested exceptions to find the root cause""" indent = " " * depth error_details = [] error_details.append(f"{indent}Exception at depth {depth}:") error_details.append(f"{indent} Type: {type(exception).__name__}") error_details.append(f"{indent} Message: {str(exception)}") error_details.append(f"{indent} Module: {getattr(type(exception), '__module__', 'unknown')}") # Check for exception groups (TaskGroup errors) if hasattr(exception, "exceptions") and exception.exceptions: error_details.append(f"{indent} ExceptionGroup with {len(exception.exceptions)} sub-exceptions:") for i, sub_exc in enumerate(exception.exceptions): error_details.append(f"{indent} Sub-exception {i}:") if depth < max_depth: error_details.extend(drill_down_exception(sub_exc, depth + 1, max_depth)) # Check for chained exceptions (__cause__ and __context__) if hasattr(exception, "__cause__") and exception.__cause__ and depth < max_depth: error_details.append(f"{indent} Caused by:") error_details.extend(drill_down_exception(exception.__cause__, depth + 1, max_depth)) if hasattr(exception, "__context__") and exception.__context__ and depth < max_depth: error_details.append(f"{indent} Context:") error_details.extend(drill_down_exception(exception.__context__, depth + 1, max_depth)) # Add traceback info import traceback if hasattr(exception, "__traceback__") and exception.__traceback__: tb_lines = traceback.format_tb(exception.__traceback__) error_details.append(f"{indent} Traceback:") for line in tb_lines[-3:]: # Show last 3 traceback lines error_details.append(f"{indent} {line.strip()}") error_info = "".join(error_details) return error_info