diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 4f47470e..d2f8de46 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -12,8 +12,7 @@ from composio.exceptions import ( EnumMetadataNotFound, EnumStringNotFound, ) -from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query -from fastapi.responses import HTMLResponse +from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, Request from pydantic import BaseModel, Field from starlette.responses import StreamingResponse @@ -38,16 +37,10 @@ from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer -from letta.services.mcp.oauth_utils import ( - MCPOAuthSession, - create_oauth_provider, - drill_down_exception, - get_oauth_success_html, - oauth_stream_event, -) +from letta.services.mcp.oauth_utils import MCPOAuthSession, create_oauth_provider, drill_down_exception, oauth_stream_event from letta.services.mcp.stdio_client import AsyncStdioMCPClient from letta.services.mcp.types import OauthStreamEvent -from letta.settings import tool_settings +from letta.settings import settings, tool_settings router = APIRouter(prefix="/tools", tags=["tools"]) @@ -700,6 +693,7 @@ async def connect_mcp_server( request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), + http_request: Request = None, ) -> StreamingResponse: """ Connect to an MCP server with support for OAuth via SSE. @@ -708,6 +702,7 @@ async def connect_mcp_server( async def oauth_stream_generator( request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], + http_request: Request, ) -> AsyncGenerator[str, None]: client = None oauth_provider = None @@ -759,11 +754,32 @@ async def connect_mcp_server( oauth_session = await server.mcp_manager.create_oauth_session(session_create, actor) session_id = oauth_session.id + # TODO: @jnjpng make this check more robust + # Check if request is from web frontend to determine redirect URI + is_web_request = ( + http_request + and http_request.headers + and http_request.headers.get("user-agent", "") == "Next.js Middleware" + and http_request.headers.__contains__("x-organization-id") + ) + + logo_uri = None + NEXT_PUBLIC_CURRENT_HOST = settings.next_public_current_host + LETTA_AGENTS_ENDPOINT = settings.letta_agents_endpoint + + if is_web_request and NEXT_PUBLIC_CURRENT_HOST: + redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}" + logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg" + elif LETTA_AGENTS_ENDPOINT: + # API and SDK usage should call core server directly + redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}" + else: + raise HTTPException(status_code=400, detail="No redirect URI found") + # Create OAuth provider for the instance of the stream connection - # Note: Using the correct API path for the callback - # do not edit this this is the correct url - redirect_uri = f"http://localhost:8283/v1/tools/mcp/oauth/callback/{session_id}" - oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, server.mcp_manager, actor) + oauth_provider = await create_oauth_provider( + session_id, request.server_url, redirect_uri, server.mcp_manager, actor, logo_uri=logo_uri + ) # Get authorization URL by triggering OAuth flow temp_client = None @@ -835,7 +851,7 @@ async def connect_mcp_server( # detailed_error = drill_down_exception(cleanup_error) logger.warning(f"Aysnc cleanup confict during temp MCP client cleanup: {cleanup_error}") - return StreamingResponseWithStatusCode(oauth_stream_generator(request), media_type="text/event-stream") + return StreamingResponseWithStatusCode(oauth_stream_generator(request, http_request), media_type="text/event-stream") class CodeInput(BaseModel): @@ -860,7 +876,7 @@ async def generate_json_schema( # TODO: @jnjpng need to route this through cloud API for production -@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback", response_class=HTMLResponse) +@router.get("/mcp/oauth/callback/{session_id}", operation_id="mcp_oauth_callback") async def mcp_oauth_callback( session_id: str, code: Optional[str] = Query(None, description="OAuth authorization code"), @@ -873,7 +889,6 @@ async def mcp_oauth_callback( """ try: oauth_session = MCPOAuthSession(session_id) - if error: error_msg = f"OAuth error: {error}" if error_description: @@ -891,7 +906,7 @@ async def mcp_oauth_callback( await oauth_session.update_session_status(OAuthSessionStatus.ERROR) return {"status": "error", "message": "Invalid state parameter"} - return HTMLResponse(content=get_oauth_success_html(), status_code=200) + return {"status": "success", "message": "Authorization successful", "server_url": success.server_url} except Exception as e: logger.error(f"OAuth callback error: {e}") diff --git a/letta/services/mcp/oauth_utils.py b/letta/services/mcp/oauth_utils.py index 7391474c..cab6f833 100644 --- a/letta/services/mcp/oauth_utils.py +++ b/letta/services/mcp/oauth_utils.py @@ -132,23 +132,18 @@ class MCPOAuthSession: except Exception: pass - async def store_authorization_code(self, code: str, state: str) -> bool: + async def store_authorization_code(self, code: str, state: str) -> Optional[MCPOAuth]: """Store the authorization code from OAuth callback.""" async with db_registry.async_session() as session: try: oauth_record = await MCPOAuth.read_async(db_session=session, identifier=self.session_id, actor=None) - - # if oauth_record.state != state: - # return False - oauth_record.authorization_code = code oauth_record.state = state oauth_record.status = OAuthSessionStatus.AUTHORIZED oauth_record.updated_at = datetime.now() - await oauth_record.update_async(db_session=session, actor=None) - return True + return await oauth_record.update_async(db_session=session, actor=None) except Exception: - return False + return None async def get_authorization_url(self) -> Optional[str]: """Get the authorization URL for this session.""" @@ -177,16 +172,18 @@ async def create_oauth_provider( redirect_uri: str, mcp_manager: MCPManager, actor: PydanticUser, + logo_uri: Optional[str] = None, url_callback: Optional[Callable[[str], None]] = None, ) -> OAuthClientProvider: """Create an OAuth provider for MCP server authentication.""" client_metadata_dict = { - "client_name": "Letta MCP Client", + "client_name": "Letta", "redirect_uris": [redirect_uri], "grant_types": ["authorization_code", "refresh_token"], "response_types": ["code"], "token_endpoint_auth_method": "client_secret_post", + "logo_uri": logo_uri, } # Use manager-based storage @@ -290,144 +287,3 @@ def drill_down_exception(exception, depth=0, max_depth=5): error_info = "".join(error_details) return error_info - - -def get_oauth_success_html() -> str: - """Generate HTML for successful OAuth authorization.""" - return """ - - - - Authorization Successful - Letta - - - -
- -

Authorization Successful

-

You have successfully connected your MCP server.

-
- You can now close this window. -
-
- - -""" diff --git a/letta/settings.py b/letta/settings.py index 16110d04..7997c33c 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -267,6 +267,10 @@ class Settings(BaseSettings): # for OCR mistral_api_key: Optional[str] = None + # OAuth redirect URLs + letta_agents_endpoint: Optional[str] = os.getenv("LETTA_AGENTS_ENDPOINT") + next_public_current_host: Optional[str] = os.getenv("NEXT_PUBLIC_CURRENT_HOST") + # LLM request timeout settings (model + embedding model) llm_request_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM requests in seconds") llm_stream_timeout_seconds: float = Field(default=60.0, ge=10.0, le=1800.0, description="Timeout for LLM streaming requests in seconds")