Revert "feat: revise mcp tool routes [LET-4321]" (#5652)

Revert "feat: revise mcp tool routes [LET-4321] (#5631)"

This reverts commit e15f120078652b2160d64a1e300317b95eccb163.
This commit is contained in:
Ari Webb
2025-10-22 11:38:58 -07:00
committed by Caren Thomas
parent 6757c7e201
commit abbd1b5595
8 changed files with 230 additions and 3596 deletions

View File

@@ -1,47 +0,0 @@
"""Add mcp_tools table
Revision ID: c6c43222e2de
Revises: 6756d04c3ddb
Create Date: 2025-10-20 17:25:54.334037
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c6c43222e2de"
down_revision: Union[str, None] = "6756d04c3ddb"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"mcp_tools",
sa.Column("mcp_server_id", sa.String(), nullable=False),
sa.Column("tool_id", sa.String(), nullable=False),
sa.Column("id", sa.String(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
)
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("mcp_tools")
# ### end Alembic commands ###

File diff suppressed because it is too large Load Diff

View File

@@ -56,12 +56,3 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
metadata_: Mapped[Optional[dict]] = mapped_column(
JSON, default=lambda: {}, doc="A dictionary of additional metadata for the MCP server."
)
class MCPTools(SqlalchemyBase, OrganizationMixin):
"""Represents a mapping of MCP server ID to tool ID"""
__tablename__ = "mcp_tools"
mcp_server_id: Mapped[str] = mapped_column(String, doc="The ID of the MCP server")
tool_id: Mapped[str] = mapped_column(String, doc="The ID of the tool")

View File

@@ -11,7 +11,6 @@ from letta.server.rest_api.routers.v1.internal_runs import router as internal_ru
from letta.server.rest_api.routers.v1.internal_templates import router as internal_templates_router
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
from letta.server.rest_api.routers.v1.llms import router as llm_router
from letta.server.rest_api.routers.v1.mcp_servers import router as mcp_servers_router
from letta.server.rest_api.routers.v1.messages import router as messages_router
from letta.server.rest_api.routers.v1.providers import router as providers_router
from letta.server.rest_api.routers.v1.runs import router as runs_router
@@ -35,7 +34,6 @@ ROUTERS = [
internal_runs_router,
internal_templates_router,
llm_router,
mcp_servers_router,
blocks_router,
jobs_router,
health_router,

View File

@@ -1,10 +1,8 @@
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Request
from httpx import HTTPStatusError
from fastapi import APIRouter, Body, Depends, HTTPException
from starlette.responses import StreamingResponse
from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
from letta.log import get_logger
from letta.schemas.letta_message import ToolReturnMessage
from letta.schemas.mcp_server import (
@@ -15,17 +13,12 @@ from letta.schemas.mcp_server import (
convert_generic_to_union,
)
from letta.schemas.tool import Tool
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.server.rest_api.dependencies import (
HeaderParams,
get_headers,
get_letta_server,
)
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
from letta.server.server import SyncServer
from letta.services.mcp.oauth_utils import drill_down_exception, oauth_stream_event
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
from letta.services.mcp.types import OauthStreamEvent
from letta.settings import tool_settings
router = APIRouter(prefix="/mcp-servers", tags=["mcp-servers"])
@@ -46,7 +39,6 @@ async def create_mcp_server(
"""
Add a new MCP server to the Letta MCP server config
"""
# TODO: add the tools to the MCP server table we made.
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
new_server = await server.mcp_server_manager.create_mcp_server_from_config_with_tools(request, actor=actor)
return convert_generic_to_union(new_server)
@@ -64,6 +56,7 @@ async def list_mcp_servers(
"""
Get a list of all configured MCP servers
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
mcp_servers = await server.mcp_server_manager.list_mcp_servers(actor=actor)
return [convert_generic_to_union(mcp_server) for mcp_server in mcp_servers]
@@ -134,10 +127,24 @@ async def list_mcp_tools_by_server(
"""
Get a list of all tools for a specific MCP server
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Use the new efficient method that queries from the database using MCPTools mapping
tools = await server.mcp_server_manager.list_tools_by_mcp_server_from_db(mcp_server_id, actor=actor)
return tools
# TODO: implement this. We want to use the new tools table instead of going to the mcp server.
pass
# actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# mcp_tools = await server.mcp_server_manager.list_mcp_server_tools(mcp_server_id, actor=actor)
# # Convert MCPTool objects to Tool objects
# tools = []
# for mcp_tool in mcp_tools:
# from letta.schemas.tool import ToolCreate
# tool_create = ToolCreate.from_mcp(mcp_server_name="", mcp_tool=mcp_tool)
# tools.append(Tool(
# id=f"mcp-tool-{mcp_tool.name}", # Generate a temporary ID
# name=mcp_tool.name,
# description=tool_create.description,
# json_schema=tool_create.json_schema,
# source_code=tool_create.source_code,
# tags=tool_create.tags,
# ))
# return tools
@router.get("/{mcp_server_id}/tools/{tool_id}", response_model=Tool, operation_id="mcp_get_mcp_tool")
@@ -151,11 +158,13 @@ async def get_mcp_tool(
Get a specific MCP tool by its ID
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
tool = await server.mcp_server_manager.get_tool_by_mcp_server(mcp_server_id, tool_id, actor=actor)
# Use the tool_manager's existing method to get the tool by ID
# Verify the tool belongs to the MCP server (optional check)
tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool_id, actor=actor)
return tool
@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolExecutionResult, operation_id="mcp_run_tool")
@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolReturnMessage, operation_id="mcp_run_tool")
async def run_mcp_tool(
mcp_server_id: str,
tool_id: str,
@@ -179,10 +188,9 @@ async def run_mcp_tool(
actor=actor,
)
# Create a ToolExecutionResult
return ToolExecutionResult(
status="success" if success else "error",
func_return=result,
# Create a ToolReturnMessage
return ToolReturnMessage(
id=f"tool-return-{tool_id}", tool_call_id=f"call-{tool_id}", tool_return=result, status="success" if success else "error"
)
@@ -223,7 +231,6 @@ async def refresh_mcp_server_tools(
)
async def connect_mcp_server(
mcp_server_id: str,
request: Request,
server: SyncServer = Depends(get_letta_server),
headers: HeaderParams = Depends(get_headers),
) -> StreamingResponse:
@@ -231,76 +238,72 @@ async def connect_mcp_server(
Connect to an MCP server with support for OAuth via SSE.
Returns a stream of events handling authorization state and exchange if OAuth is required.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
pass
# Convert the MCP server to the appropriate config type
config = mcp_server.to_config(resolve_variables=False)
# async def oauth_stream_generator(
# request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
# http_request: Request,
# ) -> AsyncGenerator[str, None]:
# client = None
async def oauth_stream_generator(
mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
http_request: Request,
) -> AsyncGenerator[str, None]:
client = None
# oauth_flow_attempted = False
# try:
# # Acknolwedge connection attempt
# yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name)
oauth_flow_attempted = False
try:
# Acknowledge connection attempt
yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=mcp_config.server_name)
# actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
# Create MCP client with respective transport type
try:
mcp_config.resolve_environment_variables()
client = await server.mcp_server_manager.get_mcp_client(mcp_config, actor)
except ValueError as e:
yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
return
# # Create MCP client with respective transport type
# try:
# request.resolve_environment_variables()
# client = await server.mcp_server_manager.get_mcp_client(request, actor)
# except ValueError as e:
# yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
# return
# Try normal connection first for flows that don't require OAuth
try:
await client.connect_to_server()
tools = await client.list_tools(serialize=True)
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
return
except ConnectionError:
# TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
if isinstance(client, AsyncStdioMCPClient):
logger.warning("OAuth not supported for stdio")
yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
return
# Continue to OAuth flow
logger.info(f"Attempting OAuth flow for {mcp_config}...")
except Exception as e:
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
return
finally:
if client:
try:
await client.cleanup()
# This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py
# For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception
except HTTPStatusError:
oauth_flow_attempted = True
async for event in server.mcp_server_manager.handle_oauth_flow(
request=mcp_config, actor=actor, http_request=http_request
):
yield event
# # Try normal connection first for flows that don't require OAuth
# try:
# await client.connect_to_server()
# tools = await client.list_tools(serialize=True)
# yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
# return
# except ConnectionError:
# # TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
# if isinstance(client, AsyncStdioMCPClient):
# logger.warning("OAuth not supported for stdio")
# yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
# return
# # Continue to OAuth flow
# logger.info(f"Attempting OAuth flow for {request}...")
# except Exception as e:
# yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
# return
# finally:
# if client:
# try:
# await client.cleanup()
# # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py
# # For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception
# except* HTTPStatusError:
# oauth_flow_attempted = True
# async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request):
# yield event
# Failsafe to make sure we don't try to handle OAuth flow twice
if not oauth_flow_attempted:
async for event in server.mcp_server_manager.handle_oauth_flow(request=mcp_config, actor=actor, http_request=http_request):
yield event
return
except Exception as e:
detailed_error = drill_down_exception(e)
logger.error(f"Error in OAuth stream:\n{detailed_error}")
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
# # Failsafe to make sure we don't try to handle OAuth flow twice
# if not oauth_flow_attempted:
# async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request):
# yield event
# return
# except Exception as e:
# detailed_error = drill_down_exception(e)
# logger.error(f"Error in OAuth stream:\n{detailed_error}")
# yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
finally:
if client:
try:
await client.cleanup()
except Exception as cleanup_error:
logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
# finally:
# if client:
# try:
# await client.cleanup()
# except Exception as cleanup_error:
# logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
return StreamingResponseWithStatusCode(oauth_stream_generator(config, request), media_type="text/event-stream")
# return StreamingResponseWithStatusCode(oauth_stream_generator(request, http_request), media_type="text/event-stream")

View File

@@ -94,7 +94,6 @@ from letta.services.mcp.base_client import AsyncBaseMCPClient
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
from letta.services.mcp_manager import MCPManager
from letta.services.mcp_server_manager import MCPServerManager
from letta.services.message_manager import MessageManager
from letta.services.organization_manager import OrganizationManager
from letta.services.passage_manager import PassageManager
@@ -155,7 +154,6 @@ class SyncServer(object):
self.user_manager = UserManager()
self.tool_manager = ToolManager()
self.mcp_manager = MCPManager()
self.mcp_server_manager = MCPServerManager()
self.block_manager = BlockManager()
self.source_manager = SourceManager()
self.sandbox_config_manager = SandboxConfigManager()

File diff suppressed because it is too large Load Diff

View File

@@ -1,858 +0,0 @@
"""
Integration tests for the new MCP server endpoints (/v1/mcp-servers/).
Tests all CRUD operations, tool management, and OAuth connection flows.
Uses plain dictionaries since SDK types are not yet generated.
"""
import os
import sys
import threading
import time
import uuid
from pathlib import Path
from typing import Any, Dict, List, Optional
import pytest
import requests
from dotenv import load_dotenv
# ------------------------------
# Fixtures
# ------------------------------
@pytest.fixture(scope="module")
def server_url() -> str:
"""
Provides the URL for the Letta server.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
yield url
@pytest.fixture(scope="module")
def auth_headers() -> Dict[str, str]:
"""
Provides authentication headers for API requests.
"""
# Get auth token from environment or use default
token = os.getenv("LETTA_API_TOKEN", "")
if token:
return {"Authorization": f"Bearer {token}"}
return {}
@pytest.fixture(scope="function")
def unique_server_id() -> str:
"""Generate a unique MCP server ID for each test."""
# MCP server IDs follow the format: mcp_server-<uuid>
return f"mcp_server-{uuid.uuid4()}"
@pytest.fixture(scope="function")
def mock_mcp_server_path() -> Path:
"""Get path to mock MCP server for testing."""
script_dir = Path(__file__).parent
mcp_server_path = script_dir / "mock_mcp_server.py"
if not mcp_server_path.exists():
# Create a minimal mock server for testing if it doesn't exist
pytest.skip(f"Mock MCP server not found at {mcp_server_path}")
return mcp_server_path
# ------------------------------
# Helper Functions
# ------------------------------
def create_stdio_server_dict(server_name: str, command: str = "npx", args: List[str] = None) -> Dict[str, Any]:
"""Create a dictionary representing a stdio MCP server configuration."""
return {
"type": "stdio",
"server_name": server_name,
"command": command,
"args": args or ["-y", "@modelcontextprotocol/server-everything"],
"env": {"NODE_ENV": "test", "DEBUG": "true"},
}
def create_sse_server_dict(server_name: str, server_url: str = None) -> Dict[str, Any]:
"""Create a dictionary representing an SSE MCP server configuration."""
return {
"type": "sse",
"server_name": server_name,
"server_url": server_url or "https://api.example.com/sse",
"auth_header": "Authorization",
"auth_token": "Bearer test_token_123",
"custom_headers": {"X-Custom-Header": "custom_value", "X-API-Version": "1.0"},
}
def create_streamable_http_server_dict(server_name: str, server_url: str = None) -> Dict[str, Any]:
"""Create a dictionary representing a streamable HTTP MCP server configuration."""
return {
"type": "streamable_http",
"server_name": server_name,
"server_url": server_url or "https://api.example.com/streamable",
"auth_header": "X-API-Key",
"auth_token": "api_key_456",
"custom_headers": {"Accept": "application/json", "X-Version": "2.0"},
}
# ------------------------------
# Test Cases for CRUD Operations
# ------------------------------
def test_create_stdio_mcp_server(server_url: str, auth_headers: Dict[str, str]):
"""Test creating a stdio MCP server."""
server_name = f"test-stdio-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name)
# Create the server
response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert response.status_code == 200, f"Failed to create server: {response.text}"
server_data = response.json()
assert server_data["server_name"] == server_name
assert server_data["command"] == server_config["command"]
assert server_data["args"] == server_config["args"]
assert "id" in server_data # Should have an ID assigned
server_id = server_data["id"]
# Cleanup - delete the server
delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert delete_response.status_code == 204, f"Failed to delete server: {delete_response.text}"
def test_create_sse_mcp_server(server_url: str, auth_headers: Dict[str, str]):
"""Test creating an SSE MCP server."""
server_name = f"test-sse-{uuid.uuid4().hex[:8]}"
server_config = create_sse_server_dict(server_name)
# Create the server
response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert response.status_code == 200, f"Failed to create server: {response.text}"
server_data = response.json()
assert server_data["server_name"] == server_name
assert server_data["server_url"] == server_config["server_url"]
assert server_data["auth_header"] == server_config["auth_header"]
assert "id" in server_data
server_id = server_data["id"]
# Cleanup
delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert delete_response.status_code == 204
def test_create_streamable_http_mcp_server(server_url: str, auth_headers: Dict[str, str]):
"""Test creating a streamable HTTP MCP server."""
server_name = f"test-http-{uuid.uuid4().hex[:8]}"
server_config = create_streamable_http_server_dict(server_name)
# Create the server
response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert response.status_code == 200, f"Failed to create server: {response.text}"
server_data = response.json()
assert server_data["server_name"] == server_name
assert server_data["server_url"] == server_config["server_url"]
assert "id" in server_data
server_id = server_data["id"]
# Cleanup
delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert delete_response.status_code == 204
def test_list_mcp_servers(server_url: str, auth_headers: Dict[str, str]):
"""Test listing all MCP servers."""
# Create multiple servers
servers_created = []
# Create stdio server
stdio_name = f"list-test-stdio-{uuid.uuid4().hex[:8]}"
stdio_config = create_stdio_server_dict(stdio_name)
stdio_response = requests.post(f"{server_url}/v1/mcp-servers/", json=stdio_config, headers=auth_headers)
assert stdio_response.status_code == 200
stdio_server = stdio_response.json()
servers_created.append(stdio_server["id"])
# Create SSE server
sse_name = f"list-test-sse-{uuid.uuid4().hex[:8]}"
sse_config = create_sse_server_dict(sse_name)
sse_response = requests.post(f"{server_url}/v1/mcp-servers/", json=sse_config, headers=auth_headers)
assert sse_response.status_code == 200
sse_server = sse_response.json()
servers_created.append(sse_server["id"])
try:
# List all servers
list_response = requests.get(f"{server_url}/v1/mcp-servers/", headers=auth_headers)
assert list_response.status_code == 200
servers_list = list_response.json()
assert isinstance(servers_list, list)
assert len(servers_list) >= 2 # At least our two servers
# Check our servers are in the list
server_ids = [s["id"] for s in servers_list]
assert stdio_server["id"] in server_ids
assert sse_server["id"] in server_ids
# Check server names
server_names = [s["server_name"] for s in servers_list]
assert stdio_name in server_names
assert sse_name in server_names
finally:
# Cleanup
for server_id in servers_created:
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_get_specific_mcp_server(server_url: str, auth_headers: Dict[str, str]):
"""Test getting a specific MCP server by ID."""
# Create a server
server_name = f"get-test-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name, command="python", args=["-m", "mcp_server"])
server_config["env"]["PYTHONPATH"] = "/usr/local/lib"
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
created_server = create_response.json()
server_id = created_server["id"]
try:
# Get the server by ID
get_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert get_response.status_code == 200
retrieved_server = get_response.json()
assert retrieved_server["id"] == server_id
assert retrieved_server["server_name"] == server_name
assert retrieved_server["command"] == "python"
assert retrieved_server["args"] == ["-m", "mcp_server"]
assert retrieved_server.get("env", {}).get("PYTHONPATH") == "/usr/local/lib"
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_update_stdio_mcp_server(server_url: str, auth_headers: Dict[str, str]):
"""Test updating a stdio MCP server."""
# Create a server
server_name = f"update-test-stdio-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name, command="node", args=["old_server.js"])
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# Update the server
update_data = {
"server_name": "updated-stdio-server",
"command": "node",
"args": ["new_server.js", "--port", "3000"],
"env": {"NEW_ENV": "new_value", "PORT": "3000"},
}
update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers)
assert update_response.status_code == 200
updated_server = update_response.json()
assert updated_server["server_name"] == "updated-stdio-server"
assert updated_server["args"] == ["new_server.js", "--port", "3000"]
assert updated_server.get("env", {}).get("NEW_ENV") == "new_value"
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_update_sse_mcp_server(server_url: str, auth_headers: Dict[str, str]):
"""Test updating an SSE MCP server."""
# Create an SSE server
server_name = f"update-test-sse-{uuid.uuid4().hex[:8]}"
server_config = create_sse_server_dict(server_name, server_url="https://old.example.com/sse")
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# Update the server
update_data = {
"server_name": "updated-sse-server",
"server_url": "https://new.example.com/sse/v2",
"token": "new_token_789",
"custom_headers": {"X-Updated": "true", "X-Version": "2.0"},
}
update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers)
assert update_response.status_code == 200
updated_server = update_response.json()
assert updated_server["server_name"] == "updated-sse-server"
assert updated_server["server_url"] == "https://new.example.com/sse/v2"
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_delete_mcp_server(server_url: str, auth_headers: Dict[str, str]):
"""Test deleting an MCP server."""
# Create a server to delete
server_name = f"delete-test-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name)
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
# Delete the server
delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert delete_response.status_code == 204
# Verify it's deleted (should get 404)
get_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert get_response.status_code == 404
# ------------------------------
# Test Cases for Tool Operations
# ------------------------------
def test_list_mcp_tools_by_server(server_url: str, auth_headers: Dict[str, str]):
"""Test listing tools for a specific MCP server."""
# Create a server
server_name = f"tools-test-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name)
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# List tools for this server
tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers)
assert tools_response.status_code == 200
tools = tools_response.json()
assert isinstance(tools, list)
# Tools might be empty initially if server hasn't connected
# But response structure should be valid
if len(tools) > 0:
# Verify tool structure
tool = tools[0]
assert "id" in tool
assert "name" in tool
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_get_specific_mcp_tool(server_url: str, auth_headers: Dict[str, str]):
"""Test getting a specific tool from an MCP server."""
# Create a server
server_name = f"tool-get-test-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name)
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# First get list of tools
tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers)
assert tools_response.status_code == 200
tools = tools_response.json()
if len(tools) > 0:
# Get a specific tool
tool_id = tools[0]["id"]
tool_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}", headers=auth_headers)
assert tool_response.status_code == 200
specific_tool = tool_response.json()
assert specific_tool["id"] == tool_id
assert "name" in specific_tool
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_run_mcp_tool(server_url: str, auth_headers: Dict[str, str]):
"""Test executing an MCP tool."""
# Create a server
server_name = f"tool-run-test-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name)
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# Get available tools
tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers)
assert tools_response.status_code == 200
tools = tools_response.json()
if len(tools) > 0:
# Run the first available tool
tool_id = tools[0]["id"]
# Run with arguments
run_request = {"args": {"test_param": "test_value", "count": 5}}
run_response = requests.post(
f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}/run", json=run_request, headers=auth_headers
)
assert run_response.status_code == 200
result = run_response.json()
assert "status" in result
assert result["status"] in ["success", "error"]
assert "func_return" in result
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_run_mcp_tool_without_args(server_url: str, auth_headers: Dict[str, str]):
"""Test executing an MCP tool without arguments."""
# Create a server
server_name = f"tool-noargs-test-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name)
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# Get available tools
tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers)
assert tools_response.status_code == 200
tools = tools_response.json()
if len(tools) > 0:
tool_id = tools[0]["id"]
# Run without arguments (empty dict)
run_request = {"args": {}}
run_response = requests.post(
f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}/run", json=run_request, headers=auth_headers
)
assert run_response.status_code == 200
result = run_response.json()
assert "status" in result
assert "func_return" in result
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_refresh_mcp_server_tools(server_url: str, auth_headers: Dict[str, str]):
"""Test refreshing tools for an MCP server."""
# Create a server
server_name = f"refresh-test-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name)
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# Get initial tools
initial_tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers)
assert initial_tools_response.status_code == 200
# Refresh tools
refresh_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}/refresh", headers=auth_headers)
assert refresh_response.status_code == 200
refresh_result = refresh_response.json()
# Result should contain summary of changes
assert refresh_result is not None
# Get tools after refresh
refreshed_tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers)
assert refreshed_tools_response.status_code == 200
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_refresh_mcp_server_tools_with_agent(server_url: str, auth_headers: Dict[str, str]):
"""Test refreshing tools with agent context."""
# Create a server
server_name = f"refresh-agent-test-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name)
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# Refresh tools with agent ID
mock_agent_id = f"agent-{uuid.uuid4()}"
refresh_response = requests.patch(
f"{server_url}/v1/mcp-servers/{server_id}/refresh", params={"agent_id": mock_agent_id}, headers=auth_headers
)
assert refresh_response.status_code == 200
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
# ------------------------------
# Test Cases for OAuth/Connection
# ------------------------------
def test_connect_mcp_server_oauth(server_url: str, auth_headers: Dict[str, str]):
"""Test connecting to an MCP server (OAuth flow)."""
# Create an SSE server that might require OAuth
server_name = f"oauth-test-{uuid.uuid4().hex[:8]}"
server_config = create_sse_server_dict(server_name, server_url="https://oauth.example.com/sse")
# Remove token to simulate OAuth requirement
server_config["auth_token"] = None
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# Attempt to connect (returns SSE stream)
# We can't fully test SSE in a simple integration test, but verify endpoint works
connect_response = requests.get(
f"{server_url}/v1/mcp-servers/connect/{server_id}",
headers={**auth_headers, "Accept": "text/event-stream"},
stream=True,
timeout=2,
)
# Should get a streaming response or error, not 404
assert connect_response.status_code in [200, 400, 500], f"Unexpected status: {connect_response.status_code}"
# Close the stream
connect_response.close()
except requests.exceptions.Timeout:
# Timeout is acceptable for SSE endpoints in tests
pass
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
# ------------------------------
# Test Cases for Error Handling
# ------------------------------
def test_error_handling_invalid_server_id(server_url: str, auth_headers: Dict[str, str]):
"""Test error handling with invalid server IDs."""
invalid_id = "invalid-server-id-12345"
# Try to get non-existent server
get_response = requests.get(f"{server_url}/v1/mcp-servers/{invalid_id}", headers=auth_headers)
assert get_response.status_code == 404
# Try to update non-existent server
update_data = {"server_name": "updated"}
update_response = requests.patch(f"{server_url}/v1/mcp-servers/{invalid_id}", json=update_data, headers=auth_headers)
assert update_response.status_code == 404 # Non-existent server returns 404
# Try to delete non-existent server
delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{invalid_id}", headers=auth_headers)
assert delete_response.status_code == 404
# Try to list tools for non-existent server
tools_response = requests.get(f"{server_url}/v1/mcp-servers/{invalid_id}/tools", headers=auth_headers)
assert tools_response.status_code == 404
def test_invalid_server_type(server_url: str, auth_headers: Dict[str, str]):
"""Test creating server with invalid type."""
invalid_config = {"type": "invalid_type", "server_name": "invalid-server", "some_field": "value"}
response = requests.post(f"{server_url}/v1/mcp-servers/", json=invalid_config, headers=auth_headers)
assert response.status_code == 422 # Validation error
# ------------------------------
# Test Cases for Complex Scenarios
# ------------------------------
def test_multiple_server_types_coexist(server_url: str, auth_headers: Dict[str, str]):
"""Test that multiple server types can coexist."""
servers_created = []
try:
# Create one of each type
stdio_config = create_stdio_server_dict(f"multi-stdio-{uuid.uuid4().hex[:8]}")
stdio_response = requests.post(f"{server_url}/v1/mcp-servers/", json=stdio_config, headers=auth_headers)
assert stdio_response.status_code == 200
stdio_server = stdio_response.json()
servers_created.append(stdio_server["id"])
sse_config = create_sse_server_dict(f"multi-sse-{uuid.uuid4().hex[:8]}")
sse_response = requests.post(f"{server_url}/v1/mcp-servers/", json=sse_config, headers=auth_headers)
assert sse_response.status_code == 200
sse_server = sse_response.json()
servers_created.append(sse_server["id"])
http_config = create_streamable_http_server_dict(f"multi-http-{uuid.uuid4().hex[:8]}")
http_response = requests.post(f"{server_url}/v1/mcp-servers/", json=http_config, headers=auth_headers)
assert http_response.status_code == 200
http_server = http_response.json()
servers_created.append(http_server["id"])
# List all servers
list_response = requests.get(f"{server_url}/v1/mcp-servers/", headers=auth_headers)
assert list_response.status_code == 200
servers_list = list_response.json()
server_ids = [s["id"] for s in servers_list]
# Verify all three are present
assert stdio_server["id"] in server_ids
assert sse_server["id"] in server_ids
assert http_server["id"] in server_ids
# Get each server and verify type-specific fields
stdio_get = requests.get(f"{server_url}/v1/mcp-servers/{stdio_server['id']}", headers=auth_headers)
assert stdio_get.status_code == 200
assert stdio_get.json()["command"] == stdio_config["command"]
sse_get = requests.get(f"{server_url}/v1/mcp-servers/{sse_server['id']}", headers=auth_headers)
assert sse_get.status_code == 200
assert sse_get.json()["server_url"] == sse_config["server_url"]
http_get = requests.get(f"{server_url}/v1/mcp-servers/{http_server['id']}", headers=auth_headers)
assert http_get.status_code == 200
assert http_get.json()["server_url"] == http_config["server_url"]
finally:
# Cleanup all servers
for server_id in servers_created:
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_partial_update_preserves_fields(server_url: str, auth_headers: Dict[str, str]):
"""Test that partial updates preserve non-updated fields."""
# Create a server with all fields
server_name = f"partial-update-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name, command="node", args=["server.js", "--port", "3000"])
server_config["env"] = {"NODE_ENV": "production", "PORT": "3000", "DEBUG": "false"}
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# Update only the server name
update_data = {"server_name": "renamed-server"}
update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers)
assert update_response.status_code == 200
updated_server = update_response.json()
assert updated_server["server_name"] == "renamed-server"
# Other fields should be preserved
assert updated_server["command"] == "node"
assert updated_server["args"] == ["server.js", "--port", "3000"]
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_concurrent_server_operations(server_url: str, auth_headers: Dict[str, str]):
"""Test multiple servers can be operated on concurrently."""
servers_created = []
try:
# Create multiple servers quickly
for i in range(3):
server_config = create_stdio_server_dict(f"concurrent-{i}-{uuid.uuid4().hex[:8]}", command="python", args=[f"server_{i}.py"])
response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert response.status_code == 200
servers_created.append(response.json()["id"])
# Update all servers
for i, server_id in enumerate(servers_created):
update_data = {"server_name": f"updated-concurrent-{i}"}
update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers)
assert update_response.status_code == 200
assert update_response.json()["server_name"] == f"updated-concurrent-{i}"
# Get all servers
for i, server_id in enumerate(servers_created):
get_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert get_response.status_code == 200
assert get_response.json()["server_name"] == f"updated-concurrent-{i}"
finally:
# Cleanup all servers
for server_id in servers_created:
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
def test_full_server_lifecycle(server_url: str, auth_headers: Dict[str, str]):
"""Test complete lifecycle: create, list, get, update, tools, delete."""
# 1. Create server
server_name = f"lifecycle-test-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name, command="npx", args=["-y", "@modelcontextprotocol/server-everything"])
server_config["env"]["TEST"] = "true"
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# 2. List servers and verify it's there
list_response = requests.get(f"{server_url}/v1/mcp-servers/", headers=auth_headers)
assert list_response.status_code == 200
assert any(s["id"] == server_id for s in list_response.json())
# 3. Get specific server
get_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert get_response.status_code == 200
assert get_response.json()["server_name"] == server_name
# 4. Update server
update_data = {"server_name": "lifecycle-updated", "env": {"TEST": "false", "NEW_VAR": "value"}}
update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers)
assert update_response.status_code == 200
assert update_response.json()["server_name"] == "lifecycle-updated"
# 5. List tools
tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers)
assert tools_response.status_code == 200
tools = tools_response.json()
assert isinstance(tools, list)
# 6. If tools exist, try to get and run one
if len(tools) > 0:
tool_id = tools[0]["id"]
# Get specific tool
tool_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}", headers=auth_headers)
assert tool_response.status_code == 200
assert tool_response.json()["id"] == tool_id
# Run tool
run_response = requests.post(
f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}/run", json={"args": {}}, headers=auth_headers
)
assert run_response.status_code == 200
# 7. Refresh tools
refresh_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}/refresh", headers=auth_headers)
assert refresh_response.status_code == 200
# 8. Try to connect (OAuth flow)
try:
connect_response = requests.get(
f"{server_url}/v1/mcp-servers/connect/{server_id}",
headers={**auth_headers, "Accept": "text/event-stream"},
stream=True,
timeout=1,
)
# Just verify it doesn't 404
assert connect_response.status_code in [200, 400, 500]
connect_response.close()
except requests.exceptions.Timeout:
pass # SSE timeout is acceptable
finally:
# 9. Delete server
delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert delete_response.status_code == 204
# 10. Verify it's deleted
get_deleted_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)
assert get_deleted_response.status_code == 404
# ------------------------------
# Test Cases for Empty Responses
# ------------------------------
def test_empty_tools_list(server_url: str, auth_headers: Dict[str, str]):
"""Test handling of servers with no tools."""
# Create a minimal server that likely has no tools
server_name = f"no-tools-{uuid.uuid4().hex[:8]}"
server_config = create_stdio_server_dict(server_name, command="echo", args=["hello"])
create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers)
assert create_response.status_code == 200
server_id = create_response.json()["id"]
try:
# List tools (should be empty)
tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers)
assert tools_response.status_code == 200
tools = tools_response.json()
assert tools is not None
assert isinstance(tools, list)
# Tools will be empty for a simple echo command
finally:
# Cleanup
requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)