From 167f30972d4cad10564bf925c7c5145c2d6ff4c5 Mon Sep 17 00:00:00 2001 From: jnjpng Date: Fri, 16 Jan 2026 17:59:39 -0800 Subject: [PATCH] feat: add mcp oauth support via /mcp connect (#570) --- src/cli/App.tsx | 27 +- src/cli/commands/mcp.ts | 7 +- src/cli/commands/registry.ts | 2 +- src/cli/components/McpConnectFlow.tsx | 599 ++++++++++++++++++++++++++ src/cli/components/McpSelector.tsx | 35 +- src/cli/helpers/mcpOauth.ts | 151 +++++++ 6 files changed, 801 insertions(+), 20 deletions(-) create mode 100644 src/cli/components/McpConnectFlow.tsx create mode 100644 src/cli/helpers/mcpOauth.ts diff --git a/src/cli/App.tsx b/src/cli/App.tsx index 3441b53..54076c7 100644 --- a/src/cli/App.tsx +++ b/src/cli/App.tsx @@ -90,6 +90,7 @@ import { ErrorMessage } from "./components/ErrorMessageRich"; import { FeedbackDialog } from "./components/FeedbackDialog"; import { HelpDialog } from "./components/HelpDialog"; import { Input } from "./components/InputRich"; +import { McpConnectFlow } from "./components/McpConnectFlow"; import { McpSelector } from "./components/McpSelector"; import { MemoryTabViewer } from "./components/MemoryTabViewer"; import { MessageSearch } from "./components/MessageSearch"; @@ -819,6 +820,7 @@ export default function App({ | "pin" | "new" | "mcp" + | "mcp-connect" | "help" | null; const [activeOverlay, setActiveOverlay] = useState(null); @@ -3563,6 +3565,12 @@ export default function App({ return { submitted: true }; } + // /mcp connect - interactive TUI for connecting with OAuth + if (firstWord === "connect") { + setActiveOverlay("mcp-connect"); + return { submitted: true }; + } + // Unknown subcommand handleMcpUsage(mcpCtx, msg); return { submitted: true }; @@ -7739,15 +7747,28 @@ Plan file path: ${planFilePath}`; { - // Close overlay and prompt user to use /mcp add command + // Switch to the MCP connect flow + setActiveOverlay("mcp-connect"); + }} + onCancel={closeOverlay} + /> + )} + + {/* MCP Connect Flow - interactive TUI for OAuth connection */} + {activeOverlay === "mcp-connect" && ( + { closeOverlay(); const cmdId = uid("cmd"); buffersRef.current.byId.set(cmdId, { kind: "command", id: cmdId, - input: "/mcp", + input: "/mcp connect", output: - "Use /mcp add --transport [...] to add a new server", + `Successfully created MCP server "${serverName}"\n` + + `ID: ${serverId}\n` + + `Discovered ${toolCount} tool${toolCount === 1 ? "" : "s"}\n` + + "Open /mcp to attach or detach tools for this server.", phase: "finished", success: true, }); diff --git a/src/cli/commands/mcp.ts b/src/cli/commands/mcp.ts index 88f7f4e..11bd8cc 100644 --- a/src/cli/commands/mcp.ts +++ b/src/cli/commands/mcp.ts @@ -355,7 +355,12 @@ export function handleMcpUsage(ctx: McpCommandContext, msg: string): void { ctx.buffersRef, ctx.refreshDerived, msg, - 'Usage: /mcp [add ...]\n /mcp - list MCP servers\n /mcp add --transport [...] - add a new server\n\nExamples:\n /mcp add --transport http notion https://mcp.notion.com/mcp\n /mcp add --transport http api https://api.example.com --header "Authorization: Bearer token"', + "Usage: /mcp [subcommand ...]\n" + + " /mcp - Open MCP server manager\n" + + " /mcp add ... - Add a new server (without OAuth)\n" + + " /mcp connect - Interactive wizard with OAuth support\n\n" + + "Examples:\n" + + " /mcp add --transport http notion https://mcp.notion.com/mcp", false, ); } diff --git a/src/cli/commands/registry.ts b/src/cli/commands/registry.ts index 7dd05af..ac373c4 100644 --- a/src/cli/commands/registry.ts +++ b/src/cli/commands/registry.ts @@ -178,7 +178,7 @@ export const commands: Record = { }, }, "/mcp": { - desc: "Manage MCP servers", + desc: "Manage MCP servers (add, connect with OAuth)", order: 32, handler: () => { // Handled specially in App.tsx to show MCP server selector diff --git a/src/cli/components/McpConnectFlow.tsx b/src/cli/components/McpConnectFlow.tsx new file mode 100644 index 0000000..02983a6 --- /dev/null +++ b/src/cli/components/McpConnectFlow.tsx @@ -0,0 +1,599 @@ +/** + * Interactive TUI for connecting to MCP servers with OAuth support. + * Flow: Select transport → Enter URL → Connect (OAuth if needed) → Enter name → Create + */ + +import { Box, Text, useInput } from "ink"; +import { memo, useCallback, useState } from "react"; + +import { getClient } from "../../agent/client"; +import { + connectMcpServer, + type McpConnectConfig, + type McpTool, + OauthStreamEvent, +} from "../helpers/mcpOauth"; +import { useTerminalWidth } from "../hooks/useTerminalWidth"; +import { colors } from "./colors"; +import { PasteAwareTextInput } from "./PasteAwareTextInput"; + +const SOLID_LINE = "─"; + +// Validate URL (outside component to avoid useCallback dependency) +function validateUrl(url: string): string | null { + if (!url.trim()) { + return "URL is required"; + } + try { + const parsed = new URL(url); + if (!["http:", "https:"].includes(parsed.protocol)) { + return "URL must use http or https protocol"; + } + } catch { + return "Invalid URL format"; + } + return null; +} + +// Validate server name (outside component to avoid useCallback dependency) +function validateName(name: string): string | null { + if (!name.trim()) { + return "Server name is required"; + } + if (!/^[a-zA-Z0-9_-]+$/.test(name.trim())) { + return "Name can only contain letters, numbers, hyphens, and underscores"; + } + if (name.trim().length > 64) { + return "Name must be 64 characters or less"; + } + return null; +} + +interface McpConnectFlowProps { + onComplete: (serverName: string, serverId: string, toolCount: number) => void; + onCancel: () => void; +} + +type Step = + | "select-transport" + | "enter-url" + | "connecting" + | "enter-name" + | "creating"; + +type Transport = "http" | "sse"; + +const TRANSPORTS: { value: Transport; label: string; description: string }[] = [ + { + value: "http", + label: "Streamable HTTP", + description: "Modern HTTP-based transport (recommended)", + }, + { + value: "sse", + label: "Server-Sent Events", + description: "SSE-based transport for legacy servers", + }, +]; + +export const McpConnectFlow = memo(function McpConnectFlow({ + onComplete, + onCancel, +}: McpConnectFlowProps) { + const terminalWidth = useTerminalWidth(); + const solidLine = SOLID_LINE.repeat(Math.max(terminalWidth, 10)); + + // Step state + const [step, setStep] = useState("select-transport"); + + // Transport selection + const [transportIndex, setTransportIndex] = useState(0); + const [selectedTransport, setSelectedTransport] = useState( + null, + ); + + // URL input + const [urlInput, setUrlInput] = useState(""); + const [urlError, setUrlError] = useState(""); + + // Connection state + const [connectionStatus, setConnectionStatus] = useState(""); + const [authUrl, setAuthUrl] = useState(null); + const [discoveredTools, setDiscoveredTools] = useState([]); + const [connectionError, setConnectionError] = useState(null); + + // Name input + const [nameInput, setNameInput] = useState(""); + const [nameError, setNameError] = useState(""); + + // Creating state + const [creatingStatus, setCreatingStatus] = useState(""); + + // Handle transport selection + useInput( + (input, key) => { + if (key.ctrl && input === "c") { + onCancel(); + return; + } + + if (key.escape) { + onCancel(); + return; + } + + if (step === "select-transport") { + if (key.upArrow) { + setTransportIndex((prev) => Math.max(0, prev - 1)); + } else if (key.downArrow) { + setTransportIndex((prev) => + Math.min(TRANSPORTS.length - 1, prev + 1), + ); + } else if (key.return) { + const selected = TRANSPORTS[transportIndex]; + if (selected) { + setSelectedTransport(selected.value); + setStep("enter-url"); + } + } + } + }, + { isActive: step === "select-transport" }, + ); + + // Handle URL input escape + useInput( + (input, key) => { + if (key.ctrl && input === "c") { + onCancel(); + return; + } + + if (key.escape) { + // Go back to transport selection + setStep("select-transport"); + setUrlInput(""); + setUrlError(""); + } + }, + { isActive: step === "enter-url" }, + ); + + // Handle connection step escape + useInput( + (input, key) => { + if (key.ctrl && input === "c") { + onCancel(); + return; + } + + if (key.escape && connectionError) { + // Go back to URL input on error + setStep("enter-url"); + setConnectionError(null); + setConnectionStatus(""); + setAuthUrl(null); + } + }, + { isActive: step === "connecting" }, + ); + + // Handle name input escape + useInput( + (input, key) => { + if (key.ctrl && input === "c") { + onCancel(); + return; + } + + if (key.escape) { + // Go back to URL input + setStep("enter-url"); + setNameInput(""); + setNameError(""); + } + }, + { isActive: step === "enter-name" }, + ); + + // Handle URL submission + const handleUrlSubmit = useCallback( + async (text: string) => { + const trimmed = text.trim(); + const error = validateUrl(trimmed); + if (error) { + setUrlError(error); + return; + } + + setUrlError(""); + setStep("connecting"); + setConnectionStatus("Connecting..."); + setConnectionError(null); + setAuthUrl(null); + + const config: McpConnectConfig = { + server_name: "temp-connection-test", + type: selectedTransport === "http" ? "streamable_http" : "sse", + server_url: trimmed, + }; + + try { + const tools = await connectMcpServer(config, { + onEvent: (event) => { + switch (event.event) { + case OauthStreamEvent.CONNECTION_ATTEMPT: + setConnectionStatus("Connecting to server..."); + break; + case OauthStreamEvent.OAUTH_REQUIRED: + setConnectionStatus("OAuth authentication required..."); + break; + case OauthStreamEvent.AUTHORIZATION_URL: + if (event.url) { + const authorizationUrl = event.url; + setAuthUrl(authorizationUrl); + setConnectionStatus("Opening browser for authorization..."); + // Open browser + import("open") + .then(({ default: open }) => open(authorizationUrl)) + .catch(() => {}); + } + break; + case OauthStreamEvent.WAITING_FOR_AUTH: + setConnectionStatus("Waiting for authorization in browser..."); + break; + } + }, + }); + + // Success! + setDiscoveredTools(tools); + setConnectionStatus(""); + + // Generate default name from URL + try { + const parsed = new URL(trimmed); + const defaultName = + parsed.hostname.replace(/^(www|mcp|api)\./, "").split(".")[0] || + "mcp-server"; + setNameInput(defaultName); + } catch { + setNameInput("mcp-server"); + } + + setStep("enter-name"); + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + setConnectionError(message); + setConnectionStatus(""); + } + }, + [selectedTransport], + ); + + // Handle name submission and create server + const handleNameSubmit = useCallback( + async (text: string) => { + const trimmed = text.trim(); + const error = validateName(trimmed); + if (error) { + setNameError(error); + return; + } + + setNameError(""); + setStep("creating"); + setCreatingStatus("Creating MCP server..."); + + try { + const client = await getClient(); + + const serverConfig = + selectedTransport === "http" + ? { + mcp_server_type: "streamable_http" as const, + server_url: urlInput.trim(), + } + : { + mcp_server_type: "sse" as const, + server_url: urlInput.trim(), + }; + + const server = await client.mcpServers.create({ + server_name: trimmed, + config: serverConfig, + }); + + onComplete(trimmed, server.id || "", discoveredTools.length); + } catch (err) { + const message = err instanceof Error ? err.message : String(err); + setNameError(`Failed to create server: ${message}`); + setStep("enter-name"); + setCreatingStatus(""); + } + }, + [selectedTransport, urlInput, discoveredTools.length, onComplete], + ); + + // Render transport selection step + if (step === "select-transport") { + return ( + + {"> /mcp connect"} + {solidLine} + + + + + Connect to MCP Server + + + + + + Select transport type: + + + + + + {TRANSPORTS.map((transport, index) => { + const isSelected = index === transportIndex; + return ( + + + + {isSelected ? "> " : " "} + {transport.label} + + + + {transport.description} + + + ); + })} + + + + + + ↑↓ navigate · Enter select · Esc cancel + + + ); + } + + // Render URL input step + if (step === "enter-url") { + const transportLabel = + TRANSPORTS.find((t) => t.value === selectedTransport)?.label || ""; + + return ( + + {"> /mcp connect"} + {solidLine} + + + + + Connect to MCP Server + + + + + + + Transport: {transportLabel} + + + + + + + Enter the server URL: + + + + {">"} + + { + setUrlInput(val); + setUrlError(""); + }} + onSubmit={handleUrlSubmit} + placeholder="https://mcp.example.com/mcp" + /> + + + {urlError && ( + + {urlError} + + )} + + + + + Enter submit · Esc back + + + ); + } + + // Render connecting step + if (step === "connecting") { + const transportLabel = + TRANSPORTS.find((t) => t.value === selectedTransport)?.label || ""; + + return ( + + {"> /mcp connect"} + {solidLine} + + + + + Connect to MCP Server + + + + + + + Transport: {transportLabel} + + + URL: {urlInput} + + + + + + {connectionStatus && ( + + {connectionStatus} + + )} + + {authUrl && ( + + Authorization URL: + {authUrl} + + )} + + {connectionError && ( + + Connection failed: + {connectionError} + + Esc to go back and try again + + + )} + + ); + } + + // Render name input step + if (step === "enter-name") { + const transportLabel = + TRANSPORTS.find((t) => t.value === selectedTransport)?.label || ""; + + return ( + + {"> /mcp connect"} + {solidLine} + + + + + Connect to MCP Server + + + + + + + Transport: {transportLabel} + + + URL: {urlInput} + + + + + + + + ✓ Connection successful! Discovered {discoveredTools.length} tool + {discoveredTools.length === 1 ? "" : "s"} + + + + {discoveredTools.length > 0 && ( + + {discoveredTools.slice(0, 5).map((tool) => ( + + • {tool.name} + + ))} + {discoveredTools.length > 5 && ( + ... and {discoveredTools.length - 5} more + )} + + )} + + + + + Enter a name for this server: + + + + {">"} + + { + setNameInput(val); + setNameError(""); + }} + onSubmit={handleNameSubmit} + placeholder="my-mcp-server" + /> + + + {nameError && ( + + {nameError} + + )} + + + + + Enter create · Esc back + + + ); + } + + // Render creating step + if (step === "creating") { + return ( + + {"> /mcp connect"} + {solidLine} + + + + + Connect to MCP Server + + + + + + {creatingStatus} + + + ); + } + + return null; +}); + +McpConnectFlow.displayName = "McpConnectFlow"; diff --git a/src/cli/components/McpSelector.tsx b/src/cli/components/McpSelector.tsx index 6643e03..b315832 100644 --- a/src/cli/components/McpSelector.tsx +++ b/src/cli/components/McpSelector.tsx @@ -109,6 +109,16 @@ export const McpSelector = memo(function McpSelector({ } }, []); + const fetchAttachedToolIds = useCallback( + async (client: Awaited>) => { + const agent = await client.agents.retrieve(agentId, { + include: ["agent.tools"], + }); + return new Set(agent.tools?.map((t) => t.id) || []); + }, + [agentId], + ); + // Load tools for a specific server const loadTools = useCallback( async (server: McpServer) => { @@ -138,8 +148,7 @@ export const McpSelector = memo(function McpSelector({ setTools(toolsList); // Fetch agent's current tools to check which are attached - const agent = await client.agents.retrieve(agentId); - const agentToolIds = new Set(agent.tools?.map((t) => t.id) || []); + const agentToolIds = await fetchAttachedToolIds(client); setAttachedToolIds(agentToolIds); setToolsPage(0); @@ -153,7 +162,7 @@ export const McpSelector = memo(function McpSelector({ setToolsLoading(false); } }, - [agentId], + [fetchAttachedToolIds], ); // Refresh tools from MCP server @@ -174,8 +183,7 @@ export const McpSelector = memo(function McpSelector({ setTools(toolsList); // Refresh agent's current tools - const agent = await client.agents.retrieve(agentId); - const agentToolIds = new Set(agent.tools?.map((t) => t.id) || []); + const agentToolIds = await fetchAttachedToolIds(client); setAttachedToolIds(agentToolIds); setToolsPage(0); @@ -194,7 +202,7 @@ export const McpSelector = memo(function McpSelector({ } finally { setToolsLoading(false); } - }, [agentId, viewingServer]); + }, [agentId, fetchAttachedToolIds, viewingServer]); // Toggle tool attachment const toggleTool = useCallback( @@ -213,8 +221,7 @@ export const McpSelector = memo(function McpSelector({ } // Fetch agent's current tools to get accurate total count - const agent = await client.agents.retrieve(agentId); - const agentToolIds = new Set(agent.tools?.map((t) => t.id) || []); + const agentToolIds = await fetchAttachedToolIds(client); setAttachedToolIds(agentToolIds); } catch (err) { setToolsError( @@ -226,7 +233,7 @@ export const McpSelector = memo(function McpSelector({ setIsTogglingTool(false); } }, - [agentId, attachedToolIds], + [agentId, attachedToolIds, fetchAttachedToolIds], ); // Attach all tools @@ -244,8 +251,7 @@ export const McpSelector = memo(function McpSelector({ ); // Fetch agent's current tools to get accurate total count - const agent = await client.agents.retrieve(agentId); - const agentToolIds = new Set(agent.tools?.map((t) => t.id) || []); + const agentToolIds = await fetchAttachedToolIds(client); setAttachedToolIds(agentToolIds); } catch (err) { setToolsError( @@ -254,7 +260,7 @@ export const McpSelector = memo(function McpSelector({ } finally { setIsTogglingTool(false); } - }, [agentId, tools, attachedToolIds]); + }, [agentId, tools, attachedToolIds, fetchAttachedToolIds]); // Detach all tools const detachAllTools = useCallback(async () => { @@ -271,8 +277,7 @@ export const McpSelector = memo(function McpSelector({ ); // Fetch agent's current tools to get accurate total count - const agent = await client.agents.retrieve(agentId); - const agentToolIds = new Set(agent.tools?.map((t) => t.id) || []); + const agentToolIds = await fetchAttachedToolIds(client); setAttachedToolIds(agentToolIds); } catch (err) { setToolsError( @@ -281,7 +286,7 @@ export const McpSelector = memo(function McpSelector({ } finally { setIsTogglingTool(false); } - }, [agentId, tools, attachedToolIds]); + }, [agentId, tools, attachedToolIds, fetchAttachedToolIds]); useEffect(() => { loadServers(); diff --git a/src/cli/helpers/mcpOauth.ts b/src/cli/helpers/mcpOauth.ts new file mode 100644 index 0000000..2a43e29 --- /dev/null +++ b/src/cli/helpers/mcpOauth.ts @@ -0,0 +1,151 @@ +/** + * MCP OAuth SSE client for connecting to MCP servers that require OAuth authentication. + * Uses the /v1/tools/mcp/servers/connect SSE streaming endpoint. + */ + +import { getServerUrl } from "../../agent/client"; +import { settingsManager } from "../../settings-manager"; + +// Match backend's OauthStreamEvent enum +export enum OauthStreamEvent { + CONNECTION_ATTEMPT = "connection_attempt", + SUCCESS = "success", + ERROR = "error", + OAUTH_REQUIRED = "oauth_required", + AUTHORIZATION_URL = "authorization_url", + WAITING_FOR_AUTH = "waiting_for_auth", +} + +export interface McpOauthEvent { + event: OauthStreamEvent; + url?: string; // For AUTHORIZATION_URL + session_id?: string; // For tracking + tools?: McpTool[]; // For SUCCESS + message?: string; // For ERROR/info + server_name?: string; // Server name +} + +export interface McpTool { + name: string; + description?: string; + inputSchema?: Record; +} + +export interface McpConnectConfig { + server_name: string; + type: "sse" | "streamable_http"; + server_url: string; + auth_header?: string; + auth_token?: string; + custom_headers?: Record; +} + +export interface McpConnectOptions { + onEvent?: (event: McpOauthEvent) => void; + abortSignal?: AbortSignal; +} + +/** + * Connect to an MCP server with OAuth support via SSE streaming. + * Returns the list of available tools on success. + * + * The flow: + * 1. Opens SSE stream to /v1/tools/mcp/servers/connect + * 2. Receives CONNECTION_ATTEMPT event + * 3. If OAuth is required: + * - Receives OAUTH_REQUIRED event + * - Receives AUTHORIZATION_URL event with OAuth URL + * - Receives WAITING_FOR_AUTH event + * - Caller should open browser with the URL + * - After user authorizes, receives SUCCESS event + * 4. Returns tools array on SUCCESS, throws on ERROR + */ +export async function connectMcpServer( + config: McpConnectConfig, + options: McpConnectOptions = {}, +): Promise { + const { onEvent, abortSignal } = options; + + const settings = await settingsManager.getSettingsWithSecureTokens(); + const baseUrl = getServerUrl(); + const apiKey = process.env.LETTA_API_KEY || settings.env?.LETTA_API_KEY; + + if (!apiKey) { + throw new Error("Missing LETTA_API_KEY"); + } + + const response = await fetch(`${baseUrl}/v1/tools/mcp/servers/connect`, { + method: "POST", + headers: { + "Content-Type": "application/json", + Accept: "text/event-stream", + Authorization: `Bearer ${apiKey}`, + "X-Letta-Source": "letta-code", + }, + body: JSON.stringify(config), + signal: abortSignal, + }); + + if (!response.ok) { + const errorText = await response.text(); + throw new Error(`Connection failed (${response.status}): ${errorText}`); + } + + const reader = response.body?.getReader(); + if (!reader) { + throw new Error("Failed to get response stream reader"); + } + + const decoder = new TextDecoder(); + let buffer = ""; + + try { + while (true) { + const { done, value } = await reader.read(); + if (done) { + throw new Error("Stream ended unexpectedly without success or error"); + } + + buffer += decoder.decode(value, { stream: true }); + const lines = buffer.split("\n"); + buffer = lines.pop() || ""; + + for (const line of lines) { + if (!line.trim() || line.trim() === "[DONE]") continue; + + let data = line; + if (line.startsWith("data: ")) { + data = line.slice(6); + } + + if (data.trim() === "[DONE]") continue; + + try { + const event = JSON.parse(data) as McpOauthEvent; + onEvent?.(event); + + switch (event.event) { + case OauthStreamEvent.SUCCESS: + return event.tools || []; + + case OauthStreamEvent.ERROR: + throw new Error(event.message || "Connection failed"); + + case OauthStreamEvent.AUTHORIZATION_URL: + // Event handler should open browser + // Continue processing stream for WAITING_FOR_AUTH and SUCCESS + break; + + // Other events are informational (CONNECTION_ATTEMPT, OAUTH_REQUIRED, WAITING_FOR_AUTH) + } + } catch (parseError) { + // Skip unparseable lines (might be partial SSE data) + if (parseError instanceof SyntaxError) continue; + throw parseError; + } + } + } + } finally { + reader.releaseLock(); + } +}