From abbd1b5595a833d7ff073e6319753f1c2428f28b Mon Sep 17 00:00:00 2001 From: Ari Webb Date: Wed, 22 Oct 2025 11:38:58 -0700 Subject: [PATCH] Revert "feat: revise mcp tool routes [LET-4321]" (#5652) Revert "feat: revise mcp tool routes [LET-4321] (#5631)" This reverts commit e15f120078652b2160d64a1e300317b95eccb163. --- .../c6c43222e2de_add_mcp_tools_table.py | 47 - fern/openapi.json | 1422 ++--------------- letta/orm/mcp_server.py | 9 - letta/server/rest_api/routers/v1/__init__.py | 2 - .../server/rest_api/routers/v1/mcp_servers.py | 175 +- letta/server/server.py | 2 - letta/services/mcp_server_manager.py | 1311 --------------- tests/integration_test_mcp_servers.py | 858 ---------- 8 files changed, 230 insertions(+), 3596 deletions(-) delete mode 100644 alembic/versions/c6c43222e2de_add_mcp_tools_table.py delete mode 100644 letta/services/mcp_server_manager.py delete mode 100644 tests/integration_test_mcp_servers.py diff --git a/alembic/versions/c6c43222e2de_add_mcp_tools_table.py b/alembic/versions/c6c43222e2de_add_mcp_tools_table.py deleted file mode 100644 index 280ef3d6..00000000 --- a/alembic/versions/c6c43222e2de_add_mcp_tools_table.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Add mcp_tools table - -Revision ID: c6c43222e2de -Revises: 6756d04c3ddb -Create Date: 2025-10-20 17:25:54.334037 - -""" - -from typing import Sequence, Union - -import sqlalchemy as sa - -from alembic import op - -# revision identifiers, used by Alembic. -revision: str = "c6c43222e2de" -down_revision: Union[str, None] = "6756d04c3ddb" -branch_labels: Union[str, Sequence[str], None] = None -depends_on: Union[str, Sequence[str], None] = None - - -def upgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.create_table( - "mcp_tools", - sa.Column("mcp_server_id", sa.String(), nullable=False), - sa.Column("tool_id", sa.String(), nullable=False), - sa.Column("id", sa.String(), nullable=False), - sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), - sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), - sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), - sa.Column("_created_by_id", sa.String(), nullable=True), - sa.Column("_last_updated_by_id", sa.String(), nullable=True), - sa.Column("organization_id", sa.String(), nullable=False), - sa.ForeignKeyConstraint( - ["organization_id"], - ["organizations.id"], - ), - sa.PrimaryKeyConstraint("id"), - ) - # ### end Alembic commands ### - - -def downgrade() -> None: - # ### commands auto generated by Alembic - please adjust! ### - op.drop_table("mcp_tools") - # ### end Alembic commands ### diff --git a/fern/openapi.json b/fern/openapi.json index 5640b3e0..e3874fa5 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -1283,13 +1283,13 @@ "schema": { "anyOf": [ { - "$ref": "#/components/schemas/letta__schemas__mcp__UpdateStdioMCPServer" + "$ref": "#/components/schemas/UpdateStdioMCPServer" }, { - "$ref": "#/components/schemas/letta__schemas__mcp__UpdateSSEMCPServer" + "$ref": "#/components/schemas/UpdateSSEMCPServer" }, { - "$ref": "#/components/schemas/letta__schemas__mcp__UpdateStreamableHTTPMCPServer" + "$ref": "#/components/schemas/UpdateStreamableHTTPMCPServer" } ], "title": "Request" @@ -1519,7 +1519,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/letta__server__rest_api__routers__v1__tools__MCPToolExecuteRequest" + "$ref": "#/components/schemas/MCPToolExecuteRequest" } } } @@ -9528,521 +9528,6 @@ } } }, - "/v1/mcp-servers/": { - "post": { - "tags": ["mcp-servers"], - "summary": "Create Mcp Server", - "description": "Add a new MCP server to the Letta MCP server config", - "operationId": "mcp_create_mcp_server", - "parameters": [], - "requestBody": { - "required": true, - "content": { - "application/json": { - "schema": { - "anyOf": [ - { - "$ref": "#/components/schemas/CreateStdioMCPServer" - }, - { - "$ref": "#/components/schemas/CreateSSEMCPServer" - }, - { - "$ref": "#/components/schemas/CreateStreamableHTTPMCPServer" - } - ], - "title": "Request" - } - } - } - }, - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "anyOf": [ - { - "$ref": "#/components/schemas/StdioMCPServer" - }, - { - "$ref": "#/components/schemas/SSEMCPServer" - }, - { - "$ref": "#/components/schemas/StreamableHTTPMCPServer" - } - ], - "title": "Response Mcp Create Mcp Server" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - }, - "get": { - "tags": ["mcp-servers"], - "summary": "List Mcp Servers", - "description": "Get a list of all configured MCP servers", - "operationId": "mcp_list_mcp_servers", - "parameters": [], - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "type": "array", - "items": { - "anyOf": [ - { - "$ref": "#/components/schemas/StdioMCPServer" - }, - { - "$ref": "#/components/schemas/SSEMCPServer" - }, - { - "$ref": "#/components/schemas/StreamableHTTPMCPServer" - } - ] - }, - "title": "Response Mcp List Mcp Servers" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - } - }, - "/v1/mcp-servers/{mcp_server_id}": { - "get": { - "tags": ["mcp-servers"], - "summary": "Get Mcp Server", - "description": "Get a specific MCP server", - "operationId": "mcp_get_mcp_server", - "parameters": [ - { - "name": "mcp_server_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Mcp Server Id" - } - } - ], - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "anyOf": [ - { - "$ref": "#/components/schemas/StdioMCPServer" - }, - { - "$ref": "#/components/schemas/SSEMCPServer" - }, - { - "$ref": "#/components/schemas/StreamableHTTPMCPServer" - } - ], - "title": "Response Mcp Get Mcp Server" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - }, - "delete": { - "tags": ["mcp-servers"], - "summary": "Delete Mcp Server", - "description": "Delete an MCP server by its ID", - "operationId": "mcp_delete_mcp_server", - "parameters": [ - { - "name": "mcp_server_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Mcp Server Id" - } - } - ], - "responses": { - "204": { - "description": "Successful Response" - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - }, - "patch": { - "tags": ["mcp-servers"], - "summary": "Update Mcp Server", - "description": "Update an existing MCP server configuration", - "operationId": "mcp_update_mcp_server", - "parameters": [ - { - "name": "mcp_server_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Mcp Server Id" - } - } - ], - "requestBody": { - "required": true, - "content": { - "application/json": { - "schema": { - "anyOf": [ - { - "$ref": "#/components/schemas/letta__schemas__mcp_server__UpdateStdioMCPServer" - }, - { - "$ref": "#/components/schemas/letta__schemas__mcp_server__UpdateSSEMCPServer" - }, - { - "$ref": "#/components/schemas/letta__schemas__mcp_server__UpdateStreamableHTTPMCPServer" - } - ], - "title": "Request" - } - } - } - }, - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "anyOf": [ - { - "$ref": "#/components/schemas/StdioMCPServer" - }, - { - "$ref": "#/components/schemas/SSEMCPServer" - }, - { - "$ref": "#/components/schemas/StreamableHTTPMCPServer" - } - ], - "title": "Response Mcp Update Mcp Server" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - } - }, - "/v1/mcp-servers/{mcp_server_id}/tools": { - "get": { - "tags": ["mcp-servers"], - "summary": "List Mcp Tools By Server", - "description": "Get a list of all tools for a specific MCP server", - "operationId": "mcp_list_mcp_tools_by_server", - "parameters": [ - { - "name": "mcp_server_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Mcp Server Id" - } - } - ], - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "type": "array", - "items": { - "$ref": "#/components/schemas/Tool" - }, - "title": "Response Mcp List Mcp Tools By Server" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - } - }, - "/v1/mcp-servers/{mcp_server_id}/tools/{tool_id}": { - "get": { - "tags": ["mcp-servers"], - "summary": "Get Mcp Tool", - "description": "Get a specific MCP tool by its ID", - "operationId": "mcp_get_mcp_tool", - "parameters": [ - { - "name": "mcp_server_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Mcp Server Id" - } - }, - { - "name": "tool_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Tool Id" - } - } - ], - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/Tool" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - } - }, - "/v1/mcp-servers/{mcp_server_id}/tools/{tool_id}/run": { - "post": { - "tags": ["mcp-servers"], - "summary": "Run Mcp Tool", - "description": "Execute a specific MCP tool\n\nThe request body should contain the tool arguments in the MCPToolExecuteRequest format.", - "operationId": "mcp_run_tool", - "parameters": [ - { - "name": "mcp_server_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Mcp Server Id" - } - }, - { - "name": "tool_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Tool Id" - } - } - ], - "requestBody": { - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/letta__schemas__mcp_server__MCPToolExecuteRequest", - "default": { - "args": {} - } - } - } - } - }, - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/ToolExecutionResult" - } - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - } - }, - "/v1/mcp-servers/{mcp_server_id}/refresh": { - "patch": { - "tags": ["mcp-servers"], - "summary": "Refresh Mcp Server Tools", - "description": "Refresh tools for an MCP server by:\n1. Fetching current tools from the MCP server\n2. Deleting tools that no longer exist on the server\n3. Updating schemas for existing tools\n4. Adding new tools from the server\n\nReturns a summary of changes made.", - "operationId": "mcp_refresh_mcp_server_tools", - "parameters": [ - { - "name": "mcp_server_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Mcp Server Id" - } - }, - { - "name": "agent_id", - "in": "query", - "required": false, - "schema": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Agent Id" - } - } - ], - "responses": { - "200": { - "description": "Successful Response", - "content": { - "application/json": { - "schema": {} - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - } - }, - "/v1/mcp-servers/connect/{mcp_server_id}": { - "get": { - "tags": ["mcp-servers"], - "summary": "Connect Mcp Server", - "description": "Connect to an MCP server with support for OAuth via SSE.\nReturns a stream of events handling authorization state and exchange if OAuth is required.", - "operationId": "mcp_connect_mcp_server", - "parameters": [ - { - "name": "mcp_server_id", - "in": "path", - "required": true, - "schema": { - "type": "string", - "title": "Mcp Server Id" - } - } - ], - "responses": { - "200": { - "description": "Successful response", - "content": { - "application/json": { - "schema": {} - }, - "text/event-stream": { - "description": "Server-Sent Events stream" - } - } - }, - "422": { - "description": "Validation Error", - "content": { - "application/json": { - "schema": { - "$ref": "#/components/schemas/HTTPValidationError" - } - } - } - } - } - } - }, "/v1/blocks/": { "get": { "tags": ["blocks"], @@ -28133,173 +27618,6 @@ "title": "CreateBlock", "description": "Create a block" }, - "CreateSSEMCPServer": { - "properties": { - "server_name": { - "type": "string", - "title": "Server Name", - "description": "The name of the server" - }, - "type": { - "$ref": "#/components/schemas/MCPServerType", - "default": "sse" - }, - "server_url": { - "type": "string", - "title": "Server Url", - "description": "The URL of the server" - }, - "auth_header": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Header", - "description": "The name of the authentication header (e.g., 'Authorization')" - }, - "auth_token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Token", - "description": "The authentication token or API key value" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom HTTP headers to include with requests" - } - }, - "type": "object", - "required": ["server_name", "server_url"], - "title": "CreateSSEMCPServer", - "description": "Create a new SSE MCP server" - }, - "CreateStdioMCPServer": { - "properties": { - "server_name": { - "type": "string", - "title": "Server Name", - "description": "The name of the server" - }, - "type": { - "$ref": "#/components/schemas/MCPServerType", - "default": "stdio" - }, - "command": { - "type": "string", - "title": "Command", - "description": "The command to run (MCP 'local' client will run this command)" - }, - "args": { - "items": { - "type": "string" - }, - "type": "array", - "title": "Args", - "description": "The arguments to pass to the command" - }, - "env": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Env", - "description": "Environment variables to set" - } - }, - "type": "object", - "required": ["server_name", "command", "args"], - "title": "CreateStdioMCPServer", - "description": "Create a new Stdio MCP server" - }, - "CreateStreamableHTTPMCPServer": { - "properties": { - "server_name": { - "type": "string", - "title": "Server Name", - "description": "The name of the server" - }, - "type": { - "$ref": "#/components/schemas/MCPServerType", - "default": "streamable_http" - }, - "server_url": { - "type": "string", - "title": "Server Url", - "description": "The URL of the server" - }, - "auth_header": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Header", - "description": "The name of the authentication header (e.g., 'Authorization')" - }, - "auth_token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Token", - "description": "The authentication token or API key value" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom HTTP headers to include with requests" - } - }, - "type": "object", - "required": ["server_name", "server_url"], - "title": "CreateStreamableHTTPMCPServer", - "description": "Create a new Streamable HTTP MCP server" - }, "Custom-Input": { "properties": { "input": { @@ -32890,6 +32208,18 @@ "title": "MCPTool", "description": "A simple wrapper around MCP's tool definition (to avoid conflict with our own)" }, + "MCPToolExecuteRequest": { + "properties": { + "args": { + "additionalProperties": true, + "type": "object", + "title": "Args", + "description": "Arguments to pass to the MCP tool" + } + }, + "type": "object", + "title": "MCPToolExecuteRequest" + }, "MCPToolHealth": { "properties": { "status": { @@ -35331,74 +34661,6 @@ "title": "RunStatus", "description": "Status of the run." }, - "SSEMCPServer": { - "properties": { - "server_name": { - "type": "string", - "title": "Server Name", - "description": "The name of the server" - }, - "type": { - "$ref": "#/components/schemas/MCPServerType", - "default": "sse" - }, - "server_url": { - "type": "string", - "title": "Server Url", - "description": "The URL of the server" - }, - "auth_header": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Header", - "description": "The name of the authentication header (e.g., 'Authorization')" - }, - "auth_token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Token", - "description": "The authentication token or API key value" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom HTTP headers to include with requests" - }, - "id": { - "type": "string", - "pattern": "^mcp_server-[a-fA-F0-9]{8}", - "title": "Id", - "description": "The human-friendly ID of the Mcp_server", - "examples": ["mcp_server-123e4567-e89b-12d3-a456-426614174000"] - } - }, - "type": "object", - "required": ["server_name", "server_url"], - "title": "SSEMCPServer", - "description": "An SSE MCP server" - }, "SSEServerConfig": { "properties": { "server_name": { @@ -36233,58 +35495,6 @@ "title": "SourceUpdate", "description": "Schema for updating an existing Source." }, - "StdioMCPServer": { - "properties": { - "server_name": { - "type": "string", - "title": "Server Name", - "description": "The name of the server" - }, - "type": { - "$ref": "#/components/schemas/MCPServerType", - "default": "stdio" - }, - "command": { - "type": "string", - "title": "Command", - "description": "The command to run (MCP 'local' client will run this command)" - }, - "args": { - "items": { - "type": "string" - }, - "type": "array", - "title": "Args", - "description": "The arguments to pass to the command" - }, - "env": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Env", - "description": "Environment variables to set" - }, - "id": { - "type": "string", - "pattern": "^mcp_server-[a-fA-F0-9]{8}", - "title": "Id", - "description": "The human-friendly ID of the Mcp_server", - "examples": ["mcp_server-123e4567-e89b-12d3-a456-426614174000"] - } - }, - "type": "object", - "required": ["server_name", "command", "args"], - "title": "StdioMCPServer", - "description": "A Stdio MCP server" - }, "StdioServerConfig": { "properties": { "server_name": { @@ -36782,74 +35992,6 @@ ], "title": "StopReasonType" }, - "StreamableHTTPMCPServer": { - "properties": { - "server_name": { - "type": "string", - "title": "Server Name", - "description": "The name of the server" - }, - "type": { - "$ref": "#/components/schemas/MCPServerType", - "default": "streamable_http" - }, - "server_url": { - "type": "string", - "title": "Server Url", - "description": "The URL of the server" - }, - "auth_header": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Header", - "description": "The name of the authentication header (e.g., 'Authorization')" - }, - "auth_token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Token", - "description": "The authentication token or API key value" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom HTTP headers to include with requests" - }, - "id": { - "type": "string", - "pattern": "^mcp_server-[a-fA-F0-9]{8}", - "title": "Id", - "description": "The human-friendly ID of the Mcp_server", - "examples": ["mcp_server-123e4567-e89b-12d3-a456-426614174000"] - } - }, - "type": "object", - "required": ["server_name", "server_url"], - "title": "StreamableHTTPMCPServer", - "description": "A Streamable HTTP MCP server" - }, "StreamableHTTPServerConfig": { "properties": { "server_name": { @@ -37887,82 +37029,6 @@ "required": ["created_at", "description", "key", "updated_at", "value"], "title": "ToolEnvVarSchema" }, - "ToolExecutionResult": { - "properties": { - "status": { - "type": "string", - "enum": ["success", "error"], - "title": "Status", - "description": "The status of the tool execution and return object" - }, - "func_return": { - "anyOf": [ - {}, - { - "type": "null" - } - ], - "title": "Func Return", - "description": "The function return object" - }, - "agent_state": { - "anyOf": [ - { - "$ref": "#/components/schemas/AgentState" - }, - { - "type": "null" - } - ], - "description": "The agent state" - }, - "stdout": { - "anyOf": [ - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Stdout", - "description": "Captured stdout (prints, logs) from function invocation" - }, - "stderr": { - "anyOf": [ - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Stderr", - "description": "Captured stderr from the function invocation" - }, - "sandbox_config_fingerprint": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Sandbox Config Fingerprint", - "description": "The fingerprint of the config for the sandbox" - } - }, - "type": "object", - "required": ["status"], - "title": "ToolExecutionResult" - }, "ToolJSONSchema": { "properties": { "name": { @@ -39024,6 +38090,131 @@ "required": ["reasoning"], "title": "UpdateReasoningMessage" }, + "UpdateSSEMCPServer": { + "properties": { + "server_url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Url", + "description": "The URL of the server (MCP SSE client will connect to this URL)" + }, + "token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Token", + "description": "The access token or API key for the MCP server (used for SSE authentication)" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom authentication headers as key-value pairs" + } + }, + "additionalProperties": false, + "type": "object", + "title": "UpdateSSEMCPServer", + "description": "Update an SSE MCP server" + }, + "UpdateStdioMCPServer": { + "properties": { + "stdio_config": { + "anyOf": [ + { + "$ref": "#/components/schemas/StdioServerConfig" + }, + { + "type": "null" + } + ], + "description": "The configuration for the server (MCP 'local' client will run this command)" + } + }, + "additionalProperties": false, + "type": "object", + "title": "UpdateStdioMCPServer", + "description": "Update a Stdio MCP server" + }, + "UpdateStreamableHTTPMCPServer": { + "properties": { + "server_url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Url", + "description": "The URL path for the streamable HTTP server (e.g., 'example/mcp')" + }, + "auth_header": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Header", + "description": "The name of the authentication header (e.g., 'Authorization')" + }, + "auth_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Token", + "description": "The authentication token or API key value" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom authentication headers as key-value pairs" + } + }, + "additionalProperties": false, + "type": "object", + "title": "UpdateStreamableHTTPMCPServer", + "description": "Update a Streamable HTTP MCP server" + }, "UpdateSystemMessage": { "properties": { "message_type": { @@ -40616,325 +39807,6 @@ "required": ["tool_return", "status", "tool_call_id"], "title": "ToolReturn" }, - "letta__schemas__mcp__UpdateSSEMCPServer": { - "properties": { - "server_url": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Server Url", - "description": "The URL of the server (MCP SSE client will connect to this URL)" - }, - "token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Token", - "description": "The access token or API key for the MCP server (used for SSE authentication)" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom authentication headers as key-value pairs" - } - }, - "additionalProperties": false, - "type": "object", - "title": "UpdateSSEMCPServer", - "description": "Update an SSE MCP server" - }, - "letta__schemas__mcp__UpdateStdioMCPServer": { - "properties": { - "stdio_config": { - "anyOf": [ - { - "$ref": "#/components/schemas/StdioServerConfig" - }, - { - "type": "null" - } - ], - "description": "The configuration for the server (MCP 'local' client will run this command)" - } - }, - "additionalProperties": false, - "type": "object", - "title": "UpdateStdioMCPServer", - "description": "Update a Stdio MCP server" - }, - "letta__schemas__mcp__UpdateStreamableHTTPMCPServer": { - "properties": { - "server_url": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Server Url", - "description": "The URL path for the streamable HTTP server (e.g., 'example/mcp')" - }, - "auth_header": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Header", - "description": "The name of the authentication header (e.g., 'Authorization')" - }, - "auth_token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Token", - "description": "The authentication token or API key value" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom authentication headers as key-value pairs" - } - }, - "additionalProperties": false, - "type": "object", - "title": "UpdateStreamableHTTPMCPServer", - "description": "Update a Streamable HTTP MCP server" - }, - "letta__schemas__mcp_server__MCPToolExecuteRequest": { - "properties": { - "args": { - "additionalProperties": true, - "type": "object", - "title": "Args", - "description": "Arguments to pass to the MCP tool" - } - }, - "additionalProperties": false, - "type": "object", - "title": "MCPToolExecuteRequest", - "description": "Request to execute an MCP tool by IDs." - }, - "letta__schemas__mcp_server__UpdateSSEMCPServer": { - "properties": { - "server_name": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Server Name", - "description": "The name of the MCP server" - }, - "server_url": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Server Url", - "description": "The URL of the SSE MCP server" - }, - "token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Token", - "description": "The authentication token" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom headers to send with requests" - } - }, - "additionalProperties": false, - "type": "object", - "title": "UpdateSSEMCPServer", - "description": "Update schema for SSE MCP server - all fields optional" - }, - "letta__schemas__mcp_server__UpdateStdioMCPServer": { - "properties": { - "server_name": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Server Name", - "description": "The name of the MCP server" - }, - "command": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Command", - "description": "The command to run the MCP server" - }, - "args": { - "anyOf": [ - { - "items": { - "type": "string" - }, - "type": "array" - }, - { - "type": "null" - } - ], - "title": "Args", - "description": "The arguments to pass to the command" - }, - "env": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Env", - "description": "Environment variables to set" - } - }, - "additionalProperties": false, - "type": "object", - "title": "UpdateStdioMCPServer", - "description": "Update schema for Stdio MCP server - all fields optional" - }, - "letta__schemas__mcp_server__UpdateStreamableHTTPMCPServer": { - "properties": { - "server_name": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Server Name", - "description": "The name of the MCP server" - }, - "server_url": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Server Url", - "description": "The URL of the Streamable HTTP MCP server" - }, - "token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Token", - "description": "The authentication token" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom headers to send with requests" - } - }, - "additionalProperties": false, - "type": "object", - "title": "UpdateStreamableHTTPMCPServer", - "description": "Update schema for Streamable HTTP MCP server - all fields optional" - }, "letta__schemas__message__ToolReturn": { "properties": { "tool_call_id": { @@ -41334,18 +40206,6 @@ ], "title": "ToolSchema" }, - "letta__server__rest_api__routers__v1__tools__MCPToolExecuteRequest": { - "properties": { - "args": { - "additionalProperties": true, - "type": "object", - "title": "Args", - "description": "Arguments to pass to the MCP tool" - } - }, - "type": "object", - "title": "MCPToolExecuteRequest" - }, "openai__types__chat__chat_completion_message_function_tool_call__Function": { "properties": { "arguments": { diff --git a/letta/orm/mcp_server.py b/letta/orm/mcp_server.py index 14888baf..49cffb84 100644 --- a/letta/orm/mcp_server.py +++ b/letta/orm/mcp_server.py @@ -56,12 +56,3 @@ class MCPServer(SqlalchemyBase, OrganizationMixin): metadata_: Mapped[Optional[dict]] = mapped_column( JSON, default=lambda: {}, doc="A dictionary of additional metadata for the MCP server." ) - - -class MCPTools(SqlalchemyBase, OrganizationMixin): - """Represents a mapping of MCP server ID to tool ID""" - - __tablename__ = "mcp_tools" - - mcp_server_id: Mapped[str] = mapped_column(String, doc="The ID of the MCP server") - tool_id: Mapped[str] = mapped_column(String, doc="The ID of the tool") diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index 520a77b4..8568485c 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -11,7 +11,6 @@ from letta.server.rest_api.routers.v1.internal_runs import router as internal_ru from letta.server.rest_api.routers.v1.internal_templates import router as internal_templates_router from letta.server.rest_api.routers.v1.jobs import router as jobs_router from letta.server.rest_api.routers.v1.llms import router as llm_router -from letta.server.rest_api.routers.v1.mcp_servers import router as mcp_servers_router from letta.server.rest_api.routers.v1.messages import router as messages_router from letta.server.rest_api.routers.v1.providers import router as providers_router from letta.server.rest_api.routers.v1.runs import router as runs_router @@ -35,7 +34,6 @@ ROUTERS = [ internal_runs_router, internal_templates_router, llm_router, - mcp_servers_router, blocks_router, jobs_router, health_router, diff --git a/letta/server/rest_api/routers/v1/mcp_servers.py b/letta/server/rest_api/routers/v1/mcp_servers.py index 618fb331..e29c5a1d 100644 --- a/letta/server/rest_api/routers/v1/mcp_servers.py +++ b/letta/server/rest_api/routers/v1/mcp_servers.py @@ -1,10 +1,8 @@ -from typing import Any, AsyncGenerator, Dict, List, Optional, Union +from typing import Any, Dict, List, Optional -from fastapi import APIRouter, Body, Depends, HTTPException, Request -from httpx import HTTPStatusError +from fastapi import APIRouter, Body, Depends, HTTPException from starlette.responses import StreamingResponse -from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig from letta.log import get_logger from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.mcp_server import ( @@ -15,17 +13,12 @@ from letta.schemas.mcp_server import ( convert_generic_to_union, ) from letta.schemas.tool import Tool -from letta.schemas.tool_execution_result import ToolExecutionResult from letta.server.rest_api.dependencies import ( HeaderParams, get_headers, get_letta_server, ) -from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode from letta.server.server import SyncServer -from letta.services.mcp.oauth_utils import drill_down_exception, oauth_stream_event -from letta.services.mcp.stdio_client import AsyncStdioMCPClient -from letta.services.mcp.types import OauthStreamEvent from letta.settings import tool_settings router = APIRouter(prefix="/mcp-servers", tags=["mcp-servers"]) @@ -46,7 +39,6 @@ async def create_mcp_server( """ Add a new MCP server to the Letta MCP server config """ - # TODO: add the tools to the MCP server table we made. actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) new_server = await server.mcp_server_manager.create_mcp_server_from_config_with_tools(request, actor=actor) return convert_generic_to_union(new_server) @@ -64,6 +56,7 @@ async def list_mcp_servers( """ Get a list of all configured MCP servers """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) mcp_servers = await server.mcp_server_manager.list_mcp_servers(actor=actor) return [convert_generic_to_union(mcp_server) for mcp_server in mcp_servers] @@ -134,10 +127,24 @@ async def list_mcp_tools_by_server( """ Get a list of all tools for a specific MCP server """ - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - # Use the new efficient method that queries from the database using MCPTools mapping - tools = await server.mcp_server_manager.list_tools_by_mcp_server_from_db(mcp_server_id, actor=actor) - return tools + # TODO: implement this. We want to use the new tools table instead of going to the mcp server. + pass + # actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + # mcp_tools = await server.mcp_server_manager.list_mcp_server_tools(mcp_server_id, actor=actor) + # # Convert MCPTool objects to Tool objects + # tools = [] + # for mcp_tool in mcp_tools: + # from letta.schemas.tool import ToolCreate + # tool_create = ToolCreate.from_mcp(mcp_server_name="", mcp_tool=mcp_tool) + # tools.append(Tool( + # id=f"mcp-tool-{mcp_tool.name}", # Generate a temporary ID + # name=mcp_tool.name, + # description=tool_create.description, + # json_schema=tool_create.json_schema, + # source_code=tool_create.source_code, + # tags=tool_create.tags, + # )) + # return tools @router.get("/{mcp_server_id}/tools/{tool_id}", response_model=Tool, operation_id="mcp_get_mcp_tool") @@ -151,11 +158,13 @@ async def get_mcp_tool( Get a specific MCP tool by its ID """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - tool = await server.mcp_server_manager.get_tool_by_mcp_server(mcp_server_id, tool_id, actor=actor) + # Use the tool_manager's existing method to get the tool by ID + # Verify the tool belongs to the MCP server (optional check) + tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool_id, actor=actor) return tool -@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolExecutionResult, operation_id="mcp_run_tool") +@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolReturnMessage, operation_id="mcp_run_tool") async def run_mcp_tool( mcp_server_id: str, tool_id: str, @@ -179,10 +188,9 @@ async def run_mcp_tool( actor=actor, ) - # Create a ToolExecutionResult - return ToolExecutionResult( - status="success" if success else "error", - func_return=result, + # Create a ToolReturnMessage + return ToolReturnMessage( + id=f"tool-return-{tool_id}", tool_call_id=f"call-{tool_id}", tool_return=result, status="success" if success else "error" ) @@ -223,7 +231,6 @@ async def refresh_mcp_server_tools( ) async def connect_mcp_server( mcp_server_id: str, - request: Request, server: SyncServer = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), ) -> StreamingResponse: @@ -231,76 +238,72 @@ async def connect_mcp_server( Connect to an MCP server with support for OAuth via SSE. Returns a stream of events handling authorization state and exchange if OAuth is required. """ - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor) + pass - # Convert the MCP server to the appropriate config type - config = mcp_server.to_config(resolve_variables=False) + # async def oauth_stream_generator( + # request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], + # http_request: Request, + # ) -> AsyncGenerator[str, None]: + # client = None - async def oauth_stream_generator( - mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], - http_request: Request, - ) -> AsyncGenerator[str, None]: - client = None + # oauth_flow_attempted = False + # try: + # # Acknolwedge connection attempt + # yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name) - oauth_flow_attempted = False - try: - # Acknowledge connection attempt - yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=mcp_config.server_name) + # actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - # Create MCP client with respective transport type - try: - mcp_config.resolve_environment_variables() - client = await server.mcp_server_manager.get_mcp_client(mcp_config, actor) - except ValueError as e: - yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e)) - return + # # Create MCP client with respective transport type + # try: + # request.resolve_environment_variables() + # client = await server.mcp_server_manager.get_mcp_client(request, actor) + # except ValueError as e: + # yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e)) + # return - # Try normal connection first for flows that don't require OAuth - try: - await client.connect_to_server() - tools = await client.list_tools(serialize=True) - yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) - return - except ConnectionError: - # TODO: jnjpng make this connection error check more specific to the 401 unauthorized error - if isinstance(client, AsyncStdioMCPClient): - logger.warning("OAuth not supported for stdio") - yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio") - return - # Continue to OAuth flow - logger.info(f"Attempting OAuth flow for {mcp_config}...") - except Exception as e: - yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}") - return - finally: - if client: - try: - await client.cleanup() - # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py - # For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception - except HTTPStatusError: - oauth_flow_attempted = True - async for event in server.mcp_server_manager.handle_oauth_flow( - request=mcp_config, actor=actor, http_request=http_request - ): - yield event + # # Try normal connection first for flows that don't require OAuth + # try: + # await client.connect_to_server() + # tools = await client.list_tools(serialize=True) + # yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) + # return + # except ConnectionError: + # # TODO: jnjpng make this connection error check more specific to the 401 unauthorized error + # if isinstance(client, AsyncStdioMCPClient): + # logger.warning("OAuth not supported for stdio") + # yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio") + # return + # # Continue to OAuth flow + # logger.info(f"Attempting OAuth flow for {request}...") + # except Exception as e: + # yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}") + # return + # finally: + # if client: + # try: + # await client.cleanup() + # # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py + # # For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception + # except* HTTPStatusError: + # oauth_flow_attempted = True + # async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request): + # yield event - # Failsafe to make sure we don't try to handle OAuth flow twice - if not oauth_flow_attempted: - async for event in server.mcp_server_manager.handle_oauth_flow(request=mcp_config, actor=actor, http_request=http_request): - yield event - return - except Exception as e: - detailed_error = drill_down_exception(e) - logger.error(f"Error in OAuth stream:\n{detailed_error}") - yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}") + # # Failsafe to make sure we don't try to handle OAuth flow twice + # if not oauth_flow_attempted: + # async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request): + # yield event + # return + # except Exception as e: + # detailed_error = drill_down_exception(e) + # logger.error(f"Error in OAuth stream:\n{detailed_error}") + # yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}") - finally: - if client: - try: - await client.cleanup() - except Exception as cleanup_error: - logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}") + # finally: + # if client: + # try: + # await client.cleanup() + # except Exception as cleanup_error: + # logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}") - return StreamingResponseWithStatusCode(oauth_stream_generator(config, request), media_type="text/event-stream") + # return StreamingResponseWithStatusCode(oauth_stream_generator(request, http_request), media_type="text/event-stream") diff --git a/letta/server/server.py b/letta/server/server.py index 4cac26b4..b6a010d0 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -94,7 +94,6 @@ from letta.services.mcp.base_client import AsyncBaseMCPClient from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient from letta.services.mcp.stdio_client import AsyncStdioMCPClient from letta.services.mcp_manager import MCPManager -from letta.services.mcp_server_manager import MCPServerManager from letta.services.message_manager import MessageManager from letta.services.organization_manager import OrganizationManager from letta.services.passage_manager import PassageManager @@ -155,7 +154,6 @@ class SyncServer(object): self.user_manager = UserManager() self.tool_manager = ToolManager() self.mcp_manager = MCPManager() - self.mcp_server_manager = MCPServerManager() self.block_manager = BlockManager() self.source_manager = SourceManager() self.sandbox_config_manager = SandboxConfigManager() diff --git a/letta/services/mcp_server_manager.py b/letta/services/mcp_server_manager.py deleted file mode 100644 index 9296164f..00000000 --- a/letta/services/mcp_server_manager.py +++ /dev/null @@ -1,1311 +0,0 @@ -import json -import os -import secrets -import uuid -from datetime import datetime, timedelta -from typing import Any, Dict, List, Optional, Tuple, Union - -from fastapi import HTTPException -from sqlalchemy import delete, desc, null, select -from starlette.requests import Request - -import letta.constants as constants -from letta.functions.mcp_client.types import ( - MCPServerType, - MCPTool, - MCPToolHealth, - SSEServerConfig, - StdioServerConfig, - StreamableHTTPServerConfig, -) -from letta.functions.schema_generator import normalize_mcp_schema -from letta.functions.schema_validator import validate_complete_json_schema -from letta.log import get_logger -from letta.orm.errors import NoResultFound -from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus -from letta.orm.mcp_server import MCPServer as MCPServerModel, MCPTools as MCPToolsModel -from letta.orm.tool import Tool as ToolModel -from letta.schemas.mcp import ( - MCPOAuthSession, - MCPOAuthSessionCreate, - MCPOAuthSessionUpdate, - MCPServer, - MCPServerResyncResult, - UpdateMCPServer, - UpdateSSEMCPServer, - UpdateStdioMCPServer, - UpdateStreamableHTTPMCPServer, -) -from letta.schemas.secret import Secret -from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate -from letta.schemas.user import User as PydanticUser -from letta.server.db import db_registry -from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient -from letta.services.mcp.stdio_client import AsyncStdioMCPClient -from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient -from letta.services.tool_manager import ToolManager -from letta.settings import settings, tool_settings -from letta.utils import enforce_types, printd, safe_create_task - -logger = get_logger(__name__) - - -class MCPServerManager: - """Manager class to handle business logic related to MCP.""" - - def __init__(self): - # TODO: timeouts? - self.tool_manager = ToolManager() - self.cached_mcp_servers = {} # maps id -> async connection - - # MCPTools mapping table management methods - @enforce_types - async def create_mcp_tool_mapping(self, mcp_server_id: str, tool_id: str, actor: PydanticUser) -> None: - """Create a mapping between an MCP server and a tool.""" - async with db_registry.async_session() as session: - mapping = MCPToolsModel( - mcp_server_id=mcp_server_id, - tool_id=tool_id, - organization_id=actor.organization_id, - ) - await mapping.create_async(session, actor=actor) - - @enforce_types - async def delete_mcp_tool_mappings_by_server(self, mcp_server_id: str, actor: PydanticUser) -> None: - """Delete all tool mappings for a specific MCP server.""" - async with db_registry.async_session() as session: - await session.execute( - delete(MCPToolsModel).where( - MCPToolsModel.mcp_server_id == mcp_server_id, - MCPToolsModel.organization_id == actor.organization_id, - ) - ) - await session.commit() - - @enforce_types - async def get_tool_ids_by_mcp_server(self, mcp_server_id: str, actor: PydanticUser) -> List[str]: - """Get all tool IDs associated with an MCP server.""" - async with db_registry.async_session() as session: - result = await session.execute( - select(MCPToolsModel.tool_id).where( - MCPToolsModel.mcp_server_id == mcp_server_id, - MCPToolsModel.organization_id == actor.organization_id, - ) - ) - return [row[0] for row in result.fetchall()] - - @enforce_types - async def get_mcp_server_id_by_tool(self, tool_id: str, actor: PydanticUser) -> Optional[str]: - """Get the MCP server ID associated with a tool.""" - async with db_registry.async_session() as session: - result = await session.execute( - select(MCPToolsModel.mcp_server_id).where( - MCPToolsModel.tool_id == tool_id, - MCPToolsModel.organization_id == actor.organization_id, - ) - ) - row = result.fetchone() - return row[0] if row else None - - @enforce_types - async def list_tools_by_mcp_server_from_db(self, mcp_server_id: str, actor: PydanticUser) -> List[PydanticTool]: - """ - Get tools associated with an MCP server from the database using the MCPTools mapping. - This is more efficient than fetching from the MCP server directly. - """ - # First get all tool IDs associated with this MCP server - tool_ids = await self.get_tool_ids_by_mcp_server(mcp_server_id, actor) - - if not tool_ids: - return [] - - # Fetch all tools in a single query - async with db_registry.async_session() as session: - result = await session.execute( - select(ToolModel).where( - ToolModel.id.in_(tool_ids), - ToolModel.organization_id == actor.organization_id, - ) - ) - tools = result.scalars().all() - return [tool.to_pydantic() for tool in tools] - - @enforce_types - async def get_tool_by_mcp_server(self, mcp_server_id: str, tool_id: str, actor: PydanticUser) -> Optional[PydanticTool]: - """ - Get a specific tool that belongs to an MCP server. - Verifies the tool is associated with the MCP server via the mapping table. - """ - async with db_registry.async_session() as session: - # Check if the tool is associated with this MCP server - result = await session.execute( - select(MCPToolsModel).where( - MCPToolsModel.mcp_server_id == mcp_server_id, - MCPToolsModel.tool_id == tool_id, - MCPToolsModel.organization_id == actor.organization_id, - ) - ) - mapping = result.scalar_one_or_none() - - if not mapping: - return None - - # Fetch the tool - tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) - return tool.to_pydantic() - - @enforce_types - async def list_mcp_server_tools(self, mcp_server_id: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]: - """Get a list of all tools for a specific MCP server by server ID.""" - mcp_client = None - try: - mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) - server_config = mcp_config.to_config() - mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id) - await mcp_client.connect_to_server() - - # list tools - tools = await mcp_client.list_tools() - # Add health information to each tool - for tool in tools: - # Try to normalize the schema and re-validate - if tool.inputSchema: - tool.inputSchema = normalize_mcp_schema(tool.inputSchema) - health_status, reasons = validate_complete_json_schema(tool.inputSchema) - tool.health = MCPToolHealth(status=health_status.value, reasons=reasons) - - return tools - except Exception as e: - # MCP tool listing errors are often due to connection/configuration issues, not system errors - # Log at info level to avoid triggering Sentry alerts for expected failures - logger.warning(f"Error listing tools for MCP server {mcp_server_id}: {e}") - raise e - finally: - if mcp_client: - try: - await mcp_client.cleanup() - except Exception as e: - logger.warning(f"Error listing tools for MCP server {mcp_server_id}: {e}") - raise e - - @enforce_types - async def execute_mcp_server_tool( - self, - mcp_server_id: str, - tool_id: str, - tool_args: Optional[Dict[str, Any]], - environment_variables: Dict[str, str], - actor: PydanticUser, - agent_id: Optional[str] = None, - ) -> Tuple[str, bool]: - """Call a specific tool from a specific MCP server by IDs.""" - mcp_client = None - try: - # Get the tool to find its actual name - async with db_registry.async_session() as session: - tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) - tool_name = tool.name - - # Get the MCP server config - mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) - server_config = mcp_config.to_config(environment_variables) - - mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id) - await mcp_client.connect_to_server() - - # call tool - result, success = await mcp_client.execute_tool(tool_name, tool_args) - logger.info(f"MCP Result: {result}, Success: {success}") - return result, success - finally: - if mcp_client: - await mcp_client.cleanup() - - @enforce_types - async def add_tool_from_mcp_server(self, mcp_server_id: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool: - """Add a tool from an MCP server to the Letta tool registry.""" - # Get the MCP server to get its name - mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) - mcp_server_name = mcp_server.server_name - - mcp_tools = await self.list_mcp_server_tools(mcp_server_id, actor=actor) - for mcp_tool in mcp_tools: - # TODO: @jnjpng move health check to tool class - if mcp_tool.name == mcp_tool_name: - # Check tool health - but try normalization first for INVALID schemas - if mcp_tool.health and mcp_tool.health.status == "INVALID": - logger.info(f"Attempting to normalize INVALID schema for tool {mcp_tool_name}") - logger.info(f"Original health reasons: {mcp_tool.health.reasons}") - - # Try to normalize the schema and re-validate - try: - # Normalize the schema to fix common issues - logger.debug(f"Normalizing schema for {mcp_tool_name}") - normalized_schema = normalize_mcp_schema(mcp_tool.inputSchema) - - # Re-validate after normalization - logger.debug(f"Re-validating schema for {mcp_tool_name}") - health_status, health_reasons = validate_complete_json_schema(normalized_schema) - logger.info(f"After normalization: status={health_status.value}, reasons={health_reasons}") - - # Update the tool's schema and health (use inputSchema, not input_schema) - mcp_tool.inputSchema = normalized_schema - mcp_tool.health.status = health_status.value - mcp_tool.health.reasons = health_reasons - - # Log the normalization result - if health_status.value != "INVALID": - logger.info(f"✓ MCP tool {mcp_tool_name} schema normalized successfully: {health_status.value}") - else: - logger.warning(f"MCP tool {mcp_tool_name} still INVALID after normalization. Reasons: {health_reasons}") - except Exception as e: - logger.error(f"Failed to normalize schema for tool {mcp_tool_name}: {e}", exc_info=True) - - # After normalization attempt, check if still INVALID - if mcp_tool.health and mcp_tool.health.status == "INVALID": - logger.warning(f"Tool {mcp_tool_name} has potentially invalid schema. Reasons: {', '.join(mcp_tool.health.reasons)}") - - tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) - created_tool = await self.tool_manager.create_mcp_tool_async( - tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor - ) - - # Create mapping in MCPTools table - if created_tool: - await self.create_mcp_tool_mapping(mcp_server_id, created_tool.id, actor) - - return created_tool - - # failed to add - handle error? - return None - - @enforce_types - async def resync_mcp_server_tools( - self, mcp_server_id: str, actor: PydanticUser, agent_id: Optional[str] = None - ) -> MCPServerResyncResult: - """ - Resync tools for an MCP server by: - 1. Fetching current tools from the MCP server - 2. Deleting tools that no longer exist on the server - 3. Updating schemas for existing tools - 4. Adding new tools from the server - - Returns a result with: - - deleted: List of deleted tool names - - updated: List of updated tool names - - added: List of added tool names - """ - # Get the MCP server to get its name - mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) - mcp_server_name = mcp_server.server_name - - # Fetch current tools from MCP server - try: - current_mcp_tools = await self.list_mcp_server_tools(mcp_server_id, actor=actor, agent_id=agent_id) - except Exception as e: - logger.error(f"Failed to fetch tools from MCP server {mcp_server_name}: {e}") - raise HTTPException( - status_code=404, - detail={ - "code": "MCPServerUnavailable", - "message": f"Could not connect to MCP server {mcp_server_name} to resync tools", - "error": str(e), - }, - ) - - # Get all persisted tools for this MCP server - async with db_registry.async_session() as session: - # Query for tools with MCP metadata matching this server - # Using JSON path query to filter by metadata - persisted_tools = await ToolModel.list_async( - db_session=session, - organization_id=actor.organization_id, - ) - - # Filter tools that belong to this MCP server - mcp_tools = [] - for tool in persisted_tools: - if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_: - if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id: - mcp_tools.append(tool) - - # Create maps for easier comparison - current_tool_map = {tool.name: tool for tool in current_mcp_tools} - persisted_tool_map = {tool.name: tool for tool in mcp_tools} - - deleted_tools = [] - updated_tools = [] - added_tools = [] - - # 1. Delete tools that no longer exist on the server - for tool_name, persisted_tool in persisted_tool_map.items(): - if tool_name not in current_tool_map: - # Delete the tool (cascade will handle agent detachment) - await persisted_tool.hard_delete_async(db_session=session, actor=actor) - deleted_tools.append(tool_name) - logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}") - - # Commit deletions - await session.commit() - - # 2. Update existing tools and add new tools - for tool_name, current_tool in current_tool_map.items(): - if tool_name in persisted_tool_map: - # Update existing tool - persisted_tool = persisted_tool_map[tool_name] - tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool) - - # Check if schema has changed - if persisted_tool.json_schema != tool_create.json_schema: - # Update the tool - update_data = ToolUpdate( - description=tool_create.description, - json_schema=tool_create.json_schema, - source_code=tool_create.source_code, - ) - - await self.tool_manager.update_tool_by_id_async(tool_id=persisted_tool.id, tool_update=update_data, actor=actor) - updated_tools.append(tool_name) - logger.info(f"Updated MCP tool {tool_name} with new schema from server {mcp_server_name}") - else: - # Add new tool - # Skip INVALID tools - if current_tool.health and current_tool.health.status == "INVALID": - logger.warning( - f"Skipping invalid tool {tool_name} from MCP server {mcp_server_name}: {', '.join(current_tool.health.reasons)}" - ) - continue - - tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool) - created_tool = await self.tool_manager.create_mcp_tool_async( - tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor - ) - - # Create mapping in MCPTools table - if created_tool: - await self.create_mcp_tool_mapping(mcp_server_id, created_tool.id, actor) - added_tools.append(tool_name) - logger.info(f"Added new MCP tool {tool_name} from server {mcp_server_name} with mapping") - - return MCPServerResyncResult( - deleted=deleted_tools, - updated=updated_tools, - added=added_tools, - ) - - @enforce_types - async def list_mcp_servers(self, actor: PydanticUser) -> List[MCPServer]: - """List all MCP servers available""" - async with db_registry.async_session() as session: - mcp_servers = await MCPServerModel.list_async( - db_session=session, - organization_id=actor.organization_id, - ) - - return [mcp_server.to_pydantic() for mcp_server in mcp_servers] - - @enforce_types - async def create_or_update_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: - """Create a new tool based on the ToolCreate schema.""" - mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name=pydantic_mcp_server.server_name, actor=actor) - if mcp_server_id: - # Put to dict and remove fields that should not be reset - update_data = pydantic_mcp_server.model_dump(exclude_unset=True, exclude_none=True) - - # If there's anything to update (can only update the configs, not the name) - # TODO: pass in custom headers for update as well? - if update_data: - if pydantic_mcp_server.server_type == MCPServerType.SSE: - update_request = UpdateSSEMCPServer(server_url=pydantic_mcp_server.server_url, token=pydantic_mcp_server.token) - elif pydantic_mcp_server.server_type == MCPServerType.STDIO: - update_request = UpdateStdioMCPServer(stdio_config=pydantic_mcp_server.stdio_config) - elif pydantic_mcp_server.server_type == MCPServerType.STREAMABLE_HTTP: - update_request = UpdateStreamableHTTPMCPServer( - server_url=pydantic_mcp_server.server_url, auth_token=pydantic_mcp_server.token - ) - else: - raise ValueError(f"Unsupported server type: {pydantic_mcp_server.server_type}") - mcp_server = await self.update_mcp_server_by_id(mcp_server_id, update_request, actor) - else: - printd( - f"`create_or_update_mcp_server` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_mcp_server.server_name}, but found existing mcp server with nothing to update." - ) - mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) - else: - mcp_server = await self.create_mcp_server(pydantic_mcp_server, actor=actor) - - return mcp_server - - @enforce_types - async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: - """Create a new MCP server.""" - async with db_registry.async_session() as session: - try: - # Set the organization id at the ORM layer - pydantic_mcp_server.organization_id = actor.organization_id - - # Explicitly populate encrypted fields - if pydantic_mcp_server.token is not None: - pydantic_mcp_server.token_enc = Secret.from_plaintext(pydantic_mcp_server.token) - if pydantic_mcp_server.custom_headers is not None: - # custom_headers is a Dict[str, str], serialize to JSON then encrypt - import json - - json_str = json.dumps(pydantic_mcp_server.custom_headers) - pydantic_mcp_server.custom_headers_enc = Secret.from_plaintext(json_str) - - mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True) - - # Ensure custom_headers None is stored as SQL NULL, not JSON null - if mcp_server_data.get("custom_headers") is None: - mcp_server_data.pop("custom_headers", None) - - mcp_server = MCPServerModel(**mcp_server_data) - mcp_server = await mcp_server.create_async(session, actor=actor, no_commit=True) - - # Link existing OAuth sessions for the same user and server URL - # This ensures OAuth sessions created during testing get linked to the server - server_url = getattr(mcp_server, "server_url", None) - if server_url: - result = await session.execute( - select(MCPOAuth).where( - MCPOAuth.server_url == server_url, - MCPOAuth.organization_id == actor.organization_id, - MCPOAuth.user_id == actor.id, # Only link sessions for the same user - MCPOAuth.server_id.is_(None), # Only update sessions not already linked - ) - ) - oauth_sessions = result.scalars().all() - - # TODO: @jnjpng we should upate sessions in bulk - for oauth_session in oauth_sessions: - oauth_session.server_id = mcp_server.id - await oauth_session.update_async(db_session=session, actor=actor, no_commit=True) - - if oauth_sessions: - logger.info( - f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}" - ) - - await session.commit() - return mcp_server.to_pydantic() - except Exception as e: - await session.rollback() - raise - - @enforce_types - async def create_mcp_server_from_config( - self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser - ) -> MCPServer: - """ - Create an MCP server from a config object, handling encryption of sensitive fields. - - This method converts the server config to an MCPServer model and encrypts - sensitive fields like tokens and custom headers. - """ - # Create base MCPServer object - if isinstance(server_config, StdioServerConfig): - mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config) - elif isinstance(server_config, SSEServerConfig): - mcp_server = MCPServer( - server_name=server_config.server_name, - server_type=server_config.type, - server_url=server_config.server_url, - ) - # Encrypt sensitive fields - token = server_config.resolve_token() - if token: - token_secret = Secret.from_plaintext(token) - mcp_server.set_token_secret(token_secret) - if server_config.custom_headers: - # Convert dict to JSON string, then encrypt as Secret - headers_json = json.dumps(server_config.custom_headers) - headers_secret = Secret.from_plaintext(headers_json) - mcp_server.set_custom_headers_secret(headers_secret) - - elif isinstance(server_config, StreamableHTTPServerConfig): - mcp_server = MCPServer( - server_name=server_config.server_name, - server_type=server_config.type, - server_url=server_config.server_url, - ) - # Encrypt sensitive fields - token = server_config.resolve_token() - if token: - token_secret = Secret.from_plaintext(token) - mcp_server.set_token_secret(token_secret) - if server_config.custom_headers: - # Convert dict to JSON string, then encrypt as Secret - headers_json = json.dumps(server_config.custom_headers) - headers_secret = Secret.from_plaintext(headers_json) - mcp_server.set_custom_headers_secret(headers_secret) - else: - raise ValueError(f"Unsupported server config type: {type(server_config)}") - - return mcp_server - - @enforce_types - async def create_mcp_server_from_config_with_tools( - self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser - ) -> MCPServer: - """ - Create an MCP server from a config object and optimistically sync its tools. - - This method handles encryption of sensitive fields and then creates the server - with automatic tool synchronization. - """ - # Convert config to MCPServer with encryption - mcp_server = await self.create_mcp_server_from_config(server_config, actor) - - # Create the server with tools - return await self.create_mcp_server_with_tools(mcp_server, actor) - - @enforce_types - async def create_mcp_server_with_tools(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: - """ - Create a new MCP server and optimistically sync its tools. - - This method: - 1. Creates the MCP server record - 2. Attempts to connect and fetch tools - 3. Persists valid tools in parallel (best-effort) - """ - import asyncio - - # First, create the MCP server - created_server = await self.create_mcp_server(pydantic_mcp_server, actor) - - # Optimistically try to sync tools - try: - logger.info(f"Attempting to auto-sync tools from MCP server: {created_server.server_name}") - - # List all tools from the MCP server - mcp_tools = await self.list_mcp_server_tools(created_server.id, actor=actor) - - # Filter out invalid tools - valid_tools = [tool for tool in mcp_tools if not (tool.health and tool.health.status == "INVALID")] - - # Register in parallel - if valid_tools: - tool_tasks = [] - for mcp_tool in valid_tools: - tool_create = ToolCreate.from_mcp(mcp_server_name=created_server.server_name, mcp_tool=mcp_tool) - task = self.tool_manager.create_mcp_tool_async( - tool_create=tool_create, mcp_server_name=created_server.server_name, mcp_server_id=created_server.id, actor=actor - ) - tool_tasks.append(task) - - results = await asyncio.gather(*tool_tasks, return_exceptions=True) - - # Create mappings in MCPTools table for successful tools - mapping_tasks = [] - successful_count = 0 - for result in results: - if not isinstance(result, Exception) and result: - # result should be a PydanticTool - mapping_task = self.create_mcp_tool_mapping(created_server.id, result.id, actor) - mapping_tasks.append(mapping_task) - successful_count += 1 - - # Execute mapping creation in parallel - if mapping_tasks: - await asyncio.gather(*mapping_tasks, return_exceptions=True) - - failed = len(results) - successful_count - logger.info( - f"Auto-sync completed for MCP server {created_server.server_name}: " - f"{successful_count} tools persisted with mappings, {failed} failed, " - f"{len(mcp_tools) - len(valid_tools)} invalid tools skipped" - ) - else: - logger.info(f"No valid tools found to sync from MCP server {created_server.server_name}") - - except Exception as e: - # Log the error but don't fail the server creation - logger.warning( - f"Failed to auto-sync tools from MCP server {created_server.server_name}: {e}. " - f"Server was created successfully but tools were not persisted." - ) - - return created_server - - @enforce_types - async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer: - """Update a tool by its ID with the given ToolUpdate object.""" - async with db_registry.async_session() as session: - # Fetch the tool by ID - mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) - - # Update tool attributes with only the fields that were explicitly set - update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True) - - # Handle encryption for token if provided - # Only re-encrypt if the value has actually changed - if "token" in update_data and update_data["token"] is not None: - # Check if value changed - existing_token = None - if mcp_server.token_enc: - existing_secret = Secret.from_encrypted(mcp_server.token_enc) - existing_token = existing_secret.get_plaintext() - elif mcp_server.token: - existing_token = mcp_server.token - - # Only re-encrypt if different - if existing_token != update_data["token"]: - mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted() - # Keep plaintext for dual-write during migration - mcp_server.token = update_data["token"] - - # Remove from update_data since we set directly on mcp_server - update_data.pop("token", None) - update_data.pop("token_enc", None) - - # Handle encryption for custom_headers if provided - # Only re-encrypt if the value has actually changed - if "custom_headers" in update_data: - if update_data["custom_headers"] is not None: - # custom_headers is a Dict[str, str], serialize to JSON then encrypt - import json - - json_str = json.dumps(update_data["custom_headers"]) - - # Check if value changed - existing_headers_json = None - if mcp_server.custom_headers_enc: - existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc) - existing_headers_json = existing_secret.get_plaintext() - elif mcp_server.custom_headers: - existing_headers_json = json.dumps(mcp_server.custom_headers) - - # Only re-encrypt if different - if existing_headers_json != json_str: - mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted() - # Keep plaintext for dual-write during migration - mcp_server.custom_headers = update_data["custom_headers"] - - # Remove from update_data since we set directly on mcp_server - update_data.pop("custom_headers", None) - update_data.pop("custom_headers_enc", None) - else: - # Ensure custom_headers None is stored as SQL NULL, not JSON null - update_data.pop("custom_headers", None) - setattr(mcp_server, "custom_headers", null()) - setattr(mcp_server, "custom_headers_enc", None) - - for key, value in update_data.items(): - setattr(mcp_server, key, value) - - mcp_server = await mcp_server.update_async(db_session=session, actor=actor) - - # Save the updated tool to the database mcp_server = await mcp_server.update_async(db_session=session, actor=actor) - return mcp_server.to_pydantic() - - @enforce_types - async def update_mcp_server_by_name(self, mcp_server_name: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer: - """Update an MCP server by its name.""" - mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) - if not mcp_server_id: - raise HTTPException( - status_code=404, - detail={ - "code": "MCPServerNotFoundError", - "message": f"MCP server {mcp_server_name} not found", - "mcp_server_name": mcp_server_name, - }, - ) - return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor) - - @enforce_types - async def get_mcp_server_id_by_name(self, mcp_server_name: str, actor: PydanticUser) -> Optional[str]: - """Retrieve a MCP server by its name and a user""" - try: - async with db_registry.async_session() as session: - mcp_server = await MCPServerModel.read_async(db_session=session, server_name=mcp_server_name, actor=actor) - return mcp_server.id - except NoResultFound: - return None - - @enforce_types - async def get_mcp_server_by_id_async(self, mcp_server_id: str, actor: PydanticUser) -> MCPServer: - """Fetch a tool by its ID.""" - async with db_registry.async_session() as session: - # Retrieve tool by id using the Tool model's read method - mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) - # Convert the SQLAlchemy Tool object to PydanticTool - return mcp_server.to_pydantic() - - @enforce_types - async def get_mcp_servers_by_ids(self, mcp_server_ids: List[str], actor: PydanticUser) -> List[MCPServer]: - """Fetch multiple MCP servers by their IDs in a single query.""" - if not mcp_server_ids: - return [] - - async with db_registry.async_session() as session: - mcp_servers = await MCPServerModel.list_async( - db_session=session, - organization_id=actor.organization_id, - id=mcp_server_ids, # This will use the IN operator - ) - return [mcp_server.to_pydantic() for mcp_server in mcp_servers] - - @enforce_types - async def get_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> PydanticTool: - """Get a MCP server by name.""" - async with db_registry.async_session() as session: - mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) - mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) - if not mcp_server: - raise HTTPException( - status_code=404, # Not Found - detail={ - "code": "MCPServerNotFoundError", - "message": f"MCP server {mcp_server_name} not found", - "mcp_server_name": mcp_server_name, - }, - ) - return mcp_server.to_pydantic() - - @enforce_types - async def delete_mcp_server_by_id(self, mcp_server_id: str, actor: PydanticUser) -> None: - """Delete a MCP server by its ID and associated tools and OAuth sessions.""" - async with db_registry.async_session() as session: - try: - mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) - if not mcp_server: - raise NoResultFound(f"MCP server with id {mcp_server_id} not found.") - - server_url = getattr(mcp_server, "server_url", None) - # Get all tools with matching metadata - stmt = select(ToolModel).where(ToolModel.organization_id == actor.organization_id) - result = await session.execute(stmt) - all_tools = result.scalars().all() - - # Filter and delete tools that belong to this MCP server - tools_deleted = 0 - for tool in all_tools: - if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_: - if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id: - await tool.hard_delete_async(db_session=session, actor=actor) - tools_deleted = 1 - logger.info(f"Deleted MCP tool {tool.name} associated with MCP server {mcp_server_id}") - - if tools_deleted > 0: - logger.info(f"Deleted {tools_deleted} MCP tools associated with MCP server {mcp_server_id}") - - # Delete all MCPTools mappings for this server - await session.execute( - delete(MCPToolsModel).where( - MCPToolsModel.mcp_server_id == mcp_server_id, - MCPToolsModel.organization_id == actor.organization_id, - ) - ) - logger.info(f"Deleted MCPTools mappings for MCP server {mcp_server_id}") - - # Delete OAuth sessions for the same user and server URL in the same transaction - # This handles orphaned sessions that were created during testing/connection - oauth_count = 0 - if server_url: - result = await session.execute( - delete(MCPOAuth).where( - MCPOAuth.server_url == server_url, - MCPOAuth.organization_id == actor.organization_id, - MCPOAuth.user_id == actor.id, # Only delete sessions for the same user - ) - ) - oauth_count = result.rowcount - if oauth_count > 0: - logger.info( - f"Deleting {oauth_count} OAuth sessions for MCP server {mcp_server_id} (URL: {server_url}) for user {actor.id}" - ) - - # Delete the MCP server, will cascade delete to linked OAuth sessions - await session.execute( - delete(MCPServerModel).where( - MCPServerModel.id == mcp_server_id, - MCPServerModel.organization_id == actor.organization_id, - ) - ) - - await session.commit() - except NoResultFound: - await session.rollback() - raise ValueError(f"MCP server with id {mcp_server_id} not found.") - except Exception as e: - await session.rollback() - logger.error(f"Failed to delete MCP server {mcp_server_id}: {e}") - raise - - def read_mcp_config(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]]: - mcp_server_list = {} - - # Attempt to read from ~/.letta/mcp_config.json - mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME) - if os.path.exists(mcp_config_path): - with open(mcp_config_path, "r") as f: - try: - mcp_config = json.load(f) - except Exception as e: - # Config parsing errors are user configuration issues, not system errors - logger.warning(f"Failed to parse MCP config file ({mcp_config_path}) as json: {e}") - return mcp_server_list - - # Proper formatting is "mcpServers" key at the top level, - # then a dict with the MCP server name as the key, - # with the value being the schema from StdioServerParameters - if MCP_CONFIG_TOPLEVEL_KEY in mcp_config: - for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items(): - # No support for duplicate server names - if server_name in mcp_server_list: - # Duplicate server names are configuration issues, not system errors - logger.warning(f"Duplicate MCP server name found (skipping): {server_name}") - continue - - if "url" in server_params_raw: - # Attempt to parse the server params as an SSE server - try: - server_params = SSEServerConfig( - server_name=server_name, - server_url=server_params_raw["url"], - auth_header=server_params_raw.get("auth_header", None), - auth_token=server_params_raw.get("auth_token", None), - headers=server_params_raw.get("headers", None), - ) - mcp_server_list[server_name] = server_params - except Exception as e: - # Config parsing errors are user configuration issues, not system errors - logger.warning(f"Failed to parse server params for MCP server {server_name} (skipping): {e}") - continue - else: - # Attempt to parse the server params as a StdioServerParameters - try: - server_params = StdioServerConfig( - server_name=server_name, - command=server_params_raw["command"], - args=server_params_raw.get("args", []), - env=server_params_raw.get("env", {}), - ) - mcp_server_list[server_name] = server_params - except Exception as e: - # Config parsing errors are user configuration issues, not system errors - logger.warning(f"Failed to parse server params for MCP server {server_name} (skipping): {e}") - continue - return mcp_server_list - - async def get_mcp_client( - self, - server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig], - actor: PydanticUser, - oauth_provider: Optional[Any] = None, - agent_id: Optional[str] = None, - ) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]: - """ - Helper function to create the appropriate MCP client based on server configuration. - - Args: - server_config: The server configuration object - actor: The user making the request - oauth_provider: Optional OAuth provider for authentication - - Returns: - The appropriate MCP client instance - - Raises: - ValueError: If server config type is not supported - """ - # If no OAuth provider is provided, check if we have stored OAuth credentials - if oauth_provider is None and hasattr(server_config, "server_url"): - oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor) - # Check if access token exists by attempting to decrypt it - if oauth_session and oauth_session.get_access_token_secret().get_plaintext(): - # Create OAuth provider from stored credentials - from letta.services.mcp.oauth_utils import create_oauth_provider - - oauth_provider = await create_oauth_provider( - session_id=oauth_session.id, - server_url=oauth_session.server_url, - redirect_uri=oauth_session.redirect_uri, - mcp_manager=self, - actor=actor, - ) - - if server_config.type == MCPServerType.SSE: - server_config = SSEServerConfig(**server_config.model_dump()) - return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id) - elif server_config.type == MCPServerType.STDIO: - server_config = StdioServerConfig(**server_config.model_dump()) - return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id) - elif server_config.type == MCPServerType.STREAMABLE_HTTP: - server_config = StreamableHTTPServerConfig(**server_config.model_dump()) - return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id) - else: - raise ValueError(f"Unsupported server config type: {type(server_config)}") - - # OAuth-related methods - def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession: - """ - Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields. - """ - # Get decrypted values using the dual-read approach - # Secret.from_db() will automatically use settings.encryption_key if available - access_token = None - if oauth_session.access_token_enc or oauth_session.access_token: - if settings.encryption_key: - secret = Secret.from_db(oauth_session.access_token_enc, oauth_session.access_token) - access_token = secret.get_plaintext() - else: - # No encryption key, use plaintext if available - access_token = oauth_session.access_token - - refresh_token = None - if oauth_session.refresh_token_enc or oauth_session.refresh_token: - if settings.encryption_key: - secret = Secret.from_db(oauth_session.refresh_token_enc, oauth_session.refresh_token) - refresh_token = secret.get_plaintext() - else: - # No encryption key, use plaintext if available - refresh_token = oauth_session.refresh_token - - client_secret = None - if oauth_session.client_secret_enc or oauth_session.client_secret: - if settings.encryption_key: - secret = Secret.from_db(oauth_session.client_secret_enc, oauth_session.client_secret) - client_secret = secret.get_plaintext() - else: - # No encryption key, use plaintext if available - client_secret = oauth_session.client_secret - - authorization_code = None - if oauth_session.authorization_code_enc or oauth_session.authorization_code: - if settings.encryption_key: - secret = Secret.from_db(oauth_session.authorization_code_enc, oauth_session.authorization_code) - authorization_code = secret.get_plaintext() - else: - # No encryption key, use plaintext if available - authorization_code = oauth_session.authorization_code - - # Create the Pydantic object with encrypted fields as Secret objects - pydantic_session = MCPOAuthSession( - id=oauth_session.id, - state=oauth_session.state, - server_id=oauth_session.server_id, - server_url=oauth_session.server_url, - server_name=oauth_session.server_name, - user_id=oauth_session.user_id, - organization_id=oauth_session.organization_id, - authorization_url=oauth_session.authorization_url, - authorization_code=authorization_code, - access_token=access_token, - refresh_token=refresh_token, - token_type=oauth_session.token_type, - expires_at=oauth_session.expires_at, - scope=oauth_session.scope, - client_id=oauth_session.client_id, - client_secret=client_secret, - redirect_uri=oauth_session.redirect_uri, - status=oauth_session.status, - created_at=oauth_session.created_at, - updated_at=oauth_session.updated_at, - # Encrypted fields as Secret objects (converted from encrypted strings in DB) - authorization_code_enc=Secret.from_encrypted(oauth_session.authorization_code_enc) - if oauth_session.authorization_code_enc - else None, - access_token_enc=Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None, - refresh_token_enc=Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None, - client_secret_enc=Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None, - ) - return pydantic_session - - @enforce_types - async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession: - """Create a new OAuth session for MCP server authentication.""" - async with db_registry.async_session() as session: - # Create the OAuth session with a unique state - oauth_session = MCPOAuth( - id="mcp-oauth-" + str(uuid.uuid4())[:8], - state=secrets.token_urlsafe(32), - server_url=session_create.server_url, - server_name=session_create.server_name, - user_id=session_create.user_id, - organization_id=session_create.organization_id, - status=OAuthSessionStatus.PENDING, - created_at=datetime.now(), - updated_at=datetime.now(), - ) - oauth_session = await oauth_session.create_async(session, actor=actor) - - # Convert to Pydantic model - note: new sessions won't have tokens yet - return self._oauth_orm_to_pydantic(oauth_session) - - @enforce_types - async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]: - """Get an OAuth session by its ID.""" - async with db_registry.async_session() as session: - try: - oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) - return self._oauth_orm_to_pydantic(oauth_session) - except NoResultFound: - return None - - @enforce_types - async def get_oauth_session_by_server(self, server_url: str, actor: PydanticUser) -> Optional[MCPOAuthSession]: - """Get the latest OAuth session by server URL, organization, and user.""" - async with db_registry.async_session() as session: - # Query for OAuth session matching organization, user, server URL, and status - # Order by updated_at desc to get the most recent record - result = await session.execute( - select(MCPOAuth) - .where( - MCPOAuth.organization_id == actor.organization_id, - MCPOAuth.user_id == actor.id, - MCPOAuth.server_url == server_url, - MCPOAuth.status == OAuthSessionStatus.AUTHORIZED, - ) - .order_by(desc(MCPOAuth.updated_at)) - .limit(1) - ) - oauth_session = result.scalar_one_or_none() - - if not oauth_session: - return None - - return self._oauth_orm_to_pydantic(oauth_session) - - @enforce_types - async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession: - """Update an existing OAuth session.""" - async with db_registry.async_session() as session: - oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) - - # Update fields that are provided - if session_update.authorization_url is not None: - oauth_session.authorization_url = session_update.authorization_url - - # Handle encryption for authorization_code - # Only re-encrypt if the value has actually changed - if session_update.authorization_code is not None: - # Check if value changed - existing_code = None - if oauth_session.authorization_code_enc: - existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc) - existing_code = existing_secret.get_plaintext() - elif oauth_session.authorization_code: - existing_code = oauth_session.authorization_code - - # Only re-encrypt if different - if existing_code != session_update.authorization_code: - oauth_session.authorization_code_enc = Secret.from_plaintext(session_update.authorization_code).get_encrypted() - # Keep plaintext for dual-write during migration - oauth_session.authorization_code = session_update.authorization_code - - # Handle encryption for access_token - # Only re-encrypt if the value has actually changed - if session_update.access_token is not None: - # Check if value changed - existing_token = None - if oauth_session.access_token_enc: - existing_secret = Secret.from_encrypted(oauth_session.access_token_enc) - existing_token = existing_secret.get_plaintext() - elif oauth_session.access_token: - existing_token = oauth_session.access_token - - # Only re-encrypt if different - if existing_token != session_update.access_token: - oauth_session.access_token_enc = Secret.from_plaintext(session_update.access_token).get_encrypted() - # Keep plaintext for dual-write during migration - oauth_session.access_token = session_update.access_token - - # Handle encryption for refresh_token - # Only re-encrypt if the value has actually changed - if session_update.refresh_token is not None: - # Check if value changed - existing_refresh = None - if oauth_session.refresh_token_enc: - existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc) - existing_refresh = existing_secret.get_plaintext() - elif oauth_session.refresh_token: - existing_refresh = oauth_session.refresh_token - - # Only re-encrypt if different - if existing_refresh != session_update.refresh_token: - oauth_session.refresh_token_enc = Secret.from_plaintext(session_update.refresh_token).get_encrypted() - # Keep plaintext for dual-write during migration - oauth_session.refresh_token = session_update.refresh_token - - if session_update.token_type is not None: - oauth_session.token_type = session_update.token_type - if session_update.expires_at is not None: - oauth_session.expires_at = session_update.expires_at - if session_update.scope is not None: - oauth_session.scope = session_update.scope - if session_update.client_id is not None: - oauth_session.client_id = session_update.client_id - - # Handle encryption for client_secret - # Only re-encrypt if the value has actually changed - if session_update.client_secret is not None: - # Check if value changed - existing_secret_val = None - if oauth_session.client_secret_enc: - existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc) - existing_secret_val = existing_secret.get_plaintext() - elif oauth_session.client_secret: - existing_secret_val = oauth_session.client_secret - - # Only re-encrypt if different - if existing_secret_val != session_update.client_secret: - oauth_session.client_secret_enc = Secret.from_plaintext(session_update.client_secret).get_encrypted() - # Keep plaintext for dual-write during migration - oauth_session.client_secret = session_update.client_secret - - if session_update.redirect_uri is not None: - oauth_session.redirect_uri = session_update.redirect_uri - if session_update.status is not None: - oauth_session.status = session_update.status - - # Always update the updated_at timestamp - oauth_session.updated_at = datetime.now() - - oauth_session = await oauth_session.update_async(db_session=session, actor=actor) - - return self._oauth_orm_to_pydantic(oauth_session) - - @enforce_types - async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None: - """Delete an OAuth session.""" - async with db_registry.async_session() as session: - try: - oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) - await oauth_session.hard_delete_async(db_session=session, actor=actor) - except NoResultFound: - raise ValueError(f"OAuth session with id {session_id} not found.") - - @enforce_types - async def cleanup_expired_oauth_sessions(self, max_age_hours: int = 24) -> int: - """Clean up expired OAuth sessions and return the count of deleted sessions.""" - cutoff_time = datetime.now() - timedelta(hours=max_age_hours) - - async with db_registry.async_session() as session: - # Find expired sessions - result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time)) - expired_sessions = result.scalars().all() - - # Delete expired sessions using async ORM method - for oauth_session in expired_sessions: - await oauth_session.hard_delete_async(db_session=session, actor=None) - - if expired_sessions: - logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions") - - return len(expired_sessions) - - @enforce_types - async def handle_oauth_flow( - self, - request: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig], - actor: PydanticUser, - http_request: Optional[Request] = None, - ): - """ - Handle OAuth flow for MCP server connection and yield SSE events. - - Args: - request: The server configuration - actor: The user making the request - http_request: The HTTP request object - - Yields: - SSE events during OAuth flow - - Returns: - Tuple of (temp_client, connect_task) after yielding events - """ - import asyncio - - from letta.services.mcp.oauth_utils import create_oauth_provider, oauth_stream_event - from letta.services.mcp.types import OauthStreamEvent - - # OAuth required, yield state to client to prepare to handle authorization URL - yield oauth_stream_event(OauthStreamEvent.OAUTH_REQUIRED, message="OAuth authentication required") - - # Create OAuth session to persist the state of the OAuth flow - session_create = MCPOAuthSessionCreate( - server_url=request.server_url, - server_name=request.server_name, - user_id=actor.id, - organization_id=actor.organization_id, - ) - oauth_session = await self.create_oauth_session(session_create, actor) - session_id = oauth_session.id - - # TODO: @jnjpng make this check more robust and remove direct os.getenv - # Check if request is from web frontend to determine redirect URI - is_web_request = ( - http_request - and http_request.headers - and http_request.headers.get("user-agent", "") == "Next.js Middleware" - and http_request.headers.__contains__("x-organization-id") - ) - - logo_uri = None - NEXT_PUBLIC_CURRENT_HOST = os.getenv("NEXT_PUBLIC_CURRENT_HOST") - LETTA_AGENTS_ENDPOINT = os.getenv("LETTA_AGENTS_ENDPOINT") - - if is_web_request and NEXT_PUBLIC_CURRENT_HOST: - redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}" - logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg" - elif LETTA_AGENTS_ENDPOINT: - # API and SDK usage should call core server directly - redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}" - else: - logger.error( - f"No redirect URI found for request and base urls: {http_request.headers if http_request else 'No headers'} {NEXT_PUBLIC_CURRENT_HOST} {LETTA_AGENTS_ENDPOINT}" - ) - raise HTTPException(status_code=400, detail="No redirect URI found") - - # Create OAuth provider for the instance of the stream connection - oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, self, actor, logo_uri=logo_uri) - - # Get authorization URL by triggering OAuth flow - temp_client = None - connect_task = None - try: - temp_client = await self.get_mcp_client(request, actor, oauth_provider) - - # Run connect_to_server in background to avoid blocking - # This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database - connect_task = safe_create_task(temp_client.connect_to_server(), label="mcp_oauth_connect") - - # Give the OAuth flow time to trigger and save the URL - await asyncio.sleep(1.0) - - # Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL - auth_session = await self.get_oauth_session_by_id(session_id, actor) - if auth_session and auth_session.authorization_url: - yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id) - - # Wait for user authorization (with timeout), client should render loading state until user completes the flow and /mcp/oauth/callback/{session_id} is hit - yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...") - - # Callback handler will poll for authorization code and state and update the OAuth session - await connect_task - - tools = await temp_client.list_tools(serialize=True) - yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) - - except Exception as e: - logger.error(f"Error triggering OAuth flow: {e}") - yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}") - raise e - finally: - # Clean up resources - if connect_task and not connect_task.done(): - connect_task.cancel() - try: - await connect_task - except asyncio.CancelledError: - pass - if temp_client: - try: - await temp_client.cleanup() - except Exception as cleanup_error: - logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}") diff --git a/tests/integration_test_mcp_servers.py b/tests/integration_test_mcp_servers.py deleted file mode 100644 index 1110e126..00000000 --- a/tests/integration_test_mcp_servers.py +++ /dev/null @@ -1,858 +0,0 @@ -""" -Integration tests for the new MCP server endpoints (/v1/mcp-servers/). -Tests all CRUD operations, tool management, and OAuth connection flows. -Uses plain dictionaries since SDK types are not yet generated. -""" - -import os -import sys -import threading -import time -import uuid -from pathlib import Path -from typing import Any, Dict, List, Optional - -import pytest -import requests -from dotenv import load_dotenv - -# ------------------------------ -# Fixtures -# ------------------------------ - - -@pytest.fixture(scope="module") -def server_url() -> str: - """ - Provides the URL for the Letta server. - If LETTA_SERVER_URL is not set, starts the server in a background thread - and polls until it's accepting connections. - """ - - def _run_server() -> None: - load_dotenv() - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=_run_server, daemon=True) - thread.start() - - # Poll until the server is up (or timeout) - timeout_seconds = 30 - deadline = time.time() + timeout_seconds - while time.time() < deadline: - try: - resp = requests.get(url + "/v1/health") - if resp.status_code < 500: - break - except requests.exceptions.RequestException: - pass - time.sleep(0.1) - else: - raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") - - yield url - - -@pytest.fixture(scope="module") -def auth_headers() -> Dict[str, str]: - """ - Provides authentication headers for API requests. - """ - # Get auth token from environment or use default - token = os.getenv("LETTA_API_TOKEN", "") - if token: - return {"Authorization": f"Bearer {token}"} - return {} - - -@pytest.fixture(scope="function") -def unique_server_id() -> str: - """Generate a unique MCP server ID for each test.""" - # MCP server IDs follow the format: mcp_server- - return f"mcp_server-{uuid.uuid4()}" - - -@pytest.fixture(scope="function") -def mock_mcp_server_path() -> Path: - """Get path to mock MCP server for testing.""" - script_dir = Path(__file__).parent - mcp_server_path = script_dir / "mock_mcp_server.py" - - if not mcp_server_path.exists(): - # Create a minimal mock server for testing if it doesn't exist - pytest.skip(f"Mock MCP server not found at {mcp_server_path}") - - return mcp_server_path - - -# ------------------------------ -# Helper Functions -# ------------------------------ - - -def create_stdio_server_dict(server_name: str, command: str = "npx", args: List[str] = None) -> Dict[str, Any]: - """Create a dictionary representing a stdio MCP server configuration.""" - return { - "type": "stdio", - "server_name": server_name, - "command": command, - "args": args or ["-y", "@modelcontextprotocol/server-everything"], - "env": {"NODE_ENV": "test", "DEBUG": "true"}, - } - - -def create_sse_server_dict(server_name: str, server_url: str = None) -> Dict[str, Any]: - """Create a dictionary representing an SSE MCP server configuration.""" - return { - "type": "sse", - "server_name": server_name, - "server_url": server_url or "https://api.example.com/sse", - "auth_header": "Authorization", - "auth_token": "Bearer test_token_123", - "custom_headers": {"X-Custom-Header": "custom_value", "X-API-Version": "1.0"}, - } - - -def create_streamable_http_server_dict(server_name: str, server_url: str = None) -> Dict[str, Any]: - """Create a dictionary representing a streamable HTTP MCP server configuration.""" - return { - "type": "streamable_http", - "server_name": server_name, - "server_url": server_url or "https://api.example.com/streamable", - "auth_header": "X-API-Key", - "auth_token": "api_key_456", - "custom_headers": {"Accept": "application/json", "X-Version": "2.0"}, - } - - -# ------------------------------ -# Test Cases for CRUD Operations -# ------------------------------ - - -def test_create_stdio_mcp_server(server_url: str, auth_headers: Dict[str, str]): - """Test creating a stdio MCP server.""" - server_name = f"test-stdio-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name) - - # Create the server - response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert response.status_code == 200, f"Failed to create server: {response.text}" - - server_data = response.json() - assert server_data["server_name"] == server_name - assert server_data["command"] == server_config["command"] - assert server_data["args"] == server_config["args"] - assert "id" in server_data # Should have an ID assigned - - server_id = server_data["id"] - - # Cleanup - delete the server - delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert delete_response.status_code == 204, f"Failed to delete server: {delete_response.text}" - - -def test_create_sse_mcp_server(server_url: str, auth_headers: Dict[str, str]): - """Test creating an SSE MCP server.""" - server_name = f"test-sse-{uuid.uuid4().hex[:8]}" - server_config = create_sse_server_dict(server_name) - - # Create the server - response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert response.status_code == 200, f"Failed to create server: {response.text}" - - server_data = response.json() - assert server_data["server_name"] == server_name - assert server_data["server_url"] == server_config["server_url"] - assert server_data["auth_header"] == server_config["auth_header"] - assert "id" in server_data - - server_id = server_data["id"] - - # Cleanup - delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert delete_response.status_code == 204 - - -def test_create_streamable_http_mcp_server(server_url: str, auth_headers: Dict[str, str]): - """Test creating a streamable HTTP MCP server.""" - server_name = f"test-http-{uuid.uuid4().hex[:8]}" - server_config = create_streamable_http_server_dict(server_name) - - # Create the server - response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert response.status_code == 200, f"Failed to create server: {response.text}" - - server_data = response.json() - assert server_data["server_name"] == server_name - assert server_data["server_url"] == server_config["server_url"] - assert "id" in server_data - - server_id = server_data["id"] - - # Cleanup - delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert delete_response.status_code == 204 - - -def test_list_mcp_servers(server_url: str, auth_headers: Dict[str, str]): - """Test listing all MCP servers.""" - # Create multiple servers - servers_created = [] - - # Create stdio server - stdio_name = f"list-test-stdio-{uuid.uuid4().hex[:8]}" - stdio_config = create_stdio_server_dict(stdio_name) - stdio_response = requests.post(f"{server_url}/v1/mcp-servers/", json=stdio_config, headers=auth_headers) - assert stdio_response.status_code == 200 - stdio_server = stdio_response.json() - servers_created.append(stdio_server["id"]) - - # Create SSE server - sse_name = f"list-test-sse-{uuid.uuid4().hex[:8]}" - sse_config = create_sse_server_dict(sse_name) - sse_response = requests.post(f"{server_url}/v1/mcp-servers/", json=sse_config, headers=auth_headers) - assert sse_response.status_code == 200 - sse_server = sse_response.json() - servers_created.append(sse_server["id"]) - - try: - # List all servers - list_response = requests.get(f"{server_url}/v1/mcp-servers/", headers=auth_headers) - assert list_response.status_code == 200 - - servers_list = list_response.json() - assert isinstance(servers_list, list) - assert len(servers_list) >= 2 # At least our two servers - - # Check our servers are in the list - server_ids = [s["id"] for s in servers_list] - assert stdio_server["id"] in server_ids - assert sse_server["id"] in server_ids - - # Check server names - server_names = [s["server_name"] for s in servers_list] - assert stdio_name in server_names - assert sse_name in server_names - - finally: - # Cleanup - for server_id in servers_created: - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_get_specific_mcp_server(server_url: str, auth_headers: Dict[str, str]): - """Test getting a specific MCP server by ID.""" - # Create a server - server_name = f"get-test-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name, command="python", args=["-m", "mcp_server"]) - server_config["env"]["PYTHONPATH"] = "/usr/local/lib" - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - created_server = create_response.json() - server_id = created_server["id"] - - try: - # Get the server by ID - get_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert get_response.status_code == 200 - - retrieved_server = get_response.json() - assert retrieved_server["id"] == server_id - assert retrieved_server["server_name"] == server_name - assert retrieved_server["command"] == "python" - assert retrieved_server["args"] == ["-m", "mcp_server"] - assert retrieved_server.get("env", {}).get("PYTHONPATH") == "/usr/local/lib" - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_update_stdio_mcp_server(server_url: str, auth_headers: Dict[str, str]): - """Test updating a stdio MCP server.""" - # Create a server - server_name = f"update-test-stdio-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name, command="node", args=["old_server.js"]) - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # Update the server - update_data = { - "server_name": "updated-stdio-server", - "command": "node", - "args": ["new_server.js", "--port", "3000"], - "env": {"NEW_ENV": "new_value", "PORT": "3000"}, - } - - update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers) - assert update_response.status_code == 200 - - updated_server = update_response.json() - assert updated_server["server_name"] == "updated-stdio-server" - assert updated_server["args"] == ["new_server.js", "--port", "3000"] - assert updated_server.get("env", {}).get("NEW_ENV") == "new_value" - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_update_sse_mcp_server(server_url: str, auth_headers: Dict[str, str]): - """Test updating an SSE MCP server.""" - # Create an SSE server - server_name = f"update-test-sse-{uuid.uuid4().hex[:8]}" - server_config = create_sse_server_dict(server_name, server_url="https://old.example.com/sse") - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # Update the server - update_data = { - "server_name": "updated-sse-server", - "server_url": "https://new.example.com/sse/v2", - "token": "new_token_789", - "custom_headers": {"X-Updated": "true", "X-Version": "2.0"}, - } - - update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers) - assert update_response.status_code == 200 - - updated_server = update_response.json() - assert updated_server["server_name"] == "updated-sse-server" - assert updated_server["server_url"] == "https://new.example.com/sse/v2" - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_delete_mcp_server(server_url: str, auth_headers: Dict[str, str]): - """Test deleting an MCP server.""" - # Create a server to delete - server_name = f"delete-test-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name) - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - # Delete the server - delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert delete_response.status_code == 204 - - # Verify it's deleted (should get 404) - get_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert get_response.status_code == 404 - - -# ------------------------------ -# Test Cases for Tool Operations -# ------------------------------ - - -def test_list_mcp_tools_by_server(server_url: str, auth_headers: Dict[str, str]): - """Test listing tools for a specific MCP server.""" - # Create a server - server_name = f"tools-test-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name) - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # List tools for this server - tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers) - assert tools_response.status_code == 200 - - tools = tools_response.json() - assert isinstance(tools, list) - - # Tools might be empty initially if server hasn't connected - # But response structure should be valid - if len(tools) > 0: - # Verify tool structure - tool = tools[0] - assert "id" in tool - assert "name" in tool - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_get_specific_mcp_tool(server_url: str, auth_headers: Dict[str, str]): - """Test getting a specific tool from an MCP server.""" - # Create a server - server_name = f"tool-get-test-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name) - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # First get list of tools - tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers) - assert tools_response.status_code == 200 - tools = tools_response.json() - - if len(tools) > 0: - # Get a specific tool - tool_id = tools[0]["id"] - tool_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}", headers=auth_headers) - assert tool_response.status_code == 200 - - specific_tool = tool_response.json() - assert specific_tool["id"] == tool_id - assert "name" in specific_tool - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_run_mcp_tool(server_url: str, auth_headers: Dict[str, str]): - """Test executing an MCP tool.""" - # Create a server - server_name = f"tool-run-test-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name) - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # Get available tools - tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers) - assert tools_response.status_code == 200 - tools = tools_response.json() - - if len(tools) > 0: - # Run the first available tool - tool_id = tools[0]["id"] - - # Run with arguments - run_request = {"args": {"test_param": "test_value", "count": 5}} - - run_response = requests.post( - f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}/run", json=run_request, headers=auth_headers - ) - assert run_response.status_code == 200 - - result = run_response.json() - assert "status" in result - assert result["status"] in ["success", "error"] - assert "func_return" in result - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_run_mcp_tool_without_args(server_url: str, auth_headers: Dict[str, str]): - """Test executing an MCP tool without arguments.""" - # Create a server - server_name = f"tool-noargs-test-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name) - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # Get available tools - tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers) - assert tools_response.status_code == 200 - tools = tools_response.json() - - if len(tools) > 0: - tool_id = tools[0]["id"] - - # Run without arguments (empty dict) - run_request = {"args": {}} - - run_response = requests.post( - f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}/run", json=run_request, headers=auth_headers - ) - assert run_response.status_code == 200 - - result = run_response.json() - assert "status" in result - assert "func_return" in result - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_refresh_mcp_server_tools(server_url: str, auth_headers: Dict[str, str]): - """Test refreshing tools for an MCP server.""" - # Create a server - server_name = f"refresh-test-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name) - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # Get initial tools - initial_tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers) - assert initial_tools_response.status_code == 200 - - # Refresh tools - refresh_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}/refresh", headers=auth_headers) - assert refresh_response.status_code == 200 - - refresh_result = refresh_response.json() - # Result should contain summary of changes - assert refresh_result is not None - - # Get tools after refresh - refreshed_tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers) - assert refreshed_tools_response.status_code == 200 - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_refresh_mcp_server_tools_with_agent(server_url: str, auth_headers: Dict[str, str]): - """Test refreshing tools with agent context.""" - # Create a server - server_name = f"refresh-agent-test-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name) - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # Refresh tools with agent ID - mock_agent_id = f"agent-{uuid.uuid4()}" - refresh_response = requests.patch( - f"{server_url}/v1/mcp-servers/{server_id}/refresh", params={"agent_id": mock_agent_id}, headers=auth_headers - ) - assert refresh_response.status_code == 200 - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -# ------------------------------ -# Test Cases for OAuth/Connection -# ------------------------------ - - -def test_connect_mcp_server_oauth(server_url: str, auth_headers: Dict[str, str]): - """Test connecting to an MCP server (OAuth flow).""" - # Create an SSE server that might require OAuth - server_name = f"oauth-test-{uuid.uuid4().hex[:8]}" - server_config = create_sse_server_dict(server_name, server_url="https://oauth.example.com/sse") - # Remove token to simulate OAuth requirement - server_config["auth_token"] = None - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # Attempt to connect (returns SSE stream) - # We can't fully test SSE in a simple integration test, but verify endpoint works - connect_response = requests.get( - f"{server_url}/v1/mcp-servers/connect/{server_id}", - headers={**auth_headers, "Accept": "text/event-stream"}, - stream=True, - timeout=2, - ) - - # Should get a streaming response or error, not 404 - assert connect_response.status_code in [200, 400, 500], f"Unexpected status: {connect_response.status_code}" - - # Close the stream - connect_response.close() - - except requests.exceptions.Timeout: - # Timeout is acceptable for SSE endpoints in tests - pass - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -# ------------------------------ -# Test Cases for Error Handling -# ------------------------------ - - -def test_error_handling_invalid_server_id(server_url: str, auth_headers: Dict[str, str]): - """Test error handling with invalid server IDs.""" - invalid_id = "invalid-server-id-12345" - - # Try to get non-existent server - get_response = requests.get(f"{server_url}/v1/mcp-servers/{invalid_id}", headers=auth_headers) - assert get_response.status_code == 404 - - # Try to update non-existent server - update_data = {"server_name": "updated"} - update_response = requests.patch(f"{server_url}/v1/mcp-servers/{invalid_id}", json=update_data, headers=auth_headers) - assert update_response.status_code == 404 # Non-existent server returns 404 - - # Try to delete non-existent server - delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{invalid_id}", headers=auth_headers) - assert delete_response.status_code == 404 - - # Try to list tools for non-existent server - tools_response = requests.get(f"{server_url}/v1/mcp-servers/{invalid_id}/tools", headers=auth_headers) - assert tools_response.status_code == 404 - - -def test_invalid_server_type(server_url: str, auth_headers: Dict[str, str]): - """Test creating server with invalid type.""" - invalid_config = {"type": "invalid_type", "server_name": "invalid-server", "some_field": "value"} - - response = requests.post(f"{server_url}/v1/mcp-servers/", json=invalid_config, headers=auth_headers) - assert response.status_code == 422 # Validation error - - -# ------------------------------ -# Test Cases for Complex Scenarios -# ------------------------------ - - -def test_multiple_server_types_coexist(server_url: str, auth_headers: Dict[str, str]): - """Test that multiple server types can coexist.""" - servers_created = [] - - try: - # Create one of each type - stdio_config = create_stdio_server_dict(f"multi-stdio-{uuid.uuid4().hex[:8]}") - stdio_response = requests.post(f"{server_url}/v1/mcp-servers/", json=stdio_config, headers=auth_headers) - assert stdio_response.status_code == 200 - stdio_server = stdio_response.json() - servers_created.append(stdio_server["id"]) - - sse_config = create_sse_server_dict(f"multi-sse-{uuid.uuid4().hex[:8]}") - sse_response = requests.post(f"{server_url}/v1/mcp-servers/", json=sse_config, headers=auth_headers) - assert sse_response.status_code == 200 - sse_server = sse_response.json() - servers_created.append(sse_server["id"]) - - http_config = create_streamable_http_server_dict(f"multi-http-{uuid.uuid4().hex[:8]}") - http_response = requests.post(f"{server_url}/v1/mcp-servers/", json=http_config, headers=auth_headers) - assert http_response.status_code == 200 - http_server = http_response.json() - servers_created.append(http_server["id"]) - - # List all servers - list_response = requests.get(f"{server_url}/v1/mcp-servers/", headers=auth_headers) - assert list_response.status_code == 200 - - servers_list = list_response.json() - server_ids = [s["id"] for s in servers_list] - - # Verify all three are present - assert stdio_server["id"] in server_ids - assert sse_server["id"] in server_ids - assert http_server["id"] in server_ids - - # Get each server and verify type-specific fields - stdio_get = requests.get(f"{server_url}/v1/mcp-servers/{stdio_server['id']}", headers=auth_headers) - assert stdio_get.status_code == 200 - assert stdio_get.json()["command"] == stdio_config["command"] - - sse_get = requests.get(f"{server_url}/v1/mcp-servers/{sse_server['id']}", headers=auth_headers) - assert sse_get.status_code == 200 - assert sse_get.json()["server_url"] == sse_config["server_url"] - - http_get = requests.get(f"{server_url}/v1/mcp-servers/{http_server['id']}", headers=auth_headers) - assert http_get.status_code == 200 - assert http_get.json()["server_url"] == http_config["server_url"] - - finally: - # Cleanup all servers - for server_id in servers_created: - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_partial_update_preserves_fields(server_url: str, auth_headers: Dict[str, str]): - """Test that partial updates preserve non-updated fields.""" - # Create a server with all fields - server_name = f"partial-update-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name, command="node", args=["server.js", "--port", "3000"]) - server_config["env"] = {"NODE_ENV": "production", "PORT": "3000", "DEBUG": "false"} - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # Update only the server name - update_data = {"server_name": "renamed-server"} - - update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers) - assert update_response.status_code == 200 - - updated_server = update_response.json() - assert updated_server["server_name"] == "renamed-server" - # Other fields should be preserved - assert updated_server["command"] == "node" - assert updated_server["args"] == ["server.js", "--port", "3000"] - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_concurrent_server_operations(server_url: str, auth_headers: Dict[str, str]): - """Test multiple servers can be operated on concurrently.""" - servers_created = [] - - try: - # Create multiple servers quickly - for i in range(3): - server_config = create_stdio_server_dict(f"concurrent-{i}-{uuid.uuid4().hex[:8]}", command="python", args=[f"server_{i}.py"]) - - response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert response.status_code == 200 - servers_created.append(response.json()["id"]) - - # Update all servers - for i, server_id in enumerate(servers_created): - update_data = {"server_name": f"updated-concurrent-{i}"} - - update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers) - assert update_response.status_code == 200 - assert update_response.json()["server_name"] == f"updated-concurrent-{i}" - - # Get all servers - for i, server_id in enumerate(servers_created): - get_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert get_response.status_code == 200 - assert get_response.json()["server_name"] == f"updated-concurrent-{i}" - - finally: - # Cleanup all servers - for server_id in servers_created: - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - - -def test_full_server_lifecycle(server_url: str, auth_headers: Dict[str, str]): - """Test complete lifecycle: create, list, get, update, tools, delete.""" - # 1. Create server - server_name = f"lifecycle-test-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name, command="npx", args=["-y", "@modelcontextprotocol/server-everything"]) - server_config["env"]["TEST"] = "true" - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # 2. List servers and verify it's there - list_response = requests.get(f"{server_url}/v1/mcp-servers/", headers=auth_headers) - assert list_response.status_code == 200 - assert any(s["id"] == server_id for s in list_response.json()) - - # 3. Get specific server - get_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert get_response.status_code == 200 - assert get_response.json()["server_name"] == server_name - - # 4. Update server - update_data = {"server_name": "lifecycle-updated", "env": {"TEST": "false", "NEW_VAR": "value"}} - update_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}", json=update_data, headers=auth_headers) - assert update_response.status_code == 200 - assert update_response.json()["server_name"] == "lifecycle-updated" - - # 5. List tools - tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers) - assert tools_response.status_code == 200 - tools = tools_response.json() - assert isinstance(tools, list) - - # 6. If tools exist, try to get and run one - if len(tools) > 0: - tool_id = tools[0]["id"] - - # Get specific tool - tool_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}", headers=auth_headers) - assert tool_response.status_code == 200 - assert tool_response.json()["id"] == tool_id - - # Run tool - run_response = requests.post( - f"{server_url}/v1/mcp-servers/{server_id}/tools/{tool_id}/run", json={"args": {}}, headers=auth_headers - ) - assert run_response.status_code == 200 - - # 7. Refresh tools - refresh_response = requests.patch(f"{server_url}/v1/mcp-servers/{server_id}/refresh", headers=auth_headers) - assert refresh_response.status_code == 200 - - # 8. Try to connect (OAuth flow) - try: - connect_response = requests.get( - f"{server_url}/v1/mcp-servers/connect/{server_id}", - headers={**auth_headers, "Accept": "text/event-stream"}, - stream=True, - timeout=1, - ) - # Just verify it doesn't 404 - assert connect_response.status_code in [200, 400, 500] - connect_response.close() - except requests.exceptions.Timeout: - pass # SSE timeout is acceptable - - finally: - # 9. Delete server - delete_response = requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert delete_response.status_code == 204 - - # 10. Verify it's deleted - get_deleted_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers) - assert get_deleted_response.status_code == 404 - - -# ------------------------------ -# Test Cases for Empty Responses -# ------------------------------ - - -def test_empty_tools_list(server_url: str, auth_headers: Dict[str, str]): - """Test handling of servers with no tools.""" - # Create a minimal server that likely has no tools - server_name = f"no-tools-{uuid.uuid4().hex[:8]}" - server_config = create_stdio_server_dict(server_name, command="echo", args=["hello"]) - - create_response = requests.post(f"{server_url}/v1/mcp-servers/", json=server_config, headers=auth_headers) - assert create_response.status_code == 200 - server_id = create_response.json()["id"] - - try: - # List tools (should be empty) - tools_response = requests.get(f"{server_url}/v1/mcp-servers/{server_id}/tools", headers=auth_headers) - assert tools_response.status_code == 200 - - tools = tools_response.json() - assert tools is not None - assert isinstance(tools, list) - # Tools will be empty for a simple echo command - - finally: - # Cleanup - requests.delete(f"{server_url}/v1/mcp-servers/{server_id}", headers=auth_headers)