feat: mcp custom headers and multiple fixes (#3079)
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -0,0 +1,31 @@
|
||||
"""add_custom_headers_to_mcp_server
|
||||
|
||||
Revision ID: 56254216524f
|
||||
Revises: 60ed28ee7138
|
||||
Create Date: 2025-07-02 14:08:59.163861
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "56254216524f"
|
||||
down_revision: Union[str, None] = "60ed28ee7138"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("mcp_server", sa.Column("custom_headers", sa.JSON(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("mcp_server", "custom_headers")
|
||||
# ### end Alembic commands ###
|
||||
@@ -39,6 +39,9 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
|
||||
# access token / api key for MCP servers that require authentication
|
||||
token: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The access token or api key for the MCP server")
|
||||
|
||||
# custom headers for authentication (key-value pairs)
|
||||
custom_headers: Mapped[Optional[dict]] = mapped_column(JSON, nullable=True, doc="Custom authentication headers as key-value pairs")
|
||||
|
||||
# stdio server
|
||||
stdio_config: Mapped[Optional[StdioServerConfig]] = mapped_column(
|
||||
MCPStdioServerConfigColumn, nullable=True, doc="The configuration for the stdio server"
|
||||
|
||||
@@ -19,12 +19,13 @@ class BaseMCPServer(LettaBase):
|
||||
|
||||
class MCPServer(BaseMCPServer):
|
||||
id: str = BaseMCPServer.generate_id_field()
|
||||
server_type: MCPServerType = MCPServerType.SSE
|
||||
server_type: MCPServerType = MCPServerType.STREAMABLE_HTTP
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
|
||||
# sse config
|
||||
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
|
||||
# sse / streamable http config
|
||||
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE/Streamable HTTP client will connect to this URL)")
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for authentication)")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
# stdio config
|
||||
stdio_config: Optional[StdioServerConfig] = Field(
|
||||
@@ -43,9 +44,9 @@ class MCPServer(BaseMCPServer):
|
||||
return SSEServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token else None,
|
||||
custom_headers=None,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
|
||||
custom_headers=self.custom_headers,
|
||||
)
|
||||
elif self.server_type == MCPServerType.STDIO:
|
||||
if self.stdio_config is None:
|
||||
@@ -57,9 +58,9 @@ class MCPServer(BaseMCPServer):
|
||||
return StreamableHTTPServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token else None,
|
||||
custom_headers=None,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
|
||||
custom_headers=self.custom_headers,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server type: {self.server_type}")
|
||||
@@ -70,6 +71,7 @@ class RegisterSSEMCPServer(LettaBase):
|
||||
server_type: MCPServerType = MCPServerType.SSE
|
||||
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server used for authentication")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
|
||||
class RegisterStdioMCPServer(LettaBase):
|
||||
@@ -84,6 +86,7 @@ class RegisterStreamableHTTPMCPServer(LettaBase):
|
||||
server_url: str = Field(..., description="The URL path for the streamable HTTP server (e.g., 'example/mcp')")
|
||||
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
|
||||
class UpdateSSEMCPServer(LettaBase):
|
||||
@@ -92,6 +95,7 @@ class UpdateSSEMCPServer(LettaBase):
|
||||
server_name: Optional[str] = Field(None, description="The name of the server")
|
||||
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
|
||||
class UpdateStdioMCPServer(LettaBase):
|
||||
@@ -110,6 +114,7 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
server_url: Optional[str] = Field(None, description="The URL path for the streamable HTTP server (e.g., 'example/mcp')")
|
||||
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
|
||||
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer]
|
||||
|
||||
@@ -508,11 +508,19 @@ async def add_mcp_server_to_config(
|
||||
)
|
||||
elif isinstance(request, SSEServerConfig):
|
||||
mapped_request = MCPServer(
|
||||
server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token()
|
||||
server_name=request.server_name,
|
||||
server_type=request.type,
|
||||
server_url=request.server_url,
|
||||
token=request.resolve_token() if not request.custom_headers else None,
|
||||
custom_headers=request.custom_headers,
|
||||
)
|
||||
elif isinstance(request, StreamableHTTPServerConfig):
|
||||
mapped_request = MCPServer(
|
||||
server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token()
|
||||
server_name=request.server_name,
|
||||
server_type=request.type,
|
||||
server_url=request.server_url,
|
||||
token=request.resolve_token() if not request.custom_headers else None,
|
||||
custom_headers=request.custom_headers,
|
||||
)
|
||||
|
||||
await server.mcp_manager.create_mcp_server(mapped_request, actor=actor)
|
||||
@@ -624,7 +632,6 @@ async def test_mcp_server(
|
||||
|
||||
await client.connect_to_server()
|
||||
tools = await client.list_tools()
|
||||
await client.cleanup()
|
||||
return tools
|
||||
except ConnectionError as e:
|
||||
raise HTTPException(
|
||||
@@ -645,11 +652,6 @@ async def test_mcp_server(
|
||||
},
|
||||
)
|
||||
except Exception as e:
|
||||
if client:
|
||||
try:
|
||||
await client.cleanup()
|
||||
except:
|
||||
pass
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail={
|
||||
@@ -658,6 +660,12 @@ async def test_mcp_server(
|
||||
"server_name": request.server_name,
|
||||
},
|
||||
)
|
||||
finally:
|
||||
if client:
|
||||
try:
|
||||
await client.cleanup()
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Error during MCP client cleanup: {cleanup_error}")
|
||||
|
||||
|
||||
class CodeInput(BaseModel):
|
||||
|
||||
@@ -77,6 +77,7 @@ class AsyncBaseMCPClient:
|
||||
logger.error("MCPClient has not been initialized")
|
||||
raise RuntimeError("MCPClient has not been initialized")
|
||||
|
||||
# TODO: still hitting some async errors for voice agents, need to fix
|
||||
async def cleanup(self):
|
||||
"""Clean up resources - ensure this runs in the same task"""
|
||||
if hasattr(self, "_cleanup_task"):
|
||||
|
||||
@@ -2,6 +2,8 @@ import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from sqlalchemy import null
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
from letta.log import get_logger
|
||||
@@ -156,6 +158,10 @@ class MCPManager:
|
||||
pydantic_mcp_server.organization_id = actor.organization_id
|
||||
mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
|
||||
|
||||
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
||||
if mcp_server_data.get("custom_headers") is None:
|
||||
mcp_server_data.pop("custom_headers", None)
|
||||
|
||||
mcp_server = MCPServerModel(**mcp_server_data)
|
||||
mcp_server = await mcp_server.create_async(session, actor=actor)
|
||||
return mcp_server.to_pydantic()
|
||||
@@ -168,7 +174,13 @@ class MCPManager:
|
||||
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
|
||||
# Update tool attributes with only the fields that were explicitly set
|
||||
update_data = mcp_server_update.model_dump(to_orm=True, exclude_none=True)
|
||||
update_data = mcp_server_update.model_dump(to_orm=True, exclude_unset=True)
|
||||
|
||||
# Ensure custom_headers None is stored as SQL NULL, not JSON null
|
||||
if update_data.get("custom_headers") is None:
|
||||
update_data.pop("custom_headers", None)
|
||||
setattr(mcp_server, "custom_headers", null())
|
||||
|
||||
for key, value in update_data.items():
|
||||
setattr(mcp_server, key, value)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user