feat: add new mcp_servers routes [LET-4321] (#5675)

---------

Co-authored-by: Ari Webb <ari@letta.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
Ari Webb
2025-10-23 17:19:08 -07:00
committed by Caren Thomas
parent 272f055b4a
commit c7c0d7507c
11 changed files with 3989 additions and 241 deletions

View File

@@ -0,0 +1,47 @@
"""Add mcp_tools table
Revision ID: c6c43222e2de
Revises: 6756d04c3ddb
Create Date: 2025-10-20 17:25:54.334037
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c6c43222e2de"
down_revision: Union[str, None] = "6756d04c3ddb"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"mcp_tools",
sa.Column("mcp_server_id", sa.String(), nullable=False),
sa.Column("tool_id", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("mcp_tools")
# ### end Alembic commands ###

File diff suppressed because it is too large Load Diff

View File

@@ -504,14 +504,43 @@ def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormat
def serialize_mcp_stdio_config(config: Union[Optional[StdioServerConfig], Dict]) -> Optional[Dict]:
"""Convert an StdioServerConfig object into a JSON-serializable dictionary."""
"""Convert an StdioServerConfig object into a JSON-serializable dictionary.
Persist required fields for successful deserialization back into a
StdioServerConfig model (namely `server_name` and `type`). The
`to_dict()` helper intentionally omits these since they're not needed
by MCP transport, but our ORM deserializer reconstructs the pydantic
model and requires them.
"""
if config and isinstance(config, StdioServerConfig):
return config.to_dict()
data = config.to_dict()
# Preserve required fields for pydantic reconstruction
data["server_name"] = config.server_name
# Store enum as its value; pydantic will coerce on load
data["type"] = config.type.value if hasattr(config.type, "value") else str(config.type)
return data
return config
def deserialize_mcp_stdio_config(data: Optional[Dict]) -> Optional[StdioServerConfig]:
"""Convert a dictionary back into an StdioServerConfig object."""
"""Convert a dictionary back into an StdioServerConfig object.
Backwards-compatibility notes:
- Older rows may only include `transport`, `command`, `args`, `env`.
In that case, provide defaults for `server_name` and `type` to
satisfy the pydantic model requirements.
- If both `type` and `transport` are present, prefer `type`.
"""
if not data:
return None
return StdioServerConfig(**data)
payload = dict(data)
# Map legacy `transport` field to required `type` if missing
if "type" not in payload and "transport" in payload:
payload["type"] = payload["transport"]
# Ensure required field exists; use a sensible placeholder when unknown
if "server_name" not in payload:
payload["server_name"] = payload.get("name", "unknown")
return StdioServerConfig(**payload)

View File

@@ -56,3 +56,12 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
metadata_: Mapped[Optional[dict]] = mapped_column(
JSON, default=lambda: {}, doc="A dictionary of additional metadata for the MCP server."
)
class MCPTools(SqlalchemyBase, OrganizationMixin):
"""Represents a mapping of MCP server ID to tool ID"""
__tablename__ = "mcp_tools"
mcp_server_id: Mapped[str] = mapped_column(String, doc="The ID of the MCP server")
tool_id: Mapped[str] = mapped_column(String, doc="The ID of the tool")

View File

@@ -148,6 +148,7 @@ class MCPServer(BaseMCPServer):
class UpdateSSEMCPServer(LettaBase):
"""Update an SSE MCP server"""
server_name: Optional[str] = Field(None, description="The name of the MCP server")
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
@@ -156,6 +157,7 @@ class UpdateSSEMCPServer(LettaBase):
class UpdateStdioMCPServer(LettaBase):
"""Update a Stdio MCP server"""
server_name: Optional[str] = Field(None, description="The name of the MCP server")
stdio_config: Optional[StdioServerConfig] = Field(
None, description="The configuration for the server (MCP 'local' client will run this command)"
)
@@ -164,6 +166,7 @@ class UpdateStdioMCPServer(LettaBase):
class UpdateStreamableHTTPMCPServer(LettaBase):
"""Update a Streamable HTTP MCP server"""
server_name: Optional[str] = Field(None, description="The name of the MCP server")
server_url: Optional[str] = Field(None, description="The URL path for the streamable HTTP server (e.g., 'example/mcp')")
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")

View File

@@ -41,18 +41,21 @@ class StdioMCPServer(CreateStdioMCPServer):
"""A Stdio MCP server"""
id: str = BaseMCPServer.generate_id_field()
type: MCPServerType = MCPServerType.STDIO
class SSEMCPServer(CreateSSEMCPServer):
"""An SSE MCP server"""
id: str = BaseMCPServer.generate_id_field()
type: MCPServerType = MCPServerType.SSE
class StreamableHTTPMCPServer(CreateStreamableHTTPMCPServer):
"""A Streamable HTTP MCP server"""
id: str = BaseMCPServer.generate_id_field()
type: MCPServerType = MCPServerType.STREAMABLE_HTTP
MCPServerUnion = Union[StdioMCPServer, SSEMCPServer, StreamableHTTPMCPServer]
@@ -74,9 +77,10 @@ class UpdateSSEMCPServer(LettaBase):
server_name: Optional[str] = Field(None, description="The name of the MCP server")
server_url: Optional[str] = Field(None, description="The URL of the SSE MCP server")
# Note: auth_token is renamed to token to match the ORM field
token: Optional[str] = Field(None, description="The authentication token")
# auth_header is excluded as it's derived from the token
# Accept both `auth_token` (API surface) and `token` (internal ORM naming)
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
token: Optional[str] = Field(None, description="The authentication token (internal)")
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to send with requests")
@@ -85,9 +89,10 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
server_name: Optional[str] = Field(None, description="The name of the MCP server")
server_url: Optional[str] = Field(None, description="The URL of the Streamable HTTP MCP server")
# Note: auth_token is renamed to token to match the ORM field
token: Optional[str] = Field(None, description="The authentication token")
# auth_header is excluded as it's derived from the token
# Accept both `auth_token` (API surface) and `token` (internal ORM naming)
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
token: Optional[str] = Field(None, description="The authentication token (internal)")
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to send with requests")
@@ -296,3 +301,49 @@ def convert_generic_to_union(server) -> MCPServerUnion:
)
else:
raise ValueError(f"Unknown server type: {server.server_type}")
def convert_update_to_internal(request: Union[UpdateStdioMCPServer, UpdateSSEMCPServer, UpdateStreamableHTTPMCPServer]):
"""Convert external API update models to internal UpdateMCPServer union used by the manager.
- Flattens stdio fields into StdioServerConfig inside UpdateStdioMCPServer
- Maps `auth_token` to `token` for HTTP-based transports
- Ignores `auth_header` at update time (header is derived from token)
"""
# Local import to avoid circulars
from letta.functions.mcp_client.types import MCPServerType as MCPType, StdioServerConfig as StdioCfg
from letta.schemas.mcp import (
UpdateSSEMCPServer as InternalUpdateSSE,
UpdateStdioMCPServer as InternalUpdateStdio,
UpdateStreamableHTTPMCPServer as InternalUpdateHTTP,
)
if isinstance(request, UpdateStdioMCPServer):
stdio_cfg = None
# Only build stdio_config if command and args are explicitly provided to avoid overwriting existing config
if request.command is not None and request.args is not None:
stdio_cfg = StdioCfg(
server_name=request.server_name or "",
type=MCPType.STDIO,
command=request.command,
args=request.args,
env=request.env,
)
kwargs: dict = {}
if request.server_name is not None:
kwargs["server_name"] = request.server_name
if stdio_cfg is not None:
kwargs["stdio_config"] = stdio_cfg
return InternalUpdateStdio(**kwargs)
elif isinstance(request, UpdateSSEMCPServer):
token_value = request.auth_token or request.token
return InternalUpdateSSE(
server_name=request.server_name, server_url=request.server_url, token=token_value, custom_headers=request.custom_headers
)
elif isinstance(request, UpdateStreamableHTTPMCPServer):
token_value = request.auth_token or request.token
return InternalUpdateHTTP(
server_name=request.server_name, server_url=request.server_url, auth_token=token_value, custom_headers=request.custom_headers
)
else:
raise TypeError(f"Unsupported update request type: {type(request)}")

View File

@@ -11,6 +11,7 @@ from letta.server.rest_api.routers.v1.internal_runs import router as internal_ru
from letta.server.rest_api.routers.v1.internal_templates import router as internal_templates_router
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
from letta.server.rest_api.routers.v1.llms import router as llm_router
from letta.server.rest_api.routers.v1.mcp_servers import router as mcp_servers_router
from letta.server.rest_api.routers.v1.messages import router as messages_router
from letta.server.rest_api.routers.v1.providers import router as providers_router
from letta.server.rest_api.routers.v1.runs import router as runs_router
@@ -34,6 +35,7 @@ ROUTERS = [
internal_runs_router,
internal_templates_router,
llm_router,
mcp_servers_router,
blocks_router,
jobs_router,
health_router,

View File

@@ -1,8 +1,10 @@
from typing import Any, Dict, List, Optional
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from fastapi import APIRouter, Body, Depends, HTTPException
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from httpx import HTTPStatusError
from starlette.responses import StreamingResponse
from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
from letta.log import get_logger
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.mcp_server import (
@@ -11,14 +13,20 @@ from letta.schemas.mcp_server import (
MCPToolExecuteRequest,
UpdateMCPServerUnion,
convert_generic_to_union,
convert_update_to_internal,
)
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.server.rest_api.dependencies import (
HeaderParams,
get_headers,
get_letta_server,
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.server.server import SyncServer
from letta.services.mcp.oauth_utils import drill_down_exception, oauth_stream_event
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
from letta.services.mcp.types import OauthStreamEvent
from letta.settings import tool_settings
router = APIRouter(prefix="/mcp-servers", tags=["mcp-servers"])
@@ -39,6 +47,7 @@ async def create_mcp_server(
"""
Add a new MCP server to the Letta MCP server config
"""
# TODO: add the tools to the MCP server table we made.
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
new_server = await server.mcp_server_manager.create_mcp_server_from_config_with_tools(request, actor=actor)
return convert_generic_to_union(new_server)
@@ -56,7 +65,6 @@ async def list_mcp_servers(
"""
Get a list of all configured MCP servers
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
mcp_servers = await server.mcp_server_manager.list_mcp_servers(actor=actor)
return [convert_generic_to_union(mcp_server) for mcp_server in mcp_servers]
@@ -112,8 +120,10 @@ async def update_mcp_server(
Update an existing MCP server configuration
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Convert external update payload to internal manager union
internal_update = convert_update_to_internal(request)
updated_server = await server.mcp_server_manager.update_mcp_server_by_id(
mcp_server_id=mcp_server_id, mcp_server_update=request, actor=actor
mcp_server_id=mcp_server_id, mcp_server_update=internal_update, actor=actor
)
return convert_generic_to_union(updated_server)
@@ -127,24 +137,10 @@ async def list_mcp_tools_by_server(
"""
Get a list of all tools for a specific MCP server
"""
# TODO: implement this. We want to use the new tools table instead of going to the mcp server.
pass
# actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# mcp_tools = await server.mcp_server_manager.list_mcp_server_tools(mcp_server_id, actor=actor)
# # Convert MCPTool objects to Tool objects
# tools = []
# for mcp_tool in mcp_tools:
# from letta.schemas.tool import ToolCreate
# tool_create = ToolCreate.from_mcp(mcp_server_name="", mcp_tool=mcp_tool)
# tools.append(Tool(
# id=f"mcp-tool-{mcp_tool.name}", # Generate a temporary ID
# name=mcp_tool.name,
# description=tool_create.description,
# json_schema=tool_create.json_schema,
# source_code=tool_create.source_code,
# tags=tool_create.tags,
# ))
# return tools
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Use the new efficient method that queries from the database using MCPTools mapping
tools = await server.mcp_server_manager.list_tools_by_mcp_server_from_db(mcp_server_id, actor=actor)
return tools
@router.get("/{mcp_server_id}/tools/{tool_id}", response_model=Tool, operation_id="mcp_get_mcp_tool")
@@ -158,13 +154,11 @@ async def get_mcp_tool(
Get a specific MCP tool by its ID
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Use the tool_manager's existing method to get the tool by ID
# Verify the tool belongs to the MCP server (optional check)
tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool_id, actor=actor)
tool = await server.mcp_server_manager.get_tool_by_mcp_server(mcp_server_id, tool_id, actor=actor)
return tool
@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolReturnMessage, operation_id="mcp_run_tool")
@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolExecutionResult, operation_id="mcp_run_tool")
async def run_mcp_tool(
mcp_server_id: str,
tool_id: str,
@@ -188,9 +182,10 @@ async def run_mcp_tool(
actor=actor,
)
# Create a ToolReturnMessage
return ToolReturnMessage(
id=f"tool-return-{tool_id}", tool_call_id=f"call-{tool_id}", tool_return=result, status="success" if success else "error"
# Create a ToolExecutionResult
return ToolExecutionResult(
status="success" if success else "error",
func_return=result,
)
@@ -231,6 +226,7 @@ async def refresh_mcp_server_tools(
)
async def connect_mcp_server(
mcp_server_id: str,
request: Request,
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
) -> StreamingResponse:
@@ -238,72 +234,76 @@ async def connect_mcp_server(
Connect to an MCP server with support for OAuth via SSE.
Returns a stream of events handling authorization state and exchange if OAuth is required.
"""
pass
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
# async def oauth_stream_generator(
# request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
# http_request: Request,
# ) -> AsyncGenerator[str, None]:
# client = None
# Convert the MCP server to the appropriate config type
config = mcp_server.to_config(resolve_variables=False)
# oauth_flow_attempted = False
# try:
# # Acknolwedge connection attempt
# yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name)
async def oauth_stream_generator(
mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
http_request: Request,
) -> AsyncGenerator[str, None]:
client = None
# actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
oauth_flow_attempted = False
try:
# Acknowledge connection attempt
yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=mcp_config.server_name)
# # Create MCP client with respective transport type
# try:
# request.resolve_environment_variables()
# client = await server.mcp_server_manager.get_mcp_client(request, actor)
# except ValueError as e:
# yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
# return
# Create MCP client with respective transport type
try:
mcp_config.resolve_environment_variables()
client = await server.mcp_server_manager.get_mcp_client(mcp_config, actor)
except ValueError as e:
yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
return
# # Try normal connection first for flows that don't require OAuth
# try:
# await client.connect_to_server()
# tools = await client.list_tools(serialize=True)
# yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
# return
# except ConnectionError:
# # TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
# if isinstance(client, AsyncStdioMCPClient):
# logger.warning("OAuth not supported for stdio")
# yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
# return
# # Continue to OAuth flow
# logger.info(f"Attempting OAuth flow for {request}...")
# except Exception as e:
# yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
# return
# finally:
# if client:
# try:
# await client.cleanup()
# # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py
# # For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception
# except* HTTPStatusError:
# oauth_flow_attempted = True
# async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request):
# yield event
# Try normal connection first for flows that don't require OAuth
try:
await client.connect_to_server()
tools = await client.list_tools(serialize=True)
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
return
except ConnectionError:
# TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
if isinstance(client, AsyncStdioMCPClient):
logger.warning("OAuth not supported for stdio")
yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
return
# Continue to OAuth flow
logger.info(f"Attempting OAuth flow for {mcp_config}...")
except Exception as e:
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
return
finally:
if client:
try:
await client.cleanup()
# This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py
# For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception
except HTTPStatusError:
oauth_flow_attempted = True
async for event in server.mcp_server_manager.handle_oauth_flow(
request=mcp_config, actor=actor, http_request=http_request
):
yield event
# # Failsafe to make sure we don't try to handle OAuth flow twice
# if not oauth_flow_attempted:
# async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request):
# yield event
# return
# except Exception as e:
# detailed_error = drill_down_exception(e)
# logger.error(f"Error in OAuth stream:\n{detailed_error}")
# yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
# Failsafe to make sure we don't try to handle OAuth flow twice
if not oauth_flow_attempted:
async for event in server.mcp_server_manager.handle_oauth_flow(request=mcp_config, actor=actor, http_request=http_request):
yield event
return
except Exception as e:
detailed_error = drill_down_exception(e)
logger.error(f"Error in OAuth stream:\n{detailed_error}")
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
# finally:
# if client:
# try:
# await client.cleanup()
# except Exception as cleanup_error:
# logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
finally:
if client:
try:
await client.cleanup()
except Exception as cleanup_error:
logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
# return StreamingResponseWithStatusCode(oauth_stream_generator(request, http_request), media_type="text/event-stream")
return StreamingResponseWithStatusCode(oauth_stream_generator(config, request), media_type="text/event-stream")

View File

@@ -94,6 +94,7 @@ from letta.services.mcp.base_client import AsyncBaseMCPClient
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
from letta.services.mcp_manager import MCPManager
from letta.services.mcp_server_manager import MCPServerManager
from letta.services.message_manager import MessageManager
from letta.services.organization_manager import OrganizationManager
from letta.services.passage_manager import PassageManager
@@ -154,6 +155,7 @@ class SyncServer(object):
self.user_manager = UserManager()
self.tool_manager = ToolManager()
self.mcp_manager = MCPManager()
self.mcp_server_manager = MCPServerManager()
self.block_manager = BlockManager()
self.source_manager = SourceManager()
self.sandbox_config_manager = SandboxConfigManager()

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff