diff --git a/alembic/versions/f5d26b0526e8_add_mcp_oauth.py b/alembic/versions/f5d26b0526e8_add_mcp_oauth.py new file mode 100644 index 00000000..52c9f764 --- /dev/null +++ b/alembic/versions/f5d26b0526e8_add_mcp_oauth.py @@ -0,0 +1,67 @@ +"""add_mcp_oauth + +Revision ID: f5d26b0526e8 +Revises: ddecfe4902bc +Create Date: 2025-07-24 12:34:05.795355 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "f5d26b0526e8" +down_revision: Union[str, None] = "ddecfe4902bc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "mcp_oauth", + sa.Column("id", sa.String(), nullable=False), + sa.Column("state", sa.String(length=255), nullable=False), + sa.Column("server_id", sa.String(length=255), nullable=True), + sa.Column("server_url", sa.Text(), nullable=False), + sa.Column("server_name", sa.Text(), nullable=False), + sa.Column("authorization_url", sa.Text(), nullable=True), + sa.Column("authorization_code", sa.Text(), nullable=True), + sa.Column("access_token", sa.Text(), nullable=True), + sa.Column("refresh_token", sa.Text(), nullable=True), + sa.Column("token_type", sa.String(length=50), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("scope", sa.Text(), nullable=True), + sa.Column("client_id", sa.Text(), nullable=True), + sa.Column("client_secret", sa.Text(), nullable=True), + sa.Column("redirect_uri", sa.Text(), nullable=True), + sa.Column("status", sa.String(length=20), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.Column("user_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.ForeignKeyConstraint(["server_id"], ["mcp_server.id"], ondelete="CASCADE"), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + ), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("state"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("mcp_oauth") + # ### end Alembic commands ### diff --git a/letta/orm/mcp_oauth.py b/letta/orm/mcp_oauth.py new file mode 100644 index 00000000..e34f685a --- /dev/null +++ b/letta/orm/mcp_oauth.py @@ -0,0 +1,62 @@ +import uuid +from datetime import datetime +from enum import Enum +from typing import Optional + +from sqlalchemy import DateTime, ForeignKey, String, Text +from sqlalchemy.orm import Mapped, mapped_column + +from letta.orm.mixins import OrganizationMixin, UserMixin +from letta.orm.sqlalchemy_base import SqlalchemyBase + + +class OAuthSessionStatus(str, Enum): + """OAuth session status enumeration.""" + + PENDING = "pending" + AUTHORIZED = "authorized" + ERROR = "error" + + +class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin): + """OAuth session model for MCP server authentication.""" + + __tablename__ = "mcp_oauth" + + # Override the id field to match database UUID generation + id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid.uuid4()}") + + # Core session information + state: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, doc="OAuth state parameter") + server_id: Mapped[str] = mapped_column(String(255), ForeignKey("mcp_server.id", ondelete="CASCADE"), nullable=True, doc="MCP server ID") + server_url: Mapped[str] = mapped_column(Text, nullable=False, doc="MCP server URL") + server_name: Mapped[str] = mapped_column(Text, nullable=False, doc="MCP server display name") + + # OAuth flow data + authorization_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth authorization URL") + authorization_code: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth authorization code") + + # Token data + access_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth access token") + refresh_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth refresh token") + token_type: Mapped[str] = mapped_column(String(50), default="Bearer", doc="Token type") + expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, doc="Token expiry time") + scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth scope") + + # Client configuration + client_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client ID") + client_secret: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client secret") + redirect_uri: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth redirect URI") + + # Session state + status: Mapped[OAuthSessionStatus] = mapped_column(String(20), default=OAuthSessionStatus.PENDING, doc="Session status") + + # Timestamps + created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), doc="Session creation time") + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now(), doc="Last update time" + ) + + # Relationships (if needed in the future) + # user: Mapped[Optional["User"]] = relationship("User", back_populates="oauth_sessions") + # organization: Mapped["Organization"] = relationship("Organization", back_populates="oauth_sessions") diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index b851f2d5..f7070e8b 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -1,3 +1,4 @@ +from datetime import datetime from typing import Any, Dict, Optional, Union from pydantic import Field @@ -10,6 +11,7 @@ from letta.functions.mcp_client.types import ( StdioServerConfig, StreamableHTTPServerConfig, ) +from letta.orm.mcp_oauth import OAuthSessionStatus from letta.schemas.letta_base import LettaBase @@ -119,3 +121,71 @@ class UpdateStreamableHTTPMCPServer(LettaBase): UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer] RegisterMCPServer = Union[RegisterSSEMCPServer, RegisterStdioMCPServer, RegisterStreamableHTTPMCPServer] + + +# OAuth-related schemas +class BaseMCPOAuth(LettaBase): + __id_prefix__ = "mcp-oauth" + + +class MCPOAuthSession(BaseMCPOAuth): + """OAuth session for MCP server authentication.""" + + id: str = BaseMCPOAuth.generate_id_field() + state: str = Field(..., description="OAuth state parameter") + server_id: Optional[str] = Field(None, description="MCP server ID") + server_url: str = Field(..., description="MCP server URL") + server_name: str = Field(..., description="MCP server display name") + + # User and organization context + user_id: Optional[str] = Field(None, description="User ID associated with the session") + organization_id: str = Field(..., description="Organization ID associated with the session") + + # OAuth flow data + authorization_url: Optional[str] = Field(None, description="OAuth authorization URL") + authorization_code: Optional[str] = Field(None, description="OAuth authorization code") + + # Token data + access_token: Optional[str] = Field(None, description="OAuth access token") + refresh_token: Optional[str] = Field(None, description="OAuth refresh token") + token_type: str = Field(default="Bearer", description="Token type") + expires_at: Optional[datetime] = Field(None, description="Token expiry time") + scope: Optional[str] = Field(None, description="OAuth scope") + + # Client configuration + client_id: Optional[str] = Field(None, description="OAuth client ID") + client_secret: Optional[str] = Field(None, description="OAuth client secret") + redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI") + + # Session state + status: OAuthSessionStatus = Field(default=OAuthSessionStatus.PENDING, description="Session status") + + # Timestamps + created_at: datetime = Field(default_factory=datetime.now, description="Session creation time") + updated_at: datetime = Field(default_factory=datetime.now, description="Last update time") + + +class MCPOAuthSessionCreate(BaseMCPOAuth): + """Create a new OAuth session.""" + + server_url: str = Field(..., description="MCP server URL") + server_name: str = Field(..., description="MCP server display name") + user_id: Optional[str] = Field(None, description="User ID associated with the session") + organization_id: str = Field(..., description="Organization ID associated with the session") + state: Optional[str] = Field(None, description="OAuth state parameter") + + +class MCPOAuthSessionUpdate(BaseMCPOAuth): + """Update an existing OAuth session.""" + + authorization_url: Optional[str] = Field(None, description="OAuth authorization URL") + authorization_code: Optional[str] = Field(None, description="OAuth authorization code") + access_token: Optional[str] = Field(None, description="OAuth access token") + refresh_token: Optional[str] = Field(None, description="OAuth refresh token") + token_type: Optional[str] = Field(None, description="Token type") + expires_at: Optional[datetime] = Field(None, description="Token expiry time") + scope: Optional[str] = Field(None, description="OAuth scope") + client_id: Optional[str] = Field(None, description="OAuth client ID") + client_secret: Optional[str] = Field(None, description="OAuth client secret") + redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI") + status: Optional[OAuthSessionStatus] = Field(None, description="Session status") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index ad171613..354afdc2 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -1,4 +1,6 @@ +import asyncio import json +from collections.abc import AsyncGenerator from typing import Any, Dict, List, Optional, Union from composio.client import ComposioClientError, HTTPError, NoItemsFound @@ -11,27 +13,37 @@ from composio.exceptions import ( EnumStringNotFound, ) from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query +from fastapi.responses import HTMLResponse from pydantic import BaseModel, Field +from starlette.responses import StreamingResponse from letta.errors import LettaToolCreateError from letta.functions.functions import derive_openai_json_schema from letta.functions.mcp_client.exceptions import MCPTimeoutError -from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig +from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig from letta.helpers.composio_helpers import get_composio_api_key +from letta.helpers.decorators import deprecated from letta.llm_api.llm_client import LLMClient from letta.log import get_logger from letta.orm.errors import UniqueConstraintViolationError +from letta.orm.mcp_oauth import OAuthSessionStatus from letta.schemas.enums import MessageRole from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.letta_message_content import TextContent -from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer +from letta.schemas.mcp import MCPOAuthSessionCreate, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer from letta.schemas.message import Message from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate +from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer -from letta.services.mcp.sse_client import AsyncSSEMCPClient -from letta.services.mcp.stdio_client import AsyncStdioMCPClient -from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient +from letta.services.mcp.oauth_utils import ( + MCPOAuthSession, + create_oauth_provider, + drill_down_exception, + get_oauth_success_html, + oauth_stream_event, +) +from letta.services.mcp.types import OauthStreamEvent from letta.settings import tool_settings router = APIRouter(prefix="/tools", tags=["tools"]) @@ -612,35 +624,26 @@ async def delete_mcp_server_from_config( return [server.to_config() for server in all_servers] -@router.post("/mcp/servers/test", response_model=List[MCPTool], operation_id="test_mcp_server") +@deprecated("Deprecated in favor of /mcp/servers/connect which handles OAuth flow via SSE stream") +@router.post("/mcp/servers/test", operation_id="test_mcp_server") async def test_mcp_server( request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...), + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Test connection to an MCP server without adding it. - Returns the list of available tools if successful. + Returns the list of available tools if successful, or OAuth information if OAuth is required. """ client = None try: - # create a temporary MCP client based on the server type - if request.type == MCPServerType.SSE: - if not isinstance(request, SSEServerConfig): - request = SSEServerConfig(**request.model_dump()) - client = AsyncSSEMCPClient(request) - elif request.type == MCPServerType.STREAMABLE_HTTP: - if not isinstance(request, StreamableHTTPServerConfig): - request = StreamableHTTPServerConfig(**request.model_dump()) - client = AsyncStreamableHTTPMCPClient(request) - elif request.type == MCPServerType.STDIO: - if not isinstance(request, StdioServerConfig): - request = StdioServerConfig(**request.model_dump()) - client = AsyncStdioMCPClient(request) - else: - raise ValueError(f"Invalid MCP server type: {request.type}") + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + client = await server.mcp_manager.get_mcp_client(request, actor) await client.connect_to_server() tools = await client.list_tools() - return tools + + return {"status": "success", "tools": tools} except ConnectionError as e: raise HTTPException( status_code=400, @@ -676,6 +679,155 @@ async def test_mcp_server( logger.warning(f"Error during MCP client cleanup: {cleanup_error}") +@router.post( + "/mcp/servers/connect", + response_model=None, + responses={ + 200: { + "description": "Successful response", + "content": { + "text/event-stream": {"description": "Server-Sent Events stream"}, + }, + } + }, + operation_id="connect_mcp_server", +) +async def connect_mcp_server( + request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...), + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +) -> StreamingResponse: + """ + Connect to an MCP server with support for OAuth via SSE. + Returns a stream of events handling authorization state and exchange if OAuth is required. + """ + + async def oauth_stream_generator( + request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] + ) -> AsyncGenerator[str, None]: + client = None + oauth_provider = None + temp_client = None + connect_task = None + + try: + # Acknolwedge connection attempt + yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name) + + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + + # Create MCP client with respective transport type + try: + client = await server.mcp_manager.get_mcp_client(request, actor) + except ValueError as e: + yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e)) + return + + # Try normal connection first for flows that don't require OAuth + try: + await client.connect_to_server() + tools = await client.list_tools(serialize=True) + yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) + return + except ConnectionError: + # Continue to OAuth flow + logger.info(f"Attempting OAuth flow for {request.server_url}...") + except Exception as e: + yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}") + return + + # OAuth required, yield state to client to prepare to handle authorization URL + yield oauth_stream_event(OauthStreamEvent.OAUTH_REQUIRED, message="OAuth authentication required") + + # Create OAuth session to persist the state of the OAuth flow + session_create = MCPOAuthSessionCreate( + server_url=request.server_url, + server_name=request.server_name, + user_id=actor.id, + organization_id=actor.organization_id, + ) + oauth_session = await server.mcp_manager.create_oauth_session(session_create, actor) + session_id = oauth_session.id + + # Create OAuth provider for the instance of the stream connection + # Note: Using the correct API path for the callback + # do not edit this this is the correct url + redirect_uri = f"http://localhost:8283/v1/tools/mcp/oauth/callback/{session_id}" + oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, server.mcp_manager, actor) + + # Get authorization URL by triggering OAuth flow + temp_client = None + try: + temp_client = await server.mcp_manager.get_mcp_client(request, actor, oauth_provider) + + # Run connect_to_server in background to avoid blocking + # This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database + connect_task = asyncio.create_task(temp_client.connect_to_server()) + + # Give the OAuth flow time to trigger and save the URL + await asyncio.sleep(1.0) + + # Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL + auth_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor) + if auth_session and auth_session.authorization_url: + yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id) + + except Exception as e: + logger.error(f"Error triggering OAuth flow: {e}") + yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}") + + # Clean up active resources + if connect_task and not connect_task.done(): + connect_task.cancel() + try: + await connect_task + except asyncio.CancelledError: + pass + if temp_client: + try: + await temp_client.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}") + return + + # Wait for user authorization (with timeout), client should render loading state until user completes the flow and /mcp/oauth/callback/{session_id} is hit + yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...") + + # Callback handler will poll for authorization code and state and update the OAuth session + await connect_task + + tools = await temp_client.list_tools(serialize=True) + + yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) + return + except Exception as e: + detailed_error = drill_down_exception(e) + logger.error(f"Error in OAuth stream:\n{detailed_error}") + yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}") + finally: + if connect_task and not connect_task.done(): + connect_task.cancel() + try: + await connect_task + except asyncio.CancelledError: + pass + if client: + try: + await client.cleanup() + except Exception as cleanup_error: + detailed_error = drill_down_exception(cleanup_error) + logger.warning(f"Error during MCP client cleanup: {detailed_error}") + if temp_client: + try: + await temp_client.cleanup() + except Exception as cleanup_error: + # TODO: @jnjpng fix async cancel scope issue + # detailed_error = drill_down_exception(cleanup_error) + logger.warning(f"Aysnc cleanup confict during temp MCP client cleanup: {cleanup_error}") + + return StreamingResponseWithStatusCode(oauth_stream_generator(request), media_type="text/event-stream") + + class CodeInput(BaseModel): code: str = Field(..., description="Python source code to parse for JSON schema") @@ -697,6 +849,45 @@ async def generate_json_schema( raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}") +# TODO: @jnjpng need to route this through cloud API for production +@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback", response_class=HTMLResponse) +async def mcp_oauth_callback( + session_id: str, + code: Optional[str] = Query(None, description="OAuth authorization code"), + state: Optional[str] = Query(None, description="OAuth state parameter"), + error: Optional[str] = Query(None, description="OAuth error"), + error_description: Optional[str] = Query(None, description="OAuth error description"), +): + """ + Handle OAuth callback for MCP server authentication. + """ + try: + oauth_session = MCPOAuthSession(session_id) + + if error: + error_msg = f"OAuth error: {error}" + if error_description: + error_msg += f" - {error_description}" + await oauth_session.update_session_status(OAuthSessionStatus.ERROR) + return {"status": "error", "message": error_msg} + + if not code or not state: + await oauth_session.update_session_status(OAuthSessionStatus.ERROR) + return {"status": "error", "message": "Missing authorization code or state"} + + # Store authorization code + success = await oauth_session.store_authorization_code(code, state) + if not success: + await oauth_session.update_session_status(OAuthSessionStatus.ERROR) + return {"status": "error", "message": "Invalid state parameter"} + + return HTMLResponse(content=get_oauth_success_html(), status_code=200) + + except Exception as e: + logger.error(f"OAuth callback error: {e}") + return {"status": "error", "message": f"OAuth callback failed: {str(e)}"} + + class GenerateToolInput(BaseModel): tool_name: str = Field(..., description="Name of the tool to generate code for") prompt: str = Field(..., description="User prompt to generate code") diff --git a/letta/services/mcp/base_client.py b/letta/services/mcp/base_client.py index 8aeda67f..0df3cba1 100644 --- a/letta/services/mcp/base_client.py +++ b/letta/services/mcp/base_client.py @@ -1,9 +1,9 @@ -import asyncio from contextlib import AsyncExitStack from typing import Optional, Tuple from mcp import ClientSession from mcp import Tool as MCPTool +from mcp.client.auth import OAuthClientProvider from mcp.types import TextContent from letta.functions.mcp_client.types import BaseServerConfig @@ -14,14 +14,12 @@ logger = get_logger(__name__) # TODO: Get rid of Async prefix on this class name once we deprecate old sync code class AsyncBaseMCPClient: - def __init__(self, server_config: BaseServerConfig): + def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None): self.server_config = server_config + self.oauth_provider = oauth_provider self.exit_stack = AsyncExitStack() self.session: Optional[ClientSession] = None self.initialized = False - # Track the task that created this client - self._creation_task = asyncio.current_task() - self._cleanup_queue = asyncio.Queue(maxsize=1) async def connect_to_server(self): try: @@ -48,9 +46,25 @@ class AsyncBaseMCPClient: async def _initialize_connection(self, server_config: BaseServerConfig) -> None: raise NotImplementedError("Subclasses must implement _initialize_connection") - async def list_tools(self) -> list[MCPTool]: + async def list_tools(self, serialize: bool = False) -> list[MCPTool]: self._check_initialized() response = await self.session.list_tools() + if serialize: + serializable_tools = [] + for tool in response.tools: + if hasattr(tool, "model_dump"): + # Pydantic model - use model_dump + serializable_tools.append(tool.model_dump()) + elif hasattr(tool, "dict"): + # Older Pydantic model - use dict() + serializable_tools.append(tool.dict()) + elif hasattr(tool, "__dict__"): + # Regular object - use __dict__ + serializable_tools.append(tool.__dict__) + else: + # Fallback - convert to string + serializable_tools.append(str(tool)) + return serializable_tools return response.tools async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]: @@ -79,29 +93,7 @@ class AsyncBaseMCPClient: # TODO: still hitting some async errors for voice agents, need to fix async def cleanup(self): - """Clean up resources - ensure this runs in the same task""" - if hasattr(self, "_cleanup_task"): - # If we're in a different task, schedule cleanup in original task - current_task = asyncio.current_task() - if current_task != self._creation_task: - # Create a future to signal completion - cleanup_done = asyncio.Future() - self._cleanup_queue.put_nowait((self.exit_stack, cleanup_done)) - await cleanup_done - return - - # Normal cleanup await self.exit_stack.aclose() def to_sync_client(self): raise NotImplementedError("Subclasses must implement to_sync_client") - - async def __aenter__(self): - """Enter the async context manager.""" - await self.connect_to_server() - return self - - async def __aexit__(self, exc_type, exc_val, exc_tb): - """Exit the async context manager.""" - await self.cleanup() - return False # Don't suppress exceptions diff --git a/letta/services/mcp/oauth_utils.py b/letta/services/mcp/oauth_utils.py new file mode 100644 index 00000000..7391474c --- /dev/null +++ b/letta/services/mcp/oauth_utils.py @@ -0,0 +1,433 @@ +"""OAuth utilities for MCP server authentication.""" + +import asyncio +import json +import secrets +import time +import uuid +from datetime import datetime, timedelta +from typing import Callable, Optional, Tuple + +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken +from sqlalchemy import select + +from letta.log import get_logger +from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus +from letta.schemas.mcp import MCPOAuthSessionUpdate +from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry +from letta.services.mcp.types import OauthStreamEvent +from letta.services.mcp_manager import MCPManager + +logger = get_logger(__name__) + + +class DatabaseTokenStorage(TokenStorage): + """Database-backed token storage using MCPOAuth table via mcp_manager.""" + + def __init__(self, session_id: str, mcp_manager: MCPManager, actor: PydanticUser): + self.session_id = session_id + self.mcp_manager = mcp_manager + self.actor = actor + + async def get_tokens(self) -> Optional[OAuthToken]: + """Retrieve tokens from database.""" + oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor) + if not oauth_session or not oauth_session.access_token: + return None + + return OAuthToken( + access_token=oauth_session.access_token, + refresh_token=oauth_session.refresh_token, + token_type=oauth_session.token_type, + expires_in=int(oauth_session.expires_at.timestamp() - time.time()), + scope=oauth_session.scope, + ) + + async def set_tokens(self, tokens: OAuthToken) -> None: + """Store tokens in database.""" + session_update = MCPOAuthSessionUpdate( + access_token=tokens.access_token, + refresh_token=tokens.refresh_token, + token_type=tokens.token_type, + expires_at=datetime.fromtimestamp(tokens.expires_in + time.time()), + scope=tokens.scope, + status=OAuthSessionStatus.AUTHORIZED, + ) + await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor) + + async def get_client_info(self) -> Optional[OAuthClientInformationFull]: + """Retrieve client information from database.""" + oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor) + if not oauth_session or not oauth_session.client_id: + return None + + return OAuthClientInformationFull( + client_id=oauth_session.client_id, + client_secret=oauth_session.client_secret, + redirect_uris=[oauth_session.redirect_uri] if oauth_session.redirect_uri else [], + ) + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + """Store client information in database.""" + session_update = MCPOAuthSessionUpdate( + client_id=client_info.client_id, + client_secret=client_info.client_secret, + redirect_uri=str(client_info.redirect_uris[0]) if client_info.redirect_uris else None, + ) + await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor) + + +class MCPOAuthSession: + """Legacy OAuth session class - deprecated, use mcp_manager directly.""" + + def __init__(self, server_url: str, server_name: str, user_id: Optional[str], organization_id: str): + self.server_url = server_url + self.server_name = server_name + self.user_id = user_id + self.organization_id = organization_id + self.session_id = str(uuid.uuid4()) + self.state = secrets.token_urlsafe(32) + + def __init__(self, session_id: str): + self.session_id = session_id + + # TODO: consolidate / deprecate this in favor of mcp_manager access + async def create_session(self) -> str: + """Create a new OAuth session in the database.""" + async with db_registry.async_session() as session: + oauth_record = MCPOAuth( + id=self.session_id, + state=self.state, + server_url=self.server_url, + server_name=self.server_name, + user_id=self.user_id, + organization_id=self.organization_id, + status=OAuthSessionStatus.PENDING, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + oauth_record = await oauth_record.create_async(session, actor=None) + + return self.session_id + + async def get_session_status(self) -> OAuthSessionStatus: + """Get the current status of the OAuth session.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + return oauth_record.status + except Exception: + return OAuthSessionStatus.ERROR + + async def update_session_status(self, status: OAuthSessionStatus) -> None: + """Update the session status.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + oauth_record.status = status + oauth_record.updated_at = datetime.now() + await oauth_record.update_async(db_session=session, actor=None) + except Exception: + pass + + async def store_authorization_code(self, code: str, state: str) -> bool: + """Store the authorization code from OAuth callback.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + + # if oauth_record.state != state: + # return False + + oauth_record.authorization_code = code + oauth_record.state = state + oauth_record.status = OAuthSessionStatus.AUTHORIZED + oauth_record.updated_at = datetime.now() + await oauth_record.update_async(db_session=session, actor=None) + return True + except Exception: + return False + + async def get_authorization_url(self) -> Optional[str]: + """Get the authorization URL for this session.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + return oauth_record.authorization_url + except Exception: + return None + + async def set_authorization_url(self, url: str) -> None: + """Set the authorization URL for this session.""" + async with db_registry.async_session() as session: + try: + oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) + oauth_record.authorization_url = url + oauth_record.updated_at = datetime.now() + await oauth_record.update_async(db_session=session, actor=None) + except Exception: + pass + + +async def create_oauth_provider( + session_id: str, + server_url: str, + redirect_uri: str, + mcp_manager: MCPManager, + actor: PydanticUser, + url_callback: Optional[Callable[[str], None]] = None, +) -> OAuthClientProvider: + """Create an OAuth provider for MCP server authentication.""" + + client_metadata_dict = { + "client_name": "Letta MCP Client", + "redirect_uris": [redirect_uri], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + } + + # Use manager-based storage + storage = DatabaseTokenStorage(session_id, mcp_manager, actor) + + # Extract base URL (remove /mcp endpoint if present) + oauth_server_url = server_url.rstrip("/").removesuffix("/sse").removesuffix("/mcp") + + async def redirect_handler(authorization_url: str) -> None: + """Handle OAuth redirect by storing the authorization URL.""" + logger.info(f"OAuth redirect handler called with URL: {authorization_url}") + session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url) + await mcp_manager.update_oauth_session(session_id, session_update, actor) + logger.info(f"OAuth authorization URL stored: {authorization_url}") + + # Call the callback if provided (e.g., to yield URL to SSE stream) + if url_callback: + url_callback(authorization_url) + + async def callback_handler() -> Tuple[str, Optional[str]]: + """Handle OAuth callback by waiting for authorization code.""" + timeout = 300 # 5 minutes + start_time = time.time() + + logger.info(f"Waiting for authorization code for session {session_id}") + while time.time() - start_time < timeout: + oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor) + if oauth_session and oauth_session.authorization_code: + return oauth_session.authorization_code, oauth_session.state + elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR: + raise Exception("OAuth authorization failed") + await asyncio.sleep(1) + + raise Exception(f"Timeout waiting for OAuth callback after {timeout} seconds") + + return OAuthClientProvider( + server_url=oauth_server_url, + client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), + storage=storage, + redirect_handler=redirect_handler, + callback_handler=callback_handler, + ) + + +async def cleanup_expired_oauth_sessions(max_age_hours: int = 24) -> None: + """Clean up expired OAuth sessions.""" + cutoff_time = datetime.now() - timedelta(hours=max_age_hours) + + async with db_registry.async_session() as session: + result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time)) + expired_sessions = result.scalars().all() + + for oauth_session in expired_sessions: + await oauth_session.hard_delete_async(db_session=session, actor=None) + + if expired_sessions: + logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions") + + +def oauth_stream_event(event: OauthStreamEvent, **kwargs) -> str: + data = {"event": event.value} + data.update(kwargs) + return f"data: {json.dumps(data)}\n\n" + + +def drill_down_exception(exception, depth=0, max_depth=5): + """Recursively drill down into nested exceptions to find the root cause""" + indent = " " * depth + error_details = [] + + error_details.append(f"{indent}Exception at depth {depth}:") + error_details.append(f"{indent} Type: {type(exception).__name__}") + error_details.append(f"{indent} Message: {str(exception)}") + error_details.append(f"{indent} Module: {getattr(type(exception), '__module__', 'unknown')}") + + # Check for exception groups (TaskGroup errors) + if hasattr(exception, "exceptions") and exception.exceptions: + error_details.append(f"{indent} ExceptionGroup with {len(exception.exceptions)} sub-exceptions:") + for i, sub_exc in enumerate(exception.exceptions): + error_details.append(f"{indent} Sub-exception {i}:") + if depth < max_depth: + error_details.extend(drill_down_exception(sub_exc, depth + 1, max_depth)) + + # Check for chained exceptions (__cause__ and __context__) + if hasattr(exception, "__cause__") and exception.__cause__ and depth < max_depth: + error_details.append(f"{indent} Caused by:") + error_details.extend(drill_down_exception(exception.__cause__, depth + 1, max_depth)) + + if hasattr(exception, "__context__") and exception.__context__ and depth < max_depth: + error_details.append(f"{indent} Context:") + error_details.extend(drill_down_exception(exception.__context__, depth + 1, max_depth)) + + # Add traceback info + import traceback + + if hasattr(exception, "__traceback__") and exception.__traceback__: + tb_lines = traceback.format_tb(exception.__traceback__) + error_details.append(f"{indent} Traceback:") + for line in tb_lines[-3:]: # Show last 3 traceback lines + error_details.append(f"{indent} {line.strip()}") + + error_info = "".join(error_details) + return error_info + + +def get_oauth_success_html() -> str: + """Generate HTML for successful OAuth authorization.""" + return """ + + + + Authorization Successful - Letta + + + +
+ +

