feat: add support for oauth mcp

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-07-24 18:23:01 -07:00
committed by GitHub
parent bc471c6055
commit 772a51777e
11 changed files with 1529 additions and 85 deletions

View File

@@ -0,0 +1,67 @@
"""add_mcp_oauth
Revision ID: f5d26b0526e8
Revises: ddecfe4902bc
Create Date: 2025-07-24 12:34:05.795355
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "f5d26b0526e8"
down_revision: Union[str, None] = "ddecfe4902bc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"mcp_oauth",
sa.Column("id", sa.String(), nullable=False),
sa.Column("state", sa.String(length=255), nullable=False),
sa.Column("server_id", sa.String(length=255), nullable=True),
sa.Column("server_url", sa.Text(), nullable=False),
sa.Column("server_name", sa.Text(), nullable=False),
sa.Column("authorization_url", sa.Text(), nullable=True),
sa.Column("authorization_code", sa.Text(), nullable=True),
sa.Column("access_token", sa.Text(), nullable=True),
sa.Column("refresh_token", sa.Text(), nullable=True),
sa.Column("token_type", sa.String(length=50), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("scope", sa.Text(), nullable=True),
sa.Column("client_id", sa.Text(), nullable=True),
sa.Column("client_secret", sa.Text(), nullable=True),
sa.Column("redirect_uri", sa.Text(), nullable=True),
sa.Column("status", sa.String(length=20), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.Column("user_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.ForeignKeyConstraint(["server_id"], ["mcp_server.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(
["user_id"],
["users.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("state"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("mcp_oauth")
# ### end Alembic commands ###

62
letta/orm/mcp_oauth.py Normal file
View File

@@ -0,0 +1,62 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Optional
from sqlalchemy import DateTime, ForeignKey, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from letta.orm.mixins import OrganizationMixin, UserMixin
from letta.orm.sqlalchemy_base import SqlalchemyBase
class OAuthSessionStatus(str, Enum):
"""OAuth session status enumeration."""
PENDING = "pending"
AUTHORIZED = "authorized"
ERROR = "error"
class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin):
"""OAuth session model for MCP server authentication."""
__tablename__ = "mcp_oauth"
# Override the id field to match database UUID generation
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid.uuid4()}")
# Core session information
state: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, doc="OAuth state parameter")
server_id: Mapped[str] = mapped_column(String(255), ForeignKey("mcp_server.id", ondelete="CASCADE"), nullable=True, doc="MCP server ID")
server_url: Mapped[str] = mapped_column(Text, nullable=False, doc="MCP server URL")
server_name: Mapped[str] = mapped_column(Text, nullable=False, doc="MCP server display name")
# OAuth flow data
authorization_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth authorization URL")
authorization_code: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth authorization code")
# Token data
access_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth access token")
refresh_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth refresh token")
token_type: Mapped[str] = mapped_column(String(50), default="Bearer", doc="Token type")
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, doc="Token expiry time")
scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth scope")
# Client configuration
client_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client ID")
client_secret: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client secret")
redirect_uri: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth redirect URI")
# Session state
status: Mapped[OAuthSessionStatus] = mapped_column(String(20), default=OAuthSessionStatus.PENDING, doc="Session status")
# Timestamps
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), doc="Session creation time")
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now(), doc="Last update time"
)
# Relationships (if needed in the future)
# user: Mapped[Optional["User"]] = relationship("User", back_populates="oauth_sessions")
# organization: Mapped["Organization"] = relationship("Organization", back_populates="oauth_sessions")

View File

@@ -1,3 +1,4 @@
from datetime import datetime
from typing import Any, Dict, Optional, Union
from pydantic import Field
@@ -10,6 +11,7 @@ from letta.functions.mcp_client.types import (
StdioServerConfig,
StreamableHTTPServerConfig,
)
from letta.orm.mcp_oauth import OAuthSessionStatus
from letta.schemas.letta_base import LettaBase
@@ -119,3 +121,71 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer]
RegisterMCPServer = Union[RegisterSSEMCPServer, RegisterStdioMCPServer, RegisterStreamableHTTPMCPServer]
# OAuth-related schemas
class BaseMCPOAuth(LettaBase):
__id_prefix__ = "mcp-oauth"
class MCPOAuthSession(BaseMCPOAuth):
"""OAuth session for MCP server authentication."""
id: str = BaseMCPOAuth.generate_id_field()
state: str = Field(..., description="OAuth state parameter")
server_id: Optional[str] = Field(None, description="MCP server ID")
server_url: str = Field(..., description="MCP server URL")
server_name: str = Field(..., description="MCP server display name")
# User and organization context
user_id: Optional[str] = Field(None, description="User ID associated with the session")
organization_id: str = Field(..., description="Organization ID associated with the session")
# OAuth flow data
authorization_url: Optional[str] = Field(None, description="OAuth authorization URL")
authorization_code: Optional[str] = Field(None, description="OAuth authorization code")
# Token data
access_token: Optional[str] = Field(None, description="OAuth access token")
refresh_token: Optional[str] = Field(None, description="OAuth refresh token")
token_type: str = Field(default="Bearer", description="Token type")
expires_at: Optional[datetime] = Field(None, description="Token expiry time")
scope: Optional[str] = Field(None, description="OAuth scope")
# Client configuration
client_id: Optional[str] = Field(None, description="OAuth client ID")
client_secret: Optional[str] = Field(None, description="OAuth client secret")
redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI")
# Session state
status: OAuthSessionStatus = Field(default=OAuthSessionStatus.PENDING, description="Session status")
# Timestamps
created_at: datetime = Field(default_factory=datetime.now, description="Session creation time")
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
class MCPOAuthSessionCreate(BaseMCPOAuth):
"""Create a new OAuth session."""
server_url: str = Field(..., description="MCP server URL")
server_name: str = Field(..., description="MCP server display name")
user_id: Optional[str] = Field(None, description="User ID associated with the session")
organization_id: str = Field(..., description="Organization ID associated with the session")
state: Optional[str] = Field(None, description="OAuth state parameter")
class MCPOAuthSessionUpdate(BaseMCPOAuth):
"""Update an existing OAuth session."""
authorization_url: Optional[str] = Field(None, description="OAuth authorization URL")
authorization_code: Optional[str] = Field(None, description="OAuth authorization code")
access_token: Optional[str] = Field(None, description="OAuth access token")
refresh_token: Optional[str] = Field(None, description="OAuth refresh token")
token_type: Optional[str] = Field(None, description="Token type")
expires_at: Optional[datetime] = Field(None, description="Token expiry time")
scope: Optional[str] = Field(None, description="OAuth scope")
client_id: Optional[str] = Field(None, description="OAuth client ID")
client_secret: Optional[str] = Field(None, description="OAuth client secret")
redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI")
status: Optional[OAuthSessionStatus] = Field(None, description="Session status")

View File

@@ -1,4 +1,6 @@
import asyncio
import json
from collections.abc import AsyncGenerator
from typing import Any, Dict, List, Optional, Union
from composio.client import ComposioClientError, HTTPError, NoItemsFound
@@ -11,27 +13,37 @@ from composio.exceptions import (
EnumStringNotFound,
)
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from fastapi.responses import HTMLResponse
from pydantic import BaseModel, Field
from starlette.responses import StreamingResponse
from letta.errors import LettaToolCreateError
from letta.functions.functions import derive_openai_json_schema
from letta.functions.mcp_client.exceptions import MCPTimeoutError
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
from letta.helpers.composio_helpers import get_composio_api_key
from letta.helpers.decorators import deprecated
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
from letta.orm.errors import UniqueConstraintViolationError
from letta.orm.mcp_oauth import OAuthSessionStatus
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
from letta.schemas.mcp import MCPOAuthSessionCreate, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
from letta.schemas.message import Message
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.services.mcp.sse_client import AsyncSSEMCPClient
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient
from letta.services.mcp.oauth_utils import (
MCPOAuthSession,
create_oauth_provider,
drill_down_exception,
get_oauth_success_html,
oauth_stream_event,
)
from letta.services.mcp.types import OauthStreamEvent
from letta.settings import tool_settings
router = APIRouter(prefix="/tools", tags=["tools"])
@@ -612,35 +624,26 @@ async def delete_mcp_server_from_config(
return [server.to_config() for server in all_servers]
@router.post("/mcp/servers/test", response_model=List[MCPTool], operation_id="test_mcp_server")
@deprecated("Deprecated in favor of /mcp/servers/connect which handles OAuth flow via SSE stream")
@router.post("/mcp/servers/test", operation_id="test_mcp_server")
async def test_mcp_server(
request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Test connection to an MCP server without adding it.
Returns the list of available tools if successful.
Returns the list of available tools if successful, or OAuth information if OAuth is required.
"""
client = None
try:
# create a temporary MCP client based on the server type
if request.type == MCPServerType.SSE:
if not isinstance(request, SSEServerConfig):
request = SSEServerConfig(**request.model_dump())
client = AsyncSSEMCPClient(request)
elif request.type == MCPServerType.STREAMABLE_HTTP:
if not isinstance(request, StreamableHTTPServerConfig):
request = StreamableHTTPServerConfig(**request.model_dump())
client = AsyncStreamableHTTPMCPClient(request)
elif request.type == MCPServerType.STDIO:
if not isinstance(request, StdioServerConfig):
request = StdioServerConfig(**request.model_dump())
client = AsyncStdioMCPClient(request)
else:
raise ValueError(f"Invalid MCP server type: {request.type}")
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
client = await server.mcp_manager.get_mcp_client(request, actor)
await client.connect_to_server()
tools = await client.list_tools()
return tools
return {"status": "success", "tools": tools}
except ConnectionError as e:
raise HTTPException(
status_code=400,
@@ -676,6 +679,155 @@ async def test_mcp_server(
logger.warning(f"Error during MCP client cleanup: {cleanup_error}")
@router.post(
"/mcp/servers/connect",
response_model=None,
responses={
200: {
"description": "Successful response",
"content": {
"text/event-stream": {"description": "Server-Sent Events stream"},
},
}
},
operation_id="connect_mcp_server",
)
async def connect_mcp_server(
request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"),
) -> StreamingResponse:
"""
Connect to an MCP server with support for OAuth via SSE.
Returns a stream of events handling authorization state and exchange if OAuth is required.
"""
async def oauth_stream_generator(
request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig]
) -> AsyncGenerator[str, None]:
client = None
oauth_provider = None
temp_client = None
connect_task = None
try:
# Acknolwedge connection attempt
yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name)
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
# Create MCP client with respective transport type
try:
client = await server.mcp_manager.get_mcp_client(request, actor)
except ValueError as e:
yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
return
# Try normal connection first for flows that don't require OAuth
try:
await client.connect_to_server()
tools = await client.list_tools(serialize=True)
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
return
except ConnectionError:
# Continue to OAuth flow
logger.info(f"Attempting OAuth flow for {request.server_url}...")
except Exception as e:
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
return
# OAuth required, yield state to client to prepare to handle authorization URL
yield oauth_stream_event(OauthStreamEvent.OAUTH_REQUIRED, message="OAuth authentication required")
# Create OAuth session to persist the state of the OAuth flow
session_create = MCPOAuthSessionCreate(
server_url=request.server_url,
server_name=request.server_name,
user_id=actor.id,
organization_id=actor.organization_id,
)
oauth_session = await server.mcp_manager.create_oauth_session(session_create, actor)
session_id = oauth_session.id
# Create OAuth provider for the instance of the stream connection
# Note: Using the correct API path for the callback
# do not edit this this is the correct url
redirect_uri = f"http://localhost:8283/v1/tools/mcp/oauth/callback/{session_id}"
oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, server.mcp_manager, actor)
# Get authorization URL by triggering OAuth flow
temp_client = None
try:
temp_client = await server.mcp_manager.get_mcp_client(request, actor, oauth_provider)
# Run connect_to_server in background to avoid blocking
# This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database
connect_task = asyncio.create_task(temp_client.connect_to_server())
# Give the OAuth flow time to trigger and save the URL
await asyncio.sleep(1.0)
# Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL
auth_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor)
if auth_session and auth_session.authorization_url:
yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id)
except Exception as e:
logger.error(f"Error triggering OAuth flow: {e}")
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}")
# Clean up active resources
if connect_task and not connect_task.done():
connect_task.cancel()
try:
await connect_task
except asyncio.CancelledError:
pass
if temp_client:
try:
await temp_client.cleanup()
except Exception as cleanup_error:
logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
return
# Wait for user authorization (with timeout), client should render loading state until user completes the flow and /mcp/oauth/callback/{session_id} is hit
yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...")
# Callback handler will poll for authorization code and state and update the OAuth session
await connect_task
tools = await temp_client.list_tools(serialize=True)
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
return
except Exception as e:
detailed_error = drill_down_exception(e)
logger.error(f"Error in OAuth stream:\n{detailed_error}")
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
finally:
if connect_task and not connect_task.done():
connect_task.cancel()
try:
await connect_task
except asyncio.CancelledError:
pass
if client:
try:
await client.cleanup()
except Exception as cleanup_error:
detailed_error = drill_down_exception(cleanup_error)
logger.warning(f"Error during MCP client cleanup: {detailed_error}")
if temp_client:
try:
await temp_client.cleanup()
except Exception as cleanup_error:
# TODO: @jnjpng fix async cancel scope issue
# detailed_error = drill_down_exception(cleanup_error)
logger.warning(f"Aysnc cleanup confict during temp MCP client cleanup: {cleanup_error}")
return StreamingResponseWithStatusCode(oauth_stream_generator(request), media_type="text/event-stream")
class CodeInput(BaseModel):
code: str = Field(..., description="Python source code to parse for JSON schema")
@@ -697,6 +849,45 @@ async def generate_json_schema(
raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}")
# TODO: @jnjpng need to route this through cloud API for production
@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback", response_class=HTMLResponse)
async def mcp_oauth_callback(
session_id: str,
code: Optional[str] = Query(None, description="OAuth authorization code"),
state: Optional[str] = Query(None, description="OAuth state parameter"),
error: Optional[str] = Query(None, description="OAuth error"),
error_description: Optional[str] = Query(None, description="OAuth error description"),
):
"""
Handle OAuth callback for MCP server authentication.
"""
try:
oauth_session = MCPOAuthSession(session_id)
if error:
error_msg = f"OAuth error: {error}"
if error_description:
error_msg += f" - {error_description}"
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
return {"status": "error", "message": error_msg}
if not code or not state:
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
return {"status": "error", "message": "Missing authorization code or state"}
# Store authorization code
success = await oauth_session.store_authorization_code(code, state)
if not success:
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
return {"status": "error", "message": "Invalid state parameter"}
return HTMLResponse(content=get_oauth_success_html(), status_code=200)
except Exception as e:
logger.error(f"OAuth callback error: {e}")
return {"status": "error", "message": f"OAuth callback failed: {str(e)}"}
class GenerateToolInput(BaseModel):
tool_name: str = Field(..., description="Name of the tool to generate code for")
prompt: str = Field(..., description="User prompt to generate code")

