feat: allow mcp authentication overrides per agent (#3318)
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -1,5 +1,7 @@
|
||||
import re
|
||||
from abc import abstractmethod
|
||||
from enum import Enum
|
||||
from typing import List, Optional
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from mcp import Tool
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -7,6 +9,9 @@ from pydantic import BaseModel, Field
|
||||
# MCP Authentication Constants
|
||||
MCP_AUTH_HEADER_AUTHORIZATION = "Authorization"
|
||||
MCP_AUTH_TOKEN_BEARER_PREFIX = "Bearer"
|
||||
TEMPLATED_VARIABLE_REGEX = (
|
||||
r"\{\{\s*([A-Z_][A-Z0-9_]*)\s*(?:\|\s*([^}]+?)\s*)?\}\}" # Allows for optional whitespace around the variable name and default value
|
||||
)
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
@@ -23,6 +28,91 @@ class BaseServerConfig(BaseModel):
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
type: MCPServerType
|
||||
|
||||
def is_templated_tool_variable(self, value: str) -> bool:
|
||||
"""
|
||||
Check if string contains templated variables.
|
||||
|
||||
Args:
|
||||
value: The value string to check
|
||||
|
||||
Returns:
|
||||
True if the value contains templated variables in the format {{ VARIABLE_NAME }} or {{ VARIABLE_NAME | default }}, False otherwise
|
||||
"""
|
||||
return bool(re.search(TEMPLATED_VARIABLE_REGEX, value))
|
||||
|
||||
def get_tool_variable(self, value: str, environment_variables: Dict[str, str]) -> Optional[str]:
|
||||
"""
|
||||
Replace templated variables in a value string with their values from environment variables.
|
||||
Supports fallback/default values with pipe syntax.
|
||||
|
||||
Args:
|
||||
value: The value string that may contain templated variables (e.g., "Bearer {{ API_KEY | default_token }}")
|
||||
environment_variables: Dictionary of environment variables
|
||||
|
||||
Returns:
|
||||
The string with templated variables replaced, or None if no templated variables found
|
||||
"""
|
||||
|
||||
# If no templated variables found or default value provided, return the original value
|
||||
if not self.is_templated_tool_variable(value):
|
||||
return value
|
||||
|
||||
def replace_template(match):
|
||||
variable_name = match.group(1)
|
||||
default_value = match.group(2) if match.group(2) else None
|
||||
|
||||
# Try to get the value from environment variables
|
||||
env_value = environment_variables.get(variable_name) if environment_variables else None
|
||||
|
||||
# Return environment value if found, otherwise return default value, otherwise return empty string
|
||||
if env_value is not None:
|
||||
return env_value
|
||||
elif default_value is not None:
|
||||
return default_value
|
||||
else:
|
||||
# If no environment value and no default, return the original template
|
||||
return match.group(0)
|
||||
|
||||
# Replace all templated variables in the token
|
||||
result = re.sub(TEMPLATED_VARIABLE_REGEX, replace_template, value)
|
||||
|
||||
# If the result still contains unreplaced templates, just return original value
|
||||
if re.search(TEMPLATED_VARIABLE_REGEX, result):
|
||||
logger.warning(f"Unable to resolve templated variable in value: {value}")
|
||||
return value
|
||||
|
||||
return result
|
||||
|
||||
def resolve_custom_headers(
|
||||
self, custom_headers: Optional[Dict[str, str]], environment_variables: Optional[Dict[str, str]] = None
|
||||
) -> Optional[Dict[str, str]]:
|
||||
"""
|
||||
Resolve templated variables in custom headers dictionary.
|
||||
|
||||
Args:
|
||||
custom_headers: Dictionary of custom headers that may contain templated variables
|
||||
environment_variables: Dictionary of environment variables for resolving templates
|
||||
|
||||
Returns:
|
||||
Dictionary with resolved header values, or None if custom_headers is None
|
||||
"""
|
||||
if custom_headers is None:
|
||||
return None
|
||||
|
||||
resolved_headers = {}
|
||||
for key, value in custom_headers.items():
|
||||
# Resolve templated variables in each header value
|
||||
if self.is_templated_tool_variable(value):
|
||||
resolved_headers[key] = self.get_tool_variable(value, environment_variables)
|
||||
else:
|
||||
resolved_headers[key] = value
|
||||
|
||||
return resolved_headers
|
||||
|
||||
@abstractmethod
|
||||
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class SSEServerConfig(BaseServerConfig):
|
||||
"""
|
||||
@@ -47,6 +137,12 @@ class SSEServerConfig(BaseServerConfig):
|
||||
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
|
||||
return self.auth_token
|
||||
|
||||
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
|
||||
if self.auth_token and super().is_templated_tool_variable(self.auth_token):
|
||||
self.auth_token = super().get_tool_variable(self.auth_token, environment_variables)
|
||||
|
||||
self.custom_headers = super().resolve_custom_headers(self.custom_headers, environment_variables)
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "sse",
|
||||
@@ -72,6 +168,10 @@ class StdioServerConfig(BaseServerConfig):
|
||||
args: List[str] = Field(..., description="The arguments to pass to the command")
|
||||
env: Optional[dict[str, str]] = Field(None, description="Environment variables to set")
|
||||
|
||||
# TODO: @jnjpng templated auth handling for stdio
|
||||
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
|
||||
pass
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "stdio",
|
||||
@@ -106,6 +206,12 @@ class StreamableHTTPServerConfig(BaseServerConfig):
|
||||
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
|
||||
return self.auth_token
|
||||
|
||||
def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None:
|
||||
if self.auth_token and super().is_templated_tool_variable(self.auth_token):
|
||||
self.auth_token = super().get_tool_variable(self.auth_token, environment_variables)
|
||||
|
||||
self.custom_headers = super().resolve_custom_headers(self.custom_headers, environment_variables)
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
"""Validate the server URL format."""
|
||||
# Basic validation for streamable HTTP URLs
|
||||
|
||||
@@ -11,5 +11,6 @@
|
||||
"quantization": "Q4_0",
|
||||
"state": "not-loaded",
|
||||
"max_context_length": 2048
|
||||
},
|
||||
...
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
@@ -41,29 +41,42 @@ class MCPServer(BaseMCPServer):
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
|
||||
|
||||
def to_config(self) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
|
||||
def to_config(
|
||||
self,
|
||||
environment_variables: Optional[Dict[str, str]] = None,
|
||||
resolve_variables: bool = True,
|
||||
) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
|
||||
if self.server_type == MCPServerType.SSE:
|
||||
return SSEServerConfig(
|
||||
config = SSEServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
|
||||
custom_headers=self.custom_headers,
|
||||
)
|
||||
if resolve_variables:
|
||||
config.resolve_environment_variables(environment_variables)
|
||||
return config
|
||||
elif self.server_type == MCPServerType.STDIO:
|
||||
if self.stdio_config is None:
|
||||
raise ValueError("stdio_config is required for STDIO server type")
|
||||
if resolve_variables:
|
||||
self.stdio_config.resolve_environment_variables(environment_variables)
|
||||
return self.stdio_config
|
||||
elif self.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||
if self.server_url is None:
|
||||
raise ValueError("server_url is required for STREAMABLE_HTTP server type")
|
||||
return StreamableHTTPServerConfig(
|
||||
|
||||
config = StreamableHTTPServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None,
|
||||
custom_headers=self.custom_headers,
|
||||
)
|
||||
if resolve_variables:
|
||||
config.resolve_environment_variables(environment_variables)
|
||||
return config
|
||||
else:
|
||||
raise ValueError(f"Unsupported server type: {self.server_type}")
|
||||
|
||||
|
||||
@@ -394,7 +394,7 @@ async def list_mcp_servers(server: SyncServer = Depends(get_letta_server), user_
|
||||
else:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
|
||||
mcp_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return {server.server_name: server.to_config() for server in mcp_servers}
|
||||
return {server.server_name: server.to_config(resolve_variables=False) for server in mcp_servers}
|
||||
|
||||
|
||||
# NOTE: async because the MCP client/session calls are async
|
||||
@@ -639,6 +639,7 @@ async def test_mcp_server(
|
||||
client = None
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
request.resolve_environment_variables()
|
||||
client = await server.mcp_manager.get_mcp_client(request, actor)
|
||||
|
||||
await client.connect_to_server()
|
||||
@@ -719,6 +720,7 @@ async def connect_mcp_server(
|
||||
|
||||
# Create MCP client with respective transport type
|
||||
try:
|
||||
request.resolve_environment_variables()
|
||||
client = await server.mcp_manager.get_mcp_client(request, actor)
|
||||
except ValueError as e:
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
|
||||
|
||||
@@ -66,7 +66,12 @@ class MCPManager:
|
||||
|
||||
@enforce_types
|
||||
async def execute_mcp_server_tool(
|
||||
self, mcp_server_name: str, tool_name: str, tool_args: Optional[Dict[str, Any]], actor: PydanticUser
|
||||
self,
|
||||
mcp_server_name: str,
|
||||
tool_name: str,
|
||||
tool_args: Optional[Dict[str, Any]],
|
||||
environment_variables: Dict[str, str],
|
||||
actor: PydanticUser,
|
||||
) -> Tuple[str, bool]:
|
||||
"""Call a specific tool from a specific MCP server."""
|
||||
from letta.settings import tool_settings
|
||||
@@ -75,7 +80,7 @@ class MCPManager:
|
||||
# read from DB
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
server_config = mcp_config.to_config(environment_variables)
|
||||
else:
|
||||
# read from config file
|
||||
mcp_config = self.read_mcp_config()
|
||||
|
||||
@@ -35,8 +35,17 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
|
||||
mcp_manager = MCPManager()
|
||||
# TODO: may need to have better client connection management
|
||||
|
||||
environment_variables = {}
|
||||
if agent_state:
|
||||
environment_variables = agent_state.get_agent_env_vars_as_dict()
|
||||
|
||||
function_response, success = await mcp_manager.execute_mcp_server_tool(
|
||||
mcp_server_name=mcp_server_name, tool_name=function_name, tool_args=function_args, actor=actor
|
||||
mcp_server_name=mcp_server_name,
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
environment_variables=environment_variables,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
return ToolExecutionResult(
|
||||
|
||||
@@ -8554,7 +8554,7 @@ async def test_create_mcp_server(server, default_user, event_loop):
|
||||
tool_name = "ask_question"
|
||||
tool_args = {"repoName": "letta-ai/letta", "question": "What is the primary programming language of this repository?"}
|
||||
result = await server.mcp_manager.execute_mcp_server_tool(
|
||||
created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user
|
||||
created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user, environment_variables={}
|
||||
)
|
||||
print(result)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user