feat: add token column to mcp_servers and pipe through to sse server config (#2775)
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
31
alembic/versions/c0ef3ff26306_add_token_to_mcp_server.py
Normal file
31
alembic/versions/c0ef3ff26306_add_token_to_mcp_server.py
Normal file
@@ -0,0 +1,31 @@
|
||||
"""add_token_to_mcp_server
|
||||
|
||||
Revision ID: c0ef3ff26306
|
||||
Revises: 1c6b6a38b713
|
||||
Create Date: 2025-06-14 14:59:53.835883
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c0ef3ff26306"
|
||||
down_revision: Union[str, None] = "1c6b6a38b713"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column("mcp_server", sa.Column("token", sa.String(), nullable=True))
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_column("mcp_server", "token")
|
||||
# ### end Alembic commands ###
|
||||
@@ -4,6 +4,10 @@ from typing import List, Optional
|
||||
from mcp import Tool
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# MCP Authentication Constants
|
||||
MCP_AUTH_HEADER_AUTHORIZATION = "Authorization"
|
||||
MCP_AUTH_TOKEN_BEARER_PREFIX = "Bearer"
|
||||
|
||||
|
||||
class MCPTool(Tool):
|
||||
"""A simple wrapper around MCP's tool definition (to avoid conflict with our own)"""
|
||||
@@ -12,6 +16,7 @@ class MCPTool(Tool):
|
||||
class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
STDIO = "stdio"
|
||||
STREAMABLE_HTTP = "streamable_http"
|
||||
|
||||
|
||||
class BaseServerConfig(BaseModel):
|
||||
@@ -20,14 +25,44 @@ class BaseServerConfig(BaseModel):
|
||||
|
||||
|
||||
class SSEServerConfig(BaseServerConfig):
|
||||
"""
|
||||
Configuration for an MCP server using SSE
|
||||
|
||||
Authentication can be provided in multiple ways:
|
||||
1. Using auth_header + auth_token: Will add a specific header with the token
|
||||
Example: auth_header="Authorization", auth_token="Bearer abc123"
|
||||
|
||||
2. Using the custom_headers dict: For more complex authentication scenarios
|
||||
Example: custom_headers={"X-API-Key": "abc123", "X-Custom-Header": "value"}
|
||||
"""
|
||||
|
||||
type: MCPServerType = MCPServerType.SSE
|
||||
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with SSE requests")
|
||||
|
||||
def resolve_token(self) -> Optional[str]:
|
||||
if self.auth_token and self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
|
||||
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
|
||||
return self.auth_token
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "sse",
|
||||
"url": self.server_url,
|
||||
}
|
||||
|
||||
# TODO: handle custom headers
|
||||
if self.custom_headers is not None or (self.auth_header is not None and self.auth_token is not None):
|
||||
headers = self.custom_headers.copy() if self.custom_headers else {}
|
||||
|
||||
# Add auth header if specified
|
||||
if self.auth_header is not None and self.auth_token is not None:
|
||||
headers[self.auth_header] = self.auth_token
|
||||
|
||||
values["headers"] = headers
|
||||
|
||||
return values
|
||||
|
||||
|
||||
@@ -46,3 +81,63 @@ class StdioServerConfig(BaseServerConfig):
|
||||
if self.env is not None:
|
||||
values["env"] = self.env
|
||||
return values
|
||||
|
||||
|
||||
class StreamableHTTPServerConfig(BaseServerConfig):
|
||||
"""
|
||||
Configuration for an MCP server using Streamable HTTP
|
||||
|
||||
Authentication can be provided in multiple ways:
|
||||
1. Using auth_header + auth_token: Will add a specific header with the token
|
||||
Example: auth_header="Authorization", auth_token="Bearer abc123"
|
||||
|
||||
2. Using the custom_headers dict: For more complex authentication scenarios
|
||||
Example: custom_headers={"X-API-Key": "abc123", "X-Custom-Header": "value"}
|
||||
"""
|
||||
|
||||
type: MCPServerType = MCPServerType.STREAMABLE_HTTP
|
||||
server_url: str = Field(..., description="The URL path for the streamable HTTP server (e.g., 'example/mcp')")
|
||||
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
custom_headers: Optional[dict[str, str]] = Field(None, description="Custom HTTP headers to include with streamable HTTP requests")
|
||||
|
||||
def resolve_token(self) -> Optional[str]:
|
||||
if self.auth_token and self.auth_token.startswith(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} "):
|
||||
return self.auth_token[len(f"{MCP_AUTH_TOKEN_BEARER_PREFIX} ") :]
|
||||
return self.auth_token
|
||||
|
||||
def model_post_init(self, __context) -> None:
|
||||
"""Validate the server URL format."""
|
||||
# Basic validation for streamable HTTP URLs
|
||||
if not self.server_url:
|
||||
raise ValueError("server_url cannot be empty")
|
||||
|
||||
# For streamable HTTP, the URL should typically be a path or full URL
|
||||
# We'll be lenient and allow both formats
|
||||
if self.server_url.startswith("http://") or self.server_url.startswith("https://"):
|
||||
# Full URL format - this is what the user is trying
|
||||
pass
|
||||
elif "/" in self.server_url:
|
||||
# Path format like "example/mcp" - this is the typical format
|
||||
pass
|
||||
else:
|
||||
# Single word - might be valid but warn in logs
|
||||
pass
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
values = {
|
||||
"transport": "streamable_http",
|
||||
"url": self.server_url,
|
||||
}
|
||||
|
||||
# Handle custom headers
|
||||
if self.custom_headers is not None or (self.auth_header is not None and self.auth_token is not None):
|
||||
headers = self.custom_headers.copy() if self.custom_headers else {}
|
||||
|
||||
# Add auth header if specified
|
||||
if self.auth_header is not None and self.auth_token is not None:
|
||||
headers[self.auth_header] = self.auth_token
|
||||
|
||||
values["headers"] = headers
|
||||
|
||||
return values
|
||||
|
||||
@@ -100,7 +100,11 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
|
||||
try:
|
||||
if field == "type":
|
||||
function_tokens += 2
|
||||
function_tokens += len(encoding.encode(v["type"]))
|
||||
# Handle both string and array types, e.g. {"type": ["string", "null"]}
|
||||
if isinstance(v["type"], list):
|
||||
function_tokens += len(encoding.encode(",".join(v["type"])))
|
||||
else:
|
||||
function_tokens += len(encoding.encode(v["type"]))
|
||||
elif field == "description":
|
||||
function_tokens += 2
|
||||
function_tokens += len(encoding.encode(v["description"]))
|
||||
|
||||
@@ -38,3 +38,4 @@ class ActorType(str, Enum):
|
||||
class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
STDIO = "stdio"
|
||||
STREAMABLE_HTTP = "streamable_http"
|
||||
|
||||
@@ -36,6 +36,9 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
|
||||
String, nullable=True, doc="The URL of the server (MCP SSE client will connect to this URL)"
|
||||
)
|
||||
|
||||
# access token / api key for MCP servers that require authentication
|
||||
token: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="The access token or api key for the MCP server")
|
||||
|
||||
# stdio server
|
||||
stdio_config: Mapped[Optional[StdioServerConfig]] = mapped_column(
|
||||
MCPStdioServerConfigColumn, nullable=True, doc="The configuration for the stdio server"
|
||||
|
||||
@@ -2,7 +2,14 @@ from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.functions.mcp_client.types import MCPServerType, SSEServerConfig, StdioServerConfig
|
||||
from letta.functions.mcp_client.types import (
|
||||
MCP_AUTH_HEADER_AUTHORIZATION,
|
||||
MCP_AUTH_TOKEN_BEARER_PREFIX,
|
||||
MCPServerType,
|
||||
SSEServerConfig,
|
||||
StdioServerConfig,
|
||||
StreamableHTTPServerConfig,
|
||||
)
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
@@ -17,6 +24,7 @@ class MCPServer(BaseMCPServer):
|
||||
|
||||
# sse config
|
||||
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
|
||||
|
||||
# stdio config
|
||||
stdio_config: Optional[StdioServerConfig] = Field(
|
||||
@@ -30,22 +38,38 @@ class MCPServer(BaseMCPServer):
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
|
||||
|
||||
# TODO: add tokens?
|
||||
|
||||
def to_config(self) -> Union[SSEServerConfig, StdioServerConfig]:
|
||||
def to_config(self) -> Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]:
|
||||
if self.server_type == MCPServerType.SSE:
|
||||
return SSEServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token else None,
|
||||
custom_headers=None,
|
||||
)
|
||||
elif self.server_type == MCPServerType.STDIO:
|
||||
if self.stdio_config is None:
|
||||
raise ValueError("stdio_config is required for STDIO server type")
|
||||
return self.stdio_config
|
||||
elif self.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||
if self.server_url is None:
|
||||
raise ValueError("server_url is required for STREAMABLE_HTTP server type")
|
||||
return StreamableHTTPServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
auth_header=MCP_AUTH_HEADER_AUTHORIZATION if self.token else None,
|
||||
auth_token=f"{MCP_AUTH_TOKEN_BEARER_PREFIX} {self.token}" if self.token else None,
|
||||
custom_headers=None,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server type: {self.server_type}")
|
||||
|
||||
|
||||
class RegisterSSEMCPServer(LettaBase):
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
server_type: MCPServerType = MCPServerType.SSE
|
||||
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server used for authentication")
|
||||
|
||||
|
||||
class RegisterStdioMCPServer(LettaBase):
|
||||
@@ -54,11 +78,20 @@ class RegisterStdioMCPServer(LettaBase):
|
||||
stdio_config: StdioServerConfig = Field(..., description="The configuration for the server (MCP 'local' client will run this command)")
|
||||
|
||||
|
||||
class RegisterStreamableHTTPMCPServer(LettaBase):
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
server_type: MCPServerType = MCPServerType.STREAMABLE_HTTP
|
||||
server_url: str = Field(..., description="The URL path for the streamable HTTP server (e.g., 'example/mcp')")
|
||||
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
|
||||
|
||||
class UpdateSSEMCPServer(LettaBase):
|
||||
"""Update an SSE MCP server"""
|
||||
|
||||
server_name: Optional[str] = Field(None, description="The name of the server")
|
||||
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
|
||||
|
||||
|
||||
class UpdateStdioMCPServer(LettaBase):
|
||||
@@ -70,5 +103,14 @@ class UpdateStdioMCPServer(LettaBase):
|
||||
)
|
||||
|
||||
|
||||
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer]
|
||||
RegisterMCPServer = Union[RegisterSSEMCPServer, RegisterStdioMCPServer]
|
||||
class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
"""Update a Streamable HTTP MCP server"""
|
||||
|
||||
server_name: Optional[str] = Field(None, description="The name of the server")
|
||||
server_url: Optional[str] = Field(None, description="The URL path for the streamable HTTP server (e.g., 'example/mcp')")
|
||||
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
|
||||
|
||||
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer]
|
||||
RegisterMCPServer = Union[RegisterSSEMCPServer, RegisterStdioMCPServer, RegisterStreamableHTTPMCPServer]
|
||||
|
||||
@@ -489,7 +489,9 @@ async def add_mcp_server_to_config(
|
||||
if tool_settings.mcp_disable_stdio: # protected server
|
||||
raise HTTPException(status_code=400, detail="StdioServerConfig is not supported")
|
||||
elif isinstance(request, SSEServerConfig):
|
||||
mapped_request = MCPServer(server_name=request.server_name, server_type=request.type, server_url=request.server_url)
|
||||
mapped_request = MCPServer(
|
||||
server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token()
|
||||
)
|
||||
# TODO: add HTTP streaming
|
||||
mcp_server = await server.mcp_manager.create_or_update_mcp_server(mapped_request, actor=actor)
|
||||
|
||||
|
||||
@@ -24,11 +24,22 @@ class AsyncBaseMCPClient:
|
||||
await self._initialize_connection(self.server_config)
|
||||
await self.session.initialize()
|
||||
self.initialized = True
|
||||
except ConnectionError as e:
|
||||
logger.error(f"MCP connection failed: {str(e)}")
|
||||
raise e
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
|
||||
f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}. Error: {str(e)}"
|
||||
)
|
||||
raise e
|
||||
if hasattr(self.server_config, "server_url") and self.server_config.server_url:
|
||||
server_info = f"server URL '{self.server_config.server_url}'"
|
||||
elif hasattr(self.server_config, "command") and self.server_config.command:
|
||||
server_info = f"command '{self.server_config.command}'"
|
||||
else:
|
||||
server_info = f"server '{self.server_config.server_name}'"
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to MCP {server_info}. Please check your configuration and ensure the server is accessible."
|
||||
) from e
|
||||
|
||||
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
|
||||
raise NotImplementedError("Subclasses must implement _initialize_connection")
|
||||
|
||||
@@ -14,7 +14,14 @@ logger = get_logger(__name__)
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncSSEMCPClient(AsyncBaseMCPClient):
|
||||
async def _initialize_connection(self, server_config: SSEServerConfig) -> None:
|
||||
sse_cm = sse_client(url=server_config.server_url)
|
||||
headers = {}
|
||||
if server_config.custom_headers:
|
||||
headers.update(server_config.custom_headers)
|
||||
|
||||
if server_config.auth_header and server_config.auth_token:
|
||||
headers[server_config.auth_header] = server_config.auth_token
|
||||
|
||||
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None)
|
||||
sse_transport = await self.exit_stack.enter_async_context(sse_cm)
|
||||
self.stdio, self.write = sse_transport
|
||||
|
||||
|
||||
56
letta/services/mcp/streamable_http_client.py
Normal file
56
letta/services/mcp/streamable_http_client.py
Normal file
@@ -0,0 +1,56 @@
|
||||
from mcp import ClientSession
|
||||
from mcp.client.streamable_http import streamablehttp_client
|
||||
|
||||
from letta.functions.mcp_client.types import BaseServerConfig, StreamableHTTPServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
|
||||
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
|
||||
if not isinstance(server_config, StreamableHTTPServerConfig):
|
||||
raise ValueError("Expected StreamableHTTPServerConfig")
|
||||
|
||||
try:
|
||||
# Prepare headers for authentication
|
||||
headers = {}
|
||||
if server_config.custom_headers:
|
||||
headers.update(server_config.custom_headers)
|
||||
|
||||
# Add auth header if specified
|
||||
if server_config.auth_header and server_config.auth_token:
|
||||
headers[server_config.auth_header] = server_config.auth_token
|
||||
|
||||
# Use streamablehttp_client context manager with headers if provided
|
||||
if headers:
|
||||
streamable_http_cm = streamablehttp_client(server_config.server_url, headers=headers)
|
||||
else:
|
||||
streamable_http_cm = streamablehttp_client(server_config.server_url)
|
||||
read_stream, write_stream, _ = await self.exit_stack.enter_async_context(streamable_http_cm)
|
||||
|
||||
# Create and enter the ClientSession context manager
|
||||
session_cm = ClientSession(read_stream, write_stream)
|
||||
self.session = await self.exit_stack.enter_async_context(session_cm)
|
||||
except Exception as e:
|
||||
# Provide more helpful error messages for specific error types
|
||||
if "404" in str(e) or "Not Found" in str(e):
|
||||
raise ConnectionError(
|
||||
f"MCP server not found at URL: {server_config.server_url}. "
|
||||
"Please verify the URL is correct and the server supports the MCP protocol."
|
||||
) from e
|
||||
elif "Connection" in str(e) or "connect" in str(e).lower():
|
||||
raise ConnectionError(
|
||||
f"Failed to connect to MCP server at: {server_config.server_url}. "
|
||||
"Please check that the server is running and accessible."
|
||||
) from e
|
||||
elif "JSON" in str(e) and "validation" in str(e):
|
||||
raise ConnectionError(
|
||||
f"MCP server at {server_config.server_url} is not returning valid JSON-RPC responses. "
|
||||
"The server may not be a proper MCP server or may be returning empty/invalid JSON. "
|
||||
"Please verify this is an MCP-compatible server endpoint."
|
||||
) from e
|
||||
else:
|
||||
# Re-raise other exceptions with additional context
|
||||
raise ConnectionError(f"Failed to initialize streamable HTTP connection to {server_config.server_url}: {str(e)}") from e
|
||||
@@ -3,17 +3,18 @@ import os
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.mcp_server import MCPServer as MCPServerModel
|
||||
from letta.schemas.mcp import MCPServer, UpdateMCPServer, UpdateSSEMCPServer, UpdateStdioMCPServer
|
||||
from letta.schemas.mcp import MCPServer, UpdateMCPServer, UpdateSSEMCPServer, UpdateStdioMCPServer, UpdateStreamableHTTPMCPServer
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
||||
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
||||
from letta.services.mcp.streamable_http_client import AsyncStreamableHTTPMCPClient
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
@@ -31,7 +32,6 @@ class MCPManager:
|
||||
@enforce_types
|
||||
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser) -> List[MCPTool]:
|
||||
"""Get a list of all tools for a specific MCP server."""
|
||||
print("mcp_server_name", mcp_server_name)
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
@@ -40,12 +40,16 @@ class MCPManager:
|
||||
mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
||||
elif mcp_config.server_type == MCPServerType.STDIO:
|
||||
mcp_client = AsyncStdioMCPClient(server_config=server_config)
|
||||
elif mcp_config.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||
mcp_client = AsyncStreamableHTTPMCPClient(server_config=server_config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported MCP server type: {mcp_config.server_type}")
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# list tools
|
||||
tools = await mcp_client.list_tools()
|
||||
# TODO: change to pydantic tools
|
||||
|
||||
# TODO: change to pydantic tools
|
||||
await mcp_client.cleanup()
|
||||
|
||||
return tools
|
||||
@@ -55,7 +59,6 @@ class MCPManager:
|
||||
self, mcp_server_name: str, tool_name: str, tool_args: Optional[Dict[str, Any]], actor: PydanticUser
|
||||
) -> Tuple[str, bool]:
|
||||
"""Call a specific tool from a specific MCP server."""
|
||||
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if not tool_settings.mcp_read_from_config:
|
||||
@@ -75,6 +78,10 @@ class MCPManager:
|
||||
mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
||||
elif isinstance(server_config, StdioServerConfig):
|
||||
mcp_client = AsyncStdioMCPClient(server_config=server_config)
|
||||
elif isinstance(server_config, StreamableHTTPServerConfig):
|
||||
mcp_client = AsyncStreamableHTTPMCPClient(server_config=server_config)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# call tool
|
||||
@@ -114,7 +121,6 @@ class MCPManager:
|
||||
async def create_or_update_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name=pydantic_mcp_server.server_name, actor=actor)
|
||||
print("FOUND SERVER", mcp_server_id, pydantic_mcp_server.server_name)
|
||||
if mcp_server_id:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = pydantic_mcp_server.model_dump(exclude_unset=True, exclude_none=True)
|
||||
@@ -122,11 +128,16 @@ class MCPManager:
|
||||
# If there's anything to update (can only update the configs, not the name)
|
||||
if update_data:
|
||||
if pydantic_mcp_server.server_type == MCPServerType.SSE:
|
||||
update_request = UpdateSSEMCPServer(server_url=pydantic_mcp_server.server_url)
|
||||
update_request = UpdateSSEMCPServer(server_url=pydantic_mcp_server.server_url, token=pydantic_mcp_server.token)
|
||||
elif pydantic_mcp_server.server_type == MCPServerType.STDIO:
|
||||
update_request = UpdateStdioMCPServer(stdio_config=pydantic_mcp_server.stdio_config)
|
||||
elif pydantic_mcp_server.server_type == MCPServerType.STREAMABLE_HTTP:
|
||||
update_request = UpdateStreamableHTTPMCPServer(
|
||||
server_url=pydantic_mcp_server.server_url, token=pydantic_mcp_server.token
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server type: {pydantic_mcp_server.server_type}")
|
||||
mcp_server = await self.update_mcp_server_by_id(mcp_server_id, update_request, actor)
|
||||
print("RETURN", mcp_server)
|
||||
else:
|
||||
printd(
|
||||
f"`create_or_update_mcp_server` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_mcp_server.server_name}, but found existing mcp server with nothing to update."
|
||||
@@ -229,7 +240,7 @@ class MCPManager:
|
||||
except NoResultFound:
|
||||
raise ValueError(f"MCP server with id {mcp_server_id} not found.")
|
||||
|
||||
def read_mcp_config(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
||||
def read_mcp_config(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]]:
|
||||
mcp_server_list = {}
|
||||
|
||||
# Attempt to read from ~/.letta/mcp_config.json
|
||||
@@ -260,6 +271,9 @@ class MCPManager:
|
||||
server_params = SSEServerConfig(
|
||||
server_name=server_name,
|
||||
server_url=server_params_raw["url"],
|
||||
auth_header=server_params_raw.get("auth_header", None),
|
||||
auth_token=server_params_raw.get("auth_token", None),
|
||||
headers=server_params_raw.get("headers", None),
|
||||
)
|
||||
mcp_server_list[server_name] = server_params
|
||||
except Exception as e:
|
||||
|
||||
42
poetry.lock
generated
42
poetry.lock
generated
@@ -1,4 +1,4 @@
|
||||
# This file is automatically @generated by Poetry 2.1.2 and should not be changed by hand.
|
||||
# This file is automatically @generated by Poetry 2.1.3 and should not be changed by hand.
|
||||
|
||||
[[package]]
|
||||
name = "aiohappyeyeballs"
|
||||
@@ -2106,6 +2106,12 @@ files = [
|
||||
{file = "geventhttpclient-2.3.3-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:447fc2d49a41449684154c12c03ab80176a413e9810d974363a061b71bdbf5a0"},
|
||||
{file = "geventhttpclient-2.3.3-pp310-pypy310_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4598c2aa14c866a10a07a2944e2c212f53d0c337ce211336ad68ae8243646216"},
|
||||
{file = "geventhttpclient-2.3.3-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:69d2bd7ab7f94a6c73325f4b88fd07b0d5f4865672ed7a519f2d896949353761"},
|
||||
{file = "geventhttpclient-2.3.3-pp311-pypy311_pp73-macosx_10_15_x86_64.whl", hash = "sha256:45a3f7e3531dd2650f5bb840ed11ce77d0eeb45d0f4c9cd6985eb805e17490e6"},
|
||||
{file = "geventhttpclient-2.3.3-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:73b427e0ea8c2750ee05980196893287bfc9f2a155a282c0f248b472ea7ae3e7"},
|
||||
{file = "geventhttpclient-2.3.3-pp311-pypy311_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c2959ef84271e4fa646c3dbaad9e6f2912bf54dcdfefa5999c2ef7c927d92127"},
|
||||
{file = "geventhttpclient-2.3.3-pp311-pypy311_pp73-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a800fcb8e53a8f4a7c02b4b403d2325a16cad63a877e57bd603aa50bf0e475b"},
|
||||
{file = "geventhttpclient-2.3.3-pp311-pypy311_pp73-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:528321e9aab686435ba09cc6ff90f12e577ace79762f74831ec2265eeab624a8"},
|
||||
{file = "geventhttpclient-2.3.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:034be44ff3318359e3c678cb5c4ed13efd69aeb558f2981a32bd3e3fb5355700"},
|
||||
{file = "geventhttpclient-2.3.3-pp39-pypy39_pp73-macosx_10_15_x86_64.whl", hash = "sha256:7a3182f1457599c2901c48a1def37a5bc4762f696077e186e2050fcc60b2fbdf"},
|
||||
{file = "geventhttpclient-2.3.3-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:86b489238dc2cbfa53cdd5621e888786a53031d327e0a8509529c7568292b0ce"},
|
||||
{file = "geventhttpclient-2.3.3-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c4c8aca6ab5da4211870c1d8410c699a9d543e86304aac47e1558ec94d0da97a"},
|
||||
@@ -3899,14 +3905,14 @@ traitlets = "*"
|
||||
|
||||
[[package]]
|
||||
name = "mcp"
|
||||
version = "1.6.0"
|
||||
version = "1.9.4"
|
||||
description = "Model Context Protocol SDK"
|
||||
optional = false
|
||||
python-versions = ">=3.10"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "mcp-1.6.0-py3-none-any.whl", hash = "sha256:7bd24c6ea042dbec44c754f100984d186620d8b841ec30f1b19eda9b93a634d0"},
|
||||
{file = "mcp-1.6.0.tar.gz", hash = "sha256:d9324876de2c5637369f43161cd71eebfd803df5a95e46225cab8d280e366723"},
|
||||
{file = "mcp-1.9.4-py3-none-any.whl", hash = "sha256:7fcf36b62936adb8e63f89346bccca1268eeca9bf6dfb562ee10b1dfbda9dac0"},
|
||||
{file = "mcp-1.9.4.tar.gz", hash = "sha256:cfb0bcd1a9535b42edaef89947b9e18a8feb49362e1cc059d6e7fc636f2cb09f"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
@@ -3915,9 +3921,12 @@ httpx = ">=0.27"
|
||||
httpx-sse = ">=0.4"
|
||||
pydantic = ">=2.7.2,<3.0.0"
|
||||
pydantic-settings = ">=2.5.2"
|
||||
python-dotenv = {version = ">=1.0.0", optional = true, markers = "extra == \"cli\""}
|
||||
python-multipart = ">=0.0.9"
|
||||
sse-starlette = ">=1.6.1"
|
||||
starlette = ">=0.27"
|
||||
uvicorn = ">=0.23.1"
|
||||
typer = {version = ">=0.12.4", optional = true, markers = "extra == \"cli\""}
|
||||
uvicorn = {version = ">=0.23.1", markers = "sys_platform != \"emscripten\""}
|
||||
|
||||
[package.extras]
|
||||
cli = ["python-dotenv (>=1.0.0)", "typer (>=0.12.4)"]
|
||||
@@ -6943,29 +6952,22 @@ test = ["argcomplete (>=3.0.3)", "mypy (>=1.7.0)", "pre-commit", "pytest (>=7.0,
|
||||
|
||||
[[package]]
|
||||
name = "typer"
|
||||
version = "0.9.4"
|
||||
version = "0.15.4"
|
||||
description = "Typer, build great CLIs. Easy to code. Based on Python type hints."
|
||||
optional = false
|
||||
python-versions = ">=3.6"
|
||||
python-versions = ">=3.7"
|
||||
groups = ["main"]
|
||||
files = [
|
||||
{file = "typer-0.9.4-py3-none-any.whl", hash = "sha256:aa6c4a4e2329d868b80ecbaf16f807f2b54e192209d7ac9dd42691d63f7a54eb"},
|
||||
{file = "typer-0.9.4.tar.gz", hash = "sha256:f714c2d90afae3a7929fcd72a3abb08df305e1ff61719381384211c4070af57f"},
|
||||
{file = "typer-0.15.4-py3-none-any.whl", hash = "sha256:eb0651654dcdea706780c466cf06d8f174405a659ffff8f163cfbfee98c0e173"},
|
||||
{file = "typer-0.15.4.tar.gz", hash = "sha256:89507b104f9b6a0730354f27c39fae5b63ccd0c95b1ce1f1a6ba0cfd329997c3"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
click = ">=7.1.1,<9.0.0"
|
||||
colorama = {version = ">=0.4.3,<0.5.0", optional = true, markers = "extra == \"all\""}
|
||||
rich = {version = ">=10.11.0,<14.0.0", optional = true, markers = "extra == \"all\""}
|
||||
shellingham = {version = ">=1.3.0,<2.0.0", optional = true, markers = "extra == \"all\""}
|
||||
click = ">=8.0.0,<8.2"
|
||||
rich = ">=10.11.0"
|
||||
shellingham = ">=1.3.0"
|
||||
typing-extensions = ">=3.7.4.3"
|
||||
|
||||
[package.extras]
|
||||
all = ["colorama (>=0.4.3,<0.5.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
|
||||
dev = ["autoflake (>=1.3.1,<2.0.0)", "flake8 (>=3.8.3,<4.0.0)", "pre-commit (>=2.17.0,<3.0.0)"]
|
||||
doc = ["cairosvg (>=2.5.2,<3.0.0)", "mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pillow (>=9.3.0,<10.0.0)"]
|
||||
test = ["black (>=22.3.0,<23.0.0)", "coverage (>=6.2,<7.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.971)", "pytest (>=4.4.0,<8.0.0)", "pytest-cov (>=2.10.0,<5.0.0)", "pytest-sugar (>=0.9.4,<0.10.0)", "pytest-xdist (>=1.32.0,<4.0.0)", "rich (>=10.11.0,<14.0.0)", "shellingham (>=1.3.0,<2.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "typing-extensions"
|
||||
version = "4.13.2"
|
||||
@@ -7799,4 +7801,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.1"
|
||||
python-versions = "<3.14,>=3.10"
|
||||
content-hash = "ce40c1a5c61327463e698e66ededc408fd37e9fb9682d5e1c20a6f7036d91635"
|
||||
content-hash = "0e74c3c79cf0358e9612971b2bcf2fde7d2d9882ff49b40a1ac1a84f0abefa26"
|
||||
|
||||
@@ -16,7 +16,7 @@ letta = "letta.main:app"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = "<3.14,>=3.10"
|
||||
typer = {extras = ["all"], version = "^0.9.0"}
|
||||
typer = "^0.15.2"
|
||||
questionary = "^2.0.1"
|
||||
pytz = "^2023.3.post1"
|
||||
tqdm = "^4.66.1"
|
||||
@@ -84,7 +84,7 @@ colorama = "^0.4.6"
|
||||
marshmallow-sqlalchemy = "^1.4.1"
|
||||
boto3 = {version = "^1.36.24", optional = true}
|
||||
datamodel-code-generator = {extras = ["http"], version = "^0.25.0"}
|
||||
mcp = "^1.3.0"
|
||||
mcp = {extras = ["cli"], version = "^1.9.4"}
|
||||
firecrawl-py = "^1.15.0"
|
||||
apscheduler = "^3.11.0"
|
||||
aiomultiprocess = "^0.9.1"
|
||||
|
||||
Reference in New Issue
Block a user