Authorization Successful

+

You have successfully connected your MCP server.

+
+ You can now close this window. +
+
+ + +""" diff --git a/letta/services/mcp/sse_client.py b/letta/services/mcp/sse_client.py index 91e2515c..950b4ae0 100644 --- a/letta/services/mcp/sse_client.py +++ b/letta/services/mcp/sse_client.py @@ -1,4 +1,7 @@ +from typing import Optional + from mcp import ClientSession +from mcp.client.auth import OAuthClientProvider from mcp.client.sse import sse_client from letta.functions.mcp_client.types import SSEServerConfig @@ -13,6 +16,9 @@ logger = get_logger(__name__) # TODO: Get rid of Async prefix on this class name once we deprecate old sync code class AsyncSSEMCPClient(AsyncBaseMCPClient): + def __init__(self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None): + super().__init__(server_config, oauth_provider) + async def _initialize_connection(self, server_config: SSEServerConfig) -> None: headers = {} if server_config.custom_headers: @@ -21,7 +27,12 @@ class AsyncSSEMCPClient(AsyncBaseMCPClient): if server_config.auth_header and server_config.auth_token: headers[server_config.auth_header] = server_config.auth_token - sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None) + # Use OAuth provider if available, otherwise use regular headers + if self.oauth_provider: + sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider) + else: + sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None) + sse_transport = await self.exit_stack.enter_async_context(sse_cm) self.stdio, self.write = sse_transport diff --git a/letta/services/mcp/streamable_http_client.py b/letta/services/mcp/streamable_http_client.py index 63c269ac..baf2f7c6 100644 --- a/letta/services/mcp/streamable_http_client.py +++ b/letta/services/mcp/streamable_http_client.py @@ -1,4 +1,7 @@ +from typing import Optional + from mcp import ClientSession +from mcp.client.auth import OAuthClientProvider from mcp.client.streamable_http import streamablehttp_client from letta.functions.mcp_client.types import BaseServerConfig, StreamableHTTPServerConfig @@ -9,10 +12,12 @@ logger = get_logger(__name__) class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient): + def __init__(self, server_config: StreamableHTTPServerConfig, oauth_provider: Optional[OAuthClientProvider] = None): + super().__init__(server_config, oauth_provider) + async def _initialize_connection(self, server_config: BaseServerConfig) -> None: if not isinstance(server_config, StreamableHTTPServerConfig): raise ValueError("Expected StreamableHTTPServerConfig") - try: # Prepare headers for authentication headers = {} @@ -23,11 +28,18 @@ class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient): if server_config.auth_header and server_config.auth_token: headers[server_config.auth_header] = server_config.auth_token - # Use streamablehttp_client context manager with headers if provided - if headers: - streamable_http_cm = streamablehttp_client(server_config.server_url, headers=headers) + # Use OAuth provider if available, otherwise use regular headers + if self.oauth_provider: + streamable_http_cm = streamablehttp_client( + server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider + ) else: - streamable_http_cm = streamablehttp_client(server_config.server_url) + # Use streamablehttp_client context manager with headers if provided + if headers: + streamable_http_cm = streamablehttp_client(server_config.server_url, headers=headers) + else: + streamable_http_cm = streamablehttp_client(server_config.server_url) + read_stream, write_stream, _ = await self.exit_stack.enter_async_context(streamable_http_cm) # Create and enter the ClientSession context manager diff --git a/letta/services/mcp/types.py b/letta/services/mcp/types.py index 2d8b7af6..b5e873cb 100644 --- a/letta/services/mcp/types.py +++ b/letta/services/mcp/types.py @@ -46,3 +46,12 @@ class StdioServerConfig(BaseServerConfig): if self.env is not None: values["env"] = self.env return values + + +class OauthStreamEvent(str, Enum): + CONNECTION_ATTEMPT = "connection_attempt" + SUCCESS = "success" + ERROR = "error" + OAUTH_REQUIRED = "oauth_required" + AUTHORIZATION_URL = "authorization_url" + WAITING_FOR_AUTH = "waiting_for_auth" diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 77cd2a9a..a075b5d0 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -1,5 +1,8 @@ import json import os +import secrets +import uuid +from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Tuple, Union from sqlalchemy import null @@ -8,8 +11,18 @@ import letta.constants as constants from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig from letta.log import get_logger from letta.orm.errors import NoResultFound +from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus from letta.orm.mcp_server import MCPServer as MCPServerModel -from letta.schemas.mcp import MCPServer, UpdateMCPServer, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer +from letta.schemas.mcp import ( + MCPOAuthSession, + MCPOAuthSessionCreate, + MCPOAuthSessionUpdate, + MCPServer, + UpdateMCPServer, + UpdateSSEMCPServer, + UpdateStdioMCPServer, + UpdateStreamableHTTPMCPServer, +) from letta.schemas.tool import Tool as PydanticTool from letta.schemas.tool import ToolCreate from letta.schemas.user import User as PydanticUser @@ -38,14 +51,7 @@ class MCPManager: mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) server_config = mcp_config.to_config() - if mcp_config.server_type == MCPServerType.SSE: - mcp_client = AsyncSSEMCPClient(server_config=server_config) - elif mcp_config.server_type == MCPServerType.STDIO: - mcp_client = AsyncStdioMCPClient(server_config=server_config) - elif mcp_config.server_type == MCPServerType.STREAMABLE_HTTP: - mcp_client = AsyncStreamableHTTPMCPClient(server_config=server_config) - else: - raise ValueError(f"Unsupported MCP server type: {mcp_config.server_type}") + mcp_client = await self.get_mcp_client(server_config, actor) await mcp_client.connect_to_server() # list tools @@ -72,28 +78,20 @@ class MCPManager: # read from config file mcp_config = self.read_mcp_config() if mcp_server_name not in mcp_config: - print("MCP server not found in config.", mcp_config) raise ValueError(f"MCP server {mcp_server_name} not found in config.") server_config = mcp_config[mcp_server_name] - if isinstance(server_config, SSEServerConfig): - # mcp_client = AsyncSSEMCPClient(server_config=server_config) - async with AsyncSSEMCPClient(server_config=server_config) as mcp_client: - result, success = await mcp_client.execute_tool(tool_name, tool_args) - logger.info(f"MCP Result: {result}, Success: {success}") - return result, success - elif isinstance(server_config, StdioServerConfig): - async with AsyncStdioMCPClient(server_config=server_config) as mcp_client: - result, success = await mcp_client.execute_tool(tool_name, tool_args) - logger.info(f"MCP Result: {result}, Success: {success}") - return result, success - elif isinstance(server_config, StreamableHTTPServerConfig): - async with AsyncStreamableHTTPMCPClient(server_config=server_config) as mcp_client: - result, success = await mcp_client.execute_tool(tool_name, tool_args) - logger.info(f"MCP Result: {result}, Success: {success}") - return result, success - else: - raise ValueError(f"Unsupported server config type: {type(server_config)}") + mcp_client = await self.get_mcp_client(server_config, actor) + await mcp_client.connect_to_server() + + # call tool + result, success = await mcp_client.execute_tool(tool_name, tool_args) + logger.info(f"MCP Result: {result}, Success: {success}") + # TODO: change to pydantic tool + + await mcp_client.cleanup() + + return result, success @enforce_types async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool: @@ -324,3 +322,246 @@ class MCPManager: logger.error(f"Failed to parse server params for MCP server {server_name} (skipping): {e}") continue return mcp_server_list + + async def get_mcp_client( + self, + server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig], + actor: PydanticUser, + oauth_provider: Optional[Any] = None, + ) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]: + """ + Helper function to create the appropriate MCP client based on server configuration. + + Args: + server_config: The server configuration object + actor: The user making the request + oauth_provider: Optional OAuth provider for authentication + + Returns: + The appropriate MCP client instance + + Raises: + ValueError: If server config type is not supported + """ + # If no OAuth provider is provided, check if we have stored OAuth credentials + if oauth_provider is None and hasattr(server_config, "server_url"): + oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor) + if oauth_session and oauth_session.access_token: + # Create OAuth provider from stored credentials + from letta.services.mcp.oauth_utils import create_oauth_provider + + oauth_provider = await create_oauth_provider( + session_id=oauth_session.id, + server_url=oauth_session.server_url, + redirect_uri=oauth_session.redirect_uri, + mcp_manager=self, + actor=actor, + ) + + if server_config.type == MCPServerType.SSE: + server_config = SSEServerConfig(**server_config.model_dump()) + return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider) + elif server_config.type == MCPServerType.STDIO: + server_config = StdioServerConfig(**server_config.model_dump()) + return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider) + elif server_config.type == MCPServerType.STREAMABLE_HTTP: + server_config = StreamableHTTPServerConfig(**server_config.model_dump()) + return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider) + else: + raise ValueError(f"Unsupported server config type: {type(server_config)}") + + # OAuth-related methods + @enforce_types + async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession: + """Create a new OAuth session for MCP server authentication.""" + async with db_registry.async_session() as session: + # Create the OAuth session with a unique state + oauth_session = MCPOAuth( + id="mcp-oauth-" + str(uuid.uuid4())[:8], + state=secrets.token_urlsafe(32), + server_url=session_create.server_url, + server_name=session_create.server_name, + user_id=session_create.user_id, + organization_id=session_create.organization_id, + status=OAuthSessionStatus.PENDING, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + oauth_session = await oauth_session.create_async(session, actor=actor) + + # Convert to Pydantic model + return MCPOAuthSession( + id=oauth_session.id, + state=oauth_session.state, + server_url=oauth_session.server_url, + server_name=oauth_session.server_name, + user_id=oauth_session.user_id, + organization_id=oauth_session.organization_id, + status=oauth_session.status, + created_at=oauth_session.created_at, + updated_at=oauth_session.updated_at, + ) + + @enforce_types + async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]: + """Get an OAuth session by its ID.""" + async with db_registry.async_session() as session: + try: + oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) + return MCPOAuthSession( + id=oauth_session.id, + state=oauth_session.state, + server_url=oauth_session.server_url, + server_name=oauth_session.server_name, + user_id=oauth_session.user_id, + organization_id=oauth_session.organization_id, + authorization_url=oauth_session.authorization_url, + authorization_code=oauth_session.authorization_code, + access_token=oauth_session.access_token, + refresh_token=oauth_session.refresh_token, + token_type=oauth_session.token_type, + expires_at=oauth_session.expires_at, + scope=oauth_session.scope, + client_id=oauth_session.client_id, + client_secret=oauth_session.client_secret, + redirect_uri=oauth_session.redirect_uri, + status=oauth_session.status, + created_at=oauth_session.created_at, + updated_at=oauth_session.updated_at, + ) + except NoResultFound: + return None + + @enforce_types + async def get_oauth_session_by_server(self, server_url: str, actor: PydanticUser) -> Optional[MCPOAuthSession]: + """Get the latest OAuth session by server URL, organization, and user.""" + from sqlalchemy import desc, select + + async with db_registry.async_session() as session: + # Query for OAuth session matching organization, user, server URL, and status + # Order by updated_at desc to get the most recent record + result = await session.execute( + select(MCPOAuth) + .where( + MCPOAuth.organization_id == actor.organization_id, + MCPOAuth.user_id == actor.id, + MCPOAuth.server_url == server_url, + MCPOAuth.status == OAuthSessionStatus.AUTHORIZED, + ) + .order_by(desc(MCPOAuth.updated_at)) + .limit(1) + ) + oauth_session = result.scalar_one_or_none() + + if not oauth_session: + return None + + return MCPOAuthSession( + id=oauth_session.id, + state=oauth_session.state, + server_url=oauth_session.server_url, + server_name=oauth_session.server_name, + user_id=oauth_session.user_id, + organization_id=oauth_session.organization_id, + authorization_url=oauth_session.authorization_url, + authorization_code=oauth_session.authorization_code, + access_token=oauth_session.access_token, + refresh_token=oauth_session.refresh_token, + token_type=oauth_session.token_type, + expires_at=oauth_session.expires_at, + scope=oauth_session.scope, + client_id=oauth_session.client_id, + client_secret=oauth_session.client_secret, + redirect_uri=oauth_session.redirect_uri, + status=oauth_session.status, + created_at=oauth_session.created_at, + updated_at=oauth_session.updated_at, + ) + + @enforce_types + async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession: + """Update an existing OAuth session.""" + async with db_registry.async_session() as session: + oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) + + # Update fields that are provided + if session_update.authorization_url is not None: + oauth_session.authorization_url = session_update.authorization_url + if session_update.authorization_code is not None: + oauth_session.authorization_code = session_update.authorization_code + if session_update.access_token is not None: + oauth_session.access_token = session_update.access_token + if session_update.refresh_token is not None: + oauth_session.refresh_token = session_update.refresh_token + if session_update.token_type is not None: + oauth_session.token_type = session_update.token_type + if session_update.expires_at is not None: + oauth_session.expires_at = session_update.expires_at + if session_update.scope is not None: + oauth_session.scope = session_update.scope + if session_update.client_id is not None: + oauth_session.client_id = session_update.client_id + if session_update.client_secret is not None: + oauth_session.client_secret = session_update.client_secret + if session_update.redirect_uri is not None: + oauth_session.redirect_uri = session_update.redirect_uri + if session_update.status is not None: + oauth_session.status = session_update.status + + # Always update the updated_at timestamp + oauth_session.updated_at = datetime.now() + + oauth_session = await oauth_session.update_async(db_session=session, actor=actor) + + return MCPOAuthSession( + id=oauth_session.id, + state=oauth_session.state, + server_url=oauth_session.server_url, + server_name=oauth_session.server_name, + user_id=oauth_session.user_id, + organization_id=oauth_session.organization_id, + authorization_url=oauth_session.authorization_url, + authorization_code=oauth_session.authorization_code, + access_token=oauth_session.access_token, + refresh_token=oauth_session.refresh_token, + token_type=oauth_session.token_type, + expires_at=oauth_session.expires_at, + scope=oauth_session.scope, + client_id=oauth_session.client_id, + client_secret=oauth_session.client_secret, + redirect_uri=oauth_session.redirect_uri, + status=oauth_session.status, + created_at=oauth_session.created_at, + updated_at=oauth_session.updated_at, + ) + + @enforce_types + async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None: + """Delete an OAuth session.""" + async with db_registry.async_session() as session: + try: + oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) + await oauth_session.hard_delete_async(db_session=session, actor=actor) + except NoResultFound: + raise ValueError(f"OAuth session with id {session_id} not found.") + + @enforce_types + async def cleanup_expired_oauth_sessions(self, max_age_hours: int = 24) -> int: + """Clean up expired OAuth sessions and return the count of deleted sessions.""" + cutoff_time = datetime.now() - timedelta(hours=max_age_hours) + + async with db_registry.async_session() as session: + from sqlalchemy import select + + # Find expired sessions + result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time)) + expired_sessions = result.scalars().all() + + # Delete expired sessions using async ORM method + for oauth_session in expired_sessions: + await oauth_session.hard_delete_async(db_session=session, actor=None) + + if expired_sessions: + logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions") + + return len(expired_sessions) diff --git a/mcp_test.py b/mcp_test.py new file mode 100644 index 00000000..507d5c7c --- /dev/null +++ b/mcp_test.py @@ -0,0 +1,356 @@ +#!/usr/bin/env python3 +""" +Simple MCP client example with OAuth authentication support. + +This client connects to an MCP server using streamable HTTP transport with OAuth. + +""" + +import asyncio +import os +import threading +import time +import webbrowser +from datetime import timedelta +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any +from urllib.parse import parse_qs, urlparse + +from mcp.client.auth import OAuthClientProvider, TokenStorage +from mcp.client.session import ClientSession +from mcp.client.sse import sse_client +from mcp.client.streamable_http import streamablehttp_client +from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken + + +class InMemoryTokenStorage(TokenStorage): + """Simple in-memory token storage implementation.""" + + def __init__(self): + self._tokens: OAuthToken | None = None + self._client_info: OAuthClientInformationFull | None = None + + async def get_tokens(self) -> OAuthToken | None: + return self._tokens + + async def set_tokens(self, tokens: OAuthToken) -> None: + self._tokens = tokens + + async def get_client_info(self) -> OAuthClientInformationFull | None: + return self._client_info + + async def set_client_info(self, client_info: OAuthClientInformationFull) -> None: + self._client_info = client_info + + +class CallbackHandler(BaseHTTPRequestHandler): + """Simple HTTP handler to capture OAuth callback.""" + + def __init__(self, request, client_address, server, callback_data): + """Initialize with callback data storage.""" + self.callback_data = callback_data + super().__init__(request, client_address, server) + + def do_GET(self): + """Handle GET request from OAuth redirect.""" + parsed = urlparse(self.path) + query_params = parse_qs(parsed.query) + + if "code" in query_params: + self.callback_data["authorization_code"] = query_params["code"][0] + self.callback_data["state"] = query_params.get("state", [None])[0] + self.send_response(200) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write( + b""" + + +

