MCP server connection failures were raising Python's builtin ConnectionError, which bypassed the LettaMCPConnectionError FastAPI exception handler and hit Datadog as unhandled 500 errors. Now all MCP client classes convert ConnectionError to LettaMCPConnectionError at the source, which the existing exception handler returns as a user-friendly 502. Datadog: https://us5.datadoghq.com/error-tracking/issue/93db4a82-fe5a-11f0-85f0-da7ad0900000 🐛 Generated with [Letta Code](https://letta.com) Co-authored-by: Letta <noreply@letta.com>
313 lines
12 KiB
Python
313 lines
12 KiB
Python
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
|
|
|
from fastapi import APIRouter, Body, Depends, HTTPException, Request
|
|
from httpx import HTTPStatusError
|
|
from starlette.responses import StreamingResponse
|
|
|
|
from letta.errors import LettaMCPConnectionError
|
|
from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
|
from letta.log import get_logger
|
|
from letta.schemas.letta_message import ToolReturnMessage
|
|
from letta.schemas.mcp_server import (
|
|
CreateMCPServerRequest,
|
|
MCPServerUnion,
|
|
ToolExecuteRequest,
|
|
UpdateMCPServerRequest,
|
|
convert_generic_to_union,
|
|
convert_update_to_internal,
|
|
)
|
|
from letta.schemas.tool import Tool
|
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
|
from letta.server.rest_api.dependencies import (
|
|
HeaderParams,
|
|
get_headers,
|
|
get_letta_server,
|
|
)
|
|
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
|
from letta.server.server import SyncServer
|
|
from letta.services.mcp.oauth_utils import drill_down_exception, oauth_stream_event
|
|
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
|
from letta.services.mcp.types import OauthStreamEvent
|
|
from letta.settings import tool_settings
|
|
|
|
# Module-level logger named after this module.
logger = get_logger(__name__)

