feat: add support for oauth mcp
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
67
alembic/versions/f5d26b0526e8_add_mcp_oauth.py
Normal file
67
alembic/versions/f5d26b0526e8_add_mcp_oauth.py
Normal file
@@ -0,0 +1,67 @@
|
||||
"""add_mcp_oauth
|
||||
|
||||
Revision ID: f5d26b0526e8
|
||||
Revises: ddecfe4902bc
|
||||
Create Date: 2025-07-24 12:34:05.795355
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "f5d26b0526e8"
|
||||
down_revision: Union[str, None] = "ddecfe4902bc"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"mcp_oauth",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("state", sa.String(length=255), nullable=False),
|
||||
sa.Column("server_id", sa.String(length=255), nullable=True),
|
||||
sa.Column("server_url", sa.Text(), nullable=False),
|
||||
sa.Column("server_name", sa.Text(), nullable=False),
|
||||
sa.Column("authorization_url", sa.Text(), nullable=True),
|
||||
sa.Column("authorization_code", sa.Text(), nullable=True),
|
||||
sa.Column("access_token", sa.Text(), nullable=True),
|
||||
sa.Column("refresh_token", sa.Text(), nullable=True),
|
||||
sa.Column("token_type", sa.String(length=50), nullable=False),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("scope", sa.Text(), nullable=True),
|
||||
sa.Column("client_id", sa.Text(), nullable=True),
|
||||
sa.Column("client_secret", sa.Text(), nullable=True),
|
||||
sa.Column("redirect_uri", sa.Text(), nullable=True),
|
||||
sa.Column("status", sa.String(length=20), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.Column("user_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.ForeignKeyConstraint(["server_id"], ["mcp_server.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["user_id"],
|
||||
["users.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.UniqueConstraint("state"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("mcp_oauth")
|
||||
# ### end Alembic commands ###
|
||||
62
letta/orm/mcp_oauth.py
Normal file
62
letta/orm/mcp_oauth.py
Normal file
@@ -0,0 +1,62 @@
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import DateTime, ForeignKey, String, Text
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.orm.mixins import OrganizationMixin, UserMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
|
||||
|
||||
class OAuthSessionStatus(str, Enum):
|
||||
"""OAuth session status enumeration."""
|
||||
|
||||
PENDING = "pending"
|
||||
AUTHORIZED = "authorized"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class MCPOAuth(SqlalchemyBase, OrganizationMixin, UserMixin):
|
||||
"""OAuth session model for MCP server authentication."""
|
||||
|
||||
__tablename__ = "mcp_oauth"
|
||||
|
||||
# Override the id field to match database UUID generation
|
||||
id: Mapped[str] = mapped_column(String, primary_key=True, default=lambda: f"{uuid.uuid4()}")
|
||||
|
||||
# Core session information
|
||||
state: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, doc="OAuth state parameter")
|
||||
server_id: Mapped[str] = mapped_column(String(255), ForeignKey("mcp_server.id", ondelete="CASCADE"), nullable=True, doc="MCP server ID")
|
||||
server_url: Mapped[str] = mapped_column(Text, nullable=False, doc="MCP server URL")
|
||||
server_name: Mapped[str] = mapped_column(Text, nullable=False, doc="MCP server display name")
|
||||
|
||||
# OAuth flow data
|
||||
authorization_url: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth authorization URL")
|
||||
authorization_code: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth authorization code")
|
||||
|
||||
# Token data
|
||||
access_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth access token")
|
||||
refresh_token: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth refresh token")
|
||||
token_type: Mapped[str] = mapped_column(String(50), default="Bearer", doc="Token type")
|
||||
expires_at: Mapped[Optional[datetime]] = mapped_column(DateTime(timezone=True), nullable=True, doc="Token expiry time")
|
||||
scope: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth scope")
|
||||
|
||||
# Client configuration
|
||||
client_id: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client ID")
|
||||
client_secret: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth client secret")
|
||||
redirect_uri: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="OAuth redirect URI")
|
||||
|
||||
# Session state
|
||||
status: Mapped[OAuthSessionStatus] = mapped_column(String(20), default=OAuthSessionStatus.PENDING, doc="Session status")
|
||||
|
||||
# Timestamps
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), default=lambda: datetime.now(), doc="Session creation time")
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), default=lambda: datetime.now(), onupdate=lambda: datetime.now(), doc="Last update time"
|
||||
)
|
||||
|
||||
# Relationships (if needed in the future)
|
||||
# user: Mapped[Optional["User"]] = relationship("User", back_populates="oauth_sessions")
|
||||
# organization: Mapped["Organization"] = relationship("Organization", back_populates="oauth_sessions")
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
@@ -10,6 +11,7 @@ from letta.functions.mcp_client.types import (
|
||||
StdioServerConfig,
|
||||
StreamableHTTPServerConfig,
|
||||
)
|
||||
from letta.orm.mcp_oauth import OAuthSessionStatus
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
@@ -119,3 +121,71 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
|
||||
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer]
|
||||
RegisterMCPServer = Union[RegisterSSEMCPServer, RegisterStdioMCPServer, RegisterStreamableHTTPMCPServer]
|
||||
|
||||
|
||||
# OAuth-related schemas
|
||||
class BaseMCPOAuth(LettaBase):
|
||||
__id_prefix__ = "mcp-oauth"
|
||||
|
||||
|
||||
class MCPOAuthSession(BaseMCPOAuth):
|
||||
"""OAuth session for MCP server authentication."""
|
||||
|
||||
id: str = BaseMCPOAuth.generate_id_field()
|
||||
state: str = Field(..., description="OAuth state parameter")
|
||||
server_id: Optional[str] = Field(None, description="MCP server ID")
|
||||
server_url: str = Field(..., description="MCP server URL")
|
||||
server_name: str = Field(..., description="MCP server display name")
|
||||
|
||||
# User and organization context
|
||||
user_id: Optional[str] = Field(None, description="User ID associated with the session")
|
||||
organization_id: str = Field(..., description="Organization ID associated with the session")
|
||||
|
||||
# OAuth flow data
|
||||
authorization_url: Optional[str] = Field(None, description="OAuth authorization URL")
|
||||
authorization_code: Optional[str] = Field(None, description="OAuth authorization code")
|
||||
|
||||
# Token data
|
||||
access_token: Optional[str] = Field(None, description="OAuth access token")
|
||||
refresh_token: Optional[str] = Field(None, description="OAuth refresh token")
|
||||
token_type: str = Field(default="Bearer", description="Token type")
|
||||
expires_at: Optional[datetime] = Field(None, description="Token expiry time")
|
||||
scope: Optional[str] = Field(None, description="OAuth scope")
|
||||
|
||||
# Client configuration
|
||||
client_id: Optional[str] = Field(None, description="OAuth client ID")
|
||||
client_secret: Optional[str] = Field(None, description="OAuth client secret")
|
||||
redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI")
|
||||
|
||||
# Session state
|
||||
status: OAuthSessionStatus = Field(default=OAuthSessionStatus.PENDING, description="Session status")
|
||||
|
||||
# Timestamps
|
||||
created_at: datetime = Field(default_factory=datetime.now, description="Session creation time")
|
||||
updated_at: datetime = Field(default_factory=datetime.now, description="Last update time")
|
||||
|
||||
|
||||
class MCPOAuthSessionCreate(BaseMCPOAuth):
|
||||
"""Create a new OAuth session."""
|
||||
|
||||
server_url: str = Field(..., description="MCP server URL")
|
||||
server_name: str = Field(..., description="MCP server display name")
|
||||
user_id: Optional[str] = Field(None, description="User ID associated with the session")
|
||||
organization_id: str = Field(..., description="Organization ID associated with the session")
|
||||
state: Optional[str] = Field(None, description="OAuth state parameter")
|
||||
|
||||
|
||||
class MCPOAuthSessionUpdate(BaseMCPOAuth):
|
||||
"""Update an existing OAuth session."""
|
||||
|
||||
authorization_url: Optional[str] = Field(None, description="OAuth authorization URL")
|
||||
authorization_code: Optional[str] = Field(None, description="OAuth authorization code")
|
||||
access_token: Optional[str] = Field(None, description="OAuth access token")
|
||||
refresh_token: Optional[str] = Field(None, description="OAuth refresh token")
|
||||
token_type: Optional[str] = Field(None, description="Token type")
|
||||
expires_at: Optional[datetime] = Field(None, description="Token expiry time")
|
||||
scope: Optional[str] = Field(None, description="OAuth scope")
|
||||
client_id: Optional[str] = Field(None, description="OAuth client ID")
|
||||
client_secret: Optional[str] = Field(None, description="OAuth client secret")
|
||||
redirect_uri: Optional[str] = Field(None, description="OAuth redirect URI")
|
||||
status: Optional[OAuthSessionStatus] = Field(None, description="Session status")
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from composio.client import ComposioClientError, HTTPError, NoItemsFound
|
||||
@@ -11,27 +13,37 @@ from composio.exceptions import (
|
||||
EnumStringNotFound,
|
||||
)
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
from fastapi.responses import HTMLResponse
|
||||
from pydantic import BaseModel, Field
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.functions.functions import derive_openai_json_schema
|
||||
from letta.functions.mcp_client.exceptions import MCPTimeoutError
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.decorators import deprecated
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.orm.mcp_oauth import OAuthSessionStatus
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
|
||||
from letta.schemas.mcp import MCPOAuthSessionCreate, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.mcp.sse_client import AsyncSSEMCPClient
|
||||
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
||||
from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient
|
||||
from letta.services.mcp.oauth_utils import (
|
||||
MCPOAuthSession,
|
||||
create_oauth_provider,
|
||||
drill_down_exception,
|
||||
get_oauth_success_html,
|
||||
oauth_stream_event,
|
||||
)
|
||||
from letta.services.mcp.types import OauthStreamEvent
|
||||
from letta.settings import tool_settings
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
@@ -612,35 +624,26 @@ async def delete_mcp_server_from_config(
|
||||
return [server.to_config() for server in all_servers]
|
||||
|
||||
|
||||
@router.post("/mcp/servers/test", response_model=List[MCPTool], operation_id="test_mcp_server")
|
||||
@deprecated("Deprecated in favor of /mcp/servers/connect which handles OAuth flow via SSE stream")
|
||||
@router.post("/mcp/servers/test", operation_id="test_mcp_server")
|
||||
async def test_mcp_server(
|
||||
request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Test connection to an MCP server without adding it.
|
||||
Returns the list of available tools if successful.
|
||||
Returns the list of available tools if successful, or OAuth information if OAuth is required.
|
||||
"""
|
||||
client = None
|
||||
try:
|
||||
# create a temporary MCP client based on the server type
|
||||
if request.type == MCPServerType.SSE:
|
||||
if not isinstance(request, SSEServerConfig):
|
||||
request = SSEServerConfig(**request.model_dump())
|
||||
client = AsyncSSEMCPClient(request)
|
||||
elif request.type == MCPServerType.STREAMABLE_HTTP:
|
||||
if not isinstance(request, StreamableHTTPServerConfig):
|
||||
request = StreamableHTTPServerConfig(**request.model_dump())
|
||||
client = AsyncStreamableHTTPMCPClient(request)
|
||||
elif request.type == MCPServerType.STDIO:
|
||||
if not isinstance(request, StdioServerConfig):
|
||||
request = StdioServerConfig(**request.model_dump())
|
||||
client = AsyncStdioMCPClient(request)
|
||||
else:
|
||||
raise ValueError(f"Invalid MCP server type: {request.type}")
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
client = await server.mcp_manager.get_mcp_client(request, actor)
|
||||
|
||||
await client.connect_to_server()
|
||||
tools = await client.list_tools()
|
||||
return tools
|
||||
|
||||
return {"status": "success", "tools": tools}
|
||||
except ConnectionError as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
@@ -676,6 +679,155 @@ async def test_mcp_server(
|
||||
logger.warning(f"Error during MCP client cleanup: {cleanup_error}")
|
||||
|
||||
|
||||
@router.post(
|
||||
"/mcp/servers/connect",
|
||||
response_model=None,
|
||||
responses={
|
||||
200: {
|
||||
"description": "Successful response",
|
||||
"content": {
|
||||
"text/event-stream": {"description": "Server-Sent Events stream"},
|
||||
},
|
||||
}
|
||||
},
|
||||
operation_id="connect_mcp_server",
|
||||
)
|
||||
async def connect_mcp_server(
|
||||
request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
) -> StreamingResponse:
|
||||
"""
|
||||
Connect to an MCP server with support for OAuth via SSE.
|
||||
Returns a stream of events handling authorization state and exchange if OAuth is required.
|
||||
"""
|
||||
|
||||
async def oauth_stream_generator(
|
||||
request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig]
|
||||
) -> AsyncGenerator[str, None]:
|
||||
client = None
|
||||
oauth_provider = None
|
||||
temp_client = None
|
||||
connect_task = None
|
||||
|
||||
try:
|
||||
# Acknolwedge connection attempt
|
||||
yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name)
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# Create MCP client with respective transport type
|
||||
try:
|
||||
client = await server.mcp_manager.get_mcp_client(request, actor)
|
||||
except ValueError as e:
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
|
||||
return
|
||||
|
||||
# Try normal connection first for flows that don't require OAuth
|
||||
try:
|
||||
await client.connect_to_server()
|
||||
tools = await client.list_tools(serialize=True)
|
||||
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
|
||||
return
|
||||
except ConnectionError:
|
||||
# Continue to OAuth flow
|
||||
logger.info(f"Attempting OAuth flow for {request.server_url}...")
|
||||
except Exception as e:
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
|
||||
return
|
||||
|
||||
# OAuth required, yield state to client to prepare to handle authorization URL
|
||||
yield oauth_stream_event(OauthStreamEvent.OAUTH_REQUIRED, message="OAuth authentication required")
|
||||
|
||||
# Create OAuth session to persist the state of the OAuth flow
|
||||
session_create = MCPOAuthSessionCreate(
|
||||
server_url=request.server_url,
|
||||
server_name=request.server_name,
|
||||
user_id=actor.id,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
oauth_session = await server.mcp_manager.create_oauth_session(session_create, actor)
|
||||
session_id = oauth_session.id
|
||||
|
||||
# Create OAuth provider for the instance of the stream connection
|
||||
# Note: Using the correct API path for the callback
|
||||
# do not edit this this is the correct url
|
||||
redirect_uri = f"http://localhost:8283/v1/tools/mcp/oauth/callback/{session_id}"
|
||||
oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, server.mcp_manager, actor)
|
||||
|
||||
# Get authorization URL by triggering OAuth flow
|
||||
temp_client = None
|
||||
try:
|
||||
temp_client = await server.mcp_manager.get_mcp_client(request, actor, oauth_provider)
|
||||
|
||||
# Run connect_to_server in background to avoid blocking
|
||||
# This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database
|
||||
connect_task = asyncio.create_task(temp_client.connect_to_server())
|
||||
|
||||
# Give the OAuth flow time to trigger and save the URL
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL
|
||||
auth_session = await server.mcp_manager.get_oauth_session_by_id(session_id, actor)
|
||||
if auth_session and auth_session.authorization_url:
|
||||
yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error triggering OAuth flow: {e}")
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}")
|
||||
|
||||
# Clean up active resources
|
||||
if connect_task and not connect_task.done():
|
||||
connect_task.cancel()
|
||||
try:
|
||||
await connect_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if temp_client:
|
||||
try:
|
||||
await temp_client.cleanup()
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
|
||||
return
|
||||
|
||||
# Wait for user authorization (with timeout), client should render loading state until user completes the flow and /mcp/oauth/callback/{session_id} is hit
|
||||
yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...")
|
||||
|
||||
# Callback handler will poll for authorization code and state and update the OAuth session
|
||||
await connect_task
|
||||
|
||||
tools = await temp_client.list_tools(serialize=True)
|
||||
|
||||
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
|
||||
return
|
||||
except Exception as e:
|
||||
detailed_error = drill_down_exception(e)
|
||||
logger.error(f"Error in OAuth stream:\n{detailed_error}")
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
|
||||
finally:
|
||||
if connect_task and not connect_task.done():
|
||||
connect_task.cancel()
|
||||
try:
|
||||
await connect_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
if client:
|
||||
try:
|
||||
await client.cleanup()
|
||||
except Exception as cleanup_error:
|
||||
detailed_error = drill_down_exception(cleanup_error)
|
||||
logger.warning(f"Error during MCP client cleanup: {detailed_error}")
|
||||
if temp_client:
|
||||
try:
|
||||
await temp_client.cleanup()
|
||||
except Exception as cleanup_error:
|
||||
# TODO: @jnjpng fix async cancel scope issue
|
||||
# detailed_error = drill_down_exception(cleanup_error)
|
||||
logger.warning(f"Aysnc cleanup confict during temp MCP client cleanup: {cleanup_error}")
|
||||
|
||||
return StreamingResponseWithStatusCode(oauth_stream_generator(request), media_type="text/event-stream")
|
||||
|
||||
|
||||
class CodeInput(BaseModel):
|
||||
code: str = Field(..., description="Python source code to parse for JSON schema")
|
||||
|
||||
@@ -697,6 +849,45 @@ async def generate_json_schema(
|
||||
raise HTTPException(status_code=400, detail=f"Failed to generate schema: {str(e)}")
|
||||
|
||||
|
||||
# TODO: @jnjpng need to route this through cloud API for production
|
||||
@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback", response_class=HTMLResponse)
|
||||
async def mcp_oauth_callback(
|
||||
session_id: str,
|
||||
code: Optional[str] = Query(None, description="OAuth authorization code"),
|
||||
state: Optional[str] = Query(None, description="OAuth state parameter"),
|
||||
error: Optional[str] = Query(None, description="OAuth error"),
|
||||
error_description: Optional[str] = Query(None, description="OAuth error description"),
|
||||
):
|
||||
"""
|
||||
Handle OAuth callback for MCP server authentication.
|
||||
"""
|
||||
try:
|
||||
oauth_session = MCPOAuthSession(session_id)
|
||||
|
||||
if error:
|
||||
error_msg = f"OAuth error: {error}"
|
||||
if error_description:
|
||||
error_msg += f" - {error_description}"
|
||||
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
|
||||
return {"status": "error", "message": error_msg}
|
||||
|
||||
if not code or not state:
|
||||
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
|
||||
return {"status": "error", "message": "Missing authorization code or state"}
|
||||
|
||||
# Store authorization code
|
||||
success = await oauth_session.store_authorization_code(code, state)
|
||||
if not success:
|
||||
await oauth_session.update_session_status(OAuthSessionStatus.ERROR)
|
||||
return {"status": "error", "message": "Invalid state parameter"}
|
||||
|
||||
return HTMLResponse(content=get_oauth_success_html(), status_code=200)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OAuth callback error: {e}")
|
||||
return {"status": "error", "message": f"OAuth callback failed: {str(e)}"}
|
||||
|
||||
|
||||
class GenerateToolInput(BaseModel):
|
||||
tool_name: str = Field(..., description="Name of the tool to generate code for")
|
||||
prompt: str = Field(..., description="User prompt to generate code")
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Optional, Tuple
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp import Tool as MCPTool
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.types import TextContent
|
||||
|
||||
from letta.functions.mcp_client.types import BaseServerConfig
|
||||
@@ -14,14 +14,12 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncBaseMCPClient:
|
||||
def __init__(self, server_config: BaseServerConfig):
|
||||
def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
|
||||
self.server_config = server_config
|
||||
self.oauth_provider = oauth_provider
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.initialized = False
|
||||
# Track the task that created this client
|
||||
self._creation_task = asyncio.current_task()
|
||||
self._cleanup_queue = asyncio.Queue(maxsize=1)
|
||||
|
||||
async def connect_to_server(self):
|
||||
try:
|
||||
@@ -48,9 +46,25 @@ class AsyncBaseMCPClient:
|
||||
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
|
||||
raise NotImplementedError("Subclasses must implement _initialize_connection")
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
async def list_tools(self, serialize: bool = False) -> list[MCPTool]:
|
||||
self._check_initialized()
|
||||
response = await self.session.list_tools()
|
||||
if serialize:
|
||||
serializable_tools = []
|
||||
for tool in response.tools:
|
||||
if hasattr(tool, "model_dump"):
|
||||
# Pydantic model - use model_dump
|
||||
serializable_tools.append(tool.model_dump())
|
||||
elif hasattr(tool, "dict"):
|
||||
# Older Pydantic model - use dict()
|
||||
serializable_tools.append(tool.dict())
|
||||
elif hasattr(tool, "__dict__"):
|
||||
# Regular object - use __dict__
|
||||
serializable_tools.append(tool.__dict__)
|
||||
else:
|
||||
# Fallback - convert to string
|
||||
serializable_tools.append(str(tool))
|
||||
return serializable_tools
|
||||
return response.tools
|
||||
|
||||
async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
@@ -79,29 +93,7 @@ class AsyncBaseMCPClient:
|
||||
|
||||
# TODO: still hitting some async errors for voice agents, need to fix
|
||||
async def cleanup(self):
|
||||
"""Clean up resources - ensure this runs in the same task"""
|
||||
if hasattr(self, "_cleanup_task"):
|
||||
# If we're in a different task, schedule cleanup in original task
|
||||
current_task = asyncio.current_task()
|
||||
if current_task != self._creation_task:
|
||||
# Create a future to signal completion
|
||||
cleanup_done = asyncio.Future()
|
||||
self._cleanup_queue.put_nowait((self.exit_stack, cleanup_done))
|
||||
await cleanup_done
|
||||
return
|
||||
|
||||
# Normal cleanup
|
||||
await self.exit_stack.aclose()
|
||||
|
||||
def to_sync_client(self):
|
||||
raise NotImplementedError("Subclasses must implement to_sync_client")
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter the async context manager."""
|
||||
await self.connect_to_server()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit the async context manager."""
|
||||
await self.cleanup()
|
||||
return False # Don't suppress exceptions
|
||||
|
||||
433
letta/services/mcp/oauth_utils.py
Normal file
433
letta/services/mcp/oauth_utils.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""OAuth utilities for MCP server authentication."""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import secrets
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Callable, Optional, Tuple
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
from sqlalchemy import select
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
|
||||
from letta.schemas.mcp import MCPOAuthSessionUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.mcp.types import OauthStreamEvent
|
||||
from letta.services.mcp_manager import MCPManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class DatabaseTokenStorage(TokenStorage):
|
||||
"""Database-backed token storage using MCPOAuth table via mcp_manager."""
|
||||
|
||||
def __init__(self, session_id: str, mcp_manager: MCPManager, actor: PydanticUser):
|
||||
self.session_id = session_id
|
||||
self.mcp_manager = mcp_manager
|
||||
self.actor = actor
|
||||
|
||||
async def get_tokens(self) -> Optional[OAuthToken]:
|
||||
"""Retrieve tokens from database."""
|
||||
oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor)
|
||||
if not oauth_session or not oauth_session.access_token:
|
||||
return None
|
||||
|
||||
return OAuthToken(
|
||||
access_token=oauth_session.access_token,
|
||||
refresh_token=oauth_session.refresh_token,
|
||||
token_type=oauth_session.token_type,
|
||||
expires_in=int(oauth_session.expires_at.timestamp() - time.time()),
|
||||
scope=oauth_session.scope,
|
||||
)
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
"""Store tokens in database."""
|
||||
session_update = MCPOAuthSessionUpdate(
|
||||
access_token=tokens.access_token,
|
||||
refresh_token=tokens.refresh_token,
|
||||
token_type=tokens.token_type,
|
||||
expires_at=datetime.fromtimestamp(tokens.expires_in + time.time()),
|
||||
scope=tokens.scope,
|
||||
status=OAuthSessionStatus.AUTHORIZED,
|
||||
)
|
||||
await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor)
|
||||
|
||||
async def get_client_info(self) -> Optional[OAuthClientInformationFull]:
|
||||
"""Retrieve client information from database."""
|
||||
oauth_session = await self.mcp_manager.get_oauth_session_by_id(self.session_id, self.actor)
|
||||
if not oauth_session or not oauth_session.client_id:
|
||||
return None
|
||||
|
||||
return OAuthClientInformationFull(
|
||||
client_id=oauth_session.client_id,
|
||||
client_secret=oauth_session.client_secret,
|
||||
redirect_uris=[oauth_session.redirect_uri] if oauth_session.redirect_uri else [],
|
||||
)
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
"""Store client information in database."""
|
||||
session_update = MCPOAuthSessionUpdate(
|
||||
client_id=client_info.client_id,
|
||||
client_secret=client_info.client_secret,
|
||||
redirect_uri=str(client_info.redirect_uris[0]) if client_info.redirect_uris else None,
|
||||
)
|
||||
await self.mcp_manager.update_oauth_session(self.session_id, session_update, self.actor)
|
||||
|
||||
|
||||
class MCPOAuthSession:
|
||||
"""Legacy OAuth session class - deprecated, use mcp_manager directly."""
|
||||
|
||||
def __init__(self, server_url: str, server_name: str, user_id: Optional[str], organization_id: str):
|
||||
self.server_url = server_url
|
||||
self.server_name = server_name
|
||||
self.user_id = user_id
|
||||
self.organization_id = organization_id
|
||||
self.session_id = str(uuid.uuid4())
|
||||
self.state = secrets.token_urlsafe(32)
|
||||
|
||||
def __init__(self, session_id: str):
|
||||
self.session_id = session_id
|
||||
|
||||
# TODO: consolidate / deprecate this in favor of mcp_manager access
|
||||
async def create_session(self) -> str:
|
||||
"""Create a new OAuth session in the database."""
|
||||
async with db_registry.async_session() as session:
|
||||
oauth_record = MCPOAuth(
|
||||
id=self.session_id,
|
||||
state=self.state,
|
||||
server_url=self.server_url,
|
||||
server_name=self.server_name,
|
||||
user_id=self.user_id,
|
||||
organization_id=self.organization_id,
|
||||
status=OAuthSessionStatus.PENDING,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
oauth_record = await oauth_record.create_async(session, actor=None)
|
||||
|
||||
return self.session_id
|
||||
|
||||
async def get_session_status(self) -> OAuthSessionStatus:
|
||||
"""Get the current status of the OAuth session."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
|
||||
return oauth_record.status
|
||||
except Exception:
|
||||
return OAuthSessionStatus.ERROR
|
||||
|
||||
async def update_session_status(self, status: OAuthSessionStatus) -> None:
|
||||
"""Update the session status."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
|
||||
oauth_record.status = status
|
||||
oauth_record.updated_at = datetime.now()
|
||||
await oauth_record.update_async(db_session=session, actor=None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def store_authorization_code(self, code: str, state: str) -> bool:
|
||||
"""Store the authorization code from OAuth callback."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
|
||||
|
||||
# if oauth_record.state != state:
|
||||
# return False
|
||||
|
||||
oauth_record.authorization_code = code
|
||||
oauth_record.state = state
|
||||
oauth_record.status = OAuthSessionStatus.AUTHORIZED
|
||||
oauth_record.updated_at = datetime.now()
|
||||
await oauth_record.update_async(db_session=session, actor=None)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def get_authorization_url(self) -> Optional[str]:
|
||||
"""Get the authorization URL for this session."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
|
||||
return oauth_record.authorization_url
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
async def set_authorization_url(self, url: str) -> None:
|
||||
"""Set the authorization URL for this session."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None)
|
||||
oauth_record.authorization_url = url
|
||||
oauth_record.updated_at = datetime.now()
|
||||
await oauth_record.update_async(db_session=session, actor=None)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def create_oauth_provider(
|
||||
session_id: str,
|
||||
server_url: str,
|
||||
redirect_uri: str,
|
||||
mcp_manager: MCPManager,
|
||||
actor: PydanticUser,
|
||||
url_callback: Optional[Callable[[str], None]] = None,
|
||||
) -> OAuthClientProvider:
|
||||
"""Create an OAuth provider for MCP server authentication."""
|
||||
|
||||
client_metadata_dict = {
|
||||
"client_name": "Letta MCP Client",
|
||||
"redirect_uris": [redirect_uri],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "client_secret_post",
|
||||
}
|
||||
|
||||
# Use manager-based storage
|
||||
storage = DatabaseTokenStorage(session_id, mcp_manager, actor)
|
||||
|
||||
# Extract base URL (remove /mcp endpoint if present)
|
||||
oauth_server_url = server_url.rstrip("/").removesuffix("/sse").removesuffix("/mcp")
|
||||
|
||||
async def redirect_handler(authorization_url: str) -> None:
|
||||
"""Handle OAuth redirect by storing the authorization URL."""
|
||||
logger.info(f"OAuth redirect handler called with URL: {authorization_url}")
|
||||
session_update = MCPOAuthSessionUpdate(authorization_url=authorization_url)
|
||||
await mcp_manager.update_oauth_session(session_id, session_update, actor)
|
||||
logger.info(f"OAuth authorization URL stored: {authorization_url}")
|
||||
|
||||
# Call the callback if provided (e.g., to yield URL to SSE stream)
|
||||
if url_callback:
|
||||
url_callback(authorization_url)
|
||||
|
||||
async def callback_handler() -> Tuple[str, Optional[str]]:
|
||||
"""Handle OAuth callback by waiting for authorization code."""
|
||||
timeout = 300 # 5 minutes
|
||||
start_time = time.time()
|
||||
|
||||
logger.info(f"Waiting for authorization code for session {session_id}")
|
||||
while time.time() - start_time < timeout:
|
||||
oauth_session = await mcp_manager.get_oauth_session_by_id(session_id, actor)
|
||||
if oauth_session and oauth_session.authorization_code:
|
||||
return oauth_session.authorization_code, oauth_session.state
|
||||
elif oauth_session and oauth_session.status == OAuthSessionStatus.ERROR:
|
||||
raise Exception("OAuth authorization failed")
|
||||
await asyncio.sleep(1)
|
||||
|
||||
raise Exception(f"Timeout waiting for OAuth callback after {timeout} seconds")
|
||||
|
||||
return OAuthClientProvider(
|
||||
server_url=oauth_server_url,
|
||||
client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
|
||||
storage=storage,
|
||||
redirect_handler=redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
)
|
||||
|
||||
|
||||
async def cleanup_expired_oauth_sessions(max_age_hours: int = 24) -> None:
|
||||
"""Clean up expired OAuth sessions."""
|
||||
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time))
|
||||
expired_sessions = result.scalars().all()
|
||||
|
||||
for oauth_session in expired_sessions:
|
||||
await oauth_session.hard_delete_async(db_session=session, actor=None)
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions")
|
||||
|
||||
|
||||
def oauth_stream_event(event: OauthStreamEvent, **kwargs) -> str:
|
||||
data = {"event": event.value}
|
||||
data.update(kwargs)
|
||||
return f"data: {json.dumps(data)}\n\n"
|
||||
|
||||
|
||||
def drill_down_exception(exception, depth=0, max_depth=5):
|
||||
"""Recursively drill down into nested exceptions to find the root cause"""
|
||||
indent = " " * depth
|
||||
error_details = []
|
||||
|
||||
error_details.append(f"{indent}Exception at depth {depth}:")
|
||||
error_details.append(f"{indent} Type: {type(exception).__name__}")
|
||||
error_details.append(f"{indent} Message: {str(exception)}")
|
||||
error_details.append(f"{indent} Module: {getattr(type(exception), '__module__', 'unknown')}")
|
||||
|
||||
# Check for exception groups (TaskGroup errors)
|
||||
if hasattr(exception, "exceptions") and exception.exceptions:
|
||||
error_details.append(f"{indent} ExceptionGroup with {len(exception.exceptions)} sub-exceptions:")
|
||||
for i, sub_exc in enumerate(exception.exceptions):
|
||||
error_details.append(f"{indent} Sub-exception {i}:")
|
||||
if depth < max_depth:
|
||||
error_details.extend(drill_down_exception(sub_exc, depth + 1, max_depth))
|
||||
|
||||
# Check for chained exceptions (__cause__ and __context__)
|
||||
if hasattr(exception, "__cause__") and exception.__cause__ and depth < max_depth:
|
||||
error_details.append(f"{indent} Caused by:")
|
||||
error_details.extend(drill_down_exception(exception.__cause__, depth + 1, max_depth))
|
||||
|
||||
if hasattr(exception, "__context__") and exception.__context__ and depth < max_depth:
|
||||
error_details.append(f"{indent} Context:")
|
||||
error_details.extend(drill_down_exception(exception.__context__, depth + 1, max_depth))
|
||||
|
||||
# Add traceback info
|
||||
import traceback
|
||||
|
||||
if hasattr(exception, "__traceback__") and exception.__traceback__:
|
||||
tb_lines = traceback.format_tb(exception.__traceback__)
|
||||
error_details.append(f"{indent} Traceback:")
|
||||
for line in tb_lines[-3:]: # Show last 3 traceback lines
|
||||
error_details.append(f"{indent} {line.strip()}")
|
||||
|
||||
error_info = "".join(error_details)
|
||||
return error_info
|
||||
|
||||
|
||||
def get_oauth_success_html() -> str:
|
||||
"""Generate HTML for successful OAuth authorization."""
|
||||
return """
|
||||
<!DOCTYPE html>
|
||||
<html>
|
||||
<head>
|
||||
<title>Authorization Successful - Letta</title>
|
||||
<style>
|
||||
* {
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
box-sizing: border-box;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
|
||||
display: flex;
|
||||
justify-content: center;
|
||||
align-items: center;
|
||||
min-height: 100vh;
|
||||
margin: 0;
|
||||
background-color: #f5f5f5;
|
||||
background-image: url("data:image/svg+xml,%3Csvg width='1440' height='860' viewBox='0 0 1440 860' fill='none' xmlns='http://www.w3.org/2000/svg'%3E%3Cg clip-path='url(%23clip0_14823_146864)'%3E%3Cpath d='M720.001 1003.14C1080.62 1003.14 1372.96 824.028 1372.96 603.083C1372.96 382.138 1080.62 203.026 720.001 203.026C359.384 203.026 67.046 382.138 67.046 603.083C67.046 824.028 359.384 1003.14 720.001 1003.14Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 978.04C910.334 978.04 1064.63 883.505 1064.63 766.891C1064.63 650.276 910.334 555.741 719.999 555.741C529.665 555.741 375.368 650.276 375.368 766.891C375.368 883.505 529.665 978.04 719.999 978.04Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720 1020.95C1262.17 1020.95 1701.68 756.371 1701.68 430C1701.68 103.629 1262.17 -160.946 720 -160.946C177.834 -160.946 -261.678 103.629 -261.678 430C-261.678 756.371 177.834 1020.95 720 1020.95Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 323.658C910.334 323.658 1064.63 223.814 1064.63 100.649C1064.63 -22.5157 910.334 -122.36 719.999 -122.36C529.665 -122.36 375.368 -22.5157 375.368 100.649C375.368 223.814 529.665 323.658 719.999 323.658Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720.001 706.676C1080.62 706.676 1372.96 517.507 1372.96 284.155C1372.96 50.8029 1080.62 -138.366 720.001 -138.366C359.384 -138.366 67.046 50.8029 67.046 284.155C67.046 517.507 359.384 706.676 720.001 706.676Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 874.604C1180.69 874.604 1554.15 645.789 1554.15 363.531C1554.15 81.2725 1180.69 -147.543 719.999 -147.543C259.311 -147.543 -114.15 81.2725 -114.15 363.531C-114.15 645.789 259.311 874.604 719.999 874.604Z' stroke='%23E1E2E3' stroke-width='1.5' stroke-miterlimit='10'/%3E%3C/g%3E%3Cdefs%3E%3CclipPath id='clip0_14823_146864'%3E%3Crect width='1440' height='860' fill='white'/%3E%3C/clipPath%3E%3C/defs%3E%3C/svg%3E");
|
||||
background-size: cover;
|
||||
background-position: center;
|
||||
background-repeat: no-repeat;
|
||||
}
|
||||
|
||||
.card {
|
||||
text-align: center;
|
||||
padding: 48px;
|
||||
background: white;
|
||||
border-radius: 8px;
|
||||
border: 1px solid #E1E2E3;
|
||||
max-width: 400px;
|
||||
width: 90%;
|
||||
position: relative;
|
||||
z-index: 1;
|
||||
}
|
||||
|
||||
.logo {
|
||||
width: 48px;
|
||||
height: 48px;
|
||||
margin: 0 auto 24px;
|
||||
display: block;
|
||||
}
|
||||
|
||||
.logo svg {
|
||||
width: 100%;
|
||||
height: 100%;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 20px;
|
||||
font-weight: 600;
|
||||
color: #101010;
|
||||
margin-bottom: 12px;
|
||||
line-height: 1.2;
|
||||
}
|
||||
|
||||
.subtitle {
|
||||
color: #666;
|
||||
font-size: 12px;
|
||||
margin-top: 10px;
|
||||
margin-bottom: 24px;
|
||||
line-height: 1.5;
|
||||
}
|
||||
|
||||
.close-info {
|
||||
font-size: 12px;
|
||||
color: #999;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 8px;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
width: 16px;
|
||||
height: 16px;
|
||||
border: 2px solid #E1E2E3;
|
||||
border-top: 2px solid #333;
|
||||
border-radius: 50%;
|
||||
animation: spin 1s linear infinite;
|
||||
}
|
||||
|
||||
@keyframes spin {
|
||||
0% { transform: rotate(0deg); }
|
||||
100% { transform: rotate(360deg); }
|
||||
}
|
||||
|
||||
/* Dark mode styles */
|
||||
@media (prefers-color-scheme: dark) {
|
||||
body {
|
||||
background-color: #101010;
|
||||
background-image: url("data:image/svg+xml,%3Csvg width='1440' height='860' viewBox='0 0 1440 860' fill='none' xmlns='http://www.w3.org/2000/svg'%3E%3Cg clip-path='url(%23clip0_14833_149362)'%3E%3Cpath d='M720.001 1003.14C1080.62 1003.14 1372.96 824.028 1372.96 603.083C1372.96 382.138 1080.62 203.026 720.001 203.026C359.384 203.026 67.046 382.138 67.046 603.083C67.046 824.028 359.384 1003.14 720.001 1003.14Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 978.04C910.334 978.04 1064.63 883.505 1064.63 766.891C1064.63 650.276 910.334 555.741 719.999 555.741C529.665 555.741 375.368 650.276 375.368 766.891C375.368 883.505 529.665 978.04 719.999 978.04Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720 1020.95C1262.17 1020.95 1701.68 756.371 1701.68 430C1701.68 103.629 1262.17 -160.946 720 -160.946C177.834 -160.946 -261.678 103.629 -261.678 430C-261.678 756.371 177.834 1020.95 720 1020.95Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 323.658C910.334 323.658 1064.63 223.814 1064.63 100.649C1064.63 -22.5157 910.334 -122.36 719.999 -122.36C529.665 -122.36 375.368 -22.5157 375.368 100.649C375.368 223.814 529.665 323.658 719.999 323.658Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M720.001 706.676C1080.62 706.676 1372.96 517.507 1372.96 284.155C1372.96 50.8029 1080.62 -138.366 720.001 -138.366C359.384 -138.366 67.046 50.8029 67.046 284.155C67.046 517.507 359.384 706.676 720.001 706.676Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3Cpath d='M719.999 874.604C1180.69 874.604 1554.15 645.789 1554.15 363.531C1554.15 81.2725 1180.69 -147.543 719.999 -147.543C259.311 -147.543 -114.15 81.2725 -114.15 363.531C-114.15 645.789 259.311 874.604 719.999 874.604Z' stroke='%2346484A' stroke-width='1.5' stroke-miterlimit='10'/%3E%3C/g%3E%3Cdefs%3E%3CclipPath id='clip0_14833_149362'%3E%3Crect width='1440' height='860' fill='white'/%3E%3C/clipPath%3E%3C/defs%3E%3C/svg%3E");
|
||||
}
|
||||
|
||||
.card {
|
||||
background-color: #141414;
|
||||
border-color: #202020;
|
||||
}
|
||||
|
||||
h1 {
|
||||
color: #E1E2E3;
|
||||
}
|
||||
|
||||
.subtitle {
|
||||
color: #999;
|
||||
}
|
||||
|
||||
.logo svg path {
|
||||
fill: #E1E2E3;
|
||||
}
|
||||
|
||||
.spinner {
|
||||
border-color: #46484A;
|
||||
border-top-color: #E1E2E3;
|
||||
}
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="card">
|
||||
<div class="logo">
|
||||
<svg width="48" height="48" viewBox="0 0 18 18" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M10.7134 7.30028H7.28759V10.7002H10.7134V7.30028Z" fill="#333"/>
|
||||
<path d="M14.1391 2.81618V0.5H3.86131V2.81618C3.86131 3.41495 3.37266 3.89991 2.76935 3.89991H0.435547V14.1001H2.76935C3.37266 14.1001 3.86131 14.5851 3.86131 15.1838V17.5H14.1391V15.1838C14.1391 14.5851 14.6277 14.1001 15.231 14.1001H17.5648V3.89991H15.231C14.6277 3.89991 14.1391 3.41495 14.1391 2.81618ZM14.1391 13.0159C14.1391 13.6147 13.6504 14.0996 13.0471 14.0996H4.95375C4.35043 14.0996 3.86179 13.6147 3.86179 13.0159V4.98363C3.86179 4.38486 4.35043 3.89991 4.95375 3.89991H13.0471C13.6504 3.89991 14.1391 4.38486 14.1391 4.98363V13.0159Z" fill="#333"/>
|
||||
</svg>
|
||||
</div>
|
||||
<h3>Authorization Successful</h3>
|
||||
<p class="subtitle">You have successfully connected your MCP server.</p>
|
||||
<div class="close-info">
|
||||
<span>You can now close this window.</span>
|
||||
</div>
|
||||
</div>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from letta.functions.mcp_client.types import SSEServerConfig
|
||||
@@ -13,6 +16,9 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncSSEMCPClient(AsyncBaseMCPClient):
|
||||
def __init__(self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
|
||||
super().__init__(server_config, oauth_provider)
|
||||
|
||||
async def _initialize_connection(self, server_config: SSEServerConfig) -> None:
|
||||
headers = {}
|
||||
if server_config.custom_headers:
|
||||
@@ -21,7 +27,12 @@ class AsyncSSEMCPClient(AsyncBaseMCPClient):
|
||||
if server_config.auth_header and server_config.auth_token:
|
||||
headers[server_config.auth_header] = server_config.auth_token
|
||||
|
||||
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None)
|
||||
# Use OAuth provider if available, otherwise use regular headers
|
||||
if self.oauth_provider:
|
||||
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider)
|
||||
else:
|
||||
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None)
|
||||
|
||||
sse_transport = await self.exit_stack.enter_async_context(sse_cm)
|
||||
self.stdio, self.write = sse_transport
|
||||
|
||||
|
||||
@@ -1,4 +1,7 @@
|
||||
from typing import Optional
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.auth import OAuthClientProvider
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from letta.functions.mcp_client.types import BaseServerConfig, StreamableHTTPServerConfig
|
||||
@@ -9,10 +12,12 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
|
||||
def __init__(self, server_config: StreamableHTTPServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
|
||||
super().__init__(server_config, oauth_provider)
|
||||
|
||||
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
|
||||
if not isinstance(server_config, StreamableHTTPServerConfig):
|
||||
raise ValueError("Expected StreamableHTTPServerConfig")
|
||||
|
||||
try:
|
||||
# Prepare headers for authentication
|
||||
headers = {}
|
||||
@@ -23,11 +28,18 @@ class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
|
||||
if server_config.auth_header and server_config.auth_token:
|
||||
headers[server_config.auth_header] = server_config.auth_token
|
||||
|
||||
# Use streamablehttp_client context manager with headers if provided
|
||||
if headers:
|
||||
streamable_http_cm = streamablehttp_client(server_config.server_url, headers=headers)
|
||||
# Use OAuth provider if available, otherwise use regular headers
|
||||
if self.oauth_provider:
|
||||
streamable_http_cm = streamablehttp_client(
|
||||
server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider
|
||||
)
|
||||
else:
|
||||
streamable_http_cm = streamablehttp_client(server_config.server_url)
|
||||
# Use streamablehttp_client context manager with headers if provided
|
||||
if headers:
|
||||
streamable_http_cm = streamablehttp_client(server_config.server_url, headers=headers)
|
||||
else:
|
||||
streamable_http_cm = streamablehttp_client(server_config.server_url)
|
||||
|
||||
read_stream, write_stream, _ = await self.exit_stack.enter_async_context(streamable_http_cm)
|
||||
|
||||
# Create and enter the ClientSession context manager
|
||||
|
||||
@@ -46,3 +46,12 @@ class StdioServerConfig(BaseServerConfig):
|
||||
if self.env is not None:
|
||||
values["env"] = self.env
|
||||
return values
|
||||
|
||||
|
||||
class OauthStreamEvent(str, Enum):
|
||||
CONNECTION_ATTEMPT = "connection_attempt"
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
OAUTH_REQUIRED = "oauth_required"
|
||||
AUTHORIZATION_URL = "authorization_url"
|
||||
WAITING_FOR_AUTH = "waiting_for_auth"
|
||||
|
||||
@@ -1,5 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from sqlalchemy import null
|
||||
@@ -8,8 +11,18 @@ import letta.constants as constants
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus
|
||||
from letta.orm.mcp_server import MCPServer as MCPServerModel
|
||||
from letta.schemas.mcp import MCPServer, UpdateMCPServer, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
|
||||
from letta.schemas.mcp import (
|
||||
MCPOAuthSession,
|
||||
MCPOAuthSessionCreate,
|
||||
MCPOAuthSessionUpdate,
|
||||
MCPServer,
|
||||
UpdateMCPServer,
|
||||
UpdateSSEMCPServer,
|
||||
UpdateStdioMCPServer,
|
||||
UpdateStreamableHTTPMCPServer,
|
||||
)
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@@ -38,14 +51,7 @@ class MCPManager:
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
|
||||
if mcp_config.server_type == MCPServerType.SSE:
|
||||
mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
||||
elif mcp_config.server_type == MCPServerType.STDIO:
|
||||
mcp_client = AsyncStdioMCPClient(server_config=server_config)
|
||||
elif mcp_config.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||
mcp_client = AsyncStreamableHTTPMCPClient(server_config=server_config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server type: {mcp_config.server_type}")
|
||||
mcp_client = await self.get_mcp_client(server_config, actor)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# list tools
|
||||
@@ -72,28 +78,20 @@ class MCPManager:
|
||||
# read from config file
|
||||
mcp_config = self.read_mcp_config()
|
||||
if mcp_server_name not in mcp_config:
|
||||
print("MCP server not found in config.", mcp_config)
|
||||
raise ValueError(f"MCP server {mcp_server_name} not found in config.")
|
||||
server_config = mcp_config[mcp_server_name]
|
||||
|
||||
if isinstance(server_config, SSEServerConfig):
|
||||
# mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
||||
async with AsyncSSEMCPClient(server_config=server_config) as mcp_client:
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
return result, success
|
||||
elif isinstance(server_config, StdioServerConfig):
|
||||
async with AsyncStdioMCPClient(server_config=server_config) as mcp_client:
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
return result, success
|
||||
elif isinstance(server_config, StreamableHTTPServerConfig):
|
||||
async with AsyncStreamableHTTPMCPClient(server_config=server_config) as mcp_client:
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
return result, success
|
||||
else:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
mcp_client = await self.get_mcp_client(server_config, actor)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# call tool
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
# TODO: change to pydantic tool
|
||||
|
||||
await mcp_client.cleanup()
|
||||
|
||||
return result, success
|
||||
|
||||
@enforce_types
|
||||
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
@@ -324,3 +322,246 @@ class MCPManager:
|
||||
logger.error(f"Failed to parse server params for MCP server {server_name} (skipping): {e}")
|
||||
continue
|
||||
return mcp_server_list
|
||||
|
||||
async def get_mcp_client(
|
||||
self,
|
||||
server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig],
|
||||
actor: PydanticUser,
|
||||
oauth_provider: Optional[Any] = None,
|
||||
) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]:
|
||||
"""
|
||||
Helper function to create the appropriate MCP client based on server configuration.
|
||||
|
||||
Args:
|
||||
server_config: The server configuration object
|
||||
actor: The user making the request
|
||||
oauth_provider: Optional OAuth provider for authentication
|
||||
|
||||
Returns:
|
||||
The appropriate MCP client instance
|
||||
|
||||
Raises:
|
||||
ValueError: If server config type is not supported
|
||||
"""
|
||||
# If no OAuth provider is provided, check if we have stored OAuth credentials
|
||||
if oauth_provider is None and hasattr(server_config, "server_url"):
|
||||
oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor)
|
||||
if oauth_session and oauth_session.access_token:
|
||||
# Create OAuth provider from stored credentials
|
||||
from letta.services.mcp.oauth_utils import create_oauth_provider
|
||||
|
||||
oauth_provider = await create_oauth_provider(
|
||||
session_id=oauth_session.id,
|
||||
server_url=oauth_session.server_url,
|
||||
redirect_uri=oauth_session.redirect_uri,
|
||||
mcp_manager=self,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
server_config = SSEServerConfig(**server_config.model_dump())
|
||||
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider)
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
server_config = StdioServerConfig(**server_config.model_dump())
|
||||
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider)
|
||||
elif server_config.type == MCPServerType.STREAMABLE_HTTP:
|
||||
server_config = StreamableHTTPServerConfig(**server_config.model_dump())
|
||||
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
# OAuth-related methods
|
||||
@enforce_types
|
||||
async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
"""Create a new OAuth session for MCP server authentication."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Create the OAuth session with a unique state
|
||||
oauth_session = MCPOAuth(
|
||||
id="mcp-oauth-" + str(uuid.uuid4())[:8],
|
||||
state=secrets.token_urlsafe(32),
|
||||
server_url=session_create.server_url,
|
||||
server_name=session_create.server_name,
|
||||
user_id=session_create.user_id,
|
||||
organization_id=session_create.organization_id,
|
||||
status=OAuthSessionStatus.PENDING,
|
||||
created_at=datetime.now(),
|
||||
updated_at=datetime.now(),
|
||||
)
|
||||
oauth_session = await oauth_session.create_async(session, actor=actor)
|
||||
|
||||
# Convert to Pydantic model
|
||||
return MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
server_url=oauth_session.server_url,
|
||||
server_name=oauth_session.server_name,
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
|
||||
"""Get an OAuth session by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
|
||||
return MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
server_url=oauth_session.server_url,
|
||||
server_name=oauth_session.server_name,
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
authorization_url=oauth_session.authorization_url,
|
||||
authorization_code=oauth_session.authorization_code,
|
||||
access_token=oauth_session.access_token,
|
||||
refresh_token=oauth_session.refresh_token,
|
||||
token_type=oauth_session.token_type,
|
||||
expires_at=oauth_session.expires_at,
|
||||
scope=oauth_session.scope,
|
||||
client_id=oauth_session.client_id,
|
||||
client_secret=oauth_session.client_secret,
|
||||
redirect_uri=oauth_session.redirect_uri,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
)
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
async def get_oauth_session_by_server(self, server_url: str, actor: PydanticUser) -> Optional[MCPOAuthSession]:
|
||||
"""Get the latest OAuth session by server URL, organization, and user."""
|
||||
from sqlalchemy import desc, select
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Query for OAuth session matching organization, user, server URL, and status
|
||||
# Order by updated_at desc to get the most recent record
|
||||
result = await session.execute(
|
||||
select(MCPOAuth)
|
||||
.where(
|
||||
MCPOAuth.organization_id == actor.organization_id,
|
||||
MCPOAuth.user_id == actor.id,
|
||||
MCPOAuth.server_url == server_url,
|
||||
MCPOAuth.status == OAuthSessionStatus.AUTHORIZED,
|
||||
)
|
||||
.order_by(desc(MCPOAuth.updated_at))
|
||||
.limit(1)
|
||||
)
|
||||
oauth_session = result.scalar_one_or_none()
|
||||
|
||||
if not oauth_session:
|
||||
return None
|
||||
|
||||
return MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
server_url=oauth_session.server_url,
|
||||
server_name=oauth_session.server_name,
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
authorization_url=oauth_session.authorization_url,
|
||||
authorization_code=oauth_session.authorization_code,
|
||||
access_token=oauth_session.access_token,
|
||||
refresh_token=oauth_session.refresh_token,
|
||||
token_type=oauth_session.token_type,
|
||||
expires_at=oauth_session.expires_at,
|
||||
scope=oauth_session.scope,
|
||||
client_id=oauth_session.client_id,
|
||||
client_secret=oauth_session.client_secret,
|
||||
redirect_uri=oauth_session.redirect_uri,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession:
|
||||
"""Update an existing OAuth session."""
|
||||
async with db_registry.async_session() as session:
|
||||
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
|
||||
|
||||
# Update fields that are provided
|
||||
if session_update.authorization_url is not None:
|
||||
oauth_session.authorization_url = session_update.authorization_url
|
||||
if session_update.authorization_code is not None:
|
||||
oauth_session.authorization_code = session_update.authorization_code
|
||||
if session_update.access_token is not None:
|
||||
oauth_session.access_token = session_update.access_token
|
||||
if session_update.refresh_token is not None:
|
||||
oauth_session.refresh_token = session_update.refresh_token
|
||||
if session_update.token_type is not None:
|
||||
oauth_session.token_type = session_update.token_type
|
||||
if session_update.expires_at is not None:
|
||||
oauth_session.expires_at = session_update.expires_at
|
||||
if session_update.scope is not None:
|
||||
oauth_session.scope = session_update.scope
|
||||
if session_update.client_id is not None:
|
||||
oauth_session.client_id = session_update.client_id
|
||||
if session_update.client_secret is not None:
|
||||
oauth_session.client_secret = session_update.client_secret
|
||||
if session_update.redirect_uri is not None:
|
||||
oauth_session.redirect_uri = session_update.redirect_uri
|
||||
if session_update.status is not None:
|
||||
oauth_session.status = session_update.status
|
||||
|
||||
# Always update the updated_at timestamp
|
||||
oauth_session.updated_at = datetime.now()
|
||||
|
||||
oauth_session = await oauth_session.update_async(db_session=session, actor=actor)
|
||||
|
||||
return MCPOAuthSession(
|
||||
id=oauth_session.id,
|
||||
state=oauth_session.state,
|
||||
server_url=oauth_session.server_url,
|
||||
server_name=oauth_session.server_name,
|
||||
user_id=oauth_session.user_id,
|
||||
organization_id=oauth_session.organization_id,
|
||||
authorization_url=oauth_session.authorization_url,
|
||||
authorization_code=oauth_session.authorization_code,
|
||||
access_token=oauth_session.access_token,
|
||||
refresh_token=oauth_session.refresh_token,
|
||||
token_type=oauth_session.token_type,
|
||||
expires_at=oauth_session.expires_at,
|
||||
scope=oauth_session.scope,
|
||||
client_id=oauth_session.client_id,
|
||||
client_secret=oauth_session.client_secret,
|
||||
redirect_uri=oauth_session.redirect_uri,
|
||||
status=oauth_session.status,
|
||||
created_at=oauth_session.created_at,
|
||||
updated_at=oauth_session.updated_at,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None:
|
||||
"""Delete an OAuth session."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor)
|
||||
await oauth_session.hard_delete_async(db_session=session, actor=actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"OAuth session with id {session_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
async def cleanup_expired_oauth_sessions(self, max_age_hours: int = 24) -> int:
|
||||
"""Clean up expired OAuth sessions and return the count of deleted sessions."""
|
||||
cutoff_time = datetime.now() - timedelta(hours=max_age_hours)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
from sqlalchemy import select
|
||||
|
||||
# Find expired sessions
|
||||
result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time))
|
||||
expired_sessions = result.scalars().all()
|
||||
|
||||
# Delete expired sessions using async ORM method
|
||||
for oauth_session in expired_sessions:
|
||||
await oauth_session.hard_delete_async(db_session=session, actor=None)
|
||||
|
||||
if expired_sessions:
|
||||
logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions")
|
||||
|
||||
return len(expired_sessions)
|
||||
|
||||
356
mcp_test.py
Normal file
356
mcp_test.py
Normal file
@@ -0,0 +1,356 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Simple MCP client example with OAuth authentication support.
|
||||
|
||||
This client connects to an MCP server using streamable HTTP transport with OAuth.
|
||||
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
import webbrowser
|
||||
from datetime import timedelta
|
||||
from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Any
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
from mcp.client.auth import OAuthClientProvider, TokenStorage
|
||||
from mcp.client.session import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
from mcp.shared.auth import OAuthClientInformationFull, OAuthClientMetadata, OAuthToken
|
||||
|
||||
|
||||
class InMemoryTokenStorage(TokenStorage):
|
||||
"""Simple in-memory token storage implementation."""
|
||||
|
||||
def __init__(self):
|
||||
self._tokens: OAuthToken | None = None
|
||||
self._client_info: OAuthClientInformationFull | None = None
|
||||
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
return self._tokens
|
||||
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
self._tokens = tokens
|
||||
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
return self._client_info
|
||||
|
||||
async def set_client_info(self, client_info: OAuthClientInformationFull) -> None:
|
||||
self._client_info = client_info
|
||||
|
||||
|
||||
class CallbackHandler(BaseHTTPRequestHandler):
|
||||
"""Simple HTTP handler to capture OAuth callback."""
|
||||
|
||||
def __init__(self, request, client_address, server, callback_data):
|
||||
"""Initialize with callback data storage."""
|
||||
self.callback_data = callback_data
|
||||
super().__init__(request, client_address, server)
|
||||
|
||||
def do_GET(self):
|
||||
"""Handle GET request from OAuth redirect."""
|
||||
parsed = urlparse(self.path)
|
||||
query_params = parse_qs(parsed.query)
|
||||
|
||||
if "code" in query_params:
|
||||
self.callback_data["authorization_code"] = query_params["code"][0]
|
||||
self.callback_data["state"] = query_params.get("state", [None])[0]
|
||||
self.send_response(200)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(
|
||||
b"""
|
||||
<html>
|
||||
<body>
|
||||
<h1>Authorization Successful!</h1>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
<script>setTimeout(() => window.close(), 2000);</script>
|
||||
</body>
|
||||
</html>
|
||||
"""
|
||||
)
|
||||
elif "error" in query_params:
|
||||
self.callback_data["error"] = query_params["error"][0]
|
||||
self.send_response(400)
|
||||
self.send_header("Content-type", "text/html")
|
||||
self.end_headers()
|
||||
self.wfile.write(
|
||||
f"""
|
||||
<html>
|
||||
<body>
|
||||
<h1>Authorization Failed</h1>
|
||||
<p>Error: {query_params['error'][0]}</p>
|
||||
<p>You can close this window and return to the terminal.</p>
|
||||
</body>
|
||||
</html>
|
||||
""".encode()
|
||||
)
|
||||
else:
|
||||
self.send_response(404)
|
||||
self.end_headers()
|
||||
|
||||
def log_message(self, format, *args):
|
||||
"""Suppress default logging."""
|
||||
|
||||
|
||||
class CallbackServer:
|
||||
"""Simple server to handle OAuth callbacks."""
|
||||
|
||||
def __init__(self, port=3000):
|
||||
self.port = port
|
||||
self.server = None
|
||||
self.thread = None
|
||||
self.callback_data = {"authorization_code": None, "state": None, "error": None}
|
||||
|
||||
def _create_handler_with_data(self):
|
||||
"""Create a handler class with access to callback data."""
|
||||
callback_data = self.callback_data
|
||||
|
||||
class DataCallbackHandler(CallbackHandler):
|
||||
def __init__(self, request, client_address, server):
|
||||
super().__init__(request, client_address, server, callback_data)
|
||||
|
||||
return DataCallbackHandler
|
||||
|
||||
def start(self):
|
||||
"""Start the callback server in a background thread."""
|
||||
handler_class = self._create_handler_with_data()
|
||||
self.server = HTTPServer(("localhost", self.port), handler_class)
|
||||
self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
|
||||
self.thread.start()
|
||||
print(f"🖥️ Started callback server on http://localhost:{self.port}")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the callback server."""
|
||||
if self.server:
|
||||
self.server.shutdown()
|
||||
self.server.server_close()
|
||||
if self.thread:
|
||||
self.thread.join(timeout=1)
|
||||
|
||||
def wait_for_callback(self, timeout=300):
|
||||
"""Wait for OAuth callback with timeout."""
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < timeout:
|
||||
if self.callback_data["authorization_code"]:
|
||||
return self.callback_data["authorization_code"]
|
||||
elif self.callback_data["error"]:
|
||||
raise Exception(f"OAuth error: {self.callback_data['error']}")
|
||||
time.sleep(0.1)
|
||||
raise Exception("Timeout waiting for OAuth callback")
|
||||
|
||||
def get_state(self):
|
||||
"""Get the received state parameter."""
|
||||
return self.callback_data["state"]
|
||||
|
||||
|
||||
class SimpleAuthClient:
|
||||
"""Simple MCP client with auth support."""
|
||||
|
||||
def __init__(self, server_url: str, transport_type: str = "streamable_http"):
|
||||
self.server_url = server_url
|
||||
self.transport_type = transport_type
|
||||
self.session: ClientSession | None = None
|
||||
|
||||
async def connect(self):
|
||||
"""Connect to the MCP server."""
|
||||
print(f"🔗 Attempting to connect to {self.server_url}...")
|
||||
|
||||
try:
|
||||
callback_server = CallbackServer(port=3030)
|
||||
callback_server.start()
|
||||
|
||||
async def callback_handler() -> tuple[str, str | None]:
|
||||
"""Wait for OAuth callback and return auth code and state."""
|
||||
print("⏳ Waiting for authorization callback...")
|
||||
try:
|
||||
auth_code = callback_server.wait_for_callback(timeout=300)
|
||||
return auth_code, callback_server.get_state()
|
||||
finally:
|
||||
callback_server.stop()
|
||||
|
||||
client_metadata_dict = {
|
||||
"client_name": "Simple Auth Client",
|
||||
"redirect_uris": ["http://localhost:3030/callback"],
|
||||
"grant_types": ["authorization_code", "refresh_token"],
|
||||
"response_types": ["code"],
|
||||
"token_endpoint_auth_method": "client_secret_post",
|
||||
}
|
||||
|
||||
async def _default_redirect_handler(authorization_url: str) -> None:
|
||||
"""Default redirect handler that opens the URL in a browser."""
|
||||
print(f"Opening browser for authorization: {authorization_url}")
|
||||
webbrowser.open(authorization_url)
|
||||
|
||||
# Create OAuth authentication handler using the new interface
|
||||
oauth_auth = OAuthClientProvider(
|
||||
server_url=self.server_url.replace("/mcp", ""),
|
||||
client_metadata=OAuthClientMetadata.model_validate(client_metadata_dict),
|
||||
storage=InMemoryTokenStorage(),
|
||||
redirect_handler=_default_redirect_handler,
|
||||
callback_handler=callback_handler,
|
||||
)
|
||||
|
||||
# Create transport with auth handler based on transport type
|
||||
if self.transport_type == "sse":
|
||||
print("📡 Opening SSE transport connection with auth...")
|
||||
async with sse_client(
|
||||
url=self.server_url,
|
||||
auth=oauth_auth,
|
||||
timeout=60,
|
||||
) as (read_stream, write_stream):
|
||||
await self._run_session(read_stream, write_stream, None)
|
||||
else:
|
||||
print("📡 Opening StreamableHTTP transport connection with auth...")
|
||||
async with streamablehttp_client(
|
||||
url=self.server_url,
|
||||
auth=oauth_auth,
|
||||
timeout=timedelta(seconds=60),
|
||||
) as (read_stream, write_stream, get_session_id):
|
||||
await self._run_session(read_stream, write_stream, get_session_id)
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to connect: {e}")
|
||||
import traceback
|
||||
|
||||
traceback.print_exc()
|
||||
|
||||
async def _run_session(self, read_stream, write_stream, get_session_id):
|
||||
"""Run the MCP session with the given streams."""
|
||||
print("🤝 Initializing MCP session...")
|
||||
async with ClientSession(read_stream, write_stream) as session:
|
||||
self.session = session
|
||||
print("⚡ Starting session initialization...")
|
||||
await session.initialize()
|
||||
print("✨ Session initialization complete!")
|
||||
|
||||
print(f"\n✅ Connected to MCP server at {self.server_url}")
|
||||
if get_session_id:
|
||||
session_id = get_session_id()
|
||||
if session_id:
|
||||
print(f"Session ID: {session_id}")
|
||||
|
||||
# Run interactive loop
|
||||
await self.interactive_loop()
|
||||
|
||||
async def list_tools(self):
|
||||
"""List available tools from the server."""
|
||||
if not self.session:
|
||||
print("❌ Not connected to server")
|
||||
return
|
||||
|
||||
try:
|
||||
result = await self.session.list_tools()
|
||||
if hasattr(result, "tools") and result.tools:
|
||||
print("\n📋 Available tools:")
|
||||
for i, tool in enumerate(result.tools, 1):
|
||||
print(f"{i}. {tool.name}")
|
||||
if tool.description:
|
||||
print(f" Description: {tool.description}")
|
||||
print()
|
||||
else:
|
||||
print("No tools available")
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to list tools: {e}")
|
||||
|
||||
async def call_tool(self, tool_name: str, arguments: dict[str, Any] | None = None):
|
||||
"""Call a specific tool."""
|
||||
if not self.session:
|
||||
print("❌ Not connected to server")
|
||||
return
|
||||
|
||||
try:
|
||||
result = await self.session.call_tool(tool_name, arguments or {})
|
||||
print(f"\n🔧 Tool '{tool_name}' result:")
|
||||
if hasattr(result, "content"):
|
||||
for content in result.content:
|
||||
if content.type == "text":
|
||||
print(content.text)
|
||||
else:
|
||||
print(content)
|
||||
else:
|
||||
print(result)
|
||||
except Exception as e:
|
||||
print(f"❌ Failed to call tool '{tool_name}': {e}")
|
||||
|
||||
async def interactive_loop(self):
|
||||
"""Run interactive command loop."""
|
||||
print("\n🎯 Interactive MCP Client")
|
||||
print("Commands:")
|
||||
print(" list - List available tools")
|
||||
print(" call <tool_name> [args] - Call a tool")
|
||||
print(" quit - Exit the client")
|
||||
print()
|
||||
|
||||
while True:
|
||||
try:
|
||||
command = input("mcp> ").strip()
|
||||
|
||||
if not command:
|
||||
continue
|
||||
|
||||
if command == "quit":
|
||||
break
|
||||
|
||||
elif command == "list":
|
||||
await self.list_tools()
|
||||
|
||||
elif command.startswith("call "):
|
||||
parts = command.split(maxsplit=2)
|
||||
tool_name = parts[1] if len(parts) > 1 else ""
|
||||
|
||||
if not tool_name:
|
||||
print("❌ Please specify a tool name")
|
||||
continue
|
||||
|
||||
# Parse arguments (simple JSON-like format)
|
||||
arguments = {}
|
||||
if len(parts) > 2:
|
||||
import json
|
||||
|
||||
try:
|
||||
arguments = json.loads(parts[2])
|
||||
except json.JSONDecodeError:
|
||||
print("❌ Invalid arguments format (expected JSON)")
|
||||
continue
|
||||
|
||||
await self.call_tool(tool_name, arguments)
|
||||
|
||||
else:
|
||||
print("❌ Unknown command. Try 'list', 'call <tool_name>', or 'quit'")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n\n👋 Goodbye!")
|
||||
break
|
||||
except EOFError:
|
||||
break
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main entry point."""
|
||||
# Default server URL - can be overridden with environment variable
|
||||
# Most MCP streamable HTTP servers use /mcp as the endpoint
|
||||
server_url = os.getenv("MCP_SERVER_PORT", 8000)
|
||||
transport_type = os.getenv("MCP_TRANSPORT_TYPE", "streamable_http")
|
||||
server_url = f"http://localhost:{server_url}/mcp" if transport_type == "streamable_http" else f"http://localhost:{server_url}/sse"
|
||||
|
||||
print("🚀 Simple MCP Auth Client")
|
||||
print(f"Connecting to: {server_url}")
|
||||
print(f"Transport type: {transport_type}")
|
||||
|
||||
# Start connection flow - OAuth will be handled automatically
|
||||
client = SimpleAuthClient(server_url, transport_type)
|
||||
await client.connect()
|
||||
|
||||
|
||||
def cli():
|
||||
"""CLI entry point for uv script."""
|
||||
asyncio.run(main())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
cli()
|
||||
Reference in New Issue
Block a user