Files
letta-server/letta/server/rest_api/routers/v1/mcp_servers.py
jnjpng 350f3a751c fix: update more plaintext non async callsites (#7223)
* bae

* update

* fix

* clean up

* last
2025-12-17 17:31:02 -08:00

313 lines
12 KiB
Python

from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from httpx import HTTPStatusError
from starlette.responses import StreamingResponse
from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
from letta.log import get_logger
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.mcp_server import (
CreateMCPServerRequest,
MCPServerUnion,
ToolExecuteRequest,
UpdateMCPServerRequest,
convert_generic_to_union,
convert_update_to_internal,
)
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.server.rest_api.dependencies import (
HeaderParams,
get_headers,
get_letta_server,
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.server.server import SyncServer
from letta.services.mcp.oauth_utils import drill_down_exception, oauth_stream_event
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
from letta.services.mcp.types import OauthStreamEvent
from letta.settings import tool_settings
# Module-scoped logger for this router's diagnostics.
logger = get_logger(__name__)
# Every MCP server management endpoint below is mounted under /mcp-servers.
router = APIRouter(tags=["mcp-servers"], prefix="/mcp-servers")
@router.post(
    "/",
    operation_id="mcp_create_mcp_server",
    response_model=MCPServerUnion,
)
async def create_mcp_server(
    request: CreateMCPServerRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Add a new MCP server to the Letta MCP server config
    """
    # TODO: add the tools to the MCP server table we made.
    # Resolve the calling actor from the actor id carried in the request headers.
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    manager = server.mcp_server_manager
    created = await manager.create_mcp_server_from_request(request, actor=requesting_actor)
    # Convert the stored record into the MCPServerUnion response shape.
    return await convert_generic_to_union(created)
@router.get(
    "/",
    operation_id="mcp_list_mcp_servers",
    response_model=List[MCPServerUnion],
)
async def list_mcp_servers(
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a list of all configured MCP servers
    """
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    stored_servers = await server.mcp_server_manager.list_mcp_servers(actor=requesting_actor)
    # Convert records one at a time (sequential awaits), keeping the manager's ordering.
    return [await convert_generic_to_union(stored) for stored in stored_servers]
@router.get(
    "/{mcp_server_id}",
    operation_id="mcp_retrieve_mcp_server",
    response_model=MCPServerUnion,
)
async def retrieve_mcp_server(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a specific MCP server
    """
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    found = await server.mcp_server_manager.get_mcp_server_by_id_async(
        mcp_server_id=mcp_server_id,
        actor=requesting_actor,
    )
    # Shape the stored record as an MCPServerUnion for the response.
    return await convert_generic_to_union(found)
@router.delete(
    "/{mcp_server_id}",
    operation_id="mcp_delete_mcp_server",
    status_code=204,
)
async def delete_mcp_server(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Delete an MCP server by its ID
    """
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    manager = server.mcp_server_manager
    await manager.delete_mcp_server_by_id(mcp_server_id, actor=requesting_actor)
    # No body is returned; the route is declared with status code 204.
@router.patch(
    "/{mcp_server_id}",
    operation_id="mcp_update_mcp_server",
    response_model=MCPServerUnion,
)
async def update_mcp_server(
    mcp_server_id: str,
    request: UpdateMCPServerRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Update an existing MCP server configuration
    """
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    # The external PATCH payload must be translated into the manager's internal update union.
    changes = convert_update_to_internal(request)
    patched = await server.mcp_server_manager.update_mcp_server_by_id(
        mcp_server_id=mcp_server_id,
        mcp_server_update=changes,
        actor=requesting_actor,
    )
    return await convert_generic_to_union(patched)
@router.get(
    "/{mcp_server_id}/tools",
    operation_id="mcp_list_tools_for_mcp_server",
    response_model=List[Tool],
)
async def list_tools_for_mcp_server(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a list of all tools for a specific MCP server
    """
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    manager = server.mcp_server_manager
    # Query the persisted tool list through the MCPTools mapping in the database.
    return await manager.list_tools_by_mcp_server_from_db(mcp_server_id, actor=requesting_actor)
@router.get(
    "/{mcp_server_id}/tools/{tool_id}",
    operation_id="mcp_retrieve_mcp_tool",
    response_model=Tool,
)
async def retrieve_mcp_tool(
    mcp_server_id: str,
    tool_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a specific MCP tool by its ID
    """
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    manager = server.mcp_server_manager
    # Fetch by the (mcp_server_id, tool_id) pair on behalf of the resolved actor.
    return await manager.get_tool_by_mcp_server(mcp_server_id, tool_id, actor=requesting_actor)
@router.post(
    "/{mcp_server_id}/tools/{tool_id}/run",
    operation_id="mcp_run_tool",
    response_model=ToolExecutionResult,
)
async def run_mcp_tool(
    mcp_server_id: str,
    tool_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
    request: ToolExecuteRequest = Body(default=ToolExecuteRequest()),
):
    """
    Execute a specific MCP tool
    The request body should contain the tool arguments in the ToolExecuteRequest format.
    """
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    # TODO: Get environment variables from somewhere if needed; none are forwarded yet.
    output, succeeded = await server.mcp_server_manager.execute_mcp_server_tool(
        mcp_server_id=mcp_server_id,
        tool_id=tool_id,
        tool_args=request.args,
        environment_variables={},
        actor=requesting_actor,
    )
    if succeeded:
        outcome = "success"
    else:
        outcome = "error"
    # Pair the raw tool output with the execution status.
    return ToolExecutionResult(status=outcome, func_return=output)
@router.patch(
    "/{mcp_server_id}/refresh",
    operation_id="mcp_refresh_mcp_server_tools",
)
async def refresh_mcp_server_tools(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
    agent_id: Optional[str] = None,
):
    """
    Refresh tools for an MCP server by:
    1. Fetching current tools from the MCP server
    2. Deleting tools that no longer exist on the server
    3. Updating schemas for existing tools
    4. Adding new tools from the server
    Returns a summary of changes made.
    """
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    # The manager performs the whole resync; its summary is returned unchanged.
    return await server.mcp_server_manager.resync_mcp_server_tools(
        mcp_server_id,
        actor=requesting_actor,
        agent_id=agent_id,
    )
@router.get(
    "/connect/{mcp_server_id}",
    response_model=None,
    # TODO: make this into a model?
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
    operation_id="mcp_connect_mcp_server",
)
async def connect_mcp_server(
    mcp_server_id: str,
    request: Request,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
) -> StreamingResponse:
    """
    Connect to an MCP server with support for OAuth via SSE.
    Returns a stream of events handling authorization state and exchange if OAuth is required.
    """
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
    # Convert the MCP server to the appropriate config type. Variables are left unresolved here;
    # resolution happens inside the stream generator below, where a ValueError is reported to the
    # client as an ERROR event instead of failing the HTTP request.
    config = await mcp_server.to_config_async(resolve_variables=False)

    async def oauth_stream_generator(
        mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
        http_request: Request,
    ) -> AsyncGenerator[str, None]:
        """Yield SSE events for a connection attempt, falling back to the OAuth flow.

        Sequence: CONNECTION_ATTEMPT, then either SUCCESS (with the serialized tool list),
        ERROR, or the events produced by ``handle_oauth_flow``.

        Args:
            mcp_config: Server config; its environment variables are resolved in place here.
            http_request: The incoming request, forwarded to the OAuth flow handler.
        """
        client = None
        # Tracks whether the cleanup-time HTTPStatusError path already ran the OAuth flow.
        oauth_flow_attempted = False
        try:
            # Acknowledge connection attempt
            yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=mcp_config.server_name)

            # Create MCP client with respective transport type
            try:
                mcp_config.resolve_environment_variables()
                client = await server.mcp_server_manager.get_mcp_client(mcp_config, actor)
            except ValueError as e:
                yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
                return

            # Try normal connection first for flows that don't require OAuth
            try:
                await client.connect_to_server()
                tools = await client.list_tools(serialize=True)
                yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
                return
            except ConnectionError:
                # TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
                if isinstance(client, AsyncStdioMCPClient):
                    logger.warning("OAuth not supported for stdio")
                    yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
                    return
                # Continue to OAuth flow.
                # Log only the server name: mcp_config has had its environment variables resolved
                # above, so its full repr may contain credentials (tokens, headers, env values).
                logger.info(f"Attempting OAuth flow for {mcp_config.server_name}...")
            except Exception as e:
                yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
                return
            finally:
                if client:
                    try:
                        await client.cleanup()
                    # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py
                    # For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception
                    # NOTE(review): this yields from inside a `finally`; if the generator is being closed
                    # (client disconnect), yielding here would raise RuntimeError — confirm this is acceptable.
                    except HTTPStatusError:
                        oauth_flow_attempted = True
                        async for event in server.mcp_server_manager.handle_oauth_flow(
                            request=mcp_config, actor=actor, http_request=http_request
                        ):
                            yield event

            # Failsafe to make sure we don't try to handle OAuth flow twice
            if not oauth_flow_attempted:
                async for event in server.mcp_server_manager.handle_oauth_flow(request=mcp_config, actor=actor, http_request=http_request):
                    yield event
            return
        except Exception as e:
            detailed_error = drill_down_exception(e)
            logger.error(f"Error in OAuth stream:\n{detailed_error}")
            yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
        finally:
            # Best-effort final cleanup. The inner finally may already have cleaned up this client,
            # so failures here are only logged.
            if client:
                try:
                    await client.cleanup()
                except Exception as cleanup_error:
                    logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")

    return StreamingResponseWithStatusCode(oauth_stream_generator(config, request), media_type="text/event-stream")