fix: add URL validation for MCP server URLs to prevent malformed protocol errors (#8079)
Adds field validators to CreateSSEMCPServer, CreateStreamableHTTPMCPServer, and their Update counterparts to validate that server_url fields: - Start with 'http://' or 'https://' - Have a valid host This prevents errors like 'httpx.UnsupportedProtocol: Request URL has an unsupported protocol 'hthttps://'' caused by user input typos. Fixes #8078 👾 Generated with [Letta Code](https://letta.com) Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Letta <noreply@letta.com> Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
This commit is contained in:
committed by
Caren Thomas
parent
700409d943
commit
21df642a43
@@ -3,7 +3,9 @@ import logging
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -51,6 +53,21 @@ class MCPServer(BaseMCPServer):
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
def get_token_secret(self) -> Optional[Secret]:
|
||||
"""Get the token as a Secret object."""
|
||||
return self.token_enc
|
||||
@@ -199,6 +216,21 @@ class UpdateSSEMCPServer(LettaBase):
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateStdioMCPServer(LettaBase):
|
||||
"""Update a Stdio MCP server"""
|
||||
@@ -218,6 +250,21 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer]
|
||||
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
from letta.functions.mcp_client.types import (
|
||||
MCP_AUTH_HEADER_AUTHORIZATION,
|
||||
@@ -41,6 +42,19 @@ class CreateSSEMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: str) -> str:
|
||||
"""Validate that server_url is a valid HTTP(S) URL."""
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
class CreateStreamableHTTPMCPServer(LettaBase):
|
||||
"""Create a new Streamable HTTP MCP server"""
|
||||
@@ -51,6 +65,19 @@ class CreateStreamableHTTPMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: str) -> str:
|
||||
"""Validate that server_url is a valid HTTP(S) URL."""
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
CreateMCPServerUnion = Union[CreateStdioMCPServer, CreateSSEMCPServer, CreateStreamableHTTPMCPServer]
|
||||
|
||||
@@ -99,6 +126,21 @@ class UpdateSSEMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
"""Update schema for Streamable HTTP MCP server - all fields optional"""
|
||||
@@ -109,6 +151,21 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with requests")
|
||||
|
||||
@field_validator("server_url")
|
||||
@classmethod
|
||||
def validate_server_url(cls, v: Optional[str]) -> Optional[str]:
|
||||
"""Validate that server_url is a valid HTTP(S) URL if provided."""
|
||||
if v is None:
|
||||
return v
|
||||
if not v:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
parsed = urlparse(v)
|
||||
if parsed.scheme not in ("http", "https"):
|
||||
raise ValueError(f"server_url must start with 'http://' or 'https://', got: '{v}'")
|
||||
if not parsed.netloc:
|
||||
raise ValueError(f"server_url must have a valid host, got: '{v}'")
|
||||
return v
|
||||
|
||||
|
||||
UpdateMCPServerUnion = Union[UpdateStdioMCPServer, UpdateSSEMCPServer, UpdateStreamableHTTPMCPServer]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user