feat: mcp custom headers and multiple fixes (#3079)

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-07-02 17:06:21 -07:00
committed by GitHub
parent c41cb35cd4
commit fe96b7001d
6 changed files with 79 additions and 19 deletions

View File

@@ -0,0 +1,31 @@
"""add_custom_headers_to_mcp_server
Revision ID: 56254216524f
Revises: 60ed28ee7138
Create Date: 2025-07-02 14:08:59.163861
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "56254216524f"
down_revision: Union[str, None] = "60ed28ee7138"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.add_column("mcp_server", sa.Column("custom_headers", sa.JSON(), nullable=True))
# ### end Alembic commands ###
def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("mcp_server", "custom_headers")
# ### end Alembic commands ###

View File

@@ -39,6 +39,9 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
# access token / api key for MCP servers that require authentication
token: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The access token or api key for the MCP server")
# custom headers for authentication (key-value pairs)
custom_headers: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Custom authentication headers as key-value pairs")
# stdio server
stdio_config: Mapped[Optional[StdioServerConfig]] = mapped_column(
MCPStdioServerConfigColumn, nullable=True, doc="The configuration for the stdio server"

View File

@@ -19,12 +19,13 @@ class BaseMCPServer(LettaBase):
class MCPServer(BaseMCPServer):
id: str = BaseMCPServer.generate_id_field()
server_type: MCPServerType = MCPServerType.SSE
server_type: MCPServerType = MCPServerType.STREAMABLE_HTTP
server_name: str = Field(..., description="The name of the server")
# sse config
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
# sse / streamable http config
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE/Streamable HTTP client will connect to this URL)")
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for authentication)")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
# stdio config
stdio_config: Optional[StdioServerConfig] = Field(
@@ -43,9 +44,9 @@ class MCPServer(BaseMCPServer):
return SSEServerConfig(
server_name=self.server_name,
server_url=self.server_url,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token else None,
custom_headers=None,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
custom_headers=self.custom_headers,
)
elif self.server_type == MCPServerType.STDIO:
if self.stdio_config is None:
@@ -57,9 +58,9 @@ class MCPServer(BaseMCPServer):
return StreamableHTTPServerConfig(
server_name=self.server_name,
server_url=self.server_url,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token else None,
custom_headers=None,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
custom_headers=self.custom_headers,
)
else:
raise ValueError(f"Unsupported server type: {self.server_type}")
@@ -70,6 +71,7 @@ class RegisterSSEMCPServer(LettaBase):
server_type: MCPServerType = MCPServerType.SSE
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
token: Optional[str] = Field(None, description="The access token or API key for the MCP server used for authentication")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
class RegisterStdioMCPServer(LettaBase):
@@ -84,6 +86,7 @@ class RegisterStreamableHTTPMCPServer(LettaBase):
server_url: str = Field(..., description="The URL path for the streamable HTTP server (e.g., 'example/mcp')")
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
class UpdateSSEMCPServer(LettaBase):
@@ -92,6 +95,7 @@ class UpdateSSEMCPServer(LettaBase):
server_name: Optional[str] = Field(None, description="The name of the server")
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
class UpdateStdioMCPServer(LettaBase):
@@ -110,6 +114,7 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
server_url: Optional[str] = Field(None, description="The URL path for the streamable HTTP server (e.g., 'example/mcp')")
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer]

View File

@@ -508,11 +508,19 @@ async def add_mcp_server_to_config(
)
elif isinstance(request, SSEServerConfig):
mapped_request = MCPServer(
server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token()
server_name=request.server_name,
server_type=request.type,
server_url=request.server_url,
token=request.resolve_token() if not request.custom_headers else None,
custom_headers=request.custom_headers,
)
elif isinstance(request, StreamableHTTPServerConfig):
mapped_request = MCPServer(
server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token()
server_name=request.server_name,
server_type=request.type,
server_url=request.server_url,
token=request.resolve_token() if not request.custom_headers else None,
custom_headers=request.custom_headers,
)
await server.mcp_manager.create_mcp_server(mapped_request, actor=actor)
@@ -624,7 +632,6 @@ async def test_mcp_server(
await client.connect_to_server()
tools = await client.list_tools()
await client.cleanup()
return tools
except ConnectionError as e:
raise HTTPException(
@@ -645,11 +652,6 @@ async def test_mcp_server(
},
)
except Exception as e:
if client:
try:
await client.cleanup()
except:
pass
raise HTTPException(
status_code=500,
detail={
@@ -658,6 +660,12 @@ async def test_mcp_server(
"server_name": request.server_name,
},
)
finally:
if client:
try:
await client.cleanup()
except Exception as cleanup_error:
logger.warning(f"Error during MCP client cleanup: {cleanup_error}")
class CodeInput(BaseModel):

View File

@@ -77,6 +77,7 @@ class AsyncBaseMCPClient:
logger.error("MCPClient has not been initialized")
raise RuntimeError("MCPClient has not been initialized")
# TODO: still hitting some async errors for voice agents, need to fix
async def cleanup(self):
"""Clean up resources - ensure this runs in the same task"""
if hasattr(self, "_cleanup_task"):

View File

@@ -2,6 +2,8 @@ import json
import os
from typing import Any, Dict, List, Optional, Tuple, Union
from sqlalchemy import null
import letta.constants as constants
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
from letta.log import get_logger
@@ -156,6 +158,10 @@ class MCPManager:
pydantic_mcp_server.organization_id = actor.organization_id
mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
# Ensure custom_headers None is stored as SQL NULL, not JSON null
if mcp_server_data.get("custom_headers") is None:
mcp_server_data.pop("custom_headers", None)
mcp_server = MCPServerModel(**mcp_server_data)
mcp_server = await mcp_server.create_async(session, actor=actor)
return mcp_server.to_pydantic()
@@ -168,7 +174,13 @@ class MCPManager:
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
# Update tool attributes with only the fields that were explicitly set
update_data = mcp_server_update.model_dump(to_orm=True, exclude_none=True)
update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True)
# Ensure custom_headers None is stored as SQL NULL, not JSON null
if update_data.get("custom_headers") is None:
update_data.pop("custom_headers", None)
setattr(mcp_server, "custom_headers", null())
for key, value in update_data.items():
setattr(mcp_server, key, value)