Authorization Successful!

+

You can close this window and return to the terminal.

+ + + + """ + ) + elif "error" in query_params: + self.callback_data["error"] = query_params["error"][0] + self.send_response(400) + self.send_header("Content-type", "text/html") + self.end_headers() + self.wfile.write( + f""" + + +

Authorization Failed

+

Error: {query_params['error'][0]}

+

You can close this window and return to the terminal.

+ + + """.encode() + ) + else: + self.send_response(404) + self.end_headers() + + def log_message(self, format, *args): + """Suppress default logging.""" + + +class CallbackServer: + """Simple server to handle OAuth callbacks.""" + + def __init__(self, port=3000): + self.port = port + self.server = None + self.thread = None + self.callback_data = {"authorization_code": None, "state": None, "error": None} + + def _create_handler_with_data(self): + """Create a handler class with access to callback data.""" + callback_data = self.callback_data + + class DataCallbackHandler(CallbackHandler): + def __init__(self, request, client_address, server): + super().__init__(request, client_address, server, callback_data) + + return DataCallbackHandler + + def start(self): + """Start the callback server in a background thread.""" + handler_class = self._create_handler_with_data() + self.server = HTTPServer(("localhost", self.port), handler_class) + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread.start() + print(f"šŸ–„ļø Started callback server on http://localhost:{self.port}") + + def stop(self): + """Stop the callback server.""" + if self.server: + self.server.shutdown() + self.server.server_close() + if self.thread: + self.thread.join(timeout=1) + + def wait_for_callback(self, timeout=300): + """Wait for OAuth callback with timeout.""" + start_time = time.time() + while time.time() - start_time < timeout: + if self.callback_data["authorization_code"]: + return self.callback_data["authorization_code"] + elif self.callback_data["error"]: + raise Exception(f"OAuth error: {self.callback_data['error']}") + time.sleep(0.1) + raise Exception("Timeout waiting for OAuth callback") + + def get_state(self): + """Get the received state parameter.""" + return self.callback_data["state"] + + +class SimpleAuthClient: + """Simple MCP client with auth support.""" + + def __init__(self, server_url: str, transport_type: str = "streamable_http"): + self.server_url = server_url + self.transport_type = transport_type + self.session: ClientSession | None = None + + async def connect(self): + """Connect to the MCP server.""" + print(f"šŸ”— Attempting to connect to {self.server_url}...") + + try: + callback_server = CallbackServer(port=3030) + callback_server.start() + + async def callback_handler() -> tuple[str, str | None]: + """Wait for OAuth callback and return auth code and state.""" + print("ā³ Waiting for authorization callback...") + try: + auth_code = callback_server.wait_for_callback(timeout=300) + return auth_code, callback_server.get_state() + finally: + callback_server.stop() + + client_metadata_dict = { + "client_name": "Simple Auth Client", + "redirect_uris": ["http://localhost:3030/callback"], + "grant_types": ["authorization_code", "refresh_token"], + "response_types": ["code"], + "token_endpoint_auth_method": "client_secret_post", + } + + async def _default_redirect_handler(authorization_url: str) -> None: + """Default redirect handler that opens the URL in a browser.""" + print(f"Opening browser for authorization: {authorization_url}") + webbrowser.open(authorization_url) + + # Create OAuth authentication handler using the new interface + oauth_auth = OAuthClientProvider( + server_url=self.server_url.replace("/mcp", ""), + client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict), + storage=InMemoryTokenStorage(), + redirect_handler=_default_redirect_handler, + callback_handler=callback_handler, + ) + + # Create transport with auth handler based on transport type + if self.transport_type == "sse": + print("šŸ“” Opening SSE transport connection with auth...") + async with sse_client( + url=self.server_url, + auth=oauth_auth, + timeout=60, + ) as (read_stream, write_stream): + await self._run_session(read_stream, write_stream, None) + else: + print("šŸ“” Opening StreamableHTTP transport connection with auth...") + async with streamablehttp_client( + url=self.server_url, + auth=oauth_auth, + timeout=timedelta(seconds=60), + ) as (read_stream, write_stream, get_session_id): + await self._run_session(read_stream, write_stream, get_session_id) + + except Exception as e: + print(f"āŒ Failed to connect: {e}") + import traceback + + traceback.print_exc() + + async def _run_session(self, read_stream, write_stream, get_session_id): + """Run the MCP session with the given streams.""" + print("šŸ¤ Initializing MCP session...") + async with ClientSession(read_stream, write_stream) as session: + self.session = session + print("⚔ Starting session initialization...") + await session.initialize() + print("✨ Session initialization complete!") + + print(f"\nāœ… Connected to MCP server at {self.server_url}") + if get_session_id: + session_id = get_session_id() + if session_id: + print(f"Session ID: {session_id}") + + # Run interactive loop + await self.interactive_loop() + + async def list_tools(self): + """List available tools from the server.""" + if not self.session: + print("āŒ Not connected to server") + return + + try: + result = await self.session.list_tools() + if hasattr(result, "tools") and result.tools: + print("\nšŸ“‹ Available tools:") + for i, tool in enumerate(result.tools, 1): + print(f"{i}. {tool.name}") + if tool.description: + print(f" Description: {tool.description}") + print() + else: + print("No tools available") + except Exception as e: + print(f"āŒ Failed to list tools: {e}") + + async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None): + """Call a specific tool.""" + if not self.session: + print("āŒ Not connected to server") + return + + try: + result = await self.session.call_tool(tool_name, arguments or {}) + print(f"\nšŸ”§ Tool '{tool_name}' result:") + if hasattr(result, "content"): + for content in result.content: + if content.type == "text": + print(content.text) + else: + print(content) + else: + print(result) + except Exception as e: + print(f"āŒ Failed to call tool '{tool_name}': {e}") + + async def interactive_loop(self): + """Run interactive command loop.""" + print("\nšŸŽÆ Interactive MCP Client") + print("Commands:") + print(" list - List available tools") + print(" call [args] - Call a tool") + print(" quit - Exit the client") + print() + + while True: + try: + command = input("mcp> ").strip() + + if not command: + continue + + if command == "quit": + break + + elif command == "list": + await self.list_tools() + + elif command.startswith("call "): + parts = command.split(maxsplit=2) + tool_name = parts[1] if len(parts) > 1 else "" + + if not tool_name: + print("āŒ Please specify a tool name") + continue + + # Parse arguments (simple JSON-like format) + arguments = {} + if len(parts) > 2: + import json + + try: + arguments = json.loads(parts[2]) + except json.JSONDecodeError: + print("āŒ Invalid arguments format (expected JSON)") + continue + + await self.call_tool(tool_name, arguments) + + else: + print("āŒ Unknown command. Try 'list', 'call ', or 'quit'") + + except KeyboardInterrupt: + print("\n\nšŸ‘‹ Goodbye!") + break + except EOFError: + break + + +async def main(): + """Main entry point.""" + # Default server URL - can be overridden with environment variable + # Most MCP streamable HTTP servers use /mcp as the endpoint + server_url = os.getenv("MCP_SERVER_PORT", 8000) + transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http") + server_url = f"http://localhost:{server_url}/mcp" if transport_type == "streamable_http" else f"http://localhost:{server_url}/sse" + + print("šŸš€ Simple MCP Auth Client") + print(f"Connecting to: {server_url}") + print(f"Transport type: {transport_type}") + + # Start connection flow - OAuth will be handled automatically + client = SimpleAuthClient(server_url, transport_type) + await client.connect() + + +def cli(): + """CLI entry point for uv script.""" + asyncio.run(main()) + + +if __name__ == "__main__": + cli()