feat: allow mcp authentication overrides per agent (#3318)

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-07-28 18:20:58 -07:00
committed by GitHub
parent 84ea52172a
commit cf8c59aab9
7 changed files with 147 additions and 11 deletions

View File

@@ -1,5 +1,7 @@
import re
from abc import abstractmethod
from enum import Enum
from typing import List, Optional
from typing import Dict, List, Optional
from mcp import Tool
from pydantic import BaseModel, Field
@@ -7,6 +9,9 @@ from pydantic import BaseModel, Field
# MCP Authentication Constants
MCP_AUTH_HEADER_AUTHORIZATION = "Authorization"
MCP_AUTH_TOKEN_BEARER_PREFIX = "Bearer"
TEMPLATED_VARIABLE_REGEX = (
r"\{\{\s*([A-Z_][A-Z0-9_]*)\s*(?:\|\s*([^}]+?)\s*)?\}\}" # Allows for optional whitespace around the variable name and default value
)
class MCPTool(Tool):
@@ -23,6 +28,91 @@ class BaseServerConfig(BaseModel):
server_name: str = Field(..., description="The name of the server")
type: MCPServerType
def is_templated_tool_variable(self, value: str) -> bool:
"""
Check if string contains templated variables.
Args:
value: The value string to check
Returns:
True if the value contains templated variables in the format {{ VARIABLE_NAME }} or {{ VARIABLE_NAME | default }}, False otherwise
"""
return bool(re.search(TEMPLATED_VARIABLE_REGEX, value))
def get_tool_variable(self, value: str, environment_variables: Dict[str, str]) -> Optional[str]:
"""
Replace templated variables in a value string with their values from environment variables.
Supports fallback/default values with pipe syntax.
Args:
value: The value string that may contain templated variables (e.g., "Bearer {{ API_KEY | default_token }}")
environment_variables: Dictionary of environment variables
Returns:
The string with templated variables replaced, or None if no templated variables found
"""
# If no templated variables found or default value provided, return the original value
if not self.is_templated_tool_variable(value):
return value
def replace_template(match):
variable_name = match.group(1)
default_value = match.group(2) if match.group(2) else None
# Try to get the value from environment variables
env_value = environment_variables.get(variable_name) if environment_variables else None
# Return environment value if found, otherwise return default value, otherwise return empty string
if env_value is not None:
return env_value
elif default_value is not None:
return default_value
else:
# If no environment value and no default, return the original template
return match.group(0)
# Replace all templated variables in the token
result = re.sub(TEMPLATED_VARIABLE_REGEX, replace_template, value)
# If the result still contains unreplaced templates, just return original value
if re.search(TEMPLATED_VARIABLE_REGEX, result):
logger.warning(f"Unable to resolve templated variable in value: {value}")
return value
return result
def resolve_custom_headers(
self, custom_headers: Optional[Dict[str, str]], environment_variables: Optional[Dict[str, str]] = None
) -> Optional[Dict[str, str]]:
"""
Resolve templated variables in custom headers dictionary.
Args:
custom_headers: Dictionary of custom headers that may contain templated variables
environment_variables: Dictionary of environment variables for resolving templates
Returns:
Dictionary with resolved header values, or None if custom_headers is None
"""
if custom_headers is None:
return None
resolved_headers = {}
for key, value in custom_headers.items():
# Resolve templated variables in each header value
if self.is_templated_tool_variable(value):
resolved_headers[key] = self.get_tool_variable(value, environment_variables)
else:
resolved_headers[key] = value
return resolved_headers
@abstractmethod
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
raise NotImplementedError
class SSEServerConfig(BaseServerConfig):
"""
@@ -47,6 +137,12 @@ class SSEServerConfig(BaseServerConfig):
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
return self.auth_token
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
if self.auth_token and super().is_templated_tool_variable(self.auth_token):
self.auth_token = super().get_tool_variable(self.auth_token, environment_variables)
self.custom_headers = super().resolve_custom_headers(self.custom_headers, environment_variables)
def to_dict(self) -> dict:
values = {
"transport": "sse",
@@ -72,6 +168,10 @@ class StdioServerConfig(BaseServerConfig):
args: List[str] = Field(..., description="The arguments to pass to the command")
env: Optional[dict[str, str]] = Field(None, description="Environment variables to set")
# TODO: @jnjpng templated auth handling for stdio
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
pass
def to_dict(self) -> dict:
values = {
"transport": "stdio",
@@ -106,6 +206,12 @@ class StreamableHTTPServerConfig(BaseServerConfig):
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
return self.auth_token
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
if self.auth_token and super().is_templated_tool_variable(self.auth_token):
self.auth_token = super().get_tool_variable(self.auth_token, environment_variables)
self.custom_headers = super().resolve_custom_headers(self.custom_headers, environment_variables)
def model_post_init(self, __context) -> None:
"""Validate the server URL format."""
# Basic validation for streamable HTTP URLs

View File

@@ -11,5 +11,6 @@
"quantization": "Q4_0",
"state": "not-loaded",
"max_context_length": 2048
},
...
}
]
}

View File

@@ -41,29 +41,42 @@ class MCPServer(BaseMCPServer):
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
def to_config(self) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
def to_config(
self,
environment_variables: Optional[Dict[str, str]] = None,
resolve_variables: bool = True,
) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
if self.server_type == MCPServerType.SSE:
return SSEServerConfig(
config = SSEServerConfig(
server_name=self.server_name,
server_url=self.server_url,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
custom_headers=self.custom_headers,
)
if resolve_variables:
config.resolve_environment_variables(environment_variables)
return config
elif self.server_type == MCPServerType.STDIO:
if self.stdio_config is None:
raise ValueError("stdio_config is required for STDIO server type")
if resolve_variables:
self.stdio_config.resolve_environment_variables(environment_variables)
return self.stdio_config
elif self.server_type == MCPServerType.STREAMABLE_HTTP:
if self.server_url is None:
raise ValueError("server_url is required for STREAMABLE_HTTP server type")
return StreamableHTTPServerConfig(
config = StreamableHTTPServerConfig(
server_name=self.server_name,
server_url=self.server_url,
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
custom_headers=self.custom_headers,
)
if resolve_variables:
config.resolve_environment_variables(environment_variables)
return config
else:
raise ValueError(f"Unsupported server type: {self.server_type}")

View File

@@ -394,7 +394,7 @@ async def list_mcp_servers(server: SyncServer = Depends(get_letta_server), user_
else:
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
mcp_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
return {server.server_name: server.to_config() for server in mcp_servers}
return {server.server_name: server.to_config(resolve_variables=False) for server in mcp_servers}
# NOTE: async because the MCP client/session calls are async
@@ -639,6 +639,7 @@ async def test_mcp_server(
client = None
try:
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
request.resolve_environment_variables()
client = await server.mcp_manager.get_mcp_client(request, actor)
await client.connect_to_server()
@@ -719,6 +720,7 @@ async def connect_mcp_server(
# Create MCP client with respective transport type
try:
request.resolve_environment_variables()
client = await server.mcp_manager.get_mcp_client(request, actor)
except ValueError as e:
yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))

View File

@@ -66,7 +66,12 @@ class MCPManager:
@enforce_types
async def execute_mcp_server_tool(
self, mcp_server_name: str, tool_name: str, tool_args: Optional[Dict[str, Any]], actor: PydanticUser
self,
mcp_server_name: str,
tool_name: str,
tool_args: Optional[Dict[str, Any]],
environment_variables: Dict[str, str],
actor: PydanticUser,
) -> Tuple[str, bool]:
"""Call a specific tool from a specific MCP server."""
from letta.settings import tool_settings
@@ -75,7 +80,7 @@ class MCPManager:
# read from DB
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config()
server_config = mcp_config.to_config(environment_variables)
else:
# read from config file
mcp_config = self.read_mcp_config()

View File

@@ -35,8 +35,17 @@ class ExternalMCPToolExecutor(ToolExecutor):
mcp_manager = MCPManager()
# TODO: may need to have better client connection management
environment_variables = {}
if agent_state:
environment_variables = agent_state.get_agent_env_vars_as_dict()
function_response, success = await mcp_manager.execute_mcp_server_tool(
mcp_server_name=mcp_server_name, tool_name=function_name, tool_args=function_args, actor=actor
mcp_server_name=mcp_server_name,
tool_name=function_name,
tool_args=function_args,
environment_variables=environment_variables,
actor=actor,
)
return ToolExecutionResult(

View File

@@ -8554,7 +8554,7 @@ async def test_create_mcp_server(server, default_user, event_loop):
tool_name = "ask_question"
tool_args = {"repoName": "letta-ai/letta", "question": "What is the primary programming language of this repository?"}
result = await server.mcp_manager.execute_mcp_server_tool(
created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user
created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user, environment_variables={}
)
print(result)