From b9ff1ea6240949bc2c1a92a494bfec477f99d059 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Thu, 13 Mar 2025 17:10:12 -0700 Subject: [PATCH] feat: add new routes for add/deleting MCP servers (#1272) --- letta/helpers/mcp_helpers.py | 31 +++++- letta/server/rest_api/routers/v1/tools.py | 34 ++++++- letta/server/server.py | 109 ++++++++++++++++++++-- 3 files changed, 158 insertions(+), 16 deletions(-) diff --git a/letta/helpers/mcp_helpers.py b/letta/helpers/mcp_helpers.py index 1d8adf92..450622a3 100644 --- a/letta/helpers/mcp_helpers.py +++ b/letta/helpers/mcp_helpers.py @@ -11,6 +11,9 @@ from letta.log import get_logger logger = get_logger(__name__) +# see: https://modelcontextprotocol.io/quickstart/user +MCP_CONFIG_TOPLEVEL_KEY = "mcpServers" + class MCPTool(Tool): """A simple wrapper around MCP's tool definition (to avoid conflict with our own)""" @@ -18,7 +21,7 @@ class MCPTool(Tool): class MCPServerType(str, Enum): SSE = "sse" - LOCAL = "local" + STDIO = "stdio" class BaseServerConfig(BaseModel): @@ -30,11 +33,29 @@ class SSEServerConfig(BaseServerConfig): type: MCPServerType = MCPServerType.SSE server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)") + def to_dict(self) -> dict: + values = { + "transport": "sse", + "url": self.server_url, + } + return values -class LocalServerConfig(BaseServerConfig): - type: MCPServerType = MCPServerType.LOCAL + +class StdioServerConfig(BaseServerConfig): + type: MCPServerType = MCPServerType.STDIO command: str = Field(..., description="The command to run (MCP 'local' client will run this command)") args: List[str] = Field(..., description="The arguments to pass to the command") + env: Optional[dict[str, str]] = Field(None, description="Environment variables to set") + + def to_dict(self) -> dict: + values = { + "transport": "stdio", + "command": self.command, + "args": self.args, + } + if self.env is not None: + values["env"] = self.env + return values class BaseMCPClient: @@ -83,8 +104,8 @@ class BaseMCPClient: logger.info("Cleaned up MCP clients on shutdown.") -class LocalMCPClient(BaseMCPClient): - def _initialize_connection(self, server_config: LocalServerConfig): +class StdioMCPClient(BaseMCPClient): + def _initialize_connection(self, server_config: StdioServerConfig): server_params = StdioServerParameters(command=server_config.command, args=server_config.args) stdio_cm = stdio_client(server_params) stdio_transport = self.loop.run_until_complete(stdio_cm.__aenter__()) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index c7c5800e..2290d281 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -13,7 +13,7 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException from letta.errors import LettaToolCreateError from letta.helpers.composio_helpers import get_composio_api_key -from letta.helpers.mcp_helpers import LocalServerConfig, MCPTool, SSEServerConfig +from letta.helpers.mcp_helpers import MCPTool, SSEServerConfig, StdioServerConfig from letta.log import get_logger from letta.orm.errors import UniqueConstraintViolationError from letta.schemas.letta_message import ToolReturnMessage @@ -333,7 +333,7 @@ def add_composio_tool( # Specific routes for MCP -@router.get("/mcp/servers", response_model=dict[str, Union[SSEServerConfig, LocalServerConfig]], operation_id="list_mcp_servers") +@router.get("/mcp/servers", response_model=dict[str, Union[SSEServerConfig, StdioServerConfig]], operation_id="list_mcp_servers") def list_mcp_servers(server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id")): """ Get a list of all configured MCP servers @@ -376,7 +376,7 @@ def add_mcp_tool( actor_id: Optional[str] = Header(None, alias="user_id"), ): """ - Add a new MCP tool by server + tool name + Register a new MCP tool as a Letta server by MCP server + tool name """ actor = server.user_manager.get_user_or_default(user_id=actor_id) @@ -399,3 +399,31 @@ def add_mcp_tool( tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool) return server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, actor=actor) + + +@router.put("/mcp/servers", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="add_mcp_server") +def add_mcp_server_to_config( + request: Union[StdioServerConfig, SSEServerConfig] = Body(...), + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Add a new MCP server to the Letta MCP server config + """ + actor = server.user_manager.get_user_or_default(user_id=actor_id) + return server.add_mcp_server_to_config(server_config=request, allow_upsert=True) + + +@router.delete( + "/mcp/servers/{mcp_server_name}", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="delete_mcp_server" +) +def delete_mcp_server_from_config( + mcp_server_name: str, + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Add a new MCP server to the Letta MCP server config + """ + actor = server.user_manager.get_user_or_default(user_id=actor_id) + return server.delete_mcp_server_from_config(server_name=mcp_server_name) diff --git a/letta/server/server.py b/letta/server/server.py index 00a80a47..336ba086 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -23,13 +23,14 @@ from letta.dynamic_multi_agent import DynamicMultiAgent from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads from letta.helpers.mcp_helpers import ( + MCP_CONFIG_TOPLEVEL_KEY, BaseMCPClient, - LocalMCPClient, - LocalServerConfig, MCPServerType, MCPTool, SSEMCPClient, SSEServerConfig, + StdioMCPClient, + StdioServerConfig, ) # TODO use custom interface @@ -338,8 +339,8 @@ class SyncServer(Server): for server_name, server_config in mcp_server_configs.items(): if server_config.type == MCPServerType.SSE: self.mcp_clients[server_name] = SSEMCPClient() - elif server_config.type == MCPServerType.LOCAL: - self.mcp_clients[server_name] = LocalMCPClient() + elif server_config.type == MCPServerType.STDIO: + self.mcp_clients[server_name] = StdioMCPClient() else: raise ValueError(f"Invalid MCP server config: {server_config}") try: @@ -1258,7 +1259,7 @@ class SyncServer(Server): # MCP wrappers # TODO support both command + SSE servers (via config) - def get_mcp_servers(self) -> dict[str, Union[SSEServerConfig, LocalServerConfig]]: + def get_mcp_servers(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]: """List the MCP servers in the config (doesn't test that they are actually working)""" mcp_server_list = {} @@ -1276,8 +1277,8 @@ class SyncServer(Server): # Proper formatting is "mcpServers" key at the top level, # then a dict with the MCP server name as the key, # with the value being the schema from StdioServerParameters - if "mcpServers" in mcp_config: - for server_name, server_params_raw in mcp_config["mcpServers"].items(): + if MCP_CONFIG_TOPLEVEL_KEY in mcp_config: + for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items(): # No support for duplicate server names if server_name in mcp_server_list: @@ -1298,7 +1299,7 @@ class SyncServer(Server): else: # Attempt to parse the server params as a StdioServerParameters try: - server_params = LocalServerConfig( + server_params = StdioServerConfig( server_name=server_name, command=server_params_raw["command"], args=server_params_raw.get("args", []), @@ -1318,6 +1319,98 @@ class SyncServer(Server): return self.mcp_clients[mcp_server_name].list_tools() + def add_mcp_server_to_config( + self, server_config: Union[SSEServerConfig, StdioServerConfig], allow_upsert: bool = True + ) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]: + """Add a new server config to the MCP config file""" + + # If the config file doesn't exist, throw an error. + mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME) + if not os.path.exists(mcp_config_path): + raise FileNotFoundError(f"MCP config file not found: {mcp_config_path}") + + # If the file does exist, attempt to parse it get calling get_mcp_servers + try: + current_mcp_servers = self.get_mcp_servers() + except Exception as e: + # Raise an error telling the user to fix the config file + logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}") + raise ValueError(f"Failed to parse MCP config file {mcp_config_path}") + + # Check if the server name is already in the config + if server_config.server_name in current_mcp_servers and not allow_upsert: + raise ValueError(f"Server name {server_config.server_name} is already in the config file") + + # Attempt to initialize the connection to the server + if server_config.type == MCPServerType.SSE: + new_mcp_client = SSEMCPClient() + elif server_config.type == MCPServerType.STDIO: + new_mcp_client = StdioMCPClient() + else: + raise ValueError(f"Invalid MCP server config: {server_config}") + try: + new_mcp_client.connect_to_server(server_config) + except: + logger.exception(f"Failed to connect to MCP server: {server_config.server_name}") + raise RuntimeError(f"Failed to connect to MCP server: {server_config.server_name}") + # Print out the tools that are connected + logger.info(f"Attempting to fetch tools from MCP server: {server_config.server_name}") + new_mcp_tools = new_mcp_client.list_tools() + logger.info(f"MCP tools connected: {", ".join([t.name for t in new_mcp_tools])}") + logger.debug(f"MCP tools: {"\n".join([str(t) for t in new_mcp_tools])}") + + # Now that we've confirmed the config is working, let's add it to the client list + self.mcp_clients[server_config.server_name] = new_mcp_client + + # Add to the server file + current_mcp_servers[server_config.server_name] = server_config + + # Write out the file, and make sure to in include the top-level mcpConfig + try: + new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}} + with open(mcp_config_path, "w") as f: + json.dump(new_mcp_file, f, indent=4) + except Exception as e: + logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}") + raise ValueError(f"Failed to write MCP config file {mcp_config_path}") + + return list(current_mcp_servers.values()) + + def delete_mcp_server_from_config(self, server_name: str) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]: + """Delete a server config from the MCP config file""" + + # If the config file doesn't exist, throw an error. + mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME) + if not os.path.exists(mcp_config_path): + raise FileNotFoundError(f"MCP config file not found: {mcp_config_path}") + + # If the file does exist, attempt to parse it get calling get_mcp_servers + try: + current_mcp_servers = self.get_mcp_servers() + except Exception as e: + # Raise an error telling the user to fix the config file + logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}") + raise ValueError(f"Failed to parse MCP config file {mcp_config_path}") + + # Check if the server name is already in the config + # If it's not, throw an error + if server_name not in current_mcp_servers: + raise ValueError(f"Server name {server_name} not found in MCP config file") + + # Remove from the server file + del current_mcp_servers[server_name] + + # Write out the file, and make sure to in include the top-level mcpConfig + try: + new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}} + with open(mcp_config_path, "w") as f: + json.dump(new_mcp_file, f, indent=4) + except Exception as e: + logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}") + raise ValueError(f"Failed to write MCP config file {mcp_config_path}") + + return list(current_mcp_servers.values()) + @trace_method async def send_message_to_agent( self,