From cf8c59aab99e9e12e6f476005e5ab3fc21cd64bd Mon Sep 17 00:00:00 2001 From: jnjpng Date: Mon, 28 Jul 2025 18:20:58 -0700 Subject: [PATCH] feat: allow mcp authentication overrides per agent (#3318) Co-authored-by: Jin Peng --- letta/functions/mcp_client/types.py | 108 +++++++++++++++++- .../lmstudio_embedding_list.json | 5 +- letta/schemas/mcp.py | 19 ++- letta/server/rest_api/routers/v1/tools.py | 4 +- letta/services/mcp_manager.py | 9 +- .../tool_executor/mcp_tool_executor.py | 11 +- tests/test_managers.py | 2 +- 7 files changed, 147 insertions(+), 11 deletions(-) diff --git a/letta/functions/mcp_client/types.py b/letta/functions/mcp_client/types.py index 179dd8f8..adf695df 100644 --- a/letta/functions/mcp_client/types.py +++ b/letta/functions/mcp_client/types.py @@ -1,5 +1,7 @@ +import re +from abc import abstractmethod from enum import Enum -from typing import List, Optional +from typing import Dict, List, Optional from mcp import Tool from pydantic import BaseModel, Field @@ -7,6 +9,9 @@ from pydantic import BaseModel, Field # MCP Authentication Constants MCP_AUTH_HEADER_AUTHORIZATION = "Authorization" MCP_AUTH_TOKEN_BEARER_PREFIX = "Bearer" +TEMPLATED_VARIABLE_REGEX = ( + r"\{\{\s*([A-Z_][A-Z0-9_]*)\s*(?:\|\s*([^}]+?)\s*)?\}\}" # Allows for optional whitespace around the variable name and default value +) class MCPTool(Tool): @@ -23,6 +28,91 @@ class BaseServerConfig(BaseModel): server_name: str = Field(..., description="The name of the server") type: MCPServerType + def is_templated_tool_variable(self, value: str) -> bool: + """ + Check if string contains templated variables. + + Args: + value: The value string to check + + Returns: + True if the value contains templated variables in the format {{ VARIABLE_NAME }} or {{ VARIABLE_NAME | default }}, False otherwise + """ + return bool(re.search(TEMPLATED_VARIABLE_REGEX, value)) + + def get_tool_variable(self, value: str, environment_variables: Dict[str, str]) -> Optional[str]: + """ + Replace templated variables in a value string with their values from environment variables. + Supports fallback/default values with pipe syntax. + + Args: + value: The value string that may contain templated variables (e.g., "Bearer {{ API_KEY | default_token }}") + environment_variables: Dictionary of environment variables + + Returns: + The string with templated variables replaced, or None if no templated variables found + """ + + # If no templated variables found or default value provided, return the original value + if not self.is_templated_tool_variable(value): + return value + + def replace_template(match): + variable_name = match.group(1) + default_value = match.group(2) if match.group(2) else None + + # Try to get the value from environment variables + env_value = environment_variables.get(variable_name) if environment_variables else None + + # Return environment value if found, otherwise return default value, otherwise return empty string + if env_value is not None: + return env_value + elif default_value is not None: + return default_value + else: + # If no environment value and no default, return the original template + return match.group(0) + + # Replace all templated variables in the token + result = re.sub(TEMPLATED_VARIABLE_REGEX, replace_template, value) + + # If the result still contains unreplaced templates, just return original value + if re.search(TEMPLATED_VARIABLE_REGEX, result): + logger.warning(f"Unable to resolve templated variable in value: {value}") + return value + + return result + + def resolve_custom_headers( + self, custom_headers: Optional[Dict[str, str]], environment_variables: Optional[Dict[str, str]] = None + ) -> Optional[Dict[str, str]]: + """ + Resolve templated variables in custom headers dictionary. + + Args: + custom_headers: Dictionary of custom headers that may contain templated variables + environment_variables: Dictionary of environment variables for resolving templates + + Returns: + Dictionary with resolved header values, or None if custom_headers is None + """ + if custom_headers is None: + return None + + resolved_headers = {} + for key, value in custom_headers.items(): + # Resolve templated variables in each header value + if self.is_templated_tool_variable(value): + resolved_headers[key] = self.get_tool_variable(value, environment_variables) + else: + resolved_headers[key] = value + + return resolved_headers + + @abstractmethod + def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None: + raise NotImplementedError + class SSEServerConfig(BaseServerConfig): """ @@ -47,6 +137,12 @@ class SSEServerConfig(BaseServerConfig): return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :] return self.auth_token + def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None: + if self.auth_token and super().is_templated_tool_variable(self.auth_token): + self.auth_token = super().get_tool_variable(self.auth_token, environment_variables) + + self.custom_headers = super().resolve_custom_headers(self.custom_headers, environment_variables) + def to_dict(self) -> dict: values = { "transport": "sse", @@ -72,6 +168,10 @@ class StdioServerConfig(BaseServerConfig): args: List[str] = Field(..., description="The arguments to pass to the command") env: Optional[dict[str, str]] = Field(None, description="Environment variables to set") + # TODO: @jnjpng templated auth handling for stdio + def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None: + pass + def to_dict(self) -> dict: values = { "transport": "stdio", @@ -106,6 +206,12 @@ class StreamableHTTPServerConfig(BaseServerConfig): return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :] return self.auth_token + def resolve_environment_variables(self, environment_variables: Optional[Dict[str, str]] = None) -> None: + if self.auth_token and super().is_templated_tool_variable(self.auth_token): + self.auth_token = super().get_tool_variable(self.auth_token, environment_variables) + + self.custom_headers = super().resolve_custom_headers(self.custom_headers, environment_variables) + def model_post_init(self, __context) -> None: """Validate the server URL format.""" # Basic validation for streamable HTTP URLs diff --git a/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json b/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json index 25489ff3..dc2a2d2c 100644 --- a/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json +++ b/letta/llm_api/sample_response_jsons/lmstudio_embedding_list.json @@ -11,5 +11,6 @@ "quantization": "Q4_0", "state": "not-loaded", "max_context_length": 2048 - }, - ... + } + ] +} diff --git a/letta/schemas/mcp.py b/letta/schemas/mcp.py index f7070e8b..f65147f6 100644 --- a/letta/schemas/mcp.py +++ b/letta/schemas/mcp.py @@ -41,29 +41,42 @@ class MCPServer(BaseMCPServer): last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.") metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.") - def to_config(self) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]: + def to_config( + self, + environment_variables: Optional[Dict[str, str]] = None, + resolve_variables: bool = True, + ) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]: if self.server_type == MCPServerType.SSE: - return SSEServerConfig( + config = SSEServerConfig( server_name=self.server_name, server_url=self.server_url, auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None, auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None, custom_headers=self.custom_headers, ) + if resolve_variables: + config.resolve_environment_variables(environment_variables) + return config elif self.server_type == MCPServerType.STDIO: if self.stdio_config is None: raise ValueError("stdio_config is required for STDIO server type") + if resolve_variables: + self.stdio_config.resolve_environment_variables(environment_variables) return self.stdio_config elif self.server_type == MCPServerType.STREAMABLE_HTTP: if self.server_url is None: raise ValueError("server_url is required for STREAMABLE_HTTP server type") - return StreamableHTTPServerConfig( + + config = StreamableHTTPServerConfig( server_name=self.server_name, server_url=self.server_url, auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token and not self.custom_headers else None, auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token and not self.custom_headers else None, custom_headers=self.custom_headers, ) + if resolve_variables: + config.resolve_environment_variables(environment_variables) + return config else: raise ValueError(f"Unsupported server type: {self.server_type}") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 86d50a37..151ff254 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -394,7 +394,7 @@ async def list_mcp_servers(server: SyncServer = Depends(get_letta_server), user_ else: actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id) mcp_servers = await server.mcp_manager.list_mcp_servers(actor=actor) - return {server.server_name: server.to_config() for server in mcp_servers} + return {server.server_name: server.to_config(resolve_variables=False) for server in mcp_servers} # NOTE: async because the MCP client/session calls are async @@ -639,6 +639,7 @@ async def test_mcp_server( client = None try: actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + request.resolve_environment_variables() client = await server.mcp_manager.get_mcp_client(request, actor) await client.connect_to_server() @@ -719,6 +720,7 @@ async def connect_mcp_server( # Create MCP client with respective transport type try: + request.resolve_environment_variables() client = await server.mcp_manager.get_mcp_client(request, actor) except ValueError as e: yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e)) diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index e0361822..2be49e04 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -66,7 +66,12 @@ class MCPManager: @enforce_types async def execute_mcp_server_tool( - self, mcp_server_name: str, tool_name: str, tool_args: Optional[Dict[str, Any]], actor: PydanticUser + self, + mcp_server_name: str, + tool_name: str, + tool_args: Optional[Dict[str, Any]], + environment_variables: Dict[str, str], + actor: PydanticUser, ) -> Tuple[str, bool]: """Call a specific tool from a specific MCP server.""" from letta.settings import tool_settings @@ -75,7 +80,7 @@ class MCPManager: # read from DB mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor) mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) - server_config = mcp_config.to_config() + server_config = mcp_config.to_config(environment_variables) else: # read from config file mcp_config = self.read_mcp_config() diff --git a/letta/services/tool_executor/mcp_tool_executor.py b/letta/services/tool_executor/mcp_tool_executor.py index 1640c145..0fd2eea5 100644 --- a/letta/services/tool_executor/mcp_tool_executor.py +++ b/letta/services/tool_executor/mcp_tool_executor.py @@ -35,8 +35,17 @@ class ExternalMCPToolExecutor(ToolExecutor): mcp_manager = MCPManager() # TODO: may need to have better client connection management + + environment_variables = {} + if agent_state: + environment_variables = agent_state.get_agent_env_vars_as_dict() + function_response, success = await mcp_manager.execute_mcp_server_tool( - mcp_server_name=mcp_server_name, tool_name=function_name, tool_args=function_args, actor=actor + mcp_server_name=mcp_server_name, + tool_name=function_name, + tool_args=function_args, + environment_variables=environment_variables, + actor=actor, ) return ToolExecutionResult( diff --git a/tests/test_managers.py b/tests/test_managers.py index 70baf8aa..7af0e4be 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -8554,7 +8554,7 @@ async def test_create_mcp_server(server, default_user, event_loop): tool_name = "ask_question" tool_args = {"repoName": "letta-ai/letta", "question": "What is the primary programming language of this repository?"} result = await server.mcp_manager.execute_mcp_server_tool( - created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user + created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user, environment_variables={} ) print(result)