View File

@@ -1,9 +1,9 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Optional, Tuple
from mcp import ClientSession
from mcp import Tool as MCPTool
from mcp.client.auth import OAuthClientProvider
from mcp.types import TextContent
from letta.functions.mcp_client.types import BaseServerConfig
@@ -14,14 +14,12 @@ logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncBaseMCPClient:
def __init__(self, server_config: BaseServerConfig):
def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
self.server_config = server_config
self.oauth_provider = oauth_provider
self.exit_stack = AsyncExitStack()
self.session: Optional[ClientSession] = None
self.initialized = False
# Track the task that created this client
self._creation_task = asyncio.current_task()
self._cleanup_queue = asyncio.Queue(maxsize=1)
async def connect_to_server(self):
try:
@@ -48,9 +46,25 @@ class AsyncBaseMCPClient:
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
raise NotImplementedError("Subclasses must implement _initialize_connection")
async def list_tools(self) -> list[MCPTool]:
async def list_tools(self, serialize: bool = False) -> list[MCPTool]:
self._check_initialized()
response = await self.session.list_tools()
if serialize:
serializable_tools = []
for tool in response.tools:
if hasattr(tool, "model_dump"):
# Pydantic model - use model_dump
serializable_tools.append(tool.model_dump())
elif hasattr(tool, "dict"):
# Older Pydantic model - use dict()
serializable_tools.append(tool.dict())
elif hasattr(tool, "__dict__"):
# Regular object - use __dict__
serializable_tools.append(tool.__dict__)
else:
# Fallback - convert to string
serializable_tools.append(str(tool))
return serializable_tools
return response.tools
async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
@@ -79,29 +93,7 @@ class AsyncBaseMCPClient:
# TODO: still hitting some async errors for voice agents, need to fix
async def cleanup(self):
"""Clean up resources - ensure this runs in the same task"""
if hasattr(self, "_cleanup_task"):
# If we're in a different task, schedule cleanup in original task
current_task = asyncio.current_task()
if current_task != self._creation_task:
# Create a future to signal completion
cleanup_done = asyncio.Future()
self._cleanup_queue.put_nowait((self.exit_stack, cleanup_done))
await cleanup_done
return
# Normal cleanup
await self.exit_stack.aclose()
def to_sync_client(self):
raise NotImplementedError("Subclasses must implement to_sync_client")
async def __aenter__(self):
"""Enter the async context manager."""
await self.connect_to_server()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit the async context manager."""
await self.cleanup()
return False # Don't suppress exceptions

View File

@@ -0,0 +1,433 @@
"""OAuth utilities for MCP server authentication."""
import asyncio
import json
import secrets
import time
import uuid
from datetime import datetime, timedelta
from typing import Callable, Optional, Tuple
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
from sqlalchemy import select
from letta.log import get_logger
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
from letta.schemas.mcp import MCPOAuthSessionUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.mcp.types import OauthStreamEvent
from letta.services.mcp_manager import MCPManager
logger = get_logger(__name__)
class DatabaseTokenStorage(TokenStorage):
"""Database-backed token storage using MCPOAuth table via mcp_manager."""
def __init__(self, session_id: str, mcp_manager: MCPManager, actor: PydanticUser):
self.session_id = session_id
self.mcp_manager = mcp_manager
self.actor = actor
async def get_tokens(self) -> Optional[OAuthToken]:
"""Retrieve tokens from database."""
oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor)
if not oauth_session or not oauth_session.access_token:
return None
return OAuthToken(
access_token=oauth_session.access_token,
refresh_token=oauth_session.refresh_token,
token_type=oauth_session.token_type,
expires_in=int(oauth_session.expires_at.timestamp() - time.time()),
scope=oauth_session.scope,
)
async def set_tokens(self, tokens: OAuthToken) -> None:
"""Store tokens in database."""
session_update = MCPOAuthSessionUpdate(
access_token=tokens.access_token,
refresh_token=tokens.refresh_token,
token_type=tokens.token_type,
expires_at=datetime.fromtimestamp(tokens.expires_in + time.time()),
scope=tokens.scope,
status=OAuthSessionStatus.AUTHORIZED,
)
await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor)
async def get_client_info(self) -> Optional[OAuthClientInformationFull]:
"""Retrieve client information from database."""
oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor)
if not oauth_session or not oauth_session.client_id:
return None
return OAuthClientInformationFull(
client_id=oauth_session.client_id,
client_secret=oauth_session.client_secret,
redirect_uris=[oauth_session.redirect_uri] if oauth_session.redirect_uri else [],
)
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
"""Store client information in database."""
session_update = MCPOAuthSessionUpdate(
client_id=client_info.client_id,
client_secret=client_info.client_secret,
redirect_uri=str(client_info.redirect_uris[0]) if client_info.redirect_uris else None,
)
await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor)
class MCPOAuthSession:
"""Legacy OAuth session class - deprecated, use mcp_manager directly."""
def __init__(self, server_url: str, server_name: str, user_id: Optional[str], organization_id: str):
self.server_url = server_url
self.server_name = server_name
self.user_id = user_id
self.organization_id = organization_id
self.session_id = str(uuid.uuid4())
self.state = secrets.token_urlsafe(32)
def __init__(self, session_id: str):
self.session_id = session_id
# TODO: consolidate / deprecate this in favor of mcp_manager access
async def create_session(self) -> str:
"""Create a new OAuth session in the database."""
async with db_registry.async_session() as session:
oauth_record = MCPOAuth(
id=self.session_id,
state=self.state,
server_url=self.server_url,
server_name=self.server_name,
user_id=self.user_id,
organization_id=self.organization_id,
status=OAuthSessionStatus.PENDING,
created_at=datetime.now(),
updated_at=datetime.now(),
)
oauth_record = await oauth_record.create_async(session, actor=None)
return self.session_id
async def get_session_status(self) -> OAuthSessionStatus:
"""Get the current status of the OAuth session."""
async with db_registry.async_session() as session:
try:
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
return oauth_record.status
except Exception:
return OAuthSessionStatus.ERROR
async def update_session_status(self, status: OAuthSessionStatus) -> None:
"""Update the session status."""
async with db_registry.async_session() as session:
try:
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
oauth_record.status = status
oauth_record.updated_at = datetime.now()
await oauth_record.update_async(db_session=session, actor=None)
except Exception:
pass
async def store_authorization_code(self, code: str, state: str) -> bool:
"""Store the authorization code from OAuth callback."""
async with db_registry.async_session() as session:
try:
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
# if oauth_record.state != state:
# return False
oauth_record.authorization_code = code
oauth_record.state = state
oauth_record.status = OAuthSessionStatus.AUTHORIZED
oauth_record.updated_at = datetime.now()
await oauth_record.update_async(db_session=session, actor=None)
return True
except Exception:
return False
async def get_authorization_url(self) -> Optional[str]:
"""Get the authorization URL for this session."""
async with db_registry.async_session() as session:
try:
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
return oauth_record.authorization_url
except Exception:
return None
async def set_authorization_url(self, url: str) -> None:
"""Set the authorization URL for this session."""
async with db_registry.async_session() as session:
try:
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
oauth_record.authorization_url = url
oauth_record.updated_at = datetime.now()
await oauth_record.update_async(db_session=session, actor=None)
except Exception:
pass
async def create_oauth_provider(
session_id: str,
server_url: str,
redirect_uri: str,
mcp_manager: MCPManager,
actor: PydanticUser,
url_callback: Optional[Callable[[str], None]] = None,
) -> OAuthClientProvider:
"""Create an OAuth provider for MCP server authentication."""
client_metadata_dict = {
"client_name": "Letta MCP Client",
"redirect_uris": [redirect_uri],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
}
# Use manager-based storage
storage = DatabaseTokenStorage(session_id, mcp_manager, actor)
# Extract base URL (remove /mcp endpoint if present)
oauth_server_url = server_url.rstrip("/").removesuffix("/sse").removesuffix("/mcp")
async def redirect_handler(authorization_url: str) -> None:
"""Handle OAuth redirect by storing the authorization URL."""
logger.info(f"OAuth redirect handler called with URL: {authorization_url}")
session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url)
await mcp_manager.update_oauth_session(session_id, session_update, actor)
logger.info(f"OAuth authorization URL stored: {authorization_url}")
# Call the callback if provided (e.g., to yield URL to SSE stream)
if url_callback:
url_callback(authorization_url)
async def callback_handler() -> Tuple[str, Optional[str]]:
"""Handle OAuth callback by waiting for authorization code."""
timeout = 300 # 5 minutes
start_time = time.time()
logger.info(f"Waiting for authorization code for session {session_id}")
while time.time() - start_time < timeout:
oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor)
if oauth_session and oauth_session.authorization_code:
return oauth_session.authorization_code, oauth_session.state
elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR:
raise Exception("OAuth authorization failed")
await asyncio.sleep(1)
raise Exception(f"Timeout waiting for OAuth callback after {timeout} seconds")
return OAuthClientProvider(
server_url=oauth_server_url,
client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
storage=storage,
redirect_handler=redirect_handler,
callback_handler=callback_handler,
)
async def cleanup_expired_oauth_sessions(max_age_hours: int = 24) -> None:
"""Clean up expired OAuth sessions."""
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
async with db_registry.async_session() as session:
result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time))
expired_sessions = result.scalars().all()
for oauth_session in expired_sessions:
await oauth_session.hard_delete_async(db_session=session, actor=None)
if expired_sessions:
logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions")
def oauth_stream_event(event: OauthStreamEvent, **kwargs) -> str:
data = {"event": event.value}
data.update(kwargs)
return f"data: {json.dumps(data)}\n\n"
def drill_down_exception(exception, depth=0, max_depth=5):
"""Recursively drill down into nested exceptions to find the root cause"""
indent = " " * depth
error_details = []
error_details.append(f"{indent}Exception at depth {depth}:")
error_details.append(f"{indent} Type: {type(exception).__name__}")
error_details.append(f"{indent} Message: {str(exception)}")
error_details.append(f"{indent} Module: {getattr(type(exception), '__module__', 'unknown')}")
# Check for exception groups (TaskGroup errors)
if hasattr(exception, "exceptions") and exception.exceptions:
error_details.append(f"{indent} ExceptionGroup with {len(exception.exceptions)} sub-exceptions:")
for i, sub_exc in enumerate(exception.exceptions):
error_details.append(f"{indent} Sub-exception {i}:")
if depth < max_depth:
error_details.extend(drill_down_exception(sub_exc, depth + 1, max_depth))
# Check for chained exceptions (__cause__ and __context__)
if hasattr(exception, "__cause__") and exception.__cause__ and depth < max_depth:
error_details.append(f"{indent} Caused by:")
error_details.extend(drill_down_exception(exception.__cause__, depth + 1, max_depth))
if hasattr(exception, "__context__") and exception.__context__ and depth < max_depth:
error_details.append(f"{indent} Context:")
error_details.extend(drill_down_exception(exception.__context__, depth + 1, max_depth))
# Add traceback info
import traceback
if hasattr(exception, "__traceback__") and exception.__traceback__:
tb_lines = traceback.format_tb(exception.__traceback__)
error_details.append(f"{indent} Traceback:")
for line in tb_lines[-3:]: # Show last 3 traceback lines
error_details.append(f"{indent} {line.strip()}")
error_info = "".join(error_details)
return error_info
def get_oauth_success_html() -> str:
"""Generate HTML for successful OAuth authorization."""
return """
<!DOCTYPE html>
<html>
<head>
<title>Authorization Successful - Letta</title>
<style>
* {
margin: 0;
padding: 0;
box-sizing: border-box;
}
body {
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
display: flex;
justify-content: center;
align-items: center;
min-height: 100vh;
margin: 0;
background-color: #f5f5f5;
background-image: url("data:image/svg+xml,%3Csvg width='1440' height='860' viewBox='0 0 1440 860' fill='none' xmlns='http://www.w3.org/2000/svg'%3E%3Cg clip-path='url(%23clip0_14823_146864)'%3E%3Cpath d='M720.001 1003.14C1080.62 1003.14 1372.96 824.028 1372.96 603.083C1372.96 382.138 1080.62 203.026 720.001 203.026C359.384 203.026 67.046 382.138 67.046 603.083C67.046 824.028 359.384 1003.14 720.001 1003.14Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 978.04C910.334 978.04 1064.63 883.505 1064.63 766.891C1064.63 650.276 910.334 555.741 719.999 555.741C529.665 555.741 375.368 650.276 375.368 766.891C375.368 883.505 529.665 978.04 719.999 978.04Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720 1020.95C1262.17 1020.95 1701.68 756.371 1701.68 430C1701.68 103.629 1262.17 -160.946 720 -160.946C177.834 -160.946 -261.678 103.629 -261.678 430C-261.678 756.371 177.834 1020.95 720 1020.95Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 323.658C910.334 323.658 1064.63 223.814 1064.63 100.649C1064.63 -22.5157 910.334 -122.36 719.999 -122.36C529.665 -122.36 375.368 -22.5157 375.368 100.649C375.368 223.814 529.665 323.658 719.999 323.658Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720.001 706.676C1080.62 706.676 1372.96 517.507 1372.96 284.155C1372.96 50.8029 1080.62 -138.366 720.001 -138.366C359.384 -138.366 67.046 50.8029 67.046 284.155C67.046 517.507 359.384 706.676 720.001 706.676Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 874.604C1180.69 874.604 1554.15 645.789 1554.15 363.531C1554.15 81.2725 1180.69 -147.543 719.999 -147.543C259.311 -147.543 -114.15 81.2725 -114.15 363.531C-114.15 645.789 259.311 874.604 719.999 874.604Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3C/g%3E%3Cdefs%3E%3CclipPath id='clip0_14823_146864'%3E%3Crect width='1440' height='860' fill='white'/%3E%3C/clipPath%3E%3C/defs%3E%3C/svg%3E");
background-size: cover;
background-position: center;
background-repeat: no-repeat;
}
.card {
text-align: center;
padding: 48px;
background: white;
border-radius: 8px;
border: 1px solid #E1E2E3;
max-width: 400px;
width: 90%;
position: relative;
z-index: 1;
}
.logo {
width: 48px;
height: 48px;
margin: 0 auto 24px;
display: block;
}
.logo svg {
width: 100%;
height: 100%;
}
h1 {
font-size: 20px;
font-weight: 600;
color: #101010;
margin-bottom: 12px;
line-height: 1.2;
}
.subtitle {
color: #666;
font-size: 12px;
margin-top: 10px;
margin-bottom: 24px;
line-height: 1.5;
}
.close-info {
font-size: 12px;
color: #999;
display: flex;
align-items: center;
justify-content: center;
gap: 8px;
}
.spinner {
width: 16px;
height: 16px;
border: 2px solid #E1E2E3;
border-top: 2px solid #333;
border-radius: 50%;
animation: spin 1s linear infinite;
}
@keyframes spin {
0% { transform: rotate(0deg); }
100% { transform: rotate(360deg); }
}
/* Dark mode styles */
@media (prefers-color-scheme: dark) {
body {
background-color: #101010;
background-image: url("data:image/svg+xml,%3Csvg width='1440' height='860' viewBox='0 0 1440 860' fill='none' xmlns='http://www.w3.org/2000/svg'%3E%3Cg clip-path='url(%23clip0_14833_149362)'%3E%3Cpath d='M720.001 1003.14C1080.62 1003.14 1372.96 824.028 1372.96 603.083C1372.96 382.138 1080.62 203.026 720.001 203.026C359.384 203.026 67.046 382.138 67.046 603.083C67.046 824.028 359.384 1003.14 720.001 1003.14Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 978.04C910.334 978.04 1064.63 883.505 1064.63 766.891C1064.63 650.276 910.334 555.741 719.999 555.741C529.665 555.741 375.368 650.276 375.368 766.891C375.368 883.505 529.665 978.04 719.999 978.04Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720 1020.95C1262.17 1020.95 1701.68 756.371 1701.68 430C1701.68 103.629 1262.17 -160.946 720 -160.946C177.834 -160.946 -261.678 103.629 -261.678 430C-261.678 756.371 177.834 1020.95 720 1020.95Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 323.658C910.334 323.658 1064.63 223.814 1064.63 100.649C1064.63 -22.5157 910.334 -122.36 719.999 -122.36C529.665 -122.36 375.368 -22.5157 375.368 100.649C375.368 223.814 529.665 323.658 719.999 323.658Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720.001 706.676C1080.62 706.676 1372.96 517.507 1372.96 284.155C1372.96 50.8029 1080.62 -138.366 720.001 -138.366C359.384 -138.366 67.046 50.8029 67.046 284.155C67.046 517.507 359.384 706.676 720.001 706.676Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 874.604C1180.69 874.604 1554.15 645.789 1554.15 363.531C1554.15 81.2725 1180.69 -147.543 719.999 -147.543C259.311 -147.543 -114.15 81.2725 -114.15 363.531C-114.15 645.789 259.311 874.604 719.999 874.604Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3C/g%3E%3Cdefs%3E%3CclipPath id='clip0_14833_149362'%3E%3Crect width='1440' height='860' fill='white'/%3E%3C/clipPath%3E%3C/defs%3E%3C/svg%3E");
}
.card {
background-color: #141414;
border-color: #202020;
}
h1 {
color: #E1E2E3;
}
.subtitle {
color: #999;
}
.logo svg path {
fill: #E1E2E3;
}
.spinner {
border-color: #46484A;
border-top-color: #E1E2E3;
}
}
</style>
</head>
<body>
<div class="card">
<div class="logo">
<svg width="48" height="48" viewBox="0 0 18 18" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M10.7134 7.30028H7.28759V10.7002H10.7134V7.30028Z" fill="#333"/>
<path d="M14.1391 2.81618V0.5H3.86131V2.81618C3.86131 3.41495 3.37266 3.89991 2.76935 3.89991H0.435547V14.1001H2.76935C3.37266 14.1001 3.86131 14.5851 3.86131 15.1838V17.5H14.1391V15.1838C14.1391 14.5851 14.6277 14.1001 15.231 14.1001H17.5648V3.89991H15.231C14.6277 3.89991 14.1391 3.41495 14.1391 2.81618ZM14.1391 13.0159C14.1391 13.6147 13.6504 14.0996 13.0471 14.0996H4.95375C4.35043 14.0996 3.86179 13.6147 3.86179 13.0159V4.98363C3.86179 4.38486 4.35043 3.89991 4.95375 3.89991H13.0471C13.6504 3.89991 14.1391 4.38486 14.1391 4.98363V13.0159Z" fill="#333"/>
</svg>
</div>
<h3>Authorization Successful</h3>
<p class="subtitle">You have successfully connected your MCP server.</p>
<div class="close-info">
<span>You can now close this window.</span>
</div>
</div>
</body>
</html>
"""

View File

@@ -1,4 +1,7 @@
from typing import Optional
from mcp import ClientSession
from mcp.client.auth import OAuthClientProvider
from mcp.client.sse import sse_client
from letta.functions.mcp_client.types import SSEServerConfig
@@ -13,6 +16,9 @@ logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncSSEMCPClient(AsyncBaseMCPClient):
def __init__(self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
super().__init__(server_config, oauth_provider)
async def _initialize_connection(self, server_config: SSEServerConfig) -> None:
headers = {}
if server_config.custom_headers:
@@ -21,7 +27,12 @@ class AsyncSSEMCPClient(AsyncBaseMCPClient):
if server_config.auth_header and server_config.auth_token:
headers[server_config.auth_header] = server_config.auth_token
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None)
# Use OAuth provider if available, otherwise use regular headers
if self.oauth_provider:
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider)
else:
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None)
sse_transport = await self.exit_stack.enter_async_context(sse_cm)
self.stdio, self.write = sse_transport

View File

@@ -1,4 +1,7 @@
from typing import Optional
from mcp import ClientSession
from mcp.client.auth import OAuthClientProvider
from mcp.client.streamable_http import streamablehttp_client
from letta.functions.mcp_client.types import BaseServerConfig, StreamableHTTPServerConfig
@@ -9,10 +12,12 @@ logger = get_logger(__name__)
class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
def __init__(self, server_config: StreamableHTTPServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
super().__init__(server_config, oauth_provider)
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
if not isinstance(server_config, StreamableHTTPServerConfig):
raise ValueError("Expected StreamableHTTPServerConfig")
try:
# Prepare headers for authentication
headers = {}
@@ -23,11 +28,18 @@ class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
if server_config.auth_header and server_config.auth_token:
headers[server_config.auth_header] = server_config.auth_token
# Use streamablehttp_client context manager with headers if provided
if headers:
streamable_http_cm = streamablehttp_client(server_config.server_url, headers=headers)
# Use OAuth provider if available, otherwise use regular headers
if self.oauth_provider:
streamable_http_cm = streamablehttp_client(
server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider
)
else:
streamable_http_cm = streamablehttp_client(server_config.server_url)
# Use streamablehttp_client context manager with headers if provided
if headers:
streamable_http_cm = streamablehttp_client(server_config.server_url, headers=headers)
else:
streamable_http_cm = streamablehttp_client(server_config.server_url)
read_stream, write_stream, _ = await self.exit_stack.enter_async_context(streamable_http_cm)
# Create and enter the ClientSession context manager

View File

@@ -46,3 +46,12 @@ class StdioServerConfig(BaseServerConfig):
if self.env is not None:
values["env"] = self.env
return values
class OauthStreamEvent(str, Enum):
CONNECTION_ATTEMPT = "connection_attempt"
SUCCESS = "success"
ERROR = "error"
OAUTH_REQUIRED = "oauth_required"
AUTHORIZATION_URL = "authorization_url"
WAITING_FOR_AUTH = "waiting_for_auth"

View File

@@ -1,5 +1,8 @@
import json
import os
import secrets
import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple, Union
from sqlalchemy import null
@@ -8,8 +11,18 @@ import letta.constants as constants
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
from letta.log import get_logger
from letta.orm.errors import NoResultFound
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
from letta.orm.mcp_server import MCPServer as MCPServerModel
from letta.schemas.mcp import MCPServer, UpdateMCPServer, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
from letta.schemas.mcp import (
MCPOAuthSession,
MCPOAuthSessionCreate,
MCPOAuthSessionUpdate,
MCPServer,
UpdateMCPServer,
UpdateSSEMCPServer,
UpdateStdioMCPServer,
UpdateStreamableHTTPMCPServer,
)
from letta.schemas.tool import Tool as PydanticTool
from letta.schemas.tool import ToolCreate
from letta.schemas.user import User as PydanticUser
@@ -38,14 +51,7 @@ class MCPManager:
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config()
if mcp_config.server_type == MCPServerType.SSE:
mcp_client = AsyncSSEMCPClient(server_config=server_config)
elif mcp_config.server_type == MCPServerType.STDIO:
mcp_client = AsyncStdioMCPClient(server_config=server_config)
elif mcp_config.server_type == MCPServerType.STREAMABLE_HTTP:
mcp_client = AsyncStreamableHTTPMCPClient(server_config=server_config)
else:
raise ValueError(f"Unsupported MCP server type: {mcp_config.server_type}")
mcp_client = await self.get_mcp_client(server_config, actor)
await mcp_client.connect_to_server()
# list tools
@@ -72,28 +78,20 @@ class MCPManager:
# read from config file
mcp_config = self.read_mcp_config()
if mcp_server_name not in mcp_config:
print("MCP server not found in config.", mcp_config)
raise ValueError(f"MCP server {mcp_server_name} not found in config.")
server_config = mcp_config[mcp_server_name]
if isinstance(server_config, SSEServerConfig):
# mcp_client = AsyncSSEMCPClient(server_config=server_config)
async with AsyncSSEMCPClient(server_config=server_config) as mcp_client:
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
return result, success
elif isinstance(server_config, StdioServerConfig):
async with AsyncStdioMCPClient(server_config=server_config) as mcp_client:
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
return result, success
elif isinstance(server_config, StreamableHTTPServerConfig):
async with AsyncStreamableHTTPMCPClient(server_config=server_config) as mcp_client:
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
return result, success
else:
raise ValueError(f"Unsupported server config type: {type(server_config)}")
mcp_client = await self.get_mcp_client(server_config, actor)
await mcp_client.connect_to_server()
# call tool
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
# TODO: change to pydantic tool
await mcp_client.cleanup()
return result, success
@enforce_types
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
@@ -324,3 +322,246 @@ class MCPManager:
logger.error(f"Failed to parse server params for MCP server {server_name} (skipping): {e}")
continue
return mcp_server_list
async def get_mcp_client(
self,
server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig],
actor: PydanticUser,
oauth_provider: Optional[Any] = None,
) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]:
"""
Helper function to create the appropriate MCP client based on server configuration.
Args:
server_config: The server configuration object
actor: The user making the request
oauth_provider: Optional OAuth provider for authentication
Returns:
The appropriate MCP client instance
Raises:
ValueError: If server config type is not supported
"""
# If no OAuth provider is provided, check if we have stored OAuth credentials
if oauth_provider is None and hasattr(server_config, "server_url"):
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
if oauth_session and oauth_session.access_token:
# Create OAuth provider from stored credentials
from letta.services.mcp.oauth_utils import create_oauth_provider
oauth_provider = await create_oauth_provider(
session_id=oauth_session.id,
server_url=oauth_session.server_url,
redirect_uri=oauth_session.redirect_uri,
mcp_manager=self,
actor=actor,
)
if server_config.type == MCPServerType.SSE:
server_config = SSEServerConfig(**server_config.model_dump())
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider)
elif server_config.type == MCPServerType.STDIO:
server_config = StdioServerConfig(**server_config.model_dump())
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider)
elif server_config.type == MCPServerType.STREAMABLE_HTTP:
server_config = StreamableHTTPServerConfig(**server_config.model_dump())
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider)
else:
raise ValueError(f"Unsupported server config type: {type(server_config)}")
# OAuth-related methods
@enforce_types
async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession:
"""Create a new OAuth session for MCP server authentication."""
async with db_registry.async_session() as session:
# Create the OAuth session with a unique state
oauth_session = MCPOAuth(
id="mcp-oauth-" + str(uuid.uuid4())[:8],
state=secrets.token_urlsafe(32),
server_url=session_create.server_url,
server_name=session_create.server_name,
user_id=session_create.user_id,
organization_id=session_create.organization_id,
status=OAuthSessionStatus.PENDING,
created_at=datetime.now(),
updated_at=datetime.now(),
)
oauth_session = await oauth_session.create_async(session, actor=actor)
# Convert to Pydantic model
return MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
server_url=oauth_session.server_url,
server_name=oauth_session.server_name,
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
)
@enforce_types
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
"""Get an OAuth session by its ID."""
async with db_registry.async_session() as session:
try:
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
return MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
server_url=oauth_session.server_url,
server_name=oauth_session.server_name,
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
authorization_url=oauth_session.authorization_url,
authorization_code=oauth_session.authorization_code,
access_token=oauth_session.access_token,
refresh_token=oauth_session.refresh_token,
token_type=oauth_session.token_type,
expires_at=oauth_session.expires_at,
scope=oauth_session.scope,
client_id=oauth_session.client_id,
client_secret=oauth_session.client_secret,
redirect_uri=oauth_session.redirect_uri,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
)
except NoResultFound:
return None
@enforce_types
async def get_oauth_session_by_server(self, server_url: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
"""Get the latest OAuth session by server URL, organization, and user."""
from sqlalchemy import desc, select
async with db_registry.async_session() as session:
# Query for OAuth session matching organization, user, server URL, and status
# Order by updated_at desc to get the most recent record
result = await session.execute(
select(MCPOAuth)
.where(
MCPOAuth.organization_id == actor.organization_id,
MCPOAuth.user_id == actor.id,
MCPOAuth.server_url == server_url,
MCPOAuth.status == OAuthSessionStatus.AUTHORIZED,
)
.order_by(desc(MCPOAuth.updated_at))
.limit(1)
)
oauth_session = result.scalar_one_or_none()
if not oauth_session:
return None
return MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
server_url=oauth_session.server_url,
server_name=oauth_session.server_name,
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
authorization_url=oauth_session.authorization_url,
authorization_code=oauth_session.authorization_code,
access_token=oauth_session.access_token,
refresh_token=oauth_session.refresh_token,
token_type=oauth_session.token_type,
expires_at=oauth_session.expires_at,
scope=oauth_session.scope,
client_id=oauth_session.client_id,
client_secret=oauth_session.client_secret,
redirect_uri=oauth_session.redirect_uri,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
)
@enforce_types
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
"""Update an existing OAuth session."""
async with db_registry.async_session() as session:
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
# Update fields that are provided
if session_update.authorization_url is not None:
oauth_session.authorization_url = session_update.authorization_url
if session_update.authorization_code is not None:
oauth_session.authorization_code = session_update.authorization_code
if session_update.access_token is not None:
oauth_session.access_token = session_update.access_token
if session_update.refresh_token is not None:
oauth_session.refresh_token = session_update.refresh_token
if session_update.token_type is not None:
oauth_session.token_type = session_update.token_type
if session_update.expires_at is not None:
oauth_session.expires_at = session_update.expires_at
if session_update.scope is not None:
oauth_session.scope = session_update.scope
if session_update.client_id is not None:
oauth_session.client_id = session_update.client_id
if session_update.client_secret is not None:
oauth_session.client_secret = session_update.client_secret
if session_update.redirect_uri is not None:
oauth_session.redirect_uri = session_update.redirect_uri
if session_update.status is not None:
oauth_session.status = session_update.status
# Always update the updated_at timestamp
oauth_session.updated_at = datetime.now()
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
return MCPOAuthSession(
id=oauth_session.id,
state=oauth_session.state,
server_url=oauth_session.server_url,
server_name=oauth_session.server_name,
user_id=oauth_session.user_id,
organization_id=oauth_session.organization_id,
authorization_url=oauth_session.authorization_url,
authorization_code=oauth_session.authorization_code,
access_token=oauth_session.access_token,
refresh_token=oauth_session.refresh_token,
token_type=oauth_session.token_type,
expires_at=oauth_session.expires_at,
scope=oauth_session.scope,
client_id=oauth_session.client_id,
client_secret=oauth_session.client_secret,
redirect_uri=oauth_session.redirect_uri,
status=oauth_session.status,
created_at=oauth_session.created_at,
updated_at=oauth_session.updated_at,
)
@enforce_types
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:
"""Delete an OAuth session."""
async with db_registry.async_session() as session:
try:
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
await oauth_session.hard_delete_async(db_session=session, actor=actor)
except NoResultFound:
raise ValueError(f"OAuth session with id {session_id} not found.")
@enforce_types
async def cleanup_expired_oauth_sessions(self, max_age_hours: int = 24) -> int:
"""Clean up expired OAuth sessions and return the count of deleted sessions."""
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
async with db_registry.async_session() as session:
from sqlalchemy import select
# Find expired sessions
result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time))
expired_sessions = result.scalars().all()
# Delete expired sessions using async ORM method
for oauth_session in expired_sessions:
await oauth_session.hard_delete_async(db_session=session, actor=None)
if expired_sessions:
logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions")
return len(expired_sessions)

356
mcp_test.py Normal file
View File

@@ -0,0 +1,356 @@
#!/usr/bin/env python3
"""
Simple MCP client example with OAuth authentication support.
This client connects to an MCP server using streamable HTTP transport with OAuth.
"""
import asyncio
import os
import threading
import time
import webbrowser
from datetime import timedelta
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any
from urllib.parse import parse_qs, urlparse
from mcp.client.auth import OAuthClientProvider, TokenStorage
from mcp.client.session import ClientSession
from mcp.client.sse import sse_client
from mcp.client.streamable_http import streamablehttp_client
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
class InMemoryTokenStorage(TokenStorage):
"""Simple in-memory token storage implementation."""
def __init__(self):
self._tokens: OAuthToken | None = None
self._client_info: OAuthClientInformationFull | None = None
async def get_tokens(self) -> OAuthToken | None:
return self._tokens
async def set_tokens(self, tokens: OAuthToken) -> None:
self._tokens = tokens
async def get_client_info(self) -> OAuthClientInformationFull | None:
return self._client_info
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
self._client_info = client_info
class CallbackHandler(BaseHTTPRequestHandler):
"""Simple HTTP handler to capture OAuth callback."""
def __init__(self, request, client_address, server, callback_data):
"""Initialize with callback data storage."""
self.callback_data = callback_data
super().__init__(request, client_address, server)
def do_GET(self):
"""Handle GET request from OAuth redirect."""
parsed = urlparse(self.path)
query_params = parse_qs(parsed.query)
if "code" in query_params:
self.callback_data["authorization_code"] = query_params["code"][0]
self.callback_data["state"] = query_params.get("state", [None])[0]
self.send_response(200)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(
b"""
<html>
<body>
<h1>Authorization Successful!</h1>
<p>You can close this window and return to the terminal.</p>
<script>setTimeout(() => window.close(), 2000);</script>
</body>
</html>
"""
)
elif "error" in query_params:
self.callback_data["error"] = query_params["error"][0]
self.send_response(400)
self.send_header("Content-type", "text/html")
self.end_headers()
self.wfile.write(
f"""
<html>
<body>
<h1>Authorization Failed</h1>
<p>Error: {query_params['error'][0]}</p>
<p>You can close this window and return to the terminal.</p>
</body>
</html>
""".encode()
)
else:
self.send_response(404)
self.end_headers()
def log_message(self, format, *args):
"""Suppress default logging."""
class CallbackServer:
"""Simple server to handle OAuth callbacks."""
def __init__(self, port=3000):
self.port = port
self.server = None
self.thread = None
self.callback_data = {"authorization_code": None, "state": None, "error": None}
def _create_handler_with_data(self):
"""Create a handler class with access to callback data."""
callback_data = self.callback_data
class DataCallbackHandler(CallbackHandler):
def __init__(self, request, client_address, server):
super().__init__(request, client_address, server, callback_data)
return DataCallbackHandler
def start(self):
"""Start the callback server in a background thread."""
handler_class = self._create_handler_with_data()
self.server = HTTPServer(("localhost", self.port), handler_class)
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
self.thread.start()
print(f"🖥️ Started callback server on http://localhost:{self.port}")
def stop(self):
"""Stop the callback server."""
if self.server:
self.server.shutdown()
self.server.server_close()
if self.thread:
self.thread.join(timeout=1)
def wait_for_callback(self, timeout=300):
"""Wait for OAuth callback with timeout."""
start_time = time.time()
while time.time() - start_time < timeout:
if self.callback_data["authorization_code"]:
return self.callback_data["authorization_code"]
elif self.callback_data["error"]:
raise Exception(f"OAuth error: {self.callback_data['error']}")
time.sleep(0.1)
raise Exception("Timeout waiting for OAuth callback")
def get_state(self):
"""Get the received state parameter."""
return self.callback_data["state"]
class SimpleAuthClient:
"""Simple MCP client with auth support."""
def __init__(self, server_url: str, transport_type: str = "streamable_http"):
self.server_url = server_url
self.transport_type = transport_type
self.session: ClientSession | None = None
async def connect(self):
"""Connect to the MCP server."""
print(f"🔗 Attempting to connect to {self.server_url}...")
try:
callback_server = CallbackServer(port=3030)
callback_server.start()
async def callback_handler() -> tuple[str, str | None]:
"""Wait for OAuth callback and return auth code and state."""
print("⏳ Waiting for authorization callback...")
try:
auth_code = callback_server.wait_for_callback(timeout=300)
return auth_code, callback_server.get_state()
finally:
callback_server.stop()
client_metadata_dict = {
"client_name": "Simple Auth Client",
"redirect_uris": ["http://localhost:3030/callback"],
"grant_types": ["authorization_code", "refresh_token"],
"response_types": ["code"],
"token_endpoint_auth_method": "client_secret_post",
}
async def _default_redirect_handler(authorization_url: str) -> None:
"""Default redirect handler that opens the URL in a browser."""
print(f"Opening browser for authorization: {authorization_url}")
webbrowser.open(authorization_url)
# Create OAuth authentication handler using the new interface
oauth_auth = OAuthClientProvider(
server_url=self.server_url.replace("/mcp", ""),
client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
storage=InMemoryTokenStorage(),
redirect_handler=_default_redirect_handler,
callback_handler=callback_handler,
)
# Create transport with auth handler based on transport type
if self.transport_type == "sse":
print("📡 Opening SSE transport connection with auth...")
async with sse_client(
url=self.server_url,
auth=oauth_auth,
timeout=60,
) as (read_stream, write_stream):
await self._run_session(read_stream, write_stream, None)
else:
print("📡 Opening StreamableHTTP transport connection with auth...")
async with streamablehttp_client(
url=self.server_url,
auth=oauth_auth,
timeout=timedelta(seconds=60),
) as (read_stream, write_stream, get_session_id):
await self._run_session(read_stream, write_stream, get_session_id)
except Exception as e:
print(f"❌ Failed to connect: {e}")
import traceback
traceback.print_exc()
async def _run_session(self, read_stream, write_stream, get_session_id):
"""Run the MCP session with the given streams."""
print("🤝 Initializing MCP session...")
async with ClientSession(read_stream, write_stream) as session:
self.session = session
print("⚡ Starting session initialization...")
await session.initialize()
print("✨ Session initialization complete!")
print(f"\n✅ Connected to MCP server at {self.server_url}")
if get_session_id:
session_id = get_session_id()
if session_id:
print(f"Session ID: {session_id}")
# Run interactive loop
await self.interactive_loop()
async def list_tools(self):
"""List available tools from the server."""
if not self.session:
print("❌ Not connected to server")
return
try:
result = await self.session.list_tools()
if hasattr(result, "tools") and result.tools:
print("\n📋 Available tools:")
for i, tool in enumerate(result.tools, 1):
print(f"{i}. {tool.name}")
if tool.description:
print(f" Description: {tool.description}")
print()
else:
print("No tools available")
except Exception as e:
print(f"❌ Failed to list tools: {e}")
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None):
"""Call a specific tool."""
if not self.session:
print("❌ Not connected to server")
return
try:
result = await self.session.call_tool(tool_name, arguments or {})
print(f"\n🔧 Tool '{tool_name}' result:")
if hasattr(result, "content"):
for content in result.content:
if content.type == "text":
print(content.text)
else:
print(content)
else:
print(result)
except Exception as e:
print(f"❌ Failed to call tool '{tool_name}': {e}")
async def interactive_loop(self):
"""Run interactive command loop."""
print("\n🎯 Interactive MCP Client")
print("Commands:")
print(" list - List available tools")
print(" call <tool_name> [args] - Call a tool")
print(" quit - Exit the client")
print()
while True:
try:
command = input("mcp> ").strip()
if not command:
continue
if command == "quit":
break
elif command == "list":
await self.list_tools()
elif command.startswith("call "):
parts = command.split(maxsplit=2)
tool_name = parts[1] if len(parts) > 1 else ""
if not tool_name:
print("❌ Please specify a tool name")
continue
# Parse arguments (simple JSON-like format)
arguments = {}
if len(parts) > 2:
import json
try:
arguments = json.loads(parts[2])
except json.JSONDecodeError:
print("❌ Invalid arguments format (expected JSON)")
continue
await self.call_tool(tool_name, arguments)
else:
print("❌ Unknown command. Try 'list', 'call <tool_name>', or 'quit'")
except KeyboardInterrupt:
print("\n\n👋 Goodbye!")
break
except EOFError:
break
async def main():
"""Main entry point."""
# Default server URL - can be overridden with environment variable
# Most MCP streamable HTTP servers use /mcp as the endpoint
server_url = os.getenv("MCP_SERVER_PORT", 8000)
transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http")
server_url = f"http://localhost:{server_url}/mcp" if transport_type == "streamable_http" else f"http://localhost:{server_url}/sse"
print("🚀 Simple MCP Auth Client")
print(f"Connecting to: {server_url}")
print(f"Transport type: {transport_type}")
# Start connection flow - OAuth will be handled automatically
client = SimpleAuthClient(server_url, transport_type)
await client.connect()
def cli():
"""CLI entry point for uv script."""
asyncio.run(main())
if __name__ == "__main__":
cli()