# All routes below are mounted under the "/mcp-servers" prefix and grouped under the "mcp-servers" OpenAPI tag.
router = APIRouter(tags=["mcp-servers"], prefix="/mcp-servers")
|
|
|
|
|
|
@router.post("/", response_model=MCPServerUnion, operation_id="mcp_create_mcp_server")
async def create_mcp_server(
    request: CreateMCPServerRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Add a new MCP server to the Letta MCP server config
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    # Resolve the caller from the actor id header (get_actor_or_default_async presumably
    # falls back to a default actor when no id is supplied).
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Create the server record from the request, then return it as the public union type.
    # TODO: add the tools to the MCP server table we made.
    created = await server.mcp_server_manager.create_mcp_server_from_request(request, actor=acting_user)
    return await convert_generic_to_union(created)
|
|
|
|
|
|
@router.get("/", response_model=List[MCPServerUnion], operation_id="mcp_list_mcp_servers")
async def list_mcp_servers(
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a list of all configured MCP servers
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    configured = await server.mcp_server_manager.list_mcp_servers(actor=acting_user)

    # Convert each record to its public union type one await at a time, keeping the manager's order.
    return [await convert_generic_to_union(entry) for entry in configured]
|
|
|
|
|
|
@router.get("/{mcp_server_id}", response_model=MCPServerUnion, operation_id="mcp_retrieve_mcp_server")
async def retrieve_mcp_server(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a specific MCP server
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Fetch the stored record scoped to the acting user, then convert it for the response.
    found = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=acting_user)
    return await convert_generic_to_union(found)
|
|
|
|
|
|
@router.delete("/{mcp_server_id}", status_code=204, operation_id="mcp_delete_mcp_server")
async def delete_mcp_server(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Delete an MCP server by its ID
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # No body is returned; the route is declared with status_code=204.
    await server.mcp_server_manager.delete_mcp_server_by_id(mcp_server_id, actor=acting_user)
|
|
|
|
|
|
@router.patch("/{mcp_server_id}", response_model=MCPServerUnion, operation_id="mcp_update_mcp_server")
async def update_mcp_server(
    mcp_server_id: str,
    request: UpdateMCPServerRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Update an existing MCP server configuration
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # The public PATCH payload differs from the manager's internal update union; translate it first.
    changes = convert_update_to_internal(request)
    updated = await server.mcp_server_manager.update_mcp_server_by_id(
        mcp_server_id=mcp_server_id,
        mcp_server_update=changes,
        actor=acting_user,
    )
    return await convert_generic_to_union(updated)
|
|
|
|
|
|
@router.get(
    "/{mcp_server_id}/tools",
    response_model=List[Tool],
    operation_id="mcp_list_tools_for_mcp_server",
)
async def list_tools_for_mcp_server(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a list of all tools for a specific MCP server
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Tools are read from the database through the MCPTools mapping (list_tools_by_mcp_server_from_db).
    return await server.mcp_server_manager.list_tools_by_mcp_server_from_db(mcp_server_id, actor=acting_user)
|
|
|
|
|
|
@router.get(
    "/{mcp_server_id}/tools/{tool_id}",
    response_model=Tool,
    operation_id="mcp_retrieve_mcp_tool",
)
async def retrieve_mcp_tool(
    mcp_server_id: str,
    tool_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a specific MCP tool by its ID
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Look the tool up within the given MCP server for the acting user.
    return await server.mcp_server_manager.get_tool_by_mcp_server(mcp_server_id, tool_id, actor=acting_user)
|
|
|
|
|
|
@router.post(
    "/{mcp_server_id}/tools/{tool_id}/run",
    response_model=ToolExecutionResult,
    operation_id="mcp_run_tool",
)
async def run_mcp_tool(
    mcp_server_id: str,
    tool_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
    request: ToolExecuteRequest = Body(default=ToolExecuteRequest()),
):
    """
    Execute a specific MCP tool

    The request body should contain the tool arguments in the ToolExecuteRequest format.
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # The manager returns a (result, success) pair.
    output, ok = await server.mcp_server_manager.execute_mcp_server_tool(
        mcp_server_id=mcp_server_id,
        tool_id=tool_id,
        tool_args=request.args,
        environment_variables={},  # TODO: Get environment variables from somewhere if needed
        actor=acting_user,
    )

    # Map the success flag onto the status string ToolExecutionResult expects.
    if ok:
        outcome = "success"
    else:
        outcome = "error"
    return ToolExecutionResult(status=outcome, func_return=output)
|
|
|
|
|
|
@router.patch("/{mcp_server_id}/refresh", operation_id="mcp_refresh_mcp_server_tools")
async def refresh_mcp_server_tools(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
    agent_id: Optional[str] = None,
):
    """
    Refresh tools for an MCP server by:
    1. Fetching current tools from the MCP server
    2. Deleting tools that no longer exist on the server
    3. Updating schemas for existing tools
    4. Adding new tools from the server

    Returns a summary of changes made.
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # The optional agent_id query parameter is forwarded unchanged; the manager's result is returned as-is.
    return await server.mcp_server_manager.resync_mcp_server_tools(mcp_server_id, actor=acting_user, agent_id=agent_id)
|
|
|
|
|
|
@router.get(
    "/connect/{mcp_server_id}",
    response_model=None,
    # TODO: make this into a model?
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
    operation_id="mcp_connect_mcp_server",
)
async def connect_mcp_server(
    mcp_server_id: str,
    request: Request,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
) -> StreamingResponse:
    """
    Connect to an MCP server with support for OAuth via SSE.
    Returns a stream of events handling authorization state and exchange if OAuth is required.
    """
    # NOTE: FastAPI uses the docstring above as this route's OpenAPI description.
    # The actor and server lookups run before the stream is created, so failures here propagate
    # out of the endpoint (to FastAPI's exception handlers) instead of becoming stream events.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)

    # Convert the MCP server to the appropriate config type. Variables are resolved later, inside the
    # stream, so a ValueError raised there is reported to the client as an ERROR event.
    config = await mcp_server.to_config_async(resolve_variables=False)

    async def oauth_stream_generator(
        mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
        http_request: Request,
    ) -> AsyncGenerator[str, None]:
        """Yield SSE events: a connection attempt, then SUCCESS with the tool list, an ERROR, or the OAuth flow's events."""
        client = None

        # Set once the OAuth flow has been started from the cleanup workaround below, so it is not run twice.
        oauth_flow_attempted = False
        try:
            # Acknowledge connection attempt
            yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=mcp_config.server_name)

            # Create MCP client with respective transport type
            try:
                mcp_config.resolve_environment_variables()
                client = await server.mcp_server_manager.get_mcp_client(mcp_config, actor)
            except ValueError as e:
                yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
                return

            # Try normal connection first for flows that don't require OAuth
            try:
                await client.connect_to_server()
                tools = await client.list_tools(serialize=True)
                yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
                return
            except (ConnectionError, LettaMCPConnectionError):
                if isinstance(client, AsyncStdioMCPClient):
                    logger.warning("OAuth not supported for stdio")
                    yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
                    return
                # Continue to OAuth flow. Log only the server name: resolve_environment_variables() has already
                # run, so the full config's repr may include resolved secret values.
                logger.info(f"Attempting OAuth flow for MCP server '{mcp_config.server_name}'...")
            except Exception as e:
                yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
                return
            finally:
                if client:
                    try:
                        await client.cleanup()
                    # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their
                    # streamable_http.py. For SSE transports the connection failure is caught above as
                    # ConnectionError / LettaMCPConnectionError, but Streamable HTTP doesn't bubble up the exception
                    # until cleanup.
                    except HTTPStatusError:
                        oauth_flow_attempted = True
                        async for event in server.mcp_server_manager.handle_oauth_flow(
                            request=mcp_config, actor=actor, http_request=http_request
                        ):
                            yield event

            # Failsafe to make sure we don't try to handle OAuth flow twice
            if not oauth_flow_attempted:
                async for event in server.mcp_server_manager.handle_oauth_flow(request=mcp_config, actor=actor, http_request=http_request):
                    yield event

            return
        except Exception as e:
            detailed_error = drill_down_exception(e)
            logger.error(f"Error in OAuth stream:\n{detailed_error}")
            yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")

        finally:
            # NOTE(review): the inner `finally` above may already have called cleanup() on this client;
            # any error from this second, best-effort call is only logged.
            if client:
                try:
                    await client.cleanup()
                except Exception as cleanup_error:
                    logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")

    return StreamingResponseWithStatusCode(oauth_stream_generator(config, request), media_type="text/event-stream")
|