From fe96b7001de45333bcf86ffdee6846a209f29000 Mon Sep 17 00:00:00 2001 From: jnjpng Date: Wed, 2 Jul 2025 17:06:21 -0700 Subject: [PATCH] feat: mcp custom headers and multiple fixes (#3079) Co-authored-by: Jin Peng --- ...16524f_add_custom_headers_to_mcp_server.py | 31 +++++++++++++++++++ letta/orm/mcp_server.py | 3 ++ letta/schemas/mcp.py | 25 +++++++++------ letta/server/rest_api/routers/v1/tools.py | 24 +++++++++----- letta/services/mcp/base_client.py | 1 + letta/services/mcp_manager.py | 14 ++++++++- 6 files changed, 79 insertions(+), 19 deletions(-) create mode 100644 alembic/versions/56254216524f_add_custom_headers_to_mcp_server.py diff --git a/alembic/versions/56254216524f_add_custom_headers_to_mcp_server.py b/alembic/versions/56254216524f_add_custom_headers_to_mcp_server.py new file mode 100644 index 00000000..62331ccb --- /dev/null +++ b/alembic/versions/56254216524f_add_custom_headers_to_mcp_server.py @@ -0,0 +1,31 @@ +"""add_custom_headers_to_mcp_server + +Revision ID: 56254216524f +Revises: 60ed28ee7138 +Create Date: 2025-07-02 14:08:59.163861 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "56254216524f" +down_revision: Union[str, None] = "60ed28ee7138" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("mcp_server", sa.Column("custom_headers", sa.JSON(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("mcp_server", "custom_headers") + # ### end Alembic commands ### diff --git a/letta/orm/mcp_server.py b/letta/orm/mcp_server.py index dd16485b..113b7183 100644 --- a/letta/orm/mcp_server.py +++ b/letta/orm/mcp_server.py @@ -39,6 +39,9 @@ class MCPServer(SqlalchemyBase, OrganizationMixin): # access token / api key for MCP servers that require authentication token: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The access token or api key for the MCP server") + # custom headers for authentication (key-value pairs) + custom_headers: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Custom authentication headers as key-value pairs") + # stdio server stdio_config: Mapped[Optional[StdioServerConfig]] = mapped_column( MCPStdioServerConfigColumn, nullable=True, doc="The configuration for the stdio server" diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index 22be1073..b851f2d5 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -19,12 +19,13 @@ class BaseMCPServer(LettaBase): class MCPServer(BaseMCPServer): id: str = BaseMCPServer.generate_id_field() - server_type: MCPServerType = MCPServerType.SSE + server_type: MCPServerType = MCPServerType.STREAMABLE_HTTP server_name: str = Field(..., description="The name of the server") - # sse config - server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)") - token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)") + # sse / streamable http config + server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE/Streamable HTTP client will connect to this URL)") + token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for authentication)") + custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs") # stdio config stdio_config: Optional[StdioServerConfig] = Field( @@ -43,9 +44,9 @@ class MCPServer(BaseMCPServer): return SSEServerConfig( server_name=self.server_name, server_url=self.server_url, - auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token else None, - auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token else None, - custom_headers=None, + auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None, + auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None, + custom_headers=self.custom_headers, ) elif self.server_type == MCPServerType.STDIO: if self.stdio_config is None: @@ -57,9 +58,9 @@ class MCPServer(BaseMCPServer): return StreamableHTTPServerConfig( server_name=self.server_name, server_url=self.server_url, - auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token else None, - auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token else None, - custom_headers=None, + auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None, + auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None, + custom_headers=self.custom_headers, ) else: raise ValueError(f"Unsupported server type: {self.server_type}") @@ -70,6 +71,7 @@ class RegisterSSEMCPServer(LettaBase): server_type: MCPServerType = MCPServerType.SSE server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)") token: Optional[str] = Field(None, description="The access token or API key for the MCP server used for authentication") + custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs") class RegisterStdioMCPServer(LettaBase): @@ -84,6 +86,7 @@ class RegisterStreamableHTTPMCPServer(LettaBase): server_url: str = Field(..., description="The URL path for the streamable HTTP server (e.g., 'example/mcp')") auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')") auth_token: Optional[str] = Field(None, description="The authentication token or API key value") + custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs") class UpdateSSEMCPServer(LettaBase): @@ -92,6 +95,7 @@ class UpdateSSEMCPServer(LettaBase): server_name: Optional[str] = Field(None, description="The name of the server") server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)") token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)") + custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs") class UpdateStdioMCPServer(LettaBase): @@ -110,6 +114,7 @@ class UpdateStreamableHTTPMCPServer(LettaBase): server_url: Optional[str] = Field(None, description="The URL path for the streamable HTTP server (e.g., 'example/mcp')") auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')") auth_token: Optional[str] = Field(None, description="The authentication token or API key value") + custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs") UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer] diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index be423a52..4feec778 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -508,11 +508,19 @@ async def add_mcp_server_to_config( ) elif isinstance(request, SSEServerConfig): mapped_request = MCPServer( - server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token() + server_name=request.server_name, + server_type=request.type, + server_url=request.server_url, + token=request.resolve_token() if not request.custom_headers else None, + custom_headers=request.custom_headers, ) elif isinstance(request, StreamableHTTPServerConfig): mapped_request = MCPServer( - server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token() + server_name=request.server_name, + server_type=request.type, + server_url=request.server_url, + token=request.resolve_token() if not request.custom_headers else None, + custom_headers=request.custom_headers, ) await server.mcp_manager.create_mcp_server(mapped_request, actor=actor) @@ -624,7 +632,6 @@ async def test_mcp_server( await client.connect_to_server() tools = await client.list_tools() - await client.cleanup() return tools except ConnectionError as e: raise HTTPException( @@ -645,11 +652,6 @@ async def test_mcp_server( }, ) except Exception as e: - if client: - try: - await client.cleanup() - except: - pass raise HTTPException( status_code=500, detail={ @@ -658,6 +660,12 @@ async def test_mcp_server( "server_name": request.server_name, }, ) + finally: + if client: + try: + await client.cleanup() + except Exception as cleanup_error: + logger.warning(f"Error during MCP client cleanup: {cleanup_error}") class CodeInput(BaseModel): diff --git a/letta/services/mcp/base_client.py b/letta/services/mcp/base_client.py index d0bc1094..8aeda67f 100644 --- a/letta/services/mcp/base_client.py +++ b/letta/services/mcp/base_client.py @@ -77,6 +77,7 @@ class AsyncBaseMCPClient: logger.error("MCPClient has not been initialized") raise RuntimeError("MCPClient has not been initialized") + # TODO: still hitting some async errors for voice agents, need to fix async def cleanup(self): """Clean up resources - ensure this runs in the same task""" if hasattr(self, "_cleanup_task"): diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 846a37f5..63a825bd 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -2,6 +2,8 @@ import json import os from typing import Any, Dict, List, Optional, Tuple, Union +from sqlalchemy import null + import letta.constants as constants from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig from letta.log import get_logger @@ -156,6 +158,10 @@ class MCPManager: pydantic_mcp_server.organization_id = actor.organization_id mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True) + # Ensure custom_headers None is stored as SQL NULL, not JSON null + if mcp_server_data.get("custom_headers") is None: + mcp_server_data.pop("custom_headers", None) + mcp_server = MCPServerModel(**mcp_server_data) mcp_server = await mcp_server.create_async(session, actor=actor) return mcp_server.to_pydantic() @@ -168,7 +174,13 @@ class MCPManager: mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) # Update tool attributes with only the fields that were explicitly set - update_data = mcp_server_update.model_dump(to_orm=True, exclude_none=True) + update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True) + + # Ensure custom_headers None is stored as SQL NULL, not JSON null + if update_data.get("custom_headers") is None: + update_data.pop("custom_headers", None) + setattr(mcp_server, "custom_headers", null()) + for key, value in update_data.items(): setattr(mcp_server, key, value)