feat: add new routes for add/deleting MCP servers (#1272)
This commit is contained in:
@@ -11,6 +11,9 @@ from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
# see: https://modelcontextprotocol.io/quickstart/user
|
||||
MCP_CONFIG_TOPLEVEL_KEY = "mcpServers"
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""A simple wrapper around MCP's tool definition (to avoid conflict with our own)"""
|
||||
@@ -18,7 +21,7 @@ class MCPTool(Tool):
|
||||
|
||||
class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
LOCAL = "local"
|
||||
STDIO = "stdio"
|
||||
|
||||
|
||||
class BaseServerConfig(BaseModel):
|
||||
@@ -30,11 +33,29 @@ class SSEServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.SSE
|
||||
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "sse",
|
||||
"url": self.server_url,
|
||||
}
|
||||
return values
|
||||
|
||||
class LocalServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.LOCAL
|
||||
|
||||
class StdioServerConfig(BaseServerConfig):
|
||||
type: MCPServerType = MCPServerType.STDIO
|
||||
command: str = Field(..., description="The command to run (MCP 'local' client will run this command)")
|
||||
args: List[str] = Field(..., description="The arguments to pass to the command")
|
||||
env: Optional[dict[str, str]] = Field(None, description="Environment variables to set")
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "stdio",
|
||||
"command": self.command,
|
||||
"args": self.args,
|
||||
}
|
||||
if self.env is not None:
|
||||
values["env"] = self.env
|
||||
return values
|
||||
|
||||
|
||||
class BaseMCPClient:
|
||||
@@ -83,8 +104,8 @@ class BaseMCPClient:
|
||||
logger.info("Cleaned up MCP clients on shutdown.")
|
||||
|
||||
|
||||
class LocalMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: LocalServerConfig):
|
||||
class StdioMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: StdioServerConfig):
|
||||
server_params = StdioServerParameters(command=server_config.command, args=server_config.args)
|
||||
stdio_cm = stdio_client(server_params)
|
||||
stdio_transport = self.loop.run_until_complete(stdio_cm.__aenter__())
|
||||
|
||||
@@ -13,7 +13,7 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
|
||||
from letta.errors import LettaToolCreateError
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.mcp_helpers import LocalServerConfig, MCPTool, SSEServerConfig
|
||||
from letta.helpers.mcp_helpers import MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
@@ -333,7 +333,7 @@ def add_composio_tool(
|
||||
|
||||
|
||||
# Specific routes for MCP
|
||||
@router.get("/mcp/servers", response_model=dict[str, Union[SSEServerConfig, LocalServerConfig]], operation_id="list_mcp_servers")
|
||||
@router.get("/mcp/servers", response_model=dict[str, Union[SSEServerConfig, StdioServerConfig]], operation_id="list_mcp_servers")
|
||||
def list_mcp_servers(server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id")):
|
||||
"""
|
||||
Get a list of all configured MCP servers
|
||||
@@ -376,7 +376,7 @@ def add_mcp_tool(
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Add a new MCP tool by server + tool name
|
||||
Register a new MCP tool as a Letta server by MCP server + tool name
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
@@ -399,3 +399,31 @@ def add_mcp_tool(
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, actor=actor)
|
||||
|
||||
|
||||
@router.put("/mcp/servers", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="add_mcp_server")
|
||||
def add_mcp_server_to_config(
|
||||
request: Union[StdioServerConfig, SSEServerConfig] = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Add a new MCP server to the Letta MCP server config
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/mcp/servers/{mcp_server_name}", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="delete_mcp_server"
|
||||
)
|
||||
def delete_mcp_server_from_config(
|
||||
mcp_server_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Add a new MCP server to the Letta MCP server config
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.delete_mcp_server_from_config(server_name=mcp_server_name)
|
||||
|
||||
@@ -23,13 +23,14 @@ from letta.dynamic_multi_agent import DynamicMultiAgent
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.mcp_helpers import (
|
||||
MCP_CONFIG_TOPLEVEL_KEY,
|
||||
BaseMCPClient,
|
||||
LocalMCPClient,
|
||||
LocalServerConfig,
|
||||
MCPServerType,
|
||||
MCPTool,
|
||||
SSEMCPClient,
|
||||
SSEServerConfig,
|
||||
StdioMCPClient,
|
||||
StdioServerConfig,
|
||||
)
|
||||
|
||||
# TODO use custom interface
|
||||
@@ -338,8 +339,8 @@ class SyncServer(Server):
|
||||
for server_name, server_config in mcp_server_configs.items():
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
self.mcp_clients[server_name] = SSEMCPClient()
|
||||
elif server_config.type == MCPServerType.LOCAL:
|
||||
self.mcp_clients[server_name] = LocalMCPClient()
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
self.mcp_clients[server_name] = StdioMCPClient()
|
||||
else:
|
||||
raise ValueError(f"Invalid MCP server config: {server_config}")
|
||||
try:
|
||||
@@ -1258,7 +1259,7 @@ class SyncServer(Server):
|
||||
|
||||
# MCP wrappers
|
||||
# TODO support both command + SSE servers (via config)
|
||||
def get_mcp_servers(self) -> dict[str, Union[SSEServerConfig, LocalServerConfig]]:
|
||||
def get_mcp_servers(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
||||
"""List the MCP servers in the config (doesn't test that they are actually working)"""
|
||||
mcp_server_list = {}
|
||||
|
||||
@@ -1276,8 +1277,8 @@ class SyncServer(Server):
|
||||
# Proper formatting is "mcpServers" key at the top level,
|
||||
# then a dict with the MCP server name as the key,
|
||||
# with the value being the schema from StdioServerParameters
|
||||
if "mcpServers" in mcp_config:
|
||||
for server_name, server_params_raw in mcp_config["mcpServers"].items():
|
||||
if MCP_CONFIG_TOPLEVEL_KEY in mcp_config:
|
||||
for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items():
|
||||
|
||||
# No support for duplicate server names
|
||||
if server_name in mcp_server_list:
|
||||
@@ -1298,7 +1299,7 @@ class SyncServer(Server):
|
||||
else:
|
||||
# Attempt to parse the server params as a StdioServerParameters
|
||||
try:
|
||||
server_params = LocalServerConfig(
|
||||
server_params = StdioServerConfig(
|
||||
server_name=server_name,
|
||||
command=server_params_raw["command"],
|
||||
args=server_params_raw.get("args", []),
|
||||
@@ -1318,6 +1319,98 @@ class SyncServer(Server):
|
||||
|
||||
return self.mcp_clients[mcp_server_name].list_tools()
|
||||
|
||||
def add_mcp_server_to_config(
|
||||
self, server_config: Union[SSEServerConfig, StdioServerConfig], allow_upsert: bool = True
|
||||
) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
||||
"""Add a new server config to the MCP config file"""
|
||||
|
||||
# If the config file doesn't exist, throw an error.
|
||||
mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
|
||||
if not os.path.exists(mcp_config_path):
|
||||
raise FileNotFoundError(f"MCP config file not found: {mcp_config_path}")
|
||||
|
||||
# If the file does exist, attempt to parse it get calling get_mcp_servers
|
||||
try:
|
||||
current_mcp_servers = self.get_mcp_servers()
|
||||
except Exception as e:
|
||||
# Raise an error telling the user to fix the config file
|
||||
logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}")
|
||||
raise ValueError(f"Failed to parse MCP config file {mcp_config_path}")
|
||||
|
||||
# Check if the server name is already in the config
|
||||
if server_config.server_name in current_mcp_servers and not allow_upsert:
|
||||
raise ValueError(f"Server name {server_config.server_name} is already in the config file")
|
||||
|
||||
# Attempt to initialize the connection to the server
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
new_mcp_client = SSEMCPClient()
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
new_mcp_client = StdioMCPClient()
|
||||
else:
|
||||
raise ValueError(f"Invalid MCP server config: {server_config}")
|
||||
try:
|
||||
new_mcp_client.connect_to_server(server_config)
|
||||
except:
|
||||
logger.exception(f"Failed to connect to MCP server: {server_config.server_name}")
|
||||
raise RuntimeError(f"Failed to connect to MCP server: {server_config.server_name}")
|
||||
# Print out the tools that are connected
|
||||
logger.info(f"Attempting to fetch tools from MCP server: {server_config.server_name}")
|
||||
new_mcp_tools = new_mcp_client.list_tools()
|
||||
logger.info(f"MCP tools connected: {", ".join([t.name for t in new_mcp_tools])}")
|
||||
logger.debug(f"MCP tools: {"\n".join([str(t) for t in new_mcp_tools])}")
|
||||
|
||||
# Now that we've confirmed the config is working, let's add it to the client list
|
||||
self.mcp_clients[server_config.server_name] = new_mcp_client
|
||||
|
||||
# Add to the server file
|
||||
current_mcp_servers[server_config.server_name] = server_config
|
||||
|
||||
# Write out the file, and make sure to in include the top-level mcpConfig
|
||||
try:
|
||||
new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}}
|
||||
with open(mcp_config_path, "w") as f:
|
||||
json.dump(new_mcp_file, f, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}")
|
||||
raise ValueError(f"Failed to write MCP config file {mcp_config_path}")
|
||||
|
||||
return list(current_mcp_servers.values())
|
||||
|
||||
def delete_mcp_server_from_config(self, server_name: str) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
||||
"""Delete a server config from the MCP config file"""
|
||||
|
||||
# If the config file doesn't exist, throw an error.
|
||||
mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
|
||||
if not os.path.exists(mcp_config_path):
|
||||
raise FileNotFoundError(f"MCP config file not found: {mcp_config_path}")
|
||||
|
||||
# If the file does exist, attempt to parse it get calling get_mcp_servers
|
||||
try:
|
||||
current_mcp_servers = self.get_mcp_servers()
|
||||
except Exception as e:
|
||||
# Raise an error telling the user to fix the config file
|
||||
logger.error(f"Failed to parse MCP config file at {mcp_config_path}: {e}")
|
||||
raise ValueError(f"Failed to parse MCP config file {mcp_config_path}")
|
||||
|
||||
# Check if the server name is already in the config
|
||||
# If it's not, throw an error
|
||||
if server_name not in current_mcp_servers:
|
||||
raise ValueError(f"Server name {server_name} not found in MCP config file")
|
||||
|
||||
# Remove from the server file
|
||||
del current_mcp_servers[server_name]
|
||||
|
||||
# Write out the file, and make sure to in include the top-level mcpConfig
|
||||
try:
|
||||
new_mcp_file = {MCP_CONFIG_TOPLEVEL_KEY: {k: v.to_dict() for k, v in current_mcp_servers.items()}}
|
||||
with open(mcp_config_path, "w") as f:
|
||||
json.dump(new_mcp_file, f, indent=4)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to write MCP config file at {mcp_config_path}: {e}")
|
||||
raise ValueError(f"Failed to write MCP config file {mcp_config_path}")
|
||||
|
||||
return list(current_mcp_servers.values())
|
||||
|
||||
@trace_method
|
||||
async def send_message_to_agent(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user