diff --git a/alembic/versions/f5d26b0526e8_add_mcp_oauth.py b/alembic/versions/f5d26b0526e8_add_mcp_oauth.py new file mode 100644 index 00000000..52c9f764 --- /dev/null +++ b/alembic/versions/f5d26b0526e8_add_mcp_oauth.py @@ -0,0 +1,67 @@ +"""add_mcp_oauth + +Revision ID: f5d26b0526e8 +Revises: ddecfe4902bc +Create Date: 2025-07-24 12:34:05.795355 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f5d26b0526e8" +down_revision: Union[str, None] = "ddecfe4902bc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "mcp_oauth", + sa.Column("id", sa.String(), nullable=False), + sa.Column("state", sa.String(length=255), nullable=False), + sa.Column("server_id", sa.String(length=255), nullable=True), + sa.Column("server_url", sa.Text(), nullable=False), + sa.Column("server_name", sa.Text(), nullable=False), + sa.Column("authorization_url", sa.Text(), nullable=True), + sa.Column("authorization_code", sa.Text(), nullable=True), + sa.Column("access_token", sa.Text(), nullable=True), + sa.Column("refresh_token", sa.Text(), nullable=True), + sa.Column("token_type", sa.String(length=50), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("scope", sa.Text(), nullable=True), + sa.Column("client_id", sa.Text(), nullable=True), + sa.Column("client_secret", sa.Text(), nullable=True), + sa.Column("redirect_uri", sa.Text(), nullable=True), + sa.Column("status", sa.String(length=20), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.ForeignKeyConstraint(["server_id"], ["mcp_server.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("state"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("mcp_oauth") + # ### end Alembic commands ### diff --git a/letta/orm/mcp_oauth.py b/letta/orm/mcp_oauth.py new file mode 100644 index 00000000..e34f685a --- /dev/null +++ b/letta/orm/mcp_oauth.py @@ -0,0 +1,62 @@ +import uuid +from datetime import datetime +from enum import Enum +from typing import Optional + +from sqlalchemy import DateTime, ForeignKey, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.mixins import OrganizationMixin, UserMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase + + +class OAuthSessionStatus(str, Enum): + """OAuth session status enumeration.""" + + PENDING = "pending" + AUTHORIZED = "authorized" + ERROR = "error" + + +class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin): + """OAuth session model for MCP server authentication.""" + + __tablename__ = "mcp_oauth" + + # Override the id field to match database UUID generation + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid.uuid4()}") + + # Core session information + state: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, doc="OAuth state parameter") + server_id: Mapped[str] = mapped_column(String(255), ForeignKey("mcp_server.id", ondelete="CASCADE"), nullable=True, doc="MCP server ID") + server_url: Mapped[str] = mapped_column(Text, nullable=False, doc="MCP server URL") + server_name: Mapped[str] = mapped_column(Text, nullable=False, doc="MCP server display name") + + # OAuth flow data + authorization_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth authorization URL") + authorization_code: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth authorization code") + + # Token data + access_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth access token") + refresh_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth refresh token") + token_type: Mapped[str] = mapped_column(String(50), default="Bearer", doc="Token type") + expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, doc="Token expiry time") + scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth scope") + + # Client configuration + client_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client ID") + client_secret: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client secret") + redirect_uri: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth redirect URI") + + # Session state + status: Mapped[OAuthSessionStatus] = mapped_column(String(20), default=OAuthSessionStatus.PENDING, doc="Session status") + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), doc="Session creation time") + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now(), doc="Last update time" + ) + + # Relationships (if needed in the future) + # user: Mapped[Optional["User"]] = relationship("User", back_populates="oauth_sessions") + # organization: Mapped["Organization"] = relationship("Organization", back_populates="oauth_sessions") diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index b851f2d5..f7070e8b 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Any, Dict, Optional, Union from pydantic import Field @@ -10,6 +11,7 @@ from letta.functions.mcp_client.types import ( StdioServerConfig, StreamableHTTPServerConfig, ) +from letta.orm.mcp_oauth import OAuthSessionStatus from letta.schemas.letta_base import LettaBase @@ -119,3 +121,71 @@ class UpdateStreamableHTTPMCPServer(LettaBase): UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer] RegisterMCPServer = Union[RegisterSSEMCPServer, RegisterStdioMCPServer, RegisterStreamableHTTPMCPServer] + + +# OAuth-related schemas +class BaseMCPOAuth(LettaBase): + __id_prefix__ = "mcp-oauth" + + +class MCPOAuthSession(BaseMCPOAuth): + """OAuth session for MCP server authentication.""" + + id: str = BaseMCPOAuth.generate_id_field() + state: str = Field(..., description="OAuth state parameter") + server_id: Optional[str] = Field(None, description="MCP server ID") + server_url: str = Field(..., description="MCP server URL") + server_name: str = Field(..., description="MCP server display name") + + # User and organization context + user_id: Optional[str] = Field(None, description="User ID associated with the session") + organization_id: str = Field(..., description="Organization ID associated with the session") + + # OAuth flow data + authorization_url: Optional[str] = Field(None, description="OAuth authorization URL") + authorization_code: Optional[str] = Field(None, description="OAuth authorization code") + + # Token data + access_token: Optional[str] = Field(None, description="OAuth access token") + refresh_token: Optional[str] = Field(None, description="OAuth refresh token") + token_type: str = Field(default="Bearer", description="Token type") + expires_at: Optional[datetime] = Field(None, description="Token expiry time") + scope: Optional[str] = Field(None, description="OAuth scope") + + # Client configuration + client_id: Optional[str] = Field(None, description="OAuth client ID") + client_secret: Optional[str] = Field(None, description="OAuth client secret") + redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI") + + # Session state + status: OAuthSessionStatus = Field(default=OAuthSessionStatus.PENDING, description="Session status") + + # Timestamps + created_at: datetime = Field(default_factory=datetime.now, description="Session creation time") + updated_at: datetime = Field(default_factory=datetime.now, description="Last update time") + + +class MCPOAuthSessionCreate(BaseMCPOAuth): + """Create a new OAuth session.""" + + server_url: str = Field(..., description="MCP server URL") + server_name: str = Field(..., description="MCP server display name") + user_id: Optional[str] = Field(None, description="User ID associated with the session") + organization_id: str = Field(..., description="Organization ID associated with the session") + state: Optional[str] = Field(None, description="OAuth state parameter") + + +class MCPOAuthSessionUpdate(BaseMCPOAuth): + """Update an existing OAuth session.""" + + authorization_url: Optional[str] = Field(None, description="OAuth authorization URL") + authorization_code: Optional[str] = Field(None, description="OAuth authorization code") + access_token: Optional[str] = Field(None, description="OAuth access token") + refresh_token: Optional[str] = Field(None, description="OAuth refresh token") + token_type: Optional[str] = Field(None, description="Token type") + expires_at: Optional[datetime] = Field(None, description="Token expiry time") + scope: Optional[str] = Field(None, description="OAuth scope") + client_id: Optional[str] = Field(None, description="OAuth client ID") + client_secret: Optional[str] = Field(None, description="OAuth client secret") + redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI") + status: Optional[OAuthSessionStatus] = Field(None, description="Session status") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index ad171613..354afdc2 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -1,4 +1,6 @@ +import asyncio import json +from collections.abc import AsyncGenerator from typing import Any, Dict, List, Optional, Union from composio.client import ComposioClientError, HTTPError, NoItemsFound @@ -11,27 +13,37 @@ from composio.exceptions import ( EnumStringNotFound, ) from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query +from fastapi.responses import HTMLResponse from pydantic import BaseModel, Field +from starlette.responses import StreamingResponse from letta.errors import LettaToolCreateError from letta.functions.functions import derive_openai_json_schema from letta.functions.mcp_client.exceptions import MCPTimeoutError -from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig +from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig from letta.helpers.composio_helpers import get_composio_api_key +from letta.helpers.decorators import deprecated from letta.llm_api.llm_client import LLMClient from letta.log import get_logger from letta.orm.errors import UniqueConstraintViolationError +from letta.orm.mcp_oauth import OAuthSessionStatus from letta.schemas.enums import MessageRole from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.letta_message_content import TextContent -from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer +from letta.schemas.mcp import MCPOAuthSessionCreate, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer from letta.schemas.message import Message from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate +from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer -from letta.services.mcp.sse_client import AsyncSSEMCPClient -from letta.services.mcp.stdio_client import AsyncStdioMCPClient -from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient +from letta.services.mcp.oauth_utils import ( + MCPOAuthSession, + create_oauth_provider, + drill_down_exception, + get_oauth_success_html, + oauth_stream_event, +) +from letta.services.mcp.types import OauthStreamEvent from letta.settings import tool_settings router = APIRouter(prefix="/tools", tags=["tools"]) @@ -612,35 +624,26 @@ async def delete_mcp_server_from_config( return [server.to_config() for server in all_servers] -@router.post("/mcp/servers/test", response_model=List[MCPTool], operation_id="test_mcp_server") +@deprecated("Deprecated in favor of /mcp/servers/connect which handles OAuth flow via SSE stream") +@router.post("/mcp/servers/test", operation_id="test_mcp_server") async def test_mcp_server( request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...), + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Test connection to an MCP server without adding it. - Returns the list of available tools if successful. + Returns the list of available tools if successful, or OAuth information if OAuth is required. """ client = None try: - # create a temporary MCP client based on the server type - if request.type == MCPServerType.SSE: - if not isinstance(request, SSEServerConfig): - request = SSEServerConfig(**request.model_dump()) - client = AsyncSSEMCPClient(request) - elif request.type == MCPServerType.STREAMABLE_HTTP: - if not isinstance(request, StreamableHTTPServerConfig): - request = StreamableHTTPServerConfig(**request.model_dump()) - client = AsyncStreamableHTTPMCPClient(request) - elif request.type == MCPServerType.STDIO: - if not isinstance(request, StdioServerConfig): - request = StdioServerConfig(**request.model_dump()) - client = AsyncStdioMCPClient(request) - else: - raise ValueError(f"Invalid MCP server type: {request.type}") + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + client = await server.mcp_manager.get_mcp_client(request, actor) await client.connect_to_server() tools = await client.list_tools() - return tools + + return {"status": "success", "tools": tools} except ConnectionError as e: raise HTTPException( status_code=400, @@ -676,6 +679,155 @@ async def test_mcp_server( logger.warning(f"Error during MCP client cleanup: {cleanup_error}") +@router.post( + "/mcp/servers/connect", + response_model=None, + responses={ + 200: { + "description": "Successful response", + "content": { + "text/event-stream": {"description": "Server-Sent Events stream"}, + }, + } + }, + operation_id="connect_mcp_server", +) +async def connect_mcp_server( + request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...), + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +) -> StreamingResponse: + """ + Connect to an MCP server with support for OAuth via SSE. + Returns a stream of events handling authorization state and exchange if OAuth is required. + """ + + async def oauth_stream_generator( + request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] + ) -> AsyncGenerator[str, None]: + client = None + oauth_provider = None + temp_client = None + connect_task = None + + try: + # Acknolwedge connection attempt + yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name) + + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + + # Create MCP client with respective transport type + try: + client = await server.mcp_manager.get_mcp_client(request, actor) + except ValueError as e: + yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e)) + return + + # Try normal connection first for flows that don't require OAuth + try: + await client.connect_to_server() + tools = await client.list_tools(serialize=True) + yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) + return + except ConnectionError: + # Continue to OAuth flow + logger.info(f"Attempting OAuth flow for {request.server_url}...") + except Exception as e: + yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}") + return + + # OAuth required, yield state to client to prepare to handle authorization URL + yield oauth_stream_event(OauthStreamEvent.OAUTH_REQUIRED, message="OAuth authentication required") + + # Create OAuth session to persist the state of the OAuth flow + session_create = MCPOAuthSessionCreate( + server_url=request.server_url, + server_name=request.server_name, + user_id=actor.id, + organization_id=actor.organization_id, + ) + oauth_session = await server.mcp_manager.create_oauth_session(session_create, actor) + session_id = oauth_session.id + + # Create OAuth provider for the instance of the stream connection + # Note: Using the correct API path for the callback + # do not edit this this is the correct url + redirect_uri = f"http://localhost:8283/v1/tools/mcp/oauth/callback/{session_id}" + oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, server.mcp_manager, actor) + + # Get authorization URL by triggering OAuth flow + temp_client = None + try: + temp_client = await server.mcp_manager.get_mcp_client(request, actor, oauth_provider) + + # Run connect_to_server in background to avoid blocking + # This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database + connect_task = asyncio.create_task(temp_client.connect_to_server()) + + # Give the OAuth flow time to trigger and save the URL + await asyncio.sleep(1.0) + + # Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL + auth_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor) + if auth_session and auth_session.authorization_url: + yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id) + + except Exception as e: + logger.error(f"Error triggering OAuth flow: {e}") + yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}") + + # Clean up active resources + if connect_task and not connect_task.done(): + connect_task.cancel() + try: + await connect_task + except asyncio.CancelledError: + pass + if temp_client: + try: + await temp_client.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}") + return + + # Wait for user authorization (with timeout), client should render loading state until user completes the flow and /mcp/oauth/callback/{session_id} is hit + yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...") + + # Callback handler will poll for authorization code and state and update the OAuth session + await connect_task + + tools = await temp_client.list_tools(serialize=True) + + yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) + return + except Exception as e: + detailed_error = drill_down_exception(e) + logger.error(f"Error in OAuth stream:\n{detailed_error}") + yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}") + finally: + if connect_task and not connect_task.done(): + connect_task.cancel() + try: + await connect_task + except asyncio.CancelledError: + pass + if client: + try: + await client.cleanup() + except Exception as cleanup_error: + detailed_error = drill_down_exception(cleanup_error) + logger.warning(f"Error during MCP client cleanup: {detailed_error}") + if temp_client: + try: + await temp_client.cleanup() + except Exception as cleanup_error: + # TODO: @jnjpng fix async cancel scope issue + # detailed_error = drill_down_exception(cleanup_error) + logger.warning(f"Aysnc cleanup confict during temp MCP client cleanup: {cleanup_error}") + + return StreamingResponseWithStatusCode(oauth_stream_generator(request), media_type="text/event-stream") + + class CodeInput(BaseModel): code: str = Field(..., description="Python source code to parse for JSON schema") @@ -697,6 +849,45 @@ async def generate_json_schema( raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}") +# TODO: @jnjpng need to route this through cloud API for production +@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback", response_class=HTMLResponse) +async def mcp_oauth_callback( + session_id: str, + code: Optional[str] = Query(None, description="OAuth authorization code"), + state: Optional[str] = Query(None, description="OAuth state parameter"), + error: Optional[str] = Query(None, description="OAuth error"), + error_description: Optional[str] = Query(None, description="OAuth error description"), +): + """ + Handle OAuth callback for MCP server authentication. + """ + try: + oauth_session = MCPOAuthSession(session_id) + + if error: + error_msg = f"OAuth error: {error}" + if error_description: + error_msg += f" - {error_description}" + await oauth_session.update_session_status(OAuthSessionStatus.ERROR) + return {"status": "error", "message": error_msg} + + if not code or not state: + await oauth_session.update_session_status(OAuthSessionStatus.ERROR) + return {"status": "error", "message": "Missing authorization code or state"} + + # Store authorization code + success = await oauth_session.store_authorization_code(code, state) + if not success: + await oauth_session.update_session_status(OAuthSessionStatus.ERROR) + return {"status": "error", "message": "Invalid state parameter"} + + return HTMLResponse(content=get_oauth_success_html(), status_code=200) + + except Exception as e: + logger.error(f"OAuth callback error: {e}") + return {"status": "error", "message": f"OAuth callback failed: {str(e)}"} + + class GenerateToolInput(BaseModel): tool_name: str = Field(..., description="Name of the tool to generate code for") prompt: str = Field(..., description="User prompt to generate code") diff --git a/letta/services/mcp/base_client.py b/letta/services/mcp/base_client.py index 8aeda67f..0df3cba1 100644 --- a/letta/services/mcp/base_client.py +++ b/letta/services/mcp/base_client.py @@ -1,9 +1,9 @@ -import asyncio from contextlib import AsyncExitStack from typing import Optional, Tuple from mcp import ClientSession from mcp import Tool as MCPTool +from mcp.client.auth import OAuthClientProvider from mcp.types import TextContent from letta.functions.mcp_client.types import BaseServerConfig @@ -14,14 +14,12 @@ logger = get_logger(__name__) # TODO: Get rid of Async prefix on this class name once we deprecate old sync code class AsyncBaseMCPClient: - def __init__(self, server_config: BaseServerConfig): + def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None): self.server_config = server_config + self.oauth_provider = oauth_provider self.exit_stack = AsyncExitStack() self.session: Optional[ClientSession] = None self.initialized = False - # Track the task that created this client - self._creation_task = asyncio.current_task() - self._cleanup_queue = asyncio.Queue(maxsize=1) async def connect_to_server(self): try: @@ -48,9 +46,25 @@ class AsyncBaseMCPClient: async def _initialize_connection(self, server_config: BaseServerConfig) -> None: raise NotImplementedError("Subclasses must implement _initialize_connection") - async def list_tools(self) -> list[MCPTool]: + async def list_tools(self, serialize: bool = False) -> list[MCPTool]: self._check_initialized() response = await self.session.list_tools() + if serialize: + serializable_tools = [] + for tool in response.tools: + if hasattr(tool, "model_dump"): + # Pydantic model - use model_dump + serializable_tools.append(tool.model_dump()) + elif hasattr(tool, "dict"): + # Older Pydantic model - use dict() + serializable_tools.append(tool.dict()) + elif hasattr(tool, "__dict__"): + # Regular object - use __dict__ + serializable_tools.append(tool.__dict__) + else: + # Fallback - convert to string + serializable_tools.append(str(tool)) + return serializable_tools return response.tools async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]: @@ -79,29 +93,7 @@ class AsyncBaseMCPClient: # TODO: still hitting some async errors for voice agents, need to fix async def cleanup(self): - """Clean up resources - ensure this runs in the same task""" - if hasattr(self, "_cleanup_task"): - # If we're in a different task, schedule cleanup in original task - current_task = asyncio.current_task() - if current_task != self._creation_task: - # Create a future to signal completion - cleanup_done = asyncio.Future() - self._cleanup_queue.put_nowait((self.exit_stack, cleanup_done)) - await cleanup_done - return - - # Normal cleanup await self.exit_stack.aclose() def to_sync_client(self): raise NotImplementedError("Subclasses must implement to_sync_client") - - async def __aenter__(self): - """Enter the async context manager.""" - await self.connect_to_server() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Exit the async context manager.""" - await self.cleanup() - return False # Don't suppress exceptions diff --git a/letta/services/mcp/oauth_utils.py b/letta/services/mcp/oauth_utils.py new file mode 100644 index 00000000..7391474c --- /dev/null +++ b/letta/services/mcp/oauth_utils.py @@ -0,0 +1,433 @@ +"""OAuth utilities for MCP server authentication.""" + +import asyncio +import json +import secrets +import time +import uuid +from datetime import datetime, timedelta +from typing import Callable, Optional, Tuple + +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from sqlalchemy import select + +from letta.log import get_logger +from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus +from letta.schemas.mcp import MCPOAuthSessionUpdate +from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry +from letta.services.mcp.types import OauthStreamEvent +from letta.services.mcp_manager import MCPManager + +logger = get_logger(__name__) + + +class DatabaseTokenStorage(TokenStorage): + """Database-backed token storage using MCPOAuth table via mcp_manager.""" + + def __init__(self, session_id: str, mcp_manager: MCPManager, actor: PydanticUser): + self.session_id = session_id + self.mcp_manager = mcp_manager + self.actor = actor + + async def get_tokens(self) -> Optional[OAuthToken]: + """Retrieve tokens from database.""" + oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor) + if not oauth_session or not oauth_session.access_token: + return None + + return OAuthToken( + access_token=oauth_session.access_token, + refresh_token=oauth_session.refresh_token, + token_type=oauth_session.token_type, + expires_in=int(oauth_session.expires_at.timestamp() - time.time()), + scope=oauth_session.scope, + ) + + async def set_tokens(self, tokens: OAuthToken) -> None: + """Store tokens in database.""" + session_update = MCPOAuthSessionUpdate( + access_token=tokens.access_token, + refresh_token=tokens.refresh_token, + token_type=tokens.token_type, + expires_at=datetime.fromtimestamp(tokens.expires_in + time.time()), + scope=tokens.scope, + status=OAuthSessionStatus.AUTHORIZED, + ) + await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor) + + async def get_client_info(self) -> Optional[OAuthClientInformationFull]: + """Retrieve client information from database.""" + oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor) + if not oauth_session or not oauth_session.client_id: + return None + + return OAuthClientInformationFull( + client_id=oauth_session.client_id, + client_secret=oauth_session.client_secret, + redirect_uris=[oauth_session.redirect_uri] if oauth_session.redirect_uri else [], + ) + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + """Store client information in database.""" + session_update = MCPOAuthSessionUpdate( + client_id=client_info.client_id, + client_secret=client_info.client_secret, + redirect_uri=str(client_info.redirect_uris[0]) if client_info.redirect_uris else None, + ) + await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor) + + +class MCPOAuthSession: + """Legacy OAuth session class - deprecated, use mcp_manager directly.""" + + def __init__(self, server_url: str, server_name: str, user_id: Optional[str], organization_id: str): + self.server_url = server_url + self.server_name = server_name + self.user_id = user_id + self.organization_id = organization_id + self.session_id = str(uuid.uuid4()) + self.state = secrets.token_urlsafe(32) + + def __init__(self, session_id: str): + self.session_id = session_id + + # TODO: consolidate / deprecate this in favor of mcp_manager access + async def create_session(self) -> str: + """Create a new OAuth session in the database.""" + async with db_registry.async_session() as session: + oauth_record = MCPOAuth( + id=self.session_id, + state=self.state, + server_url=self.server_url, + server_name=self.server_name, + user_id=self.user_id, + organization_id=self.organization_id, + status=OAuthSessionStatus.PENDING, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + oauth_record = await oauth_record.create_async(session, actor=None) + + return self.session_id + + async def get_session_status(self) -> OAuthSessionStatus: + """Get the current status of the OAuth session.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + return oauth_record.status + except Exception: + return OAuthSessionStatus.ERROR + + async def update_session_status(self, status: OAuthSessionStatus) -> None: + """Update the session status.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + oauth_record.status = status + oauth_record.updated_at = datetime.now() + await oauth_record.update_async(db_session=session, actor=None) + except Exception: + pass + + async def store_authorization_code(self, code: str, state: str) -> bool: + """Store the authorization code from OAuth callback.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + + # if oauth_record.state != state: + # return False + + oauth_record.authorization_code = code + oauth_record.state = state + oauth_record.status = OAuthSessionStatus.AUTHORIZED + oauth_record.updated_at = datetime.now() + await oauth_record.update_async(db_session=session, actor=None) + return True + except Exception: + return False + + async def get_authorization_url(self) -> Optional[str]: + """Get the authorization URL for this session.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + return oauth_record.authorization_url + except Exception: + return None + + async def set_authorization_url(self, url: str) -> None: + """Set the authorization URL for this session.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + oauth_record.authorization_url = url + oauth_record.updated_at = datetime.now() + await oauth_record.update_async(db_session=session, actor=None) + except Exception: + pass + + +async def create_oauth_provider( + session_id: str, + server_url: str, + redirect_uri: str, + mcp_manager: MCPManager, + actor: PydanticUser, + url_callback: Optional[Callable[[str], None]] = None, +) -> OAuthClientProvider: + """Create an OAuth provider for MCP server authentication.""" + + client_metadata_dict = { + "client_name": "Letta MCP Client", + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + } + + # Use manager-based storage + storage = DatabaseTokenStorage(session_id, mcp_manager, actor) + + # Extract base URL (remove /mcp endpoint if present) + oauth_server_url = server_url.rstrip("/").removesuffix("/sse").removesuffix("/mcp") + + async def redirect_handler(authorization_url: str) -> None: + """Handle OAuth redirect by storing the authorization URL.""" + logger.info(f"OAuth redirect handler called with URL: {authorization_url}") + session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url) + await mcp_manager.update_oauth_session(session_id, session_update, actor) + logger.info(f"OAuth authorization URL stored: {authorization_url}") + + # Call the callback if provided (e.g., to yield URL to SSE stream) + if url_callback: + url_callback(authorization_url) + + async def callback_handler() -> Tuple[str, Optional[str]]: + """Handle OAuth callback by waiting for authorization code.""" + timeout = 300 # 5 minutes + start_time = time.time() + + logger.info(f"Waiting for authorization code for session {session_id}") + while time.time() - start_time < timeout: + oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor) + if oauth_session and oauth_session.authorization_code: + return oauth_session.authorization_code, oauth_session.state + elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR: + raise Exception("OAuth authorization failed") + await asyncio.sleep(1) + + raise Exception(f"Timeout waiting for OAuth callback after {timeout} seconds") + + return OAuthClientProvider( + server_url=oauth_server_url, + client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + +async def cleanup_expired_oauth_sessions(max_age_hours: int = 24) -> None: + """Clean up expired OAuth sessions.""" + cutoff_time = datetime.now() - timedelta(hours=max_age_hours) + + async with db_registry.async_session() as session: + result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time)) + expired_sessions = result.scalars().all() + + for oauth_session in expired_sessions: + await oauth_session.hard_delete_async(db_session=session, actor=None) + + if expired_sessions: + logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions") + + +def oauth_stream_event(event: OauthStreamEvent, **kwargs) -> str: + data = {"event": event.value} + data.update(kwargs) + return f"data: {json.dumps(data)}\n\n" + + +def drill_down_exception(exception, depth=0, max_depth=5): + """Recursively drill down into nested exceptions to find the root cause""" + indent = " " * depth + error_details = [] + + error_details.append(f"{indent}Exception at depth {depth}:") + error_details.append(f"{indent} Type: {type(exception).__name__}") + error_details.append(f"{indent} Message: {str(exception)}") + error_details.append(f"{indent} Module: {getattr(type(exception), '__module__', 'unknown')}") + + # Check for exception groups (TaskGroup errors) + if hasattr(exception, "exceptions") and exception.exceptions: + error_details.append(f"{indent} ExceptionGroup with {len(exception.exceptions)} sub-exceptions:") + for i, sub_exc in enumerate(exception.exceptions): + error_details.append(f"{indent} Sub-exception {i}:") + if depth < max_depth: + error_details.extend(drill_down_exception(sub_exc, depth + 1, max_depth)) + + # Check for chained exceptions (__cause__ and __context__) + if hasattr(exception, "__cause__") and exception.__cause__ and depth < max_depth: + error_details.append(f"{indent} Caused by:") + error_details.extend(drill_down_exception(exception.__cause__, depth + 1, max_depth)) + + if hasattr(exception, "__context__") and exception.__context__ and depth < max_depth: + error_details.append(f"{indent} Context:") + error_details.extend(drill_down_exception(exception.__context__, depth + 1, max_depth)) + + # Add traceback info + import traceback + + if hasattr(exception, "__traceback__") and exception.__traceback__: + tb_lines = traceback.format_tb(exception.__traceback__) + error_details.append(f"{indent} Traceback:") + for line in tb_lines[-3:]: # Show last 3 traceback lines + error_details.append(f"{indent} {line.strip()}") + + error_info = "".join(error_details) + return error_info + + +def get_oauth_success_html() -> str: + """Generate HTML for successful OAuth authorization.""" + return """ + + +
+You have successfully connected your MCP server.
+You can close this window and return to the terminal.
+ + + + """ + ) + elif "error" in query_params: + self.callback_data["error"] = query_params["error"][0] + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write( + f""" + + +Error: {query_params['error'][0]}
+You can close this window and return to the terminal.
+ + + """.encode() + ) + else: + self.send_response(404) + self.end_headers() + + def log_message(self, format, *args): + """Suppress default logging.""" + + +class CallbackServer: + """Simple server to handle OAuth callbacks.""" + + def __init__(self, port=3000): + self.port = port + self.server = None + self.thread = None + self.callback_data = {"authorization_code": None, "state": None, "error": None} + + def _create_handler_with_data(self): + """Create a handler class with access to callback data.""" + callback_data = self.callback_data + + class DataCallbackHandler(CallbackHandler): + def __init__(self, request, client_address, server): + super().__init__(request, client_address, server, callback_data) + + return DataCallbackHandler + + def start(self): + """Start the callback server in a background thread.""" + handler_class = self._create_handler_with_data() + self.server = HTTPServer(("localhost", self.port), handler_class) + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread.start() + print(f"š„ļø Started callback server on http://localhost:{self.port}") + + def stop(self): + """Stop the callback server.""" + if self.server: + self.server.shutdown() + self.server.server_close() + if self.thread: + self.thread.join(timeout=1) + + def wait_for_callback(self, timeout=300): + """Wait for OAuth callback with timeout.""" + start_time = time.time() + while time.time() - start_time < timeout: + if self.callback_data["authorization_code"]: + return self.callback_data["authorization_code"] + elif self.callback_data["error"]: + raise Exception(f"OAuth error: {self.callback_data['error']}") + time.sleep(0.1) + raise Exception("Timeout waiting for OAuth callback") + + def get_state(self): + """Get the received state parameter.""" + return self.callback_data["state"] + + +class SimpleAuthClient: + """Simple MCP client with auth support.""" + + def __init__(self, server_url: str, transport_type: str = "streamable_http"): + self.server_url = server_url + self.transport_type = transport_type + self.session: ClientSession | None = None + + async def connect(self): + """Connect to the MCP server.""" + print(f"š Attempting to connect to {self.server_url}...") + + try: + callback_server = CallbackServer(port=3030) + callback_server.start() + + async def callback_handler() -> tuple[str, str | None]: + """Wait for OAuth callback and return auth code and state.""" + print("ā³ Waiting for authorization callback...") + try: + auth_code = callback_server.wait_for_callback(timeout=300) + return auth_code, callback_server.get_state() + finally: + callback_server.stop() + + client_metadata_dict = { + "client_name": "Simple Auth Client", + "redirect_uris": ["http://localhost:3030/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + } + + async def _default_redirect_handler(authorization_url: str) -> None: + """Default redirect handler that opens the URL in a browser.""" + print(f"Opening browser for authorization: {authorization_url}") + webbrowser.open(authorization_url) + + # Create OAuth authentication handler using the new interface + oauth_auth = OAuthClientProvider( + server_url=self.server_url.replace("/mcp", ""), + client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), + storage=InMemoryTokenStorage(), + redirect_handler=_default_redirect_handler, + callback_handler=callback_handler, + ) + + # Create transport with auth handler based on transport type + if self.transport_type == "sse": + print("š” Opening SSE transport connection with auth...") + async with sse_client( + url=self.server_url, + auth=oauth_auth, + timeout=60, + ) as (read_stream, write_stream): + await self._run_session(read_stream, write_stream, None) + else: + print("š” Opening StreamableHTTP transport connection with auth...") + async with streamablehttp_client( + url=self.server_url, + auth=oauth_auth, + timeout=timedelta(seconds=60), + ) as (read_stream, write_stream, get_session_id): + await self._run_session(read_stream, write_stream, get_session_id) + + except Exception as e: + print(f"ā Failed to connect: {e}") + import traceback + + traceback.print_exc() + + async def _run_session(self, read_stream, write_stream, get_session_id): + """Run the MCP session with the given streams.""" + print("š¤ Initializing MCP session...") + async with ClientSession(read_stream, write_stream) as session: + self.session = session + print("ā” Starting session initialization...") + await session.initialize() + print("⨠Session initialization complete!") + + print(f"\nā Connected to MCP server at {self.server_url}") + if get_session_id: + session_id = get_session_id() + if session_id: + print(f"Session ID: {session_id}") + + # Run interactive loop + await self.interactive_loop() + + async def list_tools(self): + """List available tools from the server.""" + if not self.session: + print("ā Not connected to server") + return + + try: + result = await self.session.list_tools() + if hasattr(result, "tools") and result.tools: + print("\nš Available tools:") + for i, tool in enumerate(result.tools, 1): + print(f"{i}. {tool.name}") + if tool.description: + print(f" Description: {tool.description}") + print() + else: + print("No tools available") + except Exception as e: + print(f"ā Failed to list tools: {e}") + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): + """Call a specific tool.""" + if not self.session: + print("ā Not connected to server") + return + + try: + result = await self.session.call_tool(tool_name, arguments or {}) + print(f"\nš§ Tool '{tool_name}' result:") + if hasattr(result, "content"): + for content in result.content: + if content.type == "text": + print(content.text) + else: + print(content) + else: + print(result) + except Exception as e: + print(f"ā Failed to call tool '{tool_name}': {e}") + + async def interactive_loop(self): + """Run interactive command loop.""" + print("\nšÆ Interactive MCP Client") + print("Commands:") + print(" list - List available tools") + print(" call