diff --git a/alembic/versions/c6c43222e2de_add_mcp_tools_table.py b/alembic/versions/c6c43222e2de_add_mcp_tools_table.py new file mode 100644 index 00000000..280ef3d6 --- /dev/null +++ b/alembic/versions/c6c43222e2de_add_mcp_tools_table.py @@ -0,0 +1,47 @@ +"""Add mcp_tools table + +Revision ID: c6c43222e2de +Revises: 6756d04c3ddb +Create Date: 2025-10-20 17:25:54.334037 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c6c43222e2de" +down_revision: Union[str, None] = "6756d04c3ddb" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "mcp_tools", + sa.Column("mcp_server_id", sa.String(), nullable=False), + sa.Column("tool_id", sa.String(), nullable=False), + sa.Column("id", sa.String(), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True), + sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False), + sa.Column("_created_by_id", sa.String(), nullable=True), + sa.Column("_last_updated_by_id", sa.String(), nullable=True), + sa.Column("organization_id", sa.String(), nullable=False), + sa.ForeignKeyConstraint( + ["organization_id"], + ["organizations.id"], + ), + sa.PrimaryKeyConstraint("id"), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("mcp_tools") + # ### end Alembic commands ### diff --git a/fern/openapi.json b/fern/openapi.json index d5703fd4..19cf5e71 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -1571,13 +1571,13 @@ "schema": { "anyOf": [ { - "$ref": "#/components/schemas/UpdateStdioMCPServer" + "$ref": "#/components/schemas/letta__schemas__mcp__UpdateStdioMCPServer" }, { - "$ref": "#/components/schemas/UpdateSSEMCPServer" + "$ref": "#/components/schemas/letta__schemas__mcp__UpdateSSEMCPServer" }, { - "$ref": "#/components/schemas/UpdateStreamableHTTPMCPServer" + "$ref": "#/components/schemas/letta__schemas__mcp__UpdateStreamableHTTPMCPServer" } ], "title": "Request" @@ -1807,7 +1807,7 @@ "content": { "application/json": { "schema": { - "$ref": "#/components/schemas/MCPToolExecuteRequest" + "$ref": "#/components/schemas/letta__server__rest_api__routers__v1__tools__MCPToolExecuteRequest" } } } @@ -10189,6 +10189,521 @@ } } }, + "/v1/mcp-servers/": { + "post": { + "tags": ["mcp-servers"], + "summary": "Create Mcp Server", + "description": "Add a new MCP server to the Letta MCP server config", + "operationId": "mcp_create_mcp_server", + "parameters": [], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "$ref": "#/components/schemas/CreateStdioMCPServer" + }, + { + "$ref": "#/components/schemas/CreateSSEMCPServer" + }, + { + "$ref": "#/components/schemas/CreateStreamableHTTPMCPServer" + } + ], + "title": "Request" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "$ref": "#/components/schemas/StdioMCPServer" + }, + { + "$ref": "#/components/schemas/SSEMCPServer" + }, + { + "$ref": "#/components/schemas/StreamableHTTPMCPServer" + } + ], + "title": "Response Mcp Create Mcp Server" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "get": { + "tags": ["mcp-servers"], + "summary": "List Mcp Servers", + "description": "Get a list of all configured MCP servers", + "operationId": "mcp_list_mcp_servers", + "parameters": [], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "anyOf": [ + { + "$ref": "#/components/schemas/StdioMCPServer" + }, + { + "$ref": "#/components/schemas/SSEMCPServer" + }, + { + "$ref": "#/components/schemas/StreamableHTTPMCPServer" + } + ] + }, + "title": "Response Mcp List Mcp Servers" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/mcp-servers/{mcp_server_id}": { + "get": { + "tags": ["mcp-servers"], + "summary": "Get Mcp Server", + "description": "Get a specific MCP server", + "operationId": "mcp_get_mcp_server", + "parameters": [ + { + "name": "mcp_server_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Mcp Server Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "$ref": "#/components/schemas/StdioMCPServer" + }, + { + "$ref": "#/components/schemas/SSEMCPServer" + }, + { + "$ref": "#/components/schemas/StreamableHTTPMCPServer" + } + ], + "title": "Response Mcp Get Mcp Server" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "delete": { + "tags": ["mcp-servers"], + "summary": "Delete Mcp Server", + "description": "Delete an MCP server by its ID", + "operationId": "mcp_delete_mcp_server", + "parameters": [ + { + "name": "mcp_server_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Mcp Server Id" + } + } + ], + "responses": { + "204": { + "description": "Successful Response" + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + }, + "patch": { + "tags": ["mcp-servers"], + "summary": "Update Mcp Server", + "description": "Update an existing MCP server configuration", + "operationId": "mcp_update_mcp_server", + "parameters": [ + { + "name": "mcp_server_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Mcp Server Id" + } + } + ], + "requestBody": { + "required": true, + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "$ref": "#/components/schemas/letta__schemas__mcp_server__UpdateStdioMCPServer" + }, + { + "$ref": "#/components/schemas/letta__schemas__mcp_server__UpdateSSEMCPServer" + }, + { + "$ref": "#/components/schemas/letta__schemas__mcp_server__UpdateStreamableHTTPMCPServer" + } + ], + "title": "Request" + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "anyOf": [ + { + "$ref": "#/components/schemas/StdioMCPServer" + }, + { + "$ref": "#/components/schemas/SSEMCPServer" + }, + { + "$ref": "#/components/schemas/StreamableHTTPMCPServer" + } + ], + "title": "Response Mcp Update Mcp Server" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/mcp-servers/{mcp_server_id}/tools": { + "get": { + "tags": ["mcp-servers"], + "summary": "List Mcp Tools By Server", + "description": "Get a list of all tools for a specific MCP server", + "operationId": "mcp_list_mcp_tools_by_server", + "parameters": [ + { + "name": "mcp_server_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Mcp Server Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "type": "array", + "items": { + "$ref": "#/components/schemas/Tool" + }, + "title": "Response Mcp List Mcp Tools By Server" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/mcp-servers/{mcp_server_id}/tools/{tool_id}": { + "get": { + "tags": ["mcp-servers"], + "summary": "Get Mcp Tool", + "description": "Get a specific MCP tool by its ID", + "operationId": "mcp_get_mcp_tool", + "parameters": [ + { + "name": "mcp_server_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Mcp Server Id" + } + }, + { + "name": "tool_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Tool Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Tool" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/mcp-servers/{mcp_server_id}/tools/{tool_id}/run": { + "post": { + "tags": ["mcp-servers"], + "summary": "Run Mcp Tool", + "description": "Execute a specific MCP tool\n\nThe request body should contain the tool arguments in the MCPToolExecuteRequest format.", + "operationId": "mcp_run_tool", + "parameters": [ + { + "name": "mcp_server_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Mcp Server Id" + } + }, + { + "name": "tool_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Tool Id" + } + } + ], + "requestBody": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/letta__schemas__mcp_server__MCPToolExecuteRequest", + "default": { + "args": {} + } + } + } + } + }, + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ToolExecutionResult" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/mcp-servers/{mcp_server_id}/refresh": { + "patch": { + "tags": ["mcp-servers"], + "summary": "Refresh Mcp Server Tools", + "description": "Refresh tools for an MCP server by:\n1. Fetching current tools from the MCP server\n2. Deleting tools that no longer exist on the server\n3. Updating schemas for existing tools\n4. Adding new tools from the server\n\nReturns a summary of changes made.", + "operationId": "mcp_refresh_mcp_server_tools", + "parameters": [ + { + "name": "mcp_server_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Mcp Server Id" + } + }, + { + "name": "agent_id", + "in": "query", + "required": false, + "schema": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Agent Id" + } + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": {} + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, + "/v1/mcp-servers/connect/{mcp_server_id}": { + "get": { + "tags": ["mcp-servers"], + "summary": "Connect Mcp Server", + "description": "Connect to an MCP server with support for OAuth via SSE.\nReturns a stream of events handling authorization state and exchange if OAuth is required.", + "operationId": "mcp_connect_mcp_server", + "parameters": [ + { + "name": "mcp_server_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "title": "Mcp Server Id" + } + } + ], + "responses": { + "200": { + "description": "Successful response", + "content": { + "application/json": { + "schema": {} + }, + "text/event-stream": { + "description": "Server-Sent Events stream" + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/blocks/": { "get": { "tags": ["blocks"], @@ -21422,6 +21937,173 @@ "title": "CreateBlock", "description": "Create a block" }, + "CreateSSEMCPServer": { + "properties": { + "server_name": { + "type": "string", + "title": "Server Name", + "description": "The name of the server" + }, + "type": { + "$ref": "#/components/schemas/MCPServerType", + "default": "sse" + }, + "server_url": { + "type": "string", + "title": "Server Url", + "description": "The URL of the server" + }, + "auth_header": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Header", + "description": "The name of the authentication header (e.g., 'Authorization')" + }, + "auth_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Token", + "description": "The authentication token or API key value" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom HTTP headers to include with requests" + } + }, + "type": "object", + "required": ["server_name", "server_url"], + "title": "CreateSSEMCPServer", + "description": "Create a new SSE MCP server" + }, + "CreateStdioMCPServer": { + "properties": { + "server_name": { + "type": "string", + "title": "Server Name", + "description": "The name of the server" + }, + "type": { + "$ref": "#/components/schemas/MCPServerType", + "default": "stdio" + }, + "command": { + "type": "string", + "title": "Command", + "description": "The command to run (MCP 'local' client will run this command)" + }, + "args": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Args", + "description": "The arguments to pass to the command" + }, + "env": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Env", + "description": "Environment variables to set" + } + }, + "type": "object", + "required": ["server_name", "command", "args"], + "title": "CreateStdioMCPServer", + "description": "Create a new Stdio MCP server" + }, + "CreateStreamableHTTPMCPServer": { + "properties": { + "server_name": { + "type": "string", + "title": "Server Name", + "description": "The name of the server" + }, + "type": { + "$ref": "#/components/schemas/MCPServerType", + "default": "streamable_http" + }, + "server_url": { + "type": "string", + "title": "Server Url", + "description": "The URL of the server" + }, + "auth_header": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Header", + "description": "The name of the authentication header (e.g., 'Authorization')" + }, + "auth_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Token", + "description": "The authentication token or API key value" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom HTTP headers to include with requests" + } + }, + "type": "object", + "required": ["server_name", "server_url"], + "title": "CreateStreamableHTTPMCPServer", + "description": "Create a new Streamable HTTP MCP server" + }, "Custom-Input": { "properties": { "input": { @@ -26020,18 +26702,6 @@ "title": "MCPTool", "description": "A simple wrapper around MCP's tool definition (to avoid conflict with our own)" }, - "MCPToolExecuteRequest": { - "properties": { - "args": { - "additionalProperties": true, - "type": "object", - "title": "Args", - "description": "Arguments to pass to the MCP tool" - } - }, - "type": "object", - "title": "MCPToolExecuteRequest" - }, "MCPToolHealth": { "properties": { "status": { @@ -28500,6 +29170,74 @@ "title": "RunStatus", "description": "Status of the run." }, + "SSEMCPServer": { + "properties": { + "server_name": { + "type": "string", + "title": "Server Name", + "description": "The name of the server" + }, + "type": { + "$ref": "#/components/schemas/MCPServerType", + "default": "sse" + }, + "server_url": { + "type": "string", + "title": "Server Url", + "description": "The URL of the server" + }, + "auth_header": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Header", + "description": "The name of the authentication header (e.g., 'Authorization')" + }, + "auth_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Token", + "description": "The authentication token or API key value" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom HTTP headers to include with requests" + }, + "id": { + "type": "string", + "pattern": "^mcp_server-[a-fA-F0-9]{8}", + "title": "Id", + "description": "The human-friendly ID of the Mcp_server", + "examples": ["mcp_server-123e4567-e89b-12d3-a456-426614174000"] + } + }, + "type": "object", + "required": ["server_name", "server_url"], + "title": "SSEMCPServer", + "description": "An SSE MCP server" + }, "SSEServerConfig": { "properties": { "server_name": { @@ -29334,6 +30072,58 @@ "title": "SourceUpdate", "description": "Schema for updating an existing Source." }, + "StdioMCPServer": { + "properties": { + "server_name": { + "type": "string", + "title": "Server Name", + "description": "The name of the server" + }, + "type": { + "$ref": "#/components/schemas/MCPServerType", + "default": "stdio" + }, + "command": { + "type": "string", + "title": "Command", + "description": "The command to run (MCP 'local' client will run this command)" + }, + "args": { + "items": { + "type": "string" + }, + "type": "array", + "title": "Args", + "description": "The arguments to pass to the command" + }, + "env": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Env", + "description": "Environment variables to set" + }, + "id": { + "type": "string", + "pattern": "^mcp_server-[a-fA-F0-9]{8}", + "title": "Id", + "description": "The human-friendly ID of the Mcp_server", + "examples": ["mcp_server-123e4567-e89b-12d3-a456-426614174000"] + } + }, + "type": "object", + "required": ["server_name", "command", "args"], + "title": "StdioMCPServer", + "description": "A Stdio MCP server" + }, "StdioServerConfig": { "properties": { "server_name": { @@ -29831,6 +30621,74 @@ ], "title": "StopReasonType" }, + "StreamableHTTPMCPServer": { + "properties": { + "server_name": { + "type": "string", + "title": "Server Name", + "description": "The name of the server" + }, + "type": { + "$ref": "#/components/schemas/MCPServerType", + "default": "streamable_http" + }, + "server_url": { + "type": "string", + "title": "Server Url", + "description": "The URL of the server" + }, + "auth_header": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Header", + "description": "The name of the authentication header (e.g., 'Authorization')" + }, + "auth_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Token", + "description": "The authentication token or API key value" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom HTTP headers to include with requests" + }, + "id": { + "type": "string", + "pattern": "^mcp_server-[a-fA-F0-9]{8}", + "title": "Id", + "description": "The human-friendly ID of the Mcp_server", + "examples": ["mcp_server-123e4567-e89b-12d3-a456-426614174000"] + } + }, + "type": "object", + "required": ["server_name", "server_url"], + "title": "StreamableHTTPMCPServer", + "description": "A Streamable HTTP MCP server" + }, "StreamableHTTPServerConfig": { "properties": { "server_name": { @@ -30868,6 +31726,82 @@ "required": ["created_at", "description", "key", "updated_at", "value"], "title": "ToolEnvVarSchema" }, + "ToolExecutionResult": { + "properties": { + "status": { + "type": "string", + "enum": ["success", "error"], + "title": "Status", + "description": "The status of the tool execution and return object" + }, + "func_return": { + "anyOf": [ + {}, + { + "type": "null" + } + ], + "title": "Func Return", + "description": "The function return object" + }, + "agent_state": { + "anyOf": [ + { + "$ref": "#/components/schemas/AgentState" + }, + { + "type": "null" + } + ], + "description": "The agent state" + }, + "stdout": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Stdout", + "description": "Captured stdout (prints, logs) from function invocation" + }, + "stderr": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Stderr", + "description": "Captured stderr from the function invocation" + }, + "sandbox_config_fingerprint": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Sandbox Config Fingerprint", + "description": "The fingerprint of the config for the sandbox" + } + }, + "type": "object", + "required": ["status"], + "title": "ToolExecutionResult" + }, "ToolJSONSchema": { "properties": { "name": { @@ -31953,131 +32887,6 @@ "required": ["reasoning"], "title": "UpdateReasoningMessage" }, - "UpdateSSEMCPServer": { - "properties": { - "server_url": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Server Url", - "description": "The URL of the server (MCP SSE client will connect to this URL)" - }, - "token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Token", - "description": "The access token or API key for the MCP server (used for SSE authentication)" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom authentication headers as key-value pairs" - } - }, - "additionalProperties": false, - "type": "object", - "title": "UpdateSSEMCPServer", - "description": "Update an SSE MCP server" - }, - "UpdateStdioMCPServer": { - "properties": { - "stdio_config": { - "anyOf": [ - { - "$ref": "#/components/schemas/StdioServerConfig" - }, - { - "type": "null" - } - ], - "description": "The configuration for the server (MCP 'local' client will run this command)" - } - }, - "additionalProperties": false, - "type": "object", - "title": "UpdateStdioMCPServer", - "description": "Update a Stdio MCP server" - }, - "UpdateStreamableHTTPMCPServer": { - "properties": { - "server_url": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Server Url", - "description": "The URL path for the streamable HTTP server (e.g., 'example/mcp')" - }, - "auth_header": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Header", - "description": "The name of the authentication header (e.g., 'Authorization')" - }, - "auth_token": { - "anyOf": [ - { - "type": "string" - }, - { - "type": "null" - } - ], - "title": "Auth Token", - "description": "The authentication token or API key value" - }, - "custom_headers": { - "anyOf": [ - { - "additionalProperties": { - "type": "string" - }, - "type": "object" - }, - { - "type": "null" - } - ], - "title": "Custom Headers", - "description": "Custom authentication headers as key-value pairs" - } - }, - "additionalProperties": false, - "type": "object", - "title": "UpdateStreamableHTTPMCPServer", - "description": "Update a Streamable HTTP MCP server" - }, "UpdateSystemMessage": { "properties": { "message_type": { @@ -33670,6 +34479,409 @@ "required": ["tool_return", "status", "tool_call_id"], "title": "ToolReturn" }, + "letta__schemas__mcp__UpdateSSEMCPServer": { + "properties": { + "server_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Name", + "description": "The name of the MCP server" + }, + "server_url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Url", + "description": "The URL of the server (MCP SSE client will connect to this URL)" + }, + "token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Token", + "description": "The access token or API key for the MCP server (used for SSE authentication)" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom authentication headers as key-value pairs" + } + }, + "additionalProperties": false, + "type": "object", + "title": "UpdateSSEMCPServer", + "description": "Update an SSE MCP server" + }, + "letta__schemas__mcp__UpdateStdioMCPServer": { + "properties": { + "server_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Name", + "description": "The name of the MCP server" + }, + "stdio_config": { + "anyOf": [ + { + "$ref": "#/components/schemas/StdioServerConfig" + }, + { + "type": "null" + } + ], + "description": "The configuration for the server (MCP 'local' client will run this command)" + } + }, + "additionalProperties": false, + "type": "object", + "title": "UpdateStdioMCPServer", + "description": "Update a Stdio MCP server" + }, + "letta__schemas__mcp__UpdateStreamableHTTPMCPServer": { + "properties": { + "server_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Name", + "description": "The name of the MCP server" + }, + "server_url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Url", + "description": "The URL path for the streamable HTTP server (e.g., 'example/mcp')" + }, + "auth_header": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Header", + "description": "The name of the authentication header (e.g., 'Authorization')" + }, + "auth_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Token", + "description": "The authentication token or API key value" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom authentication headers as key-value pairs" + } + }, + "additionalProperties": false, + "type": "object", + "title": "UpdateStreamableHTTPMCPServer", + "description": "Update a Streamable HTTP MCP server" + }, + "letta__schemas__mcp_server__MCPToolExecuteRequest": { + "properties": { + "args": { + "additionalProperties": true, + "type": "object", + "title": "Args", + "description": "Arguments to pass to the MCP tool" + } + }, + "additionalProperties": false, + "type": "object", + "title": "MCPToolExecuteRequest", + "description": "Request to execute an MCP tool by IDs." + }, + "letta__schemas__mcp_server__UpdateSSEMCPServer": { + "properties": { + "server_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Name", + "description": "The name of the MCP server" + }, + "server_url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Url", + "description": "The URL of the SSE MCP server" + }, + "auth_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Token", + "description": "The authentication token or API key value" + }, + "token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Token", + "description": "The authentication token (internal)" + }, + "auth_header": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Header", + "description": "The name of the authentication header (e.g., 'Authorization')" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom headers to send with requests" + } + }, + "additionalProperties": false, + "type": "object", + "title": "UpdateSSEMCPServer", + "description": "Update schema for SSE MCP server - all fields optional" + }, + "letta__schemas__mcp_server__UpdateStdioMCPServer": { + "properties": { + "server_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Name", + "description": "The name of the MCP server" + }, + "command": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Command", + "description": "The command to run the MCP server" + }, + "args": { + "anyOf": [ + { + "items": { + "type": "string" + }, + "type": "array" + }, + { + "type": "null" + } + ], + "title": "Args", + "description": "The arguments to pass to the command" + }, + "env": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Env", + "description": "Environment variables to set" + } + }, + "additionalProperties": false, + "type": "object", + "title": "UpdateStdioMCPServer", + "description": "Update schema for Stdio MCP server - all fields optional" + }, + "letta__schemas__mcp_server__UpdateStreamableHTTPMCPServer": { + "properties": { + "server_name": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Name", + "description": "The name of the MCP server" + }, + "server_url": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Server Url", + "description": "The URL of the Streamable HTTP MCP server" + }, + "auth_token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Token", + "description": "The authentication token or API key value" + }, + "token": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Token", + "description": "The authentication token (internal)" + }, + "auth_header": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Auth Header", + "description": "The name of the authentication header (e.g., 'Authorization')" + }, + "custom_headers": { + "anyOf": [ + { + "additionalProperties": { + "type": "string" + }, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Custom Headers", + "description": "Custom headers to send with requests" + } + }, + "additionalProperties": false, + "type": "object", + "title": "UpdateStreamableHTTPMCPServer", + "description": "Update schema for Streamable HTTP MCP server - all fields optional" + }, "letta__schemas__message__ToolReturn": { "properties": { "tool_call_id": { @@ -34069,6 +35281,18 @@ ], "title": "ToolSchema" }, + "letta__server__rest_api__routers__v1__tools__MCPToolExecuteRequest": { + "properties": { + "args": { + "additionalProperties": true, + "type": "object", + "title": "Args", + "description": "Arguments to pass to the MCP tool" + } + }, + "type": "object", + "title": "MCPToolExecuteRequest" + }, "openai__types__chat__chat_completion_message_function_tool_call__Function": { "properties": { "arguments": { diff --git a/letta/helpers/converters.py b/letta/helpers/converters.py index 9ca236fc..b07c92f3 100644 --- a/letta/helpers/converters.py +++ b/letta/helpers/converters.py @@ -504,14 +504,43 @@ def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormat def serialize_mcp_stdio_config(config: Union[Optional[StdioServerConfig], Dict]) -> Optional[Dict]: - """Convert an StdioServerConfig object into a JSON-serializable dictionary.""" + """Convert an StdioServerConfig object into a JSON-serializable dictionary. + + Persist required fields for successful deserialization back into a + StdioServerConfig model (namely `server_name` and `type`). The + `to_dict()` helper intentionally omits these since they're not needed + by MCP transport, but our ORM deserializer reconstructs the pydantic + model and requires them. + """ if config and isinstance(config, StdioServerConfig): - return config.to_dict() + data = config.to_dict() + # Preserve required fields for pydantic reconstruction + data["server_name"] = config.server_name + # Store enum as its value; pydantic will coerce on load + data["type"] = config.type.value if hasattr(config.type, "value") else str(config.type) + return data return config def deserialize_mcp_stdio_config(data: Optional[Dict]) -> Optional[StdioServerConfig]: - """Convert a dictionary back into an StdioServerConfig object.""" + """Convert a dictionary back into an StdioServerConfig object. + + Backwards-compatibility notes: + - Older rows may only include `transport`, `command`, `args`, `env`. + In that case, provide defaults for `server_name` and `type` to + satisfy the pydantic model requirements. + - If both `type` and `transport` are present, prefer `type`. + """ if not data: return None - return StdioServerConfig(**data) + + payload = dict(data) + # Map legacy `transport` field to required `type` if missing + if "type" not in payload and "transport" in payload: + payload["type"] = payload["transport"] + + # Ensure required field exists; use a sensible placeholder when unknown + if "server_name" not in payload: + payload["server_name"] = payload.get("name", "unknown") + + return StdioServerConfig(**payload) diff --git a/letta/orm/mcp_server.py b/letta/orm/mcp_server.py index 49cffb84..14888baf 100644 --- a/letta/orm/mcp_server.py +++ b/letta/orm/mcp_server.py @@ -56,3 +56,12 @@ class MCPServer(SqlalchemyBase, OrganizationMixin): metadata_: Mapped[Optional[dict]] = mapped_column( JSON, default=lambda: {}, doc="A dictionary of additional metadata for the MCP server." ) + + +class MCPTools(SqlalchemyBase, OrganizationMixin): + """Represents a mapping of MCP server ID to tool ID""" + + __tablename__ = "mcp_tools" + + mcp_server_id: Mapped[str] = mapped_column(String, doc="The ID of the MCP server") + tool_id: Mapped[str] = mapped_column(String, doc="The ID of the tool") diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index 856760fd..92996344 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -148,6 +148,7 @@ class MCPServer(BaseMCPServer): class UpdateSSEMCPServer(LettaBase): """Update an SSE MCP server""" + server_name: Optional[str] = Field(None, description="The name of the MCP server") server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)") token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)") custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs") @@ -156,6 +157,7 @@ class UpdateSSEMCPServer(LettaBase): class UpdateStdioMCPServer(LettaBase): """Update a Stdio MCP server""" + server_name: Optional[str] = Field(None, description="The name of the MCP server") stdio_config: Optional[StdioServerConfig] = Field( None, description="The configuration for the server (MCP 'local' client will run this command)" ) @@ -164,6 +166,7 @@ class UpdateStdioMCPServer(LettaBase): class UpdateStreamableHTTPMCPServer(LettaBase): """Update a Streamable HTTP MCP server""" + server_name: Optional[str] = Field(None, description="The name of the MCP server") server_url: Optional[str] = Field(None, description="The URL path for the streamable HTTP server (e.g., 'example/mcp')") auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')") auth_token: Optional[str] = Field(None, description="The authentication token or API key value") diff --git a/letta/schemas/mcp_server.py b/letta/schemas/mcp_server.py index 28d47fad..5b495c5a 100644 --- a/letta/schemas/mcp_server.py +++ b/letta/schemas/mcp_server.py @@ -41,18 +41,21 @@ class StdioMCPServer(CreateStdioMCPServer): """A Stdio MCP server""" id: str = BaseMCPServer.generate_id_field() + type: MCPServerType = MCPServerType.STDIO class SSEMCPServer(CreateSSEMCPServer): """An SSE MCP server""" id: str = BaseMCPServer.generate_id_field() + type: MCPServerType = MCPServerType.SSE class StreamableHTTPMCPServer(CreateStreamableHTTPMCPServer): """A Streamable HTTP MCP server""" id: str = BaseMCPServer.generate_id_field() + type: MCPServerType = MCPServerType.STREAMABLE_HTTP MCPServerUnion = Union[StdioMCPServer, SSEMCPServer, StreamableHTTPMCPServer] @@ -74,9 +77,10 @@ class UpdateSSEMCPServer(LettaBase): server_name: Optional[str] = Field(None, description="The name of the MCP server") server_url: Optional[str] = Field(None, description="The URL of the SSE MCP server") - # Note: auth_token is renamed to token to match the ORM field - token: Optional[str] = Field(None, description="The authentication token") - # auth_header is excluded as it's derived from the token + # Accept both `auth_token` (API surface) and `token` (internal ORM naming) + auth_token: Optional[str] = Field(None, description="The authentication token or API key value") + token: Optional[str] = Field(None, description="The authentication token (internal)") + auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')") custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to send with requests") @@ -85,9 +89,10 @@ class UpdateStreamableHTTPMCPServer(LettaBase): server_name: Optional[str] = Field(None, description="The name of the MCP server") server_url: Optional[str] = Field(None, description="The URL of the Streamable HTTP MCP server") - # Note: auth_token is renamed to token to match the ORM field - token: Optional[str] = Field(None, description="The authentication token") - # auth_header is excluded as it's derived from the token + # Accept both `auth_token` (API surface) and `token` (internal ORM naming) + auth_token: Optional[str] = Field(None, description="The authentication token or API key value") + token: Optional[str] = Field(None, description="The authentication token (internal)") + auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')") custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to send with requests") @@ -296,3 +301,49 @@ def convert_generic_to_union(server) -> MCPServerUnion: ) else: raise ValueError(f"Unknown server type: {server.server_type}") + + +def convert_update_to_internal(request: Union[UpdateStdioMCPServer, UpdateSSEMCPServer, UpdateStreamableHTTPMCPServer]): + """Convert external API update models to internal UpdateMCPServer union used by the manager. + + - Flattens stdio fields into StdioServerConfig inside UpdateStdioMCPServer + - Maps `auth_token` to `token` for HTTP-based transports + - Ignores `auth_header` at update time (header is derived from token) + """ + # Local import to avoid circulars + from letta.functions.mcp_client.types import MCPServerType as MCPType, StdioServerConfig as StdioCfg + from letta.schemas.mcp import ( + UpdateSSEMCPServer as InternalUpdateSSE, + UpdateStdioMCPServer as InternalUpdateStdio, + UpdateStreamableHTTPMCPServer as InternalUpdateHTTP, + ) + + if isinstance(request, UpdateStdioMCPServer): + stdio_cfg = None + # Only build stdio_config if command and args are explicitly provided to avoid overwriting existing config + if request.command is not None and request.args is not None: + stdio_cfg = StdioCfg( + server_name=request.server_name or "", + type=MCPType.STDIO, + command=request.command, + args=request.args, + env=request.env, + ) + kwargs: dict = {} + if request.server_name is not None: + kwargs["server_name"] = request.server_name + if stdio_cfg is not None: + kwargs["stdio_config"] = stdio_cfg + return InternalUpdateStdio(**kwargs) + elif isinstance(request, UpdateSSEMCPServer): + token_value = request.auth_token or request.token + return InternalUpdateSSE( + server_name=request.server_name, server_url=request.server_url, token=token_value, custom_headers=request.custom_headers + ) + elif isinstance(request, UpdateStreamableHTTPMCPServer): + token_value = request.auth_token or request.token + return InternalUpdateHTTP( + server_name=request.server_name, server_url=request.server_url, auth_token=token_value, custom_headers=request.custom_headers + ) + else: + raise TypeError(f"Unsupported update request type: {type(request)}") diff --git a/letta/server/rest_api/routers/v1/__init__.py b/letta/server/rest_api/routers/v1/__init__.py index 8568485c..520a77b4 100644 --- a/letta/server/rest_api/routers/v1/__init__.py +++ b/letta/server/rest_api/routers/v1/__init__.py @@ -11,6 +11,7 @@ from letta.server.rest_api.routers.v1.internal_runs import router as internal_ru from letta.server.rest_api.routers.v1.internal_templates import router as internal_templates_router from letta.server.rest_api.routers.v1.jobs import router as jobs_router from letta.server.rest_api.routers.v1.llms import router as llm_router +from letta.server.rest_api.routers.v1.mcp_servers import router as mcp_servers_router from letta.server.rest_api.routers.v1.messages import router as messages_router from letta.server.rest_api.routers.v1.providers import router as providers_router from letta.server.rest_api.routers.v1.runs import router as runs_router @@ -34,6 +35,7 @@ ROUTERS = [ internal_runs_router, internal_templates_router, llm_router, + mcp_servers_router, blocks_router, jobs_router, health_router, diff --git a/letta/server/rest_api/routers/v1/mcp_servers.py b/letta/server/rest_api/routers/v1/mcp_servers.py index e29c5a1d..723a0ffc 100644 --- a/letta/server/rest_api/routers/v1/mcp_servers.py +++ b/letta/server/rest_api/routers/v1/mcp_servers.py @@ -1,8 +1,10 @@ -from typing import Any, Dict, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional, Union -from fastapi import APIRouter, Body, Depends, HTTPException +from fastapi import APIRouter, Body, Depends, HTTPException, Request +from httpx import HTTPStatusError from starlette.responses import StreamingResponse +from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig from letta.log import get_logger from letta.schemas.letta_message import ToolReturnMessage from letta.schemas.mcp_server import ( @@ -11,14 +13,20 @@ from letta.schemas.mcp_server import ( MCPToolExecuteRequest, UpdateMCPServerUnion, convert_generic_to_union, + convert_update_to_internal, ) from letta.schemas.tool import Tool +from letta.schemas.tool_execution_result import ToolExecutionResult from letta.server.rest_api.dependencies import ( HeaderParams, get_headers, get_letta_server, ) +from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode from letta.server.server import SyncServer +from letta.services.mcp.oauth_utils import drill_down_exception, oauth_stream_event +from letta.services.mcp.stdio_client import AsyncStdioMCPClient +from letta.services.mcp.types import OauthStreamEvent from letta.settings import tool_settings router = APIRouter(prefix="/mcp-servers", tags=["mcp-servers"]) @@ -39,6 +47,7 @@ async def create_mcp_server( """ Add a new MCP server to the Letta MCP server config """ + # TODO: add the tools to the MCP server table we made. actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) new_server = await server.mcp_server_manager.create_mcp_server_from_config_with_tools(request, actor=actor) return convert_generic_to_union(new_server) @@ -56,7 +65,6 @@ async def list_mcp_servers( """ Get a list of all configured MCP servers """ - actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) mcp_servers = await server.mcp_server_manager.list_mcp_servers(actor=actor) return [convert_generic_to_union(mcp_server) for mcp_server in mcp_servers] @@ -112,8 +120,10 @@ async def update_mcp_server( Update an existing MCP server configuration """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + # Convert external update payload to internal manager union + internal_update = convert_update_to_internal(request) updated_server = await server.mcp_server_manager.update_mcp_server_by_id( - mcp_server_id=mcp_server_id, mcp_server_update=request, actor=actor + mcp_server_id=mcp_server_id, mcp_server_update=internal_update, actor=actor ) return convert_generic_to_union(updated_server) @@ -127,24 +137,10 @@ async def list_mcp_tools_by_server( """ Get a list of all tools for a specific MCP server """ - # TODO: implement this. We want to use the new tools table instead of going to the mcp server. - pass - # actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - # mcp_tools = await server.mcp_server_manager.list_mcp_server_tools(mcp_server_id, actor=actor) - # # Convert MCPTool objects to Tool objects - # tools = [] - # for mcp_tool in mcp_tools: - # from letta.schemas.tool import ToolCreate - # tool_create = ToolCreate.from_mcp(mcp_server_name="", mcp_tool=mcp_tool) - # tools.append(Tool( - # id=f"mcp-tool-{mcp_tool.name}", # Generate a temporary ID - # name=mcp_tool.name, - # description=tool_create.description, - # json_schema=tool_create.json_schema, - # source_code=tool_create.source_code, - # tags=tool_create.tags, - # )) - # return tools + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + # Use the new efficient method that queries from the database using MCPTools mapping + tools = await server.mcp_server_manager.list_tools_by_mcp_server_from_db(mcp_server_id, actor=actor) + return tools @router.get("/{mcp_server_id}/tools/{tool_id}", response_model=Tool, operation_id="mcp_get_mcp_tool") @@ -158,13 +154,11 @@ async def get_mcp_tool( Get a specific MCP tool by its ID """ actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) - # Use the tool_manager's existing method to get the tool by ID - # Verify the tool belongs to the MCP server (optional check) - tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool_id, actor=actor) + tool = await server.mcp_server_manager.get_tool_by_mcp_server(mcp_server_id, tool_id, actor=actor) return tool -@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolReturnMessage, operation_id="mcp_run_tool") +@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolExecutionResult, operation_id="mcp_run_tool") async def run_mcp_tool( mcp_server_id: str, tool_id: str, @@ -188,9 +182,10 @@ async def run_mcp_tool( actor=actor, ) - # Create a ToolReturnMessage - return ToolReturnMessage( - id=f"tool-return-{tool_id}", tool_call_id=f"call-{tool_id}", tool_return=result, status="success" if success else "error" + # Create a ToolExecutionResult + return ToolExecutionResult( + status="success" if success else "error", + func_return=result, ) @@ -231,6 +226,7 @@ async def refresh_mcp_server_tools( ) async def connect_mcp_server( mcp_server_id: str, + request: Request, server: SyncServer = Depends(get_letta_server), headers: HeaderParams = Depends(get_headers), ) -> StreamingResponse: @@ -238,72 +234,76 @@ async def connect_mcp_server( Connect to an MCP server with support for OAuth via SSE. Returns a stream of events handling authorization state and exchange if OAuth is required. """ - pass + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor) - # async def oauth_stream_generator( - # request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], - # http_request: Request, - # ) -> AsyncGenerator[str, None]: - # client = None + # Convert the MCP server to the appropriate config type + config = mcp_server.to_config(resolve_variables=False) - # oauth_flow_attempted = False - # try: - # # Acknolwedge connection attempt - # yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name) + async def oauth_stream_generator( + mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], + http_request: Request, + ) -> AsyncGenerator[str, None]: + client = None - # actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + oauth_flow_attempted = False + try: + # Acknowledge connection attempt + yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=mcp_config.server_name) - # # Create MCP client with respective transport type - # try: - # request.resolve_environment_variables() - # client = await server.mcp_server_manager.get_mcp_client(request, actor) - # except ValueError as e: - # yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e)) - # return + # Create MCP client with respective transport type + try: + mcp_config.resolve_environment_variables() + client = await server.mcp_server_manager.get_mcp_client(mcp_config, actor) + except ValueError as e: + yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e)) + return - # # Try normal connection first for flows that don't require OAuth - # try: - # await client.connect_to_server() - # tools = await client.list_tools(serialize=True) - # yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) - # return - # except ConnectionError: - # # TODO: jnjpng make this connection error check more specific to the 401 unauthorized error - # if isinstance(client, AsyncStdioMCPClient): - # logger.warning("OAuth not supported for stdio") - # yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio") - # return - # # Continue to OAuth flow - # logger.info(f"Attempting OAuth flow for {request}...") - # except Exception as e: - # yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}") - # return - # finally: - # if client: - # try: - # await client.cleanup() - # # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py - # # For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception - # except* HTTPStatusError: - # oauth_flow_attempted = True - # async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request): - # yield event + # Try normal connection first for flows that don't require OAuth + try: + await client.connect_to_server() + tools = await client.list_tools(serialize=True) + yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) + return + except ConnectionError: + # TODO: jnjpng make this connection error check more specific to the 401 unauthorized error + if isinstance(client, AsyncStdioMCPClient): + logger.warning("OAuth not supported for stdio") + yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio") + return + # Continue to OAuth flow + logger.info(f"Attempting OAuth flow for {mcp_config}...") + except Exception as e: + yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}") + return + finally: + if client: + try: + await client.cleanup() + # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py + # For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception + except HTTPStatusError: + oauth_flow_attempted = True + async for event in server.mcp_server_manager.handle_oauth_flow( + request=mcp_config, actor=actor, http_request=http_request + ): + yield event - # # Failsafe to make sure we don't try to handle OAuth flow twice - # if not oauth_flow_attempted: - # async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request): - # yield event - # return - # except Exception as e: - # detailed_error = drill_down_exception(e) - # logger.error(f"Error in OAuth stream:\n{detailed_error}") - # yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}") + # Failsafe to make sure we don't try to handle OAuth flow twice + if not oauth_flow_attempted: + async for event in server.mcp_server_manager.handle_oauth_flow(request=mcp_config, actor=actor, http_request=http_request): + yield event + return + except Exception as e: + detailed_error = drill_down_exception(e) + logger.error(f"Error in OAuth stream:\n{detailed_error}") + yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}") - # finally: - # if client: - # try: - # await client.cleanup() - # except Exception as cleanup_error: - # logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}") + finally: + if client: + try: + await client.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}") - # return StreamingResponseWithStatusCode(oauth_stream_generator(request, http_request), media_type="text/event-stream") + return StreamingResponseWithStatusCode(oauth_stream_generator(config, request), media_type="text/event-stream") diff --git a/letta/server/server.py b/letta/server/server.py index f9686557..b9b51a70 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -94,6 +94,7 @@ from letta.services.mcp.base_client import AsyncBaseMCPClient from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient from letta.services.mcp.stdio_client import AsyncStdioMCPClient from letta.services.mcp_manager import MCPManager +from letta.services.mcp_server_manager import MCPServerManager from letta.services.message_manager import MessageManager from letta.services.organization_manager import OrganizationManager from letta.services.passage_manager import PassageManager @@ -154,6 +155,7 @@ class SyncServer(object): self.user_manager = UserManager() self.tool_manager = ToolManager() self.mcp_manager = MCPManager() + self.mcp_server_manager = MCPServerManager() self.block_manager = BlockManager() self.source_manager = SourceManager() self.sandbox_config_manager = SandboxConfigManager() diff --git a/letta/services/mcp_server_manager.py b/letta/services/mcp_server_manager.py new file mode 100644 index 00000000..7f71d828 --- /dev/null +++ b/letta/services/mcp_server_manager.py @@ -0,0 +1,1331 @@ +import json +import os +import secrets +import uuid +from datetime import datetime, timedelta +from typing import Any, Dict, List, Optional, Tuple, Union + +from fastapi import HTTPException +from sqlalchemy import delete, desc, null, select +from starlette.requests import Request + +import letta.constants as constants +from letta.functions.mcp_client.types import ( + MCPServerType, + MCPTool, + MCPToolHealth, + SSEServerConfig, + StdioServerConfig, + StreamableHTTPServerConfig, +) +from letta.functions.schema_generator import normalize_mcp_schema +from letta.functions.schema_validator import validate_complete_json_schema +from letta.log import get_logger +from letta.orm.errors import NoResultFound +from letta.orm.mcp_oauth import MCPOAuth, OAuthSessionStatus +from letta.orm.mcp_server import MCPServer as MCPServerModel, MCPTools as MCPToolsModel +from letta.orm.tool import Tool as ToolModel +from letta.schemas.mcp import ( + MCPOAuthSession, + MCPOAuthSessionCreate, + MCPOAuthSessionUpdate, + MCPServer, + MCPServerResyncResult, + UpdateMCPServer, + UpdateSSEMCPServer, + UpdateStdioMCPServer, + UpdateStreamableHTTPMCPServer, +) +from letta.schemas.secret import Secret +from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate +from letta.schemas.user import User as PydanticUser +from letta.server.db import db_registry +from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient +from letta.services.mcp.stdio_client import AsyncStdioMCPClient +from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient +from letta.services.tool_manager import ToolManager +from letta.settings import settings, tool_settings +from letta.utils import enforce_types, printd, safe_create_task + +logger = get_logger(__name__) + + +class MCPServerManager: + """Manager class to handle business logic related to MCP.""" + + def __init__(self): + # TODO: timeouts? + self.tool_manager = ToolManager() + self.cached_mcp_servers = {} # maps id -> async connection + + # MCPTools mapping table management methods + @enforce_types + async def create_mcp_tool_mapping(self, mcp_server_id: str, tool_id: str, actor: PydanticUser) -> None: + """Create a mapping between an MCP server and a tool.""" + async with db_registry.async_session() as session: + mapping = MCPToolsModel( + id=f"mcp-tool-mapping-{uuid.uuid4()}", + mcp_server_id=mcp_server_id, + tool_id=tool_id, + organization_id=actor.organization_id, + ) + await mapping.create_async(session, actor=actor) + + @enforce_types + async def delete_mcp_tool_mappings_by_server(self, mcp_server_id: str, actor: PydanticUser) -> None: + """Delete all tool mappings for a specific MCP server.""" + async with db_registry.async_session() as session: + await session.execute( + delete(MCPToolsModel).where( + MCPToolsModel.mcp_server_id == mcp_server_id, + MCPToolsModel.organization_id == actor.organization_id, + ) + ) + await session.commit() + + @enforce_types + async def get_tool_ids_by_mcp_server(self, mcp_server_id: str, actor: PydanticUser) -> List[str]: + """Get all tool IDs associated with an MCP server.""" + async with db_registry.async_session() as session: + result = await session.execute( + select(MCPToolsModel.tool_id).where( + MCPToolsModel.mcp_server_id == mcp_server_id, + MCPToolsModel.organization_id == actor.organization_id, + ) + ) + return [row[0] for row in result.fetchall()] + + @enforce_types + async def get_mcp_server_id_by_tool(self, tool_id: str, actor: PydanticUser) -> Optional[str]: + """Get the MCP server ID associated with a tool.""" + async with db_registry.async_session() as session: + result = await session.execute( + select(MCPToolsModel.mcp_server_id).where( + MCPToolsModel.tool_id == tool_id, + MCPToolsModel.organization_id == actor.organization_id, + ) + ) + row = result.fetchone() + return row[0] if row else None + + @enforce_types + async def list_tools_by_mcp_server_from_db(self, mcp_server_id: str, actor: PydanticUser) -> List[PydanticTool]: + """ + Get tools associated with an MCP server from the database using the MCPTools mapping. + This is more efficient than fetching from the MCP server directly. + """ + # First get all tool IDs associated with this MCP server + tool_ids = await self.get_tool_ids_by_mcp_server(mcp_server_id, actor) + + if not tool_ids: + return [] + + # Fetch all tools in a single query + async with db_registry.async_session() as session: + result = await session.execute( + select(ToolModel).where( + ToolModel.id.in_(tool_ids), + ToolModel.organization_id == actor.organization_id, + ) + ) + tools = result.scalars().all() + return [tool.to_pydantic() for tool in tools] + + @enforce_types + async def get_tool_by_mcp_server(self, mcp_server_id: str, tool_id: str, actor: PydanticUser) -> Optional[PydanticTool]: + """ + Get a specific tool that belongs to an MCP server. + Verifies the tool is associated with the MCP server via the mapping table. + """ + async with db_registry.async_session() as session: + # Check if the tool is associated with this MCP server + result = await session.execute( + select(MCPToolsModel).where( + MCPToolsModel.mcp_server_id == mcp_server_id, + MCPToolsModel.tool_id == tool_id, + MCPToolsModel.organization_id == actor.organization_id, + ) + ) + mapping = result.scalar_one_or_none() + + if not mapping: + return None + + # Fetch the tool + tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) + return tool.to_pydantic() + + @enforce_types + async def list_mcp_server_tools(self, mcp_server_id: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]: + """Get a list of all tools for a specific MCP server by server ID.""" + mcp_client = None + try: + mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) + server_config = mcp_config.to_config() + mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id) + await mcp_client.connect_to_server() + + # list tools + tools = await mcp_client.list_tools() + # Add health information to each tool + for tool in tools: + # Try to normalize the schema and re-validate + if tool.inputSchema: + tool.inputSchema = normalize_mcp_schema(tool.inputSchema) + health_status, reasons = validate_complete_json_schema(tool.inputSchema) + tool.health = MCPToolHealth(status=health_status.value, reasons=reasons) + + return tools + except Exception as e: + # MCP tool listing errors are often due to connection/configuration issues, not system errors + # Log at info level to avoid triggering Sentry alerts for expected failures + logger.warning(f"Error listing tools for MCP server {mcp_server_id}: {e}") + raise e + finally: + if mcp_client: + try: + await mcp_client.cleanup() + except Exception as e: + logger.warning(f"Error listing tools for MCP server {mcp_server_id}: {e}") + raise e + + @enforce_types + async def execute_mcp_server_tool( + self, + mcp_server_id: str, + tool_id: str, + tool_args: Optional[Dict[str, Any]], + environment_variables: Dict[str, str], + actor: PydanticUser, + agent_id: Optional[str] = None, + ) -> Tuple[str, bool]: + """Call a specific tool from a specific MCP server by IDs.""" + mcp_client = None + try: + # Get the tool to find its actual name + async with db_registry.async_session() as session: + tool = await ToolModel.read_async(db_session=session, identifier=tool_id, actor=actor) + tool_name = tool.name + + # Get the MCP server config + mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) + server_config = mcp_config.to_config(environment_variables) + + mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id) + await mcp_client.connect_to_server() + + # call tool + result, success = await mcp_client.execute_tool(tool_name, tool_args) + logger.info(f"MCP Result: {result}, Success: {success}") + return result, success + finally: + if mcp_client: + await mcp_client.cleanup() + + @enforce_types + async def add_tool_from_mcp_server(self, mcp_server_id: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool: + """Add a tool from an MCP server to the Letta tool registry.""" + # Get the MCP server to get its name + mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) + mcp_server_name = mcp_server.server_name + + mcp_tools = await self.list_mcp_server_tools(mcp_server_id, actor=actor) + for mcp_tool in mcp_tools: + # TODO: @jnjpng move health check to tool class + if mcp_tool.name == mcp_tool_name: + # Check tool health - but try normalization first for INVALID schemas + if mcp_tool.health and mcp_tool.health.status == "INVALID": + logger.info(f"Attempting to normalize INVALID schema for tool {mcp_tool_name}") + logger.info(f"Original health reasons: {mcp_tool.health.reasons}") + + # Try to normalize the schema and re-validate + try: + # Normalize the schema to fix common issues + logger.debug(f"Normalizing schema for {mcp_tool_name}") + normalized_schema = normalize_mcp_schema(mcp_tool.inputSchema) + + # Re-validate after normalization + logger.debug(f"Re-validating schema for {mcp_tool_name}") + health_status, health_reasons = validate_complete_json_schema(normalized_schema) + logger.info(f"After normalization: status={health_status.value}, reasons={health_reasons}") + + # Update the tool's schema and health (use inputSchema, not input_schema) + mcp_tool.inputSchema = normalized_schema + mcp_tool.health.status = health_status.value + mcp_tool.health.reasons = health_reasons + + # Log the normalization result + if health_status.value != "INVALID": + logger.info(f"✓ MCP tool {mcp_tool_name} schema normalized successfully: {health_status.value}") + else: + logger.warning(f"MCP tool {mcp_tool_name} still INVALID after normalization. Reasons: {health_reasons}") + except Exception as e: + logger.error(f"Failed to normalize schema for tool {mcp_tool_name}: {e}", exc_info=True) + + # After normalization attempt, check if still INVALID + if mcp_tool.health and mcp_tool.health.status == "INVALID": + logger.warning(f"Tool {mcp_tool_name} has potentially invalid schema. Reasons: {', '.join(mcp_tool.health.reasons)}") + + tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) + created_tool = await self.tool_manager.create_mcp_tool_async( + tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor + ) + + # Create mapping in MCPTools table + if created_tool: + await self.create_mcp_tool_mapping(mcp_server_id, created_tool.id, actor) + + return created_tool + + # failed to add - handle error? + return None + + @enforce_types + async def resync_mcp_server_tools( + self, mcp_server_id: str, actor: PydanticUser, agent_id: Optional[str] = None + ) -> MCPServerResyncResult: + """ + Resync tools for an MCP server by: + 1. Fetching current tools from the MCP server + 2. Deleting tools that no longer exist on the server + 3. Updating schemas for existing tools + 4. Adding new tools from the server + + Returns a result with: + - deleted: List of deleted tool names + - updated: List of updated tool names + - added: List of added tool names + """ + # Get the MCP server to get its name + mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) + mcp_server_name = mcp_server.server_name + + # Fetch current tools from MCP server + try: + current_mcp_tools = await self.list_mcp_server_tools(mcp_server_id, actor=actor, agent_id=agent_id) + except Exception as e: + logger.error(f"Failed to fetch tools from MCP server {mcp_server_name}: {e}") + raise HTTPException( + status_code=404, + detail={ + "code": "MCPServerUnavailable", + "message": f"Could not connect to MCP server {mcp_server_name} to resync tools", + "error": str(e), + }, + ) + + # Get all persisted tools for this MCP server + async with db_registry.async_session() as session: + # Query for tools with MCP metadata matching this server + # Using JSON path query to filter by metadata + persisted_tools = await ToolModel.list_async( + db_session=session, + organization_id=actor.organization_id, + ) + + # Filter tools that belong to this MCP server + mcp_tools = [] + for tool in persisted_tools: + if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_: + if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id: + mcp_tools.append(tool) + + # Create maps for easier comparison + current_tool_map = {tool.name: tool for tool in current_mcp_tools} + persisted_tool_map = {tool.name: tool for tool in mcp_tools} + + deleted_tools = [] + updated_tools = [] + added_tools = [] + + # 1. Delete tools that no longer exist on the server + for tool_name, persisted_tool in persisted_tool_map.items(): + if tool_name not in current_tool_map: + # Delete the tool (cascade will handle agent detachment) + await persisted_tool.hard_delete_async(db_session=session, actor=actor) + deleted_tools.append(tool_name) + logger.info(f"Deleted MCP tool {tool_name} as it no longer exists on server {mcp_server_name}") + + # Commit deletions + await session.commit() + + # 2. Update existing tools and add new tools + for tool_name, current_tool in current_tool_map.items(): + if tool_name in persisted_tool_map: + # Update existing tool + persisted_tool = persisted_tool_map[tool_name] + tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool) + + # Check if schema has changed + if persisted_tool.json_schema != tool_create.json_schema: + # Update the tool + update_data = ToolUpdate( + description=tool_create.description, + json_schema=tool_create.json_schema, + source_code=tool_create.source_code, + ) + + await self.tool_manager.update_tool_by_id_async(tool_id=persisted_tool.id, tool_update=update_data, actor=actor) + updated_tools.append(tool_name) + logger.info(f"Updated MCP tool {tool_name} with new schema from server {mcp_server_name}") + else: + # Add new tool + # Skip INVALID tools + if current_tool.health and current_tool.health.status == "INVALID": + logger.warning( + f"Skipping invalid tool {tool_name} from MCP server {mcp_server_name}: {', '.join(current_tool.health.reasons)}" + ) + continue + + tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=current_tool) + created_tool = await self.tool_manager.create_mcp_tool_async( + tool_create=tool_create, mcp_server_name=mcp_server_name, mcp_server_id=mcp_server_id, actor=actor + ) + + # Create mapping in MCPTools table + if created_tool: + await self.create_mcp_tool_mapping(mcp_server_id, created_tool.id, actor) + added_tools.append(tool_name) + logger.info(f"Added new MCP tool {tool_name} from server {mcp_server_name} with mapping") + + return MCPServerResyncResult( + deleted=deleted_tools, + updated=updated_tools, + added=added_tools, + ) + + @enforce_types + async def list_mcp_servers(self, actor: PydanticUser) -> List[MCPServer]: + """List all MCP servers available""" + async with db_registry.async_session() as session: + mcp_servers = await MCPServerModel.list_async( + db_session=session, + organization_id=actor.organization_id, + ) + + return [mcp_server.to_pydantic() for mcp_server in mcp_servers] + + @enforce_types + async def create_or_update_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: + """Create a new tool based on the ToolCreate schema.""" + mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name=pydantic_mcp_server.server_name, actor=actor) + if mcp_server_id: + # Put to dict and remove fields that should not be reset + update_data = pydantic_mcp_server.model_dump(exclude_unset=True, exclude_none=True) + + # If there's anything to update (can only update the configs, not the name) + # TODO: pass in custom headers for update as well? + if update_data: + if pydantic_mcp_server.server_type == MCPServerType.SSE: + update_request = UpdateSSEMCPServer(server_url=pydantic_mcp_server.server_url, token=pydantic_mcp_server.token) + elif pydantic_mcp_server.server_type == MCPServerType.STDIO: + update_request = UpdateStdioMCPServer(stdio_config=pydantic_mcp_server.stdio_config) + elif pydantic_mcp_server.server_type == MCPServerType.STREAMABLE_HTTP: + update_request = UpdateStreamableHTTPMCPServer( + server_url=pydantic_mcp_server.server_url, auth_token=pydantic_mcp_server.token + ) + else: + raise ValueError(f"Unsupported server type: {pydantic_mcp_server.server_type}") + mcp_server = await self.update_mcp_server_by_id(mcp_server_id, update_request, actor) + else: + printd( + f"`create_or_update_mcp_server` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_mcp_server.server_name}, but found existing mcp server with nothing to update." + ) + mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) + else: + mcp_server = await self.create_mcp_server(pydantic_mcp_server, actor=actor) + + return mcp_server + + @enforce_types + async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: + """Create a new MCP server.""" + async with db_registry.async_session() as session: + try: + # Set the organization id at the ORM layer + pydantic_mcp_server.organization_id = actor.organization_id + + # Explicitly populate encrypted fields + if pydantic_mcp_server.token is not None: + pydantic_mcp_server.token_enc = Secret.from_plaintext(pydantic_mcp_server.token) + if pydantic_mcp_server.custom_headers is not None: + # custom_headers is a Dict[str, str], serialize to JSON then encrypt + import json + + json_str = json.dumps(pydantic_mcp_server.custom_headers) + pydantic_mcp_server.custom_headers_enc = Secret.from_plaintext(json_str) + + mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True) + + # Ensure custom_headers None is stored as SQL NULL, not JSON null + if mcp_server_data.get("custom_headers") is None: + mcp_server_data.pop("custom_headers", None) + + mcp_server = MCPServerModel(**mcp_server_data) + mcp_server = await mcp_server.create_async(session, actor=actor, no_commit=True) + + # Link existing OAuth sessions for the same user and server URL + # This ensures OAuth sessions created during testing get linked to the server + server_url = getattr(mcp_server, "server_url", None) + if server_url: + result = await session.execute( + select(MCPOAuth).where( + MCPOAuth.server_url == server_url, + MCPOAuth.organization_id == actor.organization_id, + MCPOAuth.user_id == actor.id, # Only link sessions for the same user + MCPOAuth.server_id.is_(None), # Only update sessions not already linked + ) + ) + oauth_sessions = result.scalars().all() + + # TODO: @jnjpng we should upate sessions in bulk + for oauth_session in oauth_sessions: + oauth_session.server_id = mcp_server.id + await oauth_session.update_async(db_session=session, actor=actor, no_commit=True) + + if oauth_sessions: + logger.info( + f"Linked {len(oauth_sessions)} OAuth sessions to MCP server {mcp_server.id} (URL: {server_url}) for user {actor.id}" + ) + + await session.commit() + return mcp_server.to_pydantic() + except Exception as e: + await session.rollback() + raise + + @enforce_types + async def create_mcp_server_from_config( + self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser + ) -> MCPServer: + """ + Create an MCP server from a config object, handling encryption of sensitive fields. + + This method converts the server config to an MCPServer model and encrypts + sensitive fields like tokens and custom headers. + """ + # Create base MCPServer object + if isinstance(server_config, StdioServerConfig): + mcp_server = MCPServer(server_name=server_config.server_name, server_type=server_config.type, stdio_config=server_config) + elif isinstance(server_config, SSEServerConfig): + mcp_server = MCPServer( + server_name=server_config.server_name, + server_type=server_config.type, + server_url=server_config.server_url, + ) + # Encrypt sensitive fields + token = server_config.resolve_token() + if token: + token_secret = Secret.from_plaintext(token) + mcp_server.set_token_secret(token_secret) + if server_config.custom_headers: + # Convert dict to JSON string, then encrypt as Secret + headers_json = json.dumps(server_config.custom_headers) + headers_secret = Secret.from_plaintext(headers_json) + mcp_server.set_custom_headers_secret(headers_secret) + + elif isinstance(server_config, StreamableHTTPServerConfig): + mcp_server = MCPServer( + server_name=server_config.server_name, + server_type=server_config.type, + server_url=server_config.server_url, + ) + # Encrypt sensitive fields + token = server_config.resolve_token() + if token: + token_secret = Secret.from_plaintext(token) + mcp_server.set_token_secret(token_secret) + if server_config.custom_headers: + # Convert dict to JSON string, then encrypt as Secret + headers_json = json.dumps(server_config.custom_headers) + headers_secret = Secret.from_plaintext(headers_json) + mcp_server.set_custom_headers_secret(headers_secret) + else: + raise ValueError(f"Unsupported server config type: {type(server_config)}") + + return mcp_server + + @enforce_types + async def create_mcp_server_from_config_with_tools( + self, server_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], actor: PydanticUser + ) -> MCPServer: + """ + Create an MCP server from a config object and optimistically sync its tools. + + This method handles encryption of sensitive fields and then creates the server + with automatic tool synchronization. + """ + # Convert config to MCPServer with encryption + mcp_server = await self.create_mcp_server_from_config(server_config, actor) + + # Create the server with tools + return await self.create_mcp_server_with_tools(mcp_server, actor) + + @enforce_types + async def create_mcp_server_with_tools(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: + """ + Create a new MCP server and optimistically sync its tools. + + This method: + 1. Creates the MCP server record + 2. Attempts to connect and fetch tools + 3. Persists valid tools in parallel (best-effort) + """ + import asyncio + + # First, create the MCP server + created_server = await self.create_mcp_server(pydantic_mcp_server, actor) + + # Optimistically try to sync tools + try: + logger.info(f"Attempting to auto-sync tools from MCP server: {created_server.server_name}") + + # List all tools from the MCP server + mcp_tools = await self.list_mcp_server_tools(created_server.id, actor=actor) + + # Filter out invalid tools + valid_tools = [tool for tool in mcp_tools if not (tool.health and tool.health.status == "INVALID")] + + # Register in parallel + if valid_tools: + tool_tasks = [] + for mcp_tool in valid_tools: + tool_create = ToolCreate.from_mcp(mcp_server_name=created_server.server_name, mcp_tool=mcp_tool) + task = self.tool_manager.create_mcp_tool_async( + tool_create=tool_create, mcp_server_name=created_server.server_name, mcp_server_id=created_server.id, actor=actor + ) + tool_tasks.append(task) + + results = await asyncio.gather(*tool_tasks, return_exceptions=True) + + # Create mappings in MCPTools table for successful tools + mapping_tasks = [] + successful_count = 0 + for result in results: + if not isinstance(result, Exception) and result: + # result should be a PydanticTool + mapping_task = self.create_mcp_tool_mapping(created_server.id, result.id, actor) + mapping_tasks.append(mapping_task) + successful_count += 1 + + # Execute mapping creation in parallel + if mapping_tasks: + await asyncio.gather(*mapping_tasks, return_exceptions=True) + + failed = len(results) - successful_count + logger.info( + f"Auto-sync completed for MCP server {created_server.server_name}: " + f"{successful_count} tools persisted with mappings, {failed} failed, " + f"{len(mcp_tools) - len(valid_tools)} invalid tools skipped" + ) + else: + logger.info(f"No valid tools found to sync from MCP server {created_server.server_name}") + + except Exception as e: + # Log the error but don't fail the server creation + logger.warning( + f"Failed to auto-sync tools from MCP server {created_server.server_name}: {e}. " + f"Server was created successfully but tools were not persisted." + ) + + return created_server + + @enforce_types + async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer: + """Update a tool by its ID with the given ToolUpdate object.""" + async with db_registry.async_session() as session: + # Fetch the tool by ID + mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) + + # Update tool attributes with only the fields that were explicitly set + update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True) + + # If renaming, proactively resolve name collisions within the same organization + new_name = update_data.get("server_name") + if new_name and new_name != getattr(mcp_server, "server_name", None): + # Look for another server with the same name in this org + existing = await MCPServerModel.list_async( + db_session=session, + organization_id=actor.organization_id, + server_name=new_name, + ) + # Delete conflicting entries that are not the current server + for other in existing: + if other.id != mcp_server.id: + await session.execute( + delete(MCPServerModel).where( + MCPServerModel.id == other.id, + MCPServerModel.organization_id == actor.organization_id, + ) + ) + + # Handle encryption for token if provided + # Only re-encrypt if the value has actually changed + if "token" in update_data and update_data["token"] is not None: + # Check if value changed + existing_token = None + if mcp_server.token_enc: + existing_secret = Secret.from_encrypted(mcp_server.token_enc) + existing_token = existing_secret.get_plaintext() + elif mcp_server.token: + existing_token = mcp_server.token + + # Only re-encrypt if different + if existing_token != update_data["token"]: + mcp_server.token_enc = Secret.from_plaintext(update_data["token"]).get_encrypted() + # Keep plaintext for dual-write during migration + mcp_server.token = update_data["token"] + + # Remove from update_data since we set directly on mcp_server + update_data.pop("token", None) + update_data.pop("token_enc", None) + + # Handle encryption for custom_headers if provided + # Only re-encrypt if the value has actually changed + if "custom_headers" in update_data: + if update_data["custom_headers"] is not None: + # custom_headers is a Dict[str, str], serialize to JSON then encrypt + import json + + json_str = json.dumps(update_data["custom_headers"]) + + # Check if value changed + existing_headers_json = None + if mcp_server.custom_headers_enc: + existing_secret = Secret.from_encrypted(mcp_server.custom_headers_enc) + existing_headers_json = existing_secret.get_plaintext() + elif mcp_server.custom_headers: + existing_headers_json = json.dumps(mcp_server.custom_headers) + + # Only re-encrypt if different + if existing_headers_json != json_str: + mcp_server.custom_headers_enc = Secret.from_plaintext(json_str).get_encrypted() + # Keep plaintext for dual-write during migration + mcp_server.custom_headers = update_data["custom_headers"] + + # Remove from update_data since we set directly on mcp_server + update_data.pop("custom_headers", None) + update_data.pop("custom_headers_enc", None) + else: + # Ensure custom_headers None is stored as SQL NULL, not JSON null + update_data.pop("custom_headers", None) + setattr(mcp_server, "custom_headers", null()) + setattr(mcp_server, "custom_headers_enc", None) + + for key, value in update_data.items(): + setattr(mcp_server, key, value) + + mcp_server = await mcp_server.update_async(db_session=session, actor=actor) + + # Save the updated tool to the database mcp_server = await mcp_server.update_async(db_session=session, actor=actor) + return mcp_server.to_pydantic() + + @enforce_types + async def update_mcp_server_by_name(self, mcp_server_name: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer: + """Update an MCP server by its name.""" + mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) + if not mcp_server_id: + raise HTTPException( + status_code=404, + detail={ + "code": "MCPServerNotFoundError", + "message": f"MCP server {mcp_server_name} not found", + "mcp_server_name": mcp_server_name, + }, + ) + return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor) + + @enforce_types + async def get_mcp_server_id_by_name(self, mcp_server_name: str, actor: PydanticUser) -> Optional[str]: + """Retrieve a MCP server by its name and a user""" + try: + async with db_registry.async_session() as session: + mcp_server = await MCPServerModel.read_async(db_session=session, server_name=mcp_server_name, actor=actor) + return mcp_server.id + except NoResultFound: + return None + + @enforce_types + async def get_mcp_server_by_id_async(self, mcp_server_id: str, actor: PydanticUser) -> MCPServer: + """Fetch a tool by its ID.""" + async with db_registry.async_session() as session: + # Retrieve tool by id using the Tool model's read method + mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) + # Convert the SQLAlchemy Tool object to PydanticTool + return mcp_server.to_pydantic() + + @enforce_types + async def get_mcp_servers_by_ids(self, mcp_server_ids: List[str], actor: PydanticUser) -> List[MCPServer]: + """Fetch multiple MCP servers by their IDs in a single query.""" + if not mcp_server_ids: + return [] + + async with db_registry.async_session() as session: + mcp_servers = await MCPServerModel.list_async( + db_session=session, + organization_id=actor.organization_id, + id=mcp_server_ids, # This will use the IN operator + ) + return [mcp_server.to_pydantic() for mcp_server in mcp_servers] + + @enforce_types + async def get_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> PydanticTool: + """Get a MCP server by name.""" + async with db_registry.async_session() as session: + mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) + mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) + if not mcp_server: + raise HTTPException( + status_code=404, # Not Found + detail={ + "code": "MCPServerNotFoundError", + "message": f"MCP server {mcp_server_name} not found", + "mcp_server_name": mcp_server_name, + }, + ) + return mcp_server.to_pydantic() + + @enforce_types + async def delete_mcp_server_by_id(self, mcp_server_id: str, actor: PydanticUser) -> None: + """Delete a MCP server by its ID and associated tools and OAuth sessions.""" + async with db_registry.async_session() as session: + try: + mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) + if not mcp_server: + raise NoResultFound(f"MCP server with id {mcp_server_id} not found.") + + server_url = getattr(mcp_server, "server_url", None) + # Get all tools with matching metadata + stmt = select(ToolModel).where(ToolModel.organization_id == actor.organization_id) + result = await session.execute(stmt) + all_tools = result.scalars().all() + + # Filter and delete tools that belong to this MCP server + tools_deleted = 0 + for tool in all_tools: + if tool.metadata_ and constants.MCP_TOOL_TAG_NAME_PREFIX in tool.metadata_: + if tool.metadata_[constants.MCP_TOOL_TAG_NAME_PREFIX].get("server_id") == mcp_server_id: + await tool.hard_delete_async(db_session=session, actor=actor) + tools_deleted = 1 + logger.info(f"Deleted MCP tool {tool.name} associated with MCP server {mcp_server_id}") + + if tools_deleted > 0: + logger.info(f"Deleted {tools_deleted} MCP tools associated with MCP server {mcp_server_id}") + + # Delete all MCPTools mappings for this server + await session.execute( + delete(MCPToolsModel).where( + MCPToolsModel.mcp_server_id == mcp_server_id, + MCPToolsModel.organization_id == actor.organization_id, + ) + ) + logger.info(f"Deleted MCPTools mappings for MCP server {mcp_server_id}") + + # Delete OAuth sessions for the same user and server URL in the same transaction + # This handles orphaned sessions that were created during testing/connection + oauth_count = 0 + if server_url: + result = await session.execute( + delete(MCPOAuth).where( + MCPOAuth.server_url == server_url, + MCPOAuth.organization_id == actor.organization_id, + MCPOAuth.user_id == actor.id, # Only delete sessions for the same user + ) + ) + oauth_count = result.rowcount + if oauth_count > 0: + logger.info( + f"Deleting {oauth_count} OAuth sessions for MCP server {mcp_server_id} (URL: {server_url}) for user {actor.id}" + ) + + # Delete the MCP server, will cascade delete to linked OAuth sessions + await session.execute( + delete(MCPServerModel).where( + MCPServerModel.id == mcp_server_id, + MCPServerModel.organization_id == actor.organization_id, + ) + ) + + await session.commit() + except NoResultFound: + await session.rollback() + raise ValueError(f"MCP server with id {mcp_server_id} not found.") + except Exception as e: + await session.rollback() + logger.error(f"Failed to delete MCP server {mcp_server_id}: {e}") + raise + + def read_mcp_config(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]]: + mcp_server_list = {} + + # Attempt to read from ~/.letta/mcp_config.json + mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME) + if os.path.exists(mcp_config_path): + with open(mcp_config_path, "r") as f: + try: + mcp_config = json.load(f) + except Exception as e: + # Config parsing errors are user configuration issues, not system errors + logger.warning(f"Failed to parse MCP config file ({mcp_config_path}) as json: {e}") + return mcp_server_list + + # Proper formatting is "mcpServers" key at the top level, + # then a dict with the MCP server name as the key, + # with the value being the schema from StdioServerParameters + if MCP_CONFIG_TOPLEVEL_KEY in mcp_config: + for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items(): + # No support for duplicate server names + if server_name in mcp_server_list: + # Duplicate server names are configuration issues, not system errors + logger.warning(f"Duplicate MCP server name found (skipping): {server_name}") + continue + + if "url" in server_params_raw: + # Attempt to parse the server params as an SSE server + try: + server_params = SSEServerConfig( + server_name=server_name, + server_url=server_params_raw["url"], + auth_header=server_params_raw.get("auth_header", None), + auth_token=server_params_raw.get("auth_token", None), + headers=server_params_raw.get("headers", None), + ) + mcp_server_list[server_name] = server_params + except Exception as e: + # Config parsing errors are user configuration issues, not system errors + logger.warning(f"Failed to parse server params for MCP server {server_name} (skipping): {e}") + continue + else: + # Attempt to parse the server params as a StdioServerParameters + try: + server_params = StdioServerConfig( + server_name=server_name, + command=server_params_raw["command"], + args=server_params_raw.get("args", []), + env=server_params_raw.get("env", {}), + ) + mcp_server_list[server_name] = server_params + except Exception as e: + # Config parsing errors are user configuration issues, not system errors + logger.warning(f"Failed to parse server params for MCP server {server_name} (skipping): {e}") + continue + return mcp_server_list + + async def get_mcp_client( + self, + server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig], + actor: PydanticUser, + oauth_provider: Optional[Any] = None, + agent_id: Optional[str] = None, + ) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]: + """ + Helper function to create the appropriate MCP client based on server configuration. + + Args: + server_config: The server configuration object + actor: The user making the request + oauth_provider: Optional OAuth provider for authentication + + Returns: + The appropriate MCP client instance + + Raises: + ValueError: If server config type is not supported + """ + # If no OAuth provider is provided, check if we have stored OAuth credentials + if oauth_provider is None and hasattr(server_config, "server_url"): + oauth_session = await self.get_oauth_session_by_server(server_config.server_url, actor) + # Check if access token exists by attempting to decrypt it + if oauth_session and oauth_session.get_access_token_secret().get_plaintext(): + # Create OAuth provider from stored credentials + from letta.services.mcp.oauth_utils import create_oauth_provider + + oauth_provider = await create_oauth_provider( + session_id=oauth_session.id, + server_url=oauth_session.server_url, + redirect_uri=oauth_session.redirect_uri, + mcp_manager=self, + actor=actor, + ) + + if server_config.type == MCPServerType.SSE: + server_config = SSEServerConfig(**server_config.model_dump()) + return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id) + elif server_config.type == MCPServerType.STDIO: + server_config = StdioServerConfig(**server_config.model_dump()) + return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id) + elif server_config.type == MCPServerType.STREAMABLE_HTTP: + server_config = StreamableHTTPServerConfig(**server_config.model_dump()) + return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id) + else: + raise ValueError(f"Unsupported server config type: {type(server_config)}") + + # OAuth-related methods + def _oauth_orm_to_pydantic(self, oauth_session: MCPOAuth) -> MCPOAuthSession: + """ + Convert OAuth ORM model to Pydantic model, handling decryption of sensitive fields. + """ + # Get decrypted values using the dual-read approach + # Secret.from_db() will automatically use settings.encryption_key if available + access_token = None + if oauth_session.access_token_enc or oauth_session.access_token: + if settings.encryption_key: + secret = Secret.from_db(oauth_session.access_token_enc, oauth_session.access_token) + access_token = secret.get_plaintext() + else: + # No encryption key, use plaintext if available + access_token = oauth_session.access_token + + refresh_token = None + if oauth_session.refresh_token_enc or oauth_session.refresh_token: + if settings.encryption_key: + secret = Secret.from_db(oauth_session.refresh_token_enc, oauth_session.refresh_token) + refresh_token = secret.get_plaintext() + else: + # No encryption key, use plaintext if available + refresh_token = oauth_session.refresh_token + + client_secret = None + if oauth_session.client_secret_enc or oauth_session.client_secret: + if settings.encryption_key: + secret = Secret.from_db(oauth_session.client_secret_enc, oauth_session.client_secret) + client_secret = secret.get_plaintext() + else: + # No encryption key, use plaintext if available + client_secret = oauth_session.client_secret + + authorization_code = None + if oauth_session.authorization_code_enc or oauth_session.authorization_code: + if settings.encryption_key: + secret = Secret.from_db(oauth_session.authorization_code_enc, oauth_session.authorization_code) + authorization_code = secret.get_plaintext() + else: + # No encryption key, use plaintext if available + authorization_code = oauth_session.authorization_code + + # Create the Pydantic object with encrypted fields as Secret objects + pydantic_session = MCPOAuthSession( + id=oauth_session.id, + state=oauth_session.state, + server_id=oauth_session.server_id, + server_url=oauth_session.server_url, + server_name=oauth_session.server_name, + user_id=oauth_session.user_id, + organization_id=oauth_session.organization_id, + authorization_url=oauth_session.authorization_url, + authorization_code=authorization_code, + access_token=access_token, + refresh_token=refresh_token, + token_type=oauth_session.token_type, + expires_at=oauth_session.expires_at, + scope=oauth_session.scope, + client_id=oauth_session.client_id, + client_secret=client_secret, + redirect_uri=oauth_session.redirect_uri, + status=oauth_session.status, + created_at=oauth_session.created_at, + updated_at=oauth_session.updated_at, + # Encrypted fields as Secret objects (converted from encrypted strings in DB) + authorization_code_enc=Secret.from_encrypted(oauth_session.authorization_code_enc) + if oauth_session.authorization_code_enc + else None, + access_token_enc=Secret.from_encrypted(oauth_session.access_token_enc) if oauth_session.access_token_enc else None, + refresh_token_enc=Secret.from_encrypted(oauth_session.refresh_token_enc) if oauth_session.refresh_token_enc else None, + client_secret_enc=Secret.from_encrypted(oauth_session.client_secret_enc) if oauth_session.client_secret_enc else None, + ) + return pydantic_session + + @enforce_types + async def create_oauth_session(self, session_create: MCPOAuthSessionCreate, actor: PydanticUser) -> MCPOAuthSession: + """Create a new OAuth session for MCP server authentication.""" + async with db_registry.async_session() as session: + # Create the OAuth session with a unique state + oauth_session = MCPOAuth( + id="mcp-oauth-" + str(uuid.uuid4())[:8], + state=secrets.token_urlsafe(32), + server_url=session_create.server_url, + server_name=session_create.server_name, + user_id=session_create.user_id, + organization_id=session_create.organization_id, + status=OAuthSessionStatus.PENDING, + created_at=datetime.now(), + updated_at=datetime.now(), + ) + oauth_session = await oauth_session.create_async(session, actor=actor) + + # Convert to Pydantic model - note: new sessions won't have tokens yet + return self._oauth_orm_to_pydantic(oauth_session) + + @enforce_types + async def get_oauth_session_by_id(self, session_id: str, actor: PydanticUser) -> Optional[MCPOAuthSession]: + """Get an OAuth session by its ID.""" + async with db_registry.async_session() as session: + try: + oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) + return self._oauth_orm_to_pydantic(oauth_session) + except NoResultFound: + return None + + @enforce_types + async def get_oauth_session_by_server(self, server_url: str, actor: PydanticUser) -> Optional[MCPOAuthSession]: + """Get the latest OAuth session by server URL, organization, and user.""" + async with db_registry.async_session() as session: + # Query for OAuth session matching organization, user, server URL, and status + # Order by updated_at desc to get the most recent record + result = await session.execute( + select(MCPOAuth) + .where( + MCPOAuth.organization_id == actor.organization_id, + MCPOAuth.user_id == actor.id, + MCPOAuth.server_url == server_url, + MCPOAuth.status == OAuthSessionStatus.AUTHORIZED, + ) + .order_by(desc(MCPOAuth.updated_at)) + .limit(1) + ) + oauth_session = result.scalar_one_or_none() + + if not oauth_session: + return None + + return self._oauth_orm_to_pydantic(oauth_session) + + @enforce_types + async def update_oauth_session(self, session_id: str, session_update: MCPOAuthSessionUpdate, actor: PydanticUser) -> MCPOAuthSession: + """Update an existing OAuth session.""" + async with db_registry.async_session() as session: + oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) + + # Update fields that are provided + if session_update.authorization_url is not None: + oauth_session.authorization_url = session_update.authorization_url + + # Handle encryption for authorization_code + # Only re-encrypt if the value has actually changed + if session_update.authorization_code is not None: + # Check if value changed + existing_code = None + if oauth_session.authorization_code_enc: + existing_secret = Secret.from_encrypted(oauth_session.authorization_code_enc) + existing_code = existing_secret.get_plaintext() + elif oauth_session.authorization_code: + existing_code = oauth_session.authorization_code + + # Only re-encrypt if different + if existing_code != session_update.authorization_code: + oauth_session.authorization_code_enc = Secret.from_plaintext(session_update.authorization_code).get_encrypted() + # Keep plaintext for dual-write during migration + oauth_session.authorization_code = session_update.authorization_code + + # Handle encryption for access_token + # Only re-encrypt if the value has actually changed + if session_update.access_token is not None: + # Check if value changed + existing_token = None + if oauth_session.access_token_enc: + existing_secret = Secret.from_encrypted(oauth_session.access_token_enc) + existing_token = existing_secret.get_plaintext() + elif oauth_session.access_token: + existing_token = oauth_session.access_token + + # Only re-encrypt if different + if existing_token != session_update.access_token: + oauth_session.access_token_enc = Secret.from_plaintext(session_update.access_token).get_encrypted() + # Keep plaintext for dual-write during migration + oauth_session.access_token = session_update.access_token + + # Handle encryption for refresh_token + # Only re-encrypt if the value has actually changed + if session_update.refresh_token is not None: + # Check if value changed + existing_refresh = None + if oauth_session.refresh_token_enc: + existing_secret = Secret.from_encrypted(oauth_session.refresh_token_enc) + existing_refresh = existing_secret.get_plaintext() + elif oauth_session.refresh_token: + existing_refresh = oauth_session.refresh_token + + # Only re-encrypt if different + if existing_refresh != session_update.refresh_token: + oauth_session.refresh_token_enc = Secret.from_plaintext(session_update.refresh_token).get_encrypted() + # Keep plaintext for dual-write during migration + oauth_session.refresh_token = session_update.refresh_token + + if session_update.token_type is not None: + oauth_session.token_type = session_update.token_type + if session_update.expires_at is not None: + oauth_session.expires_at = session_update.expires_at + if session_update.scope is not None: + oauth_session.scope = session_update.scope + if session_update.client_id is not None: + oauth_session.client_id = session_update.client_id + + # Handle encryption for client_secret + # Only re-encrypt if the value has actually changed + if session_update.client_secret is not None: + # Check if value changed + existing_secret_val = None + if oauth_session.client_secret_enc: + existing_secret = Secret.from_encrypted(oauth_session.client_secret_enc) + existing_secret_val = existing_secret.get_plaintext() + elif oauth_session.client_secret: + existing_secret_val = oauth_session.client_secret + + # Only re-encrypt if different + if existing_secret_val != session_update.client_secret: + oauth_session.client_secret_enc = Secret.from_plaintext(session_update.client_secret).get_encrypted() + # Keep plaintext for dual-write during migration + oauth_session.client_secret = session_update.client_secret + + if session_update.redirect_uri is not None: + oauth_session.redirect_uri = session_update.redirect_uri + if session_update.status is not None: + oauth_session.status = session_update.status + + # Always update the updated_at timestamp + oauth_session.updated_at = datetime.now() + + oauth_session = await oauth_session.update_async(db_session=session, actor=actor) + + return self._oauth_orm_to_pydantic(oauth_session) + + @enforce_types + async def delete_oauth_session(self, session_id: str, actor: PydanticUser) -> None: + """Delete an OAuth session.""" + async with db_registry.async_session() as session: + try: + oauth_session = await MCPOAuth.read_async(db_session=session, identifier=session_id, actor=actor) + await oauth_session.hard_delete_async(db_session=session, actor=actor) + except NoResultFound: + raise ValueError(f"OAuth session with id {session_id} not found.") + + @enforce_types + async def cleanup_expired_oauth_sessions(self, max_age_hours: int = 24) -> int: + """Clean up expired OAuth sessions and return the count of deleted sessions.""" + cutoff_time = datetime.now() - timedelta(hours=max_age_hours) + + async with db_registry.async_session() as session: + # Find expired sessions + result = await session.execute(select(MCPOAuth).where(MCPOAuth.created_at < cutoff_time)) + expired_sessions = result.scalars().all() + + # Delete expired sessions using async ORM method + for oauth_session in expired_sessions: + await oauth_session.hard_delete_async(db_session=session, actor=None) + + if expired_sessions: + logger.info(f"Cleaned up {len(expired_sessions)} expired OAuth sessions") + + return len(expired_sessions) + + @enforce_types + async def handle_oauth_flow( + self, + request: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig], + actor: PydanticUser, + http_request: Optional[Request] = None, + ): + """ + Handle OAuth flow for MCP server connection and yield SSE events. + + Args: + request: The server configuration + actor: The user making the request + http_request: The HTTP request object + + Yields: + SSE events during OAuth flow + + Returns: + Tuple of (temp_client, connect_task) after yielding events + """ + import asyncio + + from letta.services.mcp.oauth_utils import create_oauth_provider, oauth_stream_event + from letta.services.mcp.types import OauthStreamEvent + + # OAuth required, yield state to client to prepare to handle authorization URL + yield oauth_stream_event(OauthStreamEvent.OAUTH_REQUIRED, message="OAuth authentication required") + + # Create OAuth session to persist the state of the OAuth flow + session_create = MCPOAuthSessionCreate( + server_url=request.server_url, + server_name=request.server_name, + user_id=actor.id, + organization_id=actor.organization_id, + ) + oauth_session = await self.create_oauth_session(session_create, actor) + session_id = oauth_session.id + + # TODO: @jnjpng make this check more robust and remove direct os.getenv + # Check if request is from web frontend to determine redirect URI + is_web_request = ( + http_request + and http_request.headers + and http_request.headers.get("user-agent", "") == "Next.js Middleware" + and http_request.headers.__contains__("x-organization-id") + ) + + logo_uri = None + NEXT_PUBLIC_CURRENT_HOST = os.getenv("NEXT_PUBLIC_CURRENT_HOST") + LETTA_AGENTS_ENDPOINT = os.getenv("LETTA_AGENTS_ENDPOINT") + + if is_web_request and NEXT_PUBLIC_CURRENT_HOST: + redirect_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/oauth/callback/{session_id}" + logo_uri = f"{NEXT_PUBLIC_CURRENT_HOST}/seo/favicon.svg" + elif LETTA_AGENTS_ENDPOINT: + # API and SDK usage should call core server directly + redirect_uri = f"{LETTA_AGENTS_ENDPOINT}/v1/tools/mcp/oauth/callback/{session_id}" + else: + logger.error( + f"No redirect URI found for request and base urls: {http_request.headers if http_request else 'No headers'} {NEXT_PUBLIC_CURRENT_HOST} {LETTA_AGENTS_ENDPOINT}" + ) + raise HTTPException(status_code=400, detail="No redirect URI found") + + # Create OAuth provider for the instance of the stream connection + oauth_provider = await create_oauth_provider(session_id, request.server_url, redirect_uri, self, actor, logo_uri=logo_uri) + + # Get authorization URL by triggering OAuth flow + temp_client = None + connect_task = None + try: + temp_client = await self.get_mcp_client(request, actor, oauth_provider) + + # Run connect_to_server in background to avoid blocking + # This will trigger the OAuth flow and the redirect_handler will save the authorization URL to database + connect_task = safe_create_task(temp_client.connect_to_server(), label="mcp_oauth_connect") + + # Give the OAuth flow time to trigger and save the URL + await asyncio.sleep(1.0) + + # Fetch the authorization URL from database and yield state to client to proceed with handling authorization URL + auth_session = await self.get_oauth_session_by_id(session_id, actor) + if auth_session and auth_session.authorization_url: + yield oauth_stream_event(OauthStreamEvent.AUTHORIZATION_URL, url=auth_session.authorization_url, session_id=session_id) + + # Wait for user authorization (with timeout), client should render loading state until user completes the flow and /mcp/oauth/callback/{session_id} is hit + yield oauth_stream_event(OauthStreamEvent.WAITING_FOR_AUTH, message="Waiting for user authorization...") + + # Callback handler will poll for authorization code and state and update the OAuth session + await connect_task + + tools = await temp_client.list_tools(serialize=True) + yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools) + + except Exception as e: + logger.error(f"Error triggering OAuth flow: {e}") + yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Failed to trigger OAuth: {str(e)}") + raise e + finally: + # Clean up resources + if connect_task and not connect_task.done(): + connect_task.cancel() + try: + await connect_task + except asyncio.CancelledError: + pass + if temp_client: + try: + await temp_client.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}") diff --git a/tests/integration_test_mcp_servers.py b/tests/integration_test_mcp_servers.py new file mode 100644 index 00000000..d404b2e1 --- /dev/null +++ b/tests/integration_test_mcp_servers.py @@ -0,0 +1,1050 @@ +""" +Integration tests for the new MCP server endpoints (/v1/mcp-servers/). +Tests all CRUD operations, tool management, and OAuth connection flows. +Uses the Letta SDK client instead of direct HTTP requests. +""" + +import os +import sys +import threading +import time +import uuid +from pathlib import Path +from typing import Any, Dict, List, Optional + +import pytest +import requests +from dotenv import load_dotenv +from letta_client import ( + CreateSsemcpServer, + CreateStdioMcpServer, + CreateStreamableHttpmcpServer, + Letta, + LettaSchemasMcpServerUpdateSsemcpServer, + LettaSchemasMcpServerUpdateStdioMcpServer, + LettaSchemasMcpServerUpdateStreamableHttpmcpServer, + LettaServerRestApiRoutersV1ToolsMcpToolExecuteRequest, + MessageCreate, + ToolCallMessage, + ToolReturnMessage, +) +from letta_client.core import ApiError + +from letta.schemas.agent import AgentState +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig + +# ------------------------------ +# Fixtures +# ------------------------------ + + +@pytest.fixture(scope="module") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it's accepting connections. + """ + + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") + + yield url + + +@pytest.fixture(scope="module") +def letta_client(server_url: str) -> Letta: + """ + Provides a configured Letta SDK client. + """ + token = os.getenv("LETTA_API_TOKEN", "") + + # Initialize the SDK client + client = Letta( + base_url=server_url, + token=token if token else None, + # Skip default cloud environment since we're using a custom server + environment=None, + ) + yield client + + +@pytest.fixture(scope="function") +def unique_server_id() -> str: + """Generate a unique MCP server ID for each test.""" + # MCP server IDs follow the format: mcp_server- + return f"mcp_server-{uuid.uuid4()}" + + +@pytest.fixture(scope="function") +def mock_mcp_server_path() -> Path: + """Get path to mock MCP server for testing.""" + script_dir = Path(__file__).parent + mcp_server_path = script_dir / "mock_mcp_server.py" + + if not mcp_server_path.exists(): + # Create a minimal mock server for testing if it doesn't exist + pytest.skip(f"Mock MCP server not found at {mcp_server_path}") + + return mcp_server_path + + +@pytest.fixture(scope="function") +def mock_mcp_server_config_for_agent() -> CreateStdioMcpServer: + """ + Creates a stdio configuration for the mock MCP server for agent testing. + """ + # Get path to mock_mcp_server.py + script_dir = Path(__file__).parent + mcp_server_path = script_dir / "mock_mcp_server.py" + + if not mcp_server_path.exists(): + raise FileNotFoundError(f"Mock MCP server not found at {mcp_server_path}") + + server_name = f"test-mcp-agent-{uuid.uuid4().hex[:8]}" + + return CreateStdioMcpServer( + server_name=server_name, + command=sys.executable, # Use the current Python interpreter + args=[str(mcp_server_path)], + ) + + +@pytest.fixture(scope="function") +def agent_with_mcp_tools(letta_client: Letta, mock_mcp_server_config_for_agent: CreateStdioMcpServer) -> AgentState: + """ + Creates an agent with MCP tools attached for testing. + """ + # Register the MCP server (this should automatically sync tools) + server = letta_client.mcp_servers.mcp_create_mcp_server(request=mock_mcp_server_config_for_agent) + server_id = server.id + + try: + # List available MCP tools from the database (they should have been synced during server creation) + mcp_tools = letta_client.mcp_servers.mcp_list_mcp_tools_by_server(server_id) + assert len(mcp_tools) > 0, "No tools found from MCP server" + + # Find the echo and add tools (they should already be in Letta's tool registry) + echo_tool = next((t for t in mcp_tools if t.name == "echo"), None) + add_tool = next((t for t in mcp_tools if t.name == "add"), None) + + assert echo_tool is not None, "echo tool not found" + assert add_tool is not None, "add tool not found" + + # Create agent with the MCP tools (using tool IDs from the synced tools) + agent = letta_client.agents.create( + name=f"test_mcp_agent_{uuid.uuid4().hex[:8]}", + include_base_tools=True, + tool_ids=[echo_tool.id, add_tool.id], + memory_blocks=[ + { + "label": "human", + "value": "Name: Test User", + }, + { + "label": "persona", + "value": "You are a helpful assistant that can use MCP tools to help the user.", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tags=["test_mcp_agent"], + ) + + yield agent + + finally: + # Cleanup agent if it exists + if "agent" in locals(): + try: + letta_client.agents.delete(agent.id) + except Exception as e: + print(f"Warning: Failed to delete agent {agent.id}: {e}") + + # Cleanup MCP server + try: + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + except Exception as e: + print(f"Warning: Failed to delete MCP server {server_id}: {e}") + + +# ------------------------------ +# Helper Functions +# ------------------------------ + + +def create_stdio_server_request(server_name: str, command: str = "npx", args: List[str] = None) -> CreateStdioMcpServer: + """Create a stdio MCP server configuration object.""" + return CreateStdioMcpServer( + server_name=server_name, + command=command, + args=args or ["-y", "@modelcontextprotocol/server-everything"], + env={"NODE_ENV": "test", "DEBUG": "true"}, + ) + + +def create_sse_server_request(server_name: str, server_url: str = None) -> CreateSsemcpServer: + """Create an SSE MCP server configuration object.""" + return CreateSsemcpServer( + server_name=server_name, + server_url=server_url or "https://api.example.com/sse", + auth_header="Authorization", + auth_token="Bearer test_token_123", + custom_headers={"X-Custom-Header": "custom_value", "X-API-Version": "1.0"}, + ) + + +def create_streamable_http_server_request(server_name: str, server_url: str = None) -> CreateStreamableHttpmcpServer: + """Create a streamable HTTP MCP server configuration object.""" + return CreateStreamableHttpmcpServer( + server_name=server_name, + server_url=server_url or "https://api.example.com/streamable", + auth_header="X-API-Key", + auth_token="api_key_456", + custom_headers={"Accept": "application/json", "X-Version": "2.0"}, + ) + + +def create_exa_streamable_http_server_request(server_name: str) -> CreateStreamableHttpmcpServer: + """Create a Streamable HTTP config for Exa MCP with no auth. + + Reference: https://mcp.exa.ai/mcp + """ + return CreateStreamableHttpmcpServer( + server_name=server_name, + server_url="https://mcp.exa.ai/mcp?exaApiKey=your-exa-api-key", + # no auth header/token, no custom headers + ) + + +# ------------------------------ +# Test Cases for CRUD Operations +# ------------------------------ + + +def test_create_stdio_mcp_server(letta_client: Letta): + """Test creating a stdio MCP server.""" + server_name = f"test-stdio-{uuid.uuid4().hex[:8]}" + server_config = create_stdio_server_request(server_name) + + # Create the server + server_data = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + + assert server_data.server_name == server_name + assert server_data.command == server_config.command + assert server_data.args == server_config.args + assert server_data.id is not None # Should have an ID assigned + + server_id = server_data.id + + # Cleanup - delete the server + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_create_sse_mcp_server(letta_client: Letta): + """Test creating an SSE MCP server.""" + server_name = f"test-sse-{uuid.uuid4().hex[:8]}" + server_config = create_sse_server_request(server_name) + + # Create the server + server_data = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + + assert server_data.server_name == server_name + assert server_data.server_url == server_config.server_url + assert server_data.auth_header == server_config.auth_header + assert server_data.id is not None + + server_id = server_data.id + + # Cleanup + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_create_streamable_http_mcp_server(letta_client: Letta): + """Test creating a streamable HTTP MCP server.""" + server_name = f"test-http-{uuid.uuid4().hex[:8]}" + server_config = create_streamable_http_server_request(server_name) + + # Create the server + server_data = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + + assert server_data.server_name == server_name + assert server_data.server_url == server_config.server_url + assert server_data.id is not None + + server_id = server_data.id + + # Cleanup + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_list_mcp_servers(letta_client: Letta): + """Test listing all MCP servers.""" + # Create multiple servers + servers_created = [] + + # Create stdio server + stdio_name = f"list-test-stdio-{uuid.uuid4().hex[:8]}" + stdio_config = create_stdio_server_request(stdio_name) + stdio_server = letta_client.mcp_servers.mcp_create_mcp_server(request=stdio_config) + servers_created.append(stdio_server.id) + + # Create SSE server + sse_name = f"list-test-sse-{uuid.uuid4().hex[:8]}" + sse_config = create_sse_server_request(sse_name) + sse_server = letta_client.mcp_servers.mcp_create_mcp_server(request=sse_config) + servers_created.append(sse_server.id) + + try: + # List all servers + servers_list = letta_client.mcp_servers.mcp_list_mcp_servers() + assert isinstance(servers_list, list) + assert len(servers_list) >= 2 # At least our two servers + + # Check our servers are in the list + server_ids = [s.id for s in servers_list] + assert stdio_server.id in server_ids + assert sse_server.id in server_ids + + # Check server names + server_names = [s.server_name for s in servers_list] + assert stdio_name in server_names + assert sse_name in server_names + + finally: + # Cleanup + for server_id in servers_created: + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_get_specific_mcp_server(letta_client: Letta): + """Test getting a specific MCP server by ID.""" + # Create a server + server_name = f"get-test-{uuid.uuid4().hex[:8]}" + server_config = create_stdio_server_request(server_name, command="python", args=["-m", "mcp_server"]) + server_config.env["PYTHONPATH"] = "/usr/local/lib" + + created_server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = created_server.id + + try: + # Get the server by ID + retrieved_server = letta_client.mcp_servers.mcp_get_mcp_server(server_id) + + assert retrieved_server.id == server_id + assert retrieved_server.server_name == server_name + assert retrieved_server.command == "python" + assert retrieved_server.args == ["-m", "mcp_server"] + assert retrieved_server.env.get("PYTHONPATH") == "/usr/local/lib" + + finally: + # Cleanup + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_update_stdio_mcp_server(letta_client: Letta): + """Test updating a stdio MCP server.""" + # Create a server + server_name = f"update-test-stdio-{uuid.uuid4().hex[:8]}" + server_config = create_stdio_server_request(server_name, command="node", args=["old_server.js"]) + + created_server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = created_server.id + + try: + # Update the server + update_request = LettaSchemasMcpServerUpdateStdioMcpServer( + server_name="updated-stdio-server", + command="node", + args=["new_server.js", "--port", "3000"], + env={"NEW_ENV": "new_value", "PORT": "3000"}, + ) + + updated_server = letta_client.mcp_servers.mcp_update_mcp_server(server_id, request=update_request) + + assert updated_server.server_name == "updated-stdio-server" + assert updated_server.args == ["new_server.js", "--port", "3000"] + assert updated_server.env.get("NEW_ENV") == "new_value" + + finally: + # Cleanup + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_update_sse_mcp_server(letta_client: Letta): + """Test updating an SSE MCP server.""" + # Create an SSE server + server_name = f"update-test-sse-{uuid.uuid4().hex[:8]}" + server_config = create_sse_server_request(server_name, server_url="https://old.example.com/sse") + + created_server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = created_server.id + + try: + # Update the server + update_request = LettaSchemasMcpServerUpdateSsemcpServer( + server_name="updated-sse-server", + server_url="https://new.example.com/sse/v2", + auth_token="new_token_789", + custom_headers={"X-Updated": "true", "X-Version": "2.0"}, + ) + + updated_server = letta_client.mcp_servers.mcp_update_mcp_server(server_id, request=update_request) + + assert updated_server.server_name == "updated-sse-server" + assert updated_server.server_url == "https://new.example.com/sse/v2" + + finally: + # Cleanup + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_delete_mcp_server(letta_client: Letta): + """Test deleting an MCP server.""" + # Create a server to delete + server_name = f"delete-test-{uuid.uuid4().hex[:8]}" + server_config = create_stdio_server_request(server_name) + + created_server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = created_server.id + + # Delete the server + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + # Verify it's deleted (should raise ApiError with 404) + with pytest.raises(ApiError) as exc_info: + letta_client.mcp_servers.mcp_get_mcp_server(server_id) + assert exc_info.value.status_code == 404 + + +# ------------------------------ +# Test Cases for Error Handling +# ------------------------------ + + +def test_invalid_server_type(letta_client: Letta): + """Test creating server with invalid type.""" + # The SDK should handle type validation, so we'll test with an invalid configuration + # that would be rejected by the API + try: + # Try to create a server with an invalid configuration + # The SDK validates types, so this test might need adjustment based on actual SDK behavior + invalid_config = CreateStdioMcpServer( + server_name="invalid-server", + command="", # Empty command should be invalid + args=[], + ) + with pytest.raises(ApiError) as exc_info: + letta_client.mcp_servers.mcp_create_mcp_server(request=invalid_config) + assert exc_info.value.status_code in [400, 422] # Bad request or validation error + except Exception: + # SDK might handle validation differently + pass + + +# # ------------------------------ +# # Test Cases for Complex Scenarios +# # ------------------------------ + + +def test_multiple_server_types_coexist(letta_client: Letta): + """Test that multiple server types can coexist.""" + servers_created = [] + + try: + # Create one of each type + stdio_config = create_stdio_server_request(f"multi-stdio-{uuid.uuid4().hex[:8]}") + stdio_server = letta_client.mcp_servers.mcp_create_mcp_server(request=stdio_config) + servers_created.append(stdio_server.id) + + sse_config = create_sse_server_request(f"multi-sse-{uuid.uuid4().hex[:8]}") + sse_server = letta_client.mcp_servers.mcp_create_mcp_server(request=sse_config) + servers_created.append(sse_server.id) + + http_config = create_streamable_http_server_request(f"multi-http-{uuid.uuid4().hex[:8]}") + http_server = letta_client.mcp_servers.mcp_create_mcp_server(request=http_config) + servers_created.append(http_server.id) + + # List all servers + servers_list = letta_client.mcp_servers.mcp_list_mcp_servers() + server_ids = [s.id for s in servers_list] + + # Verify all three are present + assert stdio_server.id in server_ids + assert sse_server.id in server_ids + assert http_server.id in server_ids + + # Get each server and verify type-specific fields + stdio_retrieved = letta_client.mcp_servers.mcp_get_mcp_server(stdio_server.id) + assert stdio_retrieved.command == stdio_config.command + + sse_retrieved = letta_client.mcp_servers.mcp_get_mcp_server(sse_server.id) + assert sse_retrieved.server_url == sse_config.server_url + + http_retrieved = letta_client.mcp_servers.mcp_get_mcp_server(http_server.id) + assert http_retrieved.server_url == http_config.server_url + + finally: + # Cleanup all servers + for server_id in servers_created: + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_partial_update_preserves_fields(letta_client: Letta): + """Test that partial updates preserve non-updated fields.""" + # Create a server with all fields + server_name = f"partial-update-{uuid.uuid4().hex[:8]}" + server_config = CreateStdioMcpServer( + server_name=server_name, + command="node", + args=["server.js", "--port", "3000"], + env={"NODE_ENV": "production", "PORT": "3000", "DEBUG": "false"}, + ) + + created_server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = created_server.id + + try: + # Update only the server name + update_request = LettaSchemasMcpServerUpdateStdioMcpServer(server_name="renamed-server") + + updated_server = letta_client.mcp_servers.mcp_update_mcp_server(server_id, request=update_request) + + assert updated_server.server_name == "renamed-server" + # Other fields should be preserved + assert updated_server.command == "node" + assert updated_server.args == ["server.js", "--port", "3000"] + + finally: + # Cleanup + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_concurrent_server_operations(letta_client: Letta): + """Test multiple servers can be operated on concurrently.""" + servers_created = [] + + try: + # Create multiple servers quickly + for i in range(3): + server_config = create_stdio_server_request(f"concurrent-{i}-{uuid.uuid4().hex[:8]}", command="python", args=[f"server_{i}.py"]) + + server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + servers_created.append(server.id) + + # Update all servers + for i, server_id in enumerate(servers_created): + update_request = LettaSchemasMcpServerUpdateStdioMcpServer(server_name=f"updated-concurrent-{i}") + + updated_server = letta_client.mcp_servers.mcp_update_mcp_server(server_id, request=update_request) + assert updated_server.server_name == f"updated-concurrent-{i}" + + # Get all servers + for i, server_id in enumerate(servers_created): + server = letta_client.mcp_servers.mcp_get_mcp_server(server_id) + assert server.server_name == f"updated-concurrent-{i}" + + finally: + # Cleanup all servers + for server_id in servers_created: + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_full_server_lifecycle(letta_client: Letta): + """Test complete lifecycle: create, list, get, update, tools, delete.""" + # 1. Create server + server_name = f"lifecycle-test-{uuid.uuid4().hex[:8]}" + server_config = create_stdio_server_request(server_name, command="npx", args=["-y", "@modelcontextprotocol/server-everything"]) + server_config.env["TEST"] = "true" + + created_server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = created_server.id + + try: + # 2. List servers and verify it's there + servers_list = letta_client.mcp_servers.mcp_list_mcp_servers() + assert any(s.id == server_id for s in servers_list) + + # 3. Get specific server + retrieved_server = letta_client.mcp_servers.mcp_get_mcp_server(server_id) + assert retrieved_server.server_name == server_name + + # 4. Update server + update_request = LettaSchemasMcpServerUpdateStdioMcpServer( + server_name="lifecycle-updated", env={"TEST": "false", "NEW_VAR": "value"} + ) + updated_server = letta_client.mcp_servers.mcp_update_mcp_server(server_id, request=update_request) + assert updated_server.server_name == "lifecycle-updated" + + # 5. List tools + tools = letta_client.mcp_servers.mcp_list_mcp_tools_by_server(server_id) + assert isinstance(tools, list) + + # 6. If tools exist, try to get and run one + if len(tools) > 0: + # Find the echo tool specifically since we know its schema + echo_tool = next((t for t in tools if t.name == "echo"), None) + if echo_tool: + # Get specific tool + tool = letta_client.mcp_servers.mcp_get_mcp_tool(server_id, echo_tool.id) + assert tool.id == echo_tool.id + + # Run the tool directly with required args + result = letta_client.mcp_servers.mcp_run_tool(server_id, echo_tool.id, args={"message": "Test lifecycle tool execution"}) + assert hasattr(result, "status"), "Tool execution result should have status" + + finally: + # 9. Delete server + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + # 10. Verify it's deleted + with pytest.raises(ApiError) as exc_info: + letta_client.mcp_servers.mcp_get_mcp_server(server_id) + assert exc_info.value.status_code == 404 + + +# ------------------------------ +# Test Cases for Empty Responses +# ------------------------------ + + +def test_empty_tools_list(letta_client: Letta): + """Test handling of servers with no tools.""" + # Create a minimal server that likely has no tools + server_name = f"no-tools-{uuid.uuid4().hex[:8]}" + server_config = create_stdio_server_request(server_name, command="echo", args=["hello"]) + + created_server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = created_server.id + + try: + # List tools (should be empty) + tools = letta_client.mcp_servers.mcp_list_mcp_tools_by_server(server_id) + + assert tools is not None + assert isinstance(tools, list) + # Tools will be empty for a simple echo command + + finally: + # Cleanup + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +# ------------------------------ +# Test Cases for Tool Execution with Agents +# ------------------------------ + + +def test_mcp_echo_tool_with_agent(letta_client: Letta, agent_with_mcp_tools: AgentState): + """ + Test that an agent can successfully call the echo tool from the MCP server. + """ + test_message = "Hello from MCP integration test!" + + response = letta_client.agents.messages.create( + agent_id=agent_with_mcp_tools.id, + messages=[ + MessageCreate( + role="user", + content=f"Use the echo tool to echo back this exact message: '{test_message}'", + ) + ], + ) + + # Check for tool call message + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) > 0, "Expected at least one ToolCallMessage" + + # Find the echo tool call + echo_call = next((m for m in tool_calls if m.tool_call.name == "echo"), None) + assert echo_call is not None, f"No echo tool call found. Tool calls: {[m.tool_call.name for m in tool_calls]}" + + # Check for tool return message + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage" + + # Find the return for the echo call + echo_return = next((m for m in tool_returns if m.tool_call_id == echo_call.tool_call.tool_call_id), None) + assert echo_return is not None, "No tool return found for echo call" + assert echo_return.status == "success", f"Echo tool failed with status: {echo_return.status}" + + # Verify the echo response contains our message + assert test_message in echo_return.tool_return, f"Expected '{test_message}' in tool return, got: {echo_return.tool_return}" + + +def test_mcp_add_tool_with_agent(letta_client: Letta, agent_with_mcp_tools: AgentState): + """ + Test that an agent can successfully call the add tool from the MCP server. + """ + a, b = 42, 58 + expected_sum = a + b + + response = letta_client.agents.messages.create( + agent_id=agent_with_mcp_tools.id, + messages=[ + MessageCreate( + role="user", + content=f"Use the add tool to add {a} and {b}.", + ) + ], + ) + + # Check for tool call message + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) > 0, "Expected at least one ToolCallMessage" + + # Find the add tool call + add_call = next((m for m in tool_calls if m.tool_call.name == "add"), None) + assert add_call is not None, f"No add tool call found. Tool calls: {[m.tool_call.name for m in tool_calls]}" + + # Check for tool return message + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage" + + # Find the return for the add call + add_return = next((m for m in tool_returns if m.tool_call_id == add_call.tool_call.tool_call_id), None) + assert add_return is not None, "No tool return found for add call" + assert add_return.status == "success", f"Add tool failed with status: {add_return.status}" + + # Verify the result contains the expected sum + assert str(expected_sum) in add_return.tool_return, f"Expected '{expected_sum}' in tool return, got: {add_return.tool_return}" + + +def test_mcp_multiple_tools_in_sequence_with_agent(letta_client: Letta): + """ + Test that an agent can call multiple MCP tools in sequence. + """ + # Create server with multiple tools + script_dir = Path(__file__).parent + mcp_server_path = script_dir / "mock_mcp_server.py" + + if not mcp_server_path.exists(): + pytest.skip(f"Mock MCP server not found at {mcp_server_path}") + + server_name = f"test-multi-tools-{uuid.uuid4().hex[:8]}" + server_config = CreateStdioMcpServer( + server_name=server_name, + command=sys.executable, + args=[str(mcp_server_path)], + ) + + # Register the MCP server + server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = server.id + + try: + # List available MCP tools + mcp_tools = letta_client.mcp_servers.mcp_list_mcp_tools_by_server(server_id) + + # Get multiple tools + add_tool = next((t for t in mcp_tools if t.name == "add"), None) + multiply_tool = next((t for t in mcp_tools if t.name == "multiply"), None) + echo_tool = next((t for t in mcp_tools if t.name == "echo"), None) + + assert add_tool is not None, "add tool not found" + assert multiply_tool is not None, "multiply tool not found" + assert echo_tool is not None, "echo tool not found" + + # Create agent with multiple tools + agent = letta_client.agents.create( + name=f"test_multi_tools_{uuid.uuid4().hex[:8]}", + include_base_tools=True, + tool_ids=[add_tool.id, multiply_tool.id, echo_tool.id], + memory_blocks=[ + { + "label": "human", + "value": "Name: Test User", + }, + { + "label": "persona", + "value": "You are a helpful assistant that can use MCP tools to help the user.", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tags=["test_multi_tools"], + ) + + # Send message requiring multiple tool calls + response = letta_client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content="First use the add tool to add 10 and 20. Then use the multiply tool to multiply the result by 2. " + "Finally, use the echo tool to echo back the final result.", + ) + ], + ) + + # Check for tool call messages + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) >= 3, f"Expected at least 3 tool calls, got {len(tool_calls)}" + + # Verify all three tools were called + tool_names = [m.tool_call.name for m in tool_calls] + assert "add" in tool_names, f"add tool not called. Tools called: {tool_names}" + assert "multiply" in tool_names, f"multiply tool not called. Tools called: {tool_names}" + assert "echo" in tool_names, f"echo tool not called. Tools called: {tool_names}" + + # Check for tool return messages + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert len(tool_returns) >= 3, f"Expected at least 3 tool returns, got {len(tool_returns)}" + + # Verify all tools succeeded + for tool_return in tool_returns: + assert tool_return.status == "success", f"Tool call failed with status: {tool_return.status}" + + # Cleanup agent + letta_client.agents.delete(agent.id) + + finally: + # Cleanup MCP server + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_mcp_complex_schema_tool_with_agent(letta_client: Letta): + """ + Test that an agent can successfully call a tool with complex nested schema. + This tests the get_parameter_type_description tool which has: + - Enum-like preset parameter + - Optional string field + - Optional nested object with arrays of objects + """ + # Create server + script_dir = Path(__file__).parent + mcp_server_path = script_dir / "mock_mcp_server.py" + + if not mcp_server_path.exists(): + pytest.skip(f"Mock MCP server not found at {mcp_server_path}") + + server_name = f"test-complex-schema-{uuid.uuid4().hex[:8]}" + server_config = CreateStdioMcpServer( + server_name=server_name, + command=sys.executable, + args=[str(mcp_server_path)], + ) + + # Register the MCP server + server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = server.id + + try: + # List available tools + mcp_tools = letta_client.mcp_servers.mcp_list_mcp_tools_by_server(server_id) + + # Find the complex schema tool + complex_tool = next((t for t in mcp_tools if t.name == "get_parameter_type_description"), None) + assert complex_tool is not None, f"get_parameter_type_description tool not found. Available: {[t.name for t in mcp_tools]}" + + # Find other complex tools + create_person_tool = next((t for t in mcp_tools if t.name == "create_person"), None) + manage_tasks_tool = next((t for t in mcp_tools if t.name == "manage_tasks"), None) + + # Create agent with complex schema tools + tool_ids = [complex_tool.id] + if create_person_tool: + tool_ids.append(create_person_tool.id) + if manage_tasks_tool: + tool_ids.append(manage_tasks_tool.id) + + agent = letta_client.agents.create( + name=f"test_complex_schema_{uuid.uuid4().hex[:8]}", + include_base_tools=True, + tool_ids=tool_ids, + memory_blocks=[ + { + "label": "human", + "value": "Name: Test User", + }, + { + "label": "persona", + "value": "You are a helpful assistant that can use MCP tools with complex schemas.", + }, + ], + llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + tags=["test_complex_schema"], + ) + + # Test 1: Simple call with just preset + response = letta_client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content='Use the get_parameter_type_description tool with preset "a" to get parameter information.', + ) + ], + ) + + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) > 0, "Expected at least one ToolCallMessage" + + complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None) + assert complex_call is not None, f"No get_parameter_type_description call found. Calls: {[m.tool_call.name for m in tool_calls]}" + + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + assert len(tool_returns) > 0, "Expected at least one ToolReturnMessage" + + complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None) + assert complex_return is not None, "No tool return found for complex schema call" + assert complex_return.status == "success", f"Complex schema tool failed with status: {complex_return.status}" + assert "Preset: a" in complex_return.tool_return, f"Expected 'Preset: a' in return, got: {complex_return.tool_return}" + + # Test 2: Complex call with nested data + response = letta_client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content="Use the get_parameter_type_description tool with these arguments: " + 'preset="b", connected_service_descriptor="test-service", ' + "and instantiation_data with isAbstract=true, isMultiplicity=false, " + 'and one instantiation with doid="TEST123" and nodeFamilyId=42.', + ) + ], + ) + + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + assert len(tool_calls) > 0, "Expected at least one ToolCallMessage for complex nested call" + + complex_call = next((m for m in tool_calls if m.tool_call.name == "get_parameter_type_description"), None) + assert complex_call is not None, "No get_parameter_type_description call found for nested test" + + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + complex_return = next((m for m in tool_returns if m.tool_call_id == complex_call.tool_call.tool_call_id), None) + assert complex_return is not None, "No tool return found for complex nested call" + assert complex_return.status == "success", f"Complex nested call failed with status: {complex_return.status}" + + # Verify the response contains our complex data + assert "Preset: b" in complex_return.tool_return, "Expected preset 'b' in response" + assert "test-service" in complex_return.tool_return, "Expected service descriptor in response" + + # Test 3: If create_person tool is available, test it + if create_person_tool: + response = letta_client.agents.messages.create( + agent_id=agent.id, + messages=[ + MessageCreate( + role="user", + content='Use the create_person tool to create a person named "John Doe", age 30, ' + 'email "john@example.com", with address at "123 Main St", city "New York", zip "10001".', + ) + ], + ) + + tool_calls = [m for m in response.messages if isinstance(m, ToolCallMessage)] + person_call = next((m for m in tool_calls if m.tool_call.name == "create_person"), None) + assert person_call is not None, "No create_person call found" + + tool_returns = [m for m in response.messages if isinstance(m, ToolReturnMessage)] + person_return = next((m for m in tool_returns if m.tool_call_id == person_call.tool_call.tool_call_id), None) + assert person_return is not None, "No tool return found for create_person call" + assert person_return.status == "success", f"create_person failed with status: {person_return.status}" + assert "John Doe" in person_return.tool_return, "Expected person name in response" + + # Cleanup agent + letta_client.agents.delete(agent.id) + + finally: + # Cleanup MCP server + letta_client.mcp_servers.mcp_delete_mcp_server(server_id) + + +def test_comprehensive_mcp_server_tool_listing(letta_client: Letta): + """ + Comprehensive test for MCP server registration, tool listing, and management. + """ + # Create server + script_dir = Path(__file__).parent + mcp_server_path = script_dir / "mock_mcp_server.py" + + if not mcp_server_path.exists(): + pytest.skip(f"Mock MCP server not found at {mcp_server_path}") + + server_name = f"test-comprehensive-{uuid.uuid4().hex[:8]}" + server_config = CreateStdioMcpServer( + server_name=server_name, + command=sys.executable, + args=[str(mcp_server_path)], + ) + + # Register the MCP server + server = letta_client.mcp_servers.mcp_create_mcp_server(request=server_config) + server_id = server.id + + try: + # Verify server is in the list + servers = letta_client.mcp_servers.mcp_list_mcp_servers() + server_ids = [s.id for s in servers] + assert server_id in server_ids, f"MCP server {server_id} not found in {server_ids}" + + # List available tools + mcp_tools = letta_client.mcp_servers.mcp_list_mcp_tools_by_server(server_id) + assert len(mcp_tools) > 0, "No tools found from MCP server" + + # Verify expected tools are present + tool_names = [t.name for t in mcp_tools] + expected_tools = [ + "echo", + "add", + "multiply", + "reverse_string", + "create_person", + "manage_tasks", + "search_with_filters", + "process_nested_data", + "get_parameter_type_description", + ] + + for expected_tool in expected_tools: + assert expected_tool in tool_names, f"Expected tool '{expected_tool}' not found. Available: {tool_names}" + + # Test getting individual tools + for tool in mcp_tools[:3]: # Test first 3 tools + retrieved_tool = letta_client.mcp_servers.mcp_get_mcp_tool(server_id, tool.id) + assert retrieved_tool.id == tool.id, f"Tool ID mismatch: expected {tool.id}, got {retrieved_tool.id}" + assert retrieved_tool.name == tool.name, f"Tool name mismatch: expected {tool.name}, got {retrieved_tool.name}" + + # Test running a simple tool directly (without agent) + echo_tool = next((t for t in mcp_tools if t.name == "echo"), None) + if echo_tool: + result = letta_client.mcp_servers.mcp_run_tool(server_id, echo_tool.id, args={"message": "Test direct tool execution"}) + assert hasattr(result, "status"), "Tool execution result should have status" + # The exact structure of result depends on the API implementation + + # Test tool schema inspection + complex_tool = next((t for t in mcp_tools if t.name == "get_parameter_type_description"), None) + if complex_tool: + # Verify the tool has appropriate schema/description + assert complex_tool.description is not None, "Complex tool should have a description" + # Could add more schema validation here if the API exposes it + + finally: + # Cleanup MCP server + letta_client.mcp_servers.mcp_delete_mcp_server(server_id)