feat: add new mcp_servers routes [LET-4321] (#5675)
--------- Co-authored-by: Ari Webb <ari@letta.com> Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
This commit is contained in:
47
alembic/versions/c6c43222e2de_add_mcp_tools_table.py
Normal file
47
alembic/versions/c6c43222e2de_add_mcp_tools_table.py
Normal file
@@ -0,0 +1,47 @@
|
||||
"""Add mcp_tools table
|
||||
|
||||
Revision ID: c6c43222e2de
|
||||
Revises: 6756d04c3ddb
|
||||
Create Date: 2025-10-20 17:25:54.334037
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c6c43222e2de"
|
||||
down_revision: Union[str, None] = "6756d04c3ddb"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"mcp_tools",
|
||||
sa.Column("mcp_server_id", sa.String(), nullable=False),
|
||||
sa.Column("tool_id", sa.String(), nullable=False),
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("mcp_tools")
|
||||
# ### end Alembic commands ###
|
||||
1506
fern/openapi.json
1506
fern/openapi.json
File diff suppressed because it is too large
Load Diff
@@ -504,14 +504,43 @@ def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormat
|
||||
|
||||
|
||||
def serialize_mcp_stdio_config(config: Union[Optional[StdioServerConfig], Dict]) -> Optional[Dict]:
|
||||
"""Convert an StdioServerConfig object into a JSON-serializable dictionary."""
|
||||
"""Convert an StdioServerConfig object into a JSON-serializable dictionary.
|
||||
|
||||
Persist required fields for successful deserialization back into a
|
||||
StdioServerConfig model (namely `server_name` and `type`). The
|
||||
`to_dict()` helper intentionally omits these since they're not needed
|
||||
by MCP transport, but our ORM deserializer reconstructs the pydantic
|
||||
model and requires them.
|
||||
"""
|
||||
if config and isinstance(config, StdioServerConfig):
|
||||
return config.to_dict()
|
||||
data = config.to_dict()
|
||||
# Preserve required fields for pydantic reconstruction
|
||||
data["server_name"] = config.server_name
|
||||
# Store enum as its value; pydantic will coerce on load
|
||||
data["type"] = config.type.value if hasattr(config.type, "value") else str(config.type)
|
||||
return data
|
||||
return config
|
||||
|
||||
|
||||
def deserialize_mcp_stdio_config(data: Optional[Dict]) -> Optional[StdioServerConfig]:
|
||||
"""Convert a dictionary back into an StdioServerConfig object."""
|
||||
"""Convert a dictionary back into an StdioServerConfig object.
|
||||
|
||||
Backwards-compatibility notes:
|
||||
- Older rows may only include `transport`, `command`, `args`, `env`.
|
||||
In that case, provide defaults for `server_name` and `type` to
|
||||
satisfy the pydantic model requirements.
|
||||
- If both `type` and `transport` are present, prefer `type`.
|
||||
"""
|
||||
if not data:
|
||||
return None
|
||||
return StdioServerConfig(**data)
|
||||
|
||||
payload = dict(data)
|
||||
# Map legacy `transport` field to required `type` if missing
|
||||
if "type" not in payload and "transport" in payload:
|
||||
payload["type"] = payload["transport"]
|
||||
|
||||
# Ensure required field exists; use a sensible placeholder when unknown
|
||||
if "server_name" not in payload:
|
||||
payload["server_name"] = payload.get("name", "unknown")
|
||||
|
||||
return StdioServerConfig(**payload)
|
||||
|
||||
@@ -56,3 +56,12 @@ class MCPServer(SqlalchemyBase, OrganizationMixin):
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(
|
||||
JSON, default=lambda: {}, doc="A dictionary of additional metadata for the MCP server."
|
||||
)
|
||||
|
||||
|
||||
class MCPTools(SqlalchemyBase, OrganizationMixin):
|
||||
"""Represents a mapping of MCP server ID to tool ID"""
|
||||
|
||||
__tablename__ = "mcp_tools"
|
||||
|
||||
mcp_server_id: Mapped[str] = mapped_column(String, doc="The ID of the MCP server")
|
||||
tool_id: Mapped[str] = mapped_column(String, doc="The ID of the tool")
|
||||
|
||||
@@ -148,6 +148,7 @@ class MCPServer(BaseMCPServer):
|
||||
class UpdateSSEMCPServer(LettaBase):
|
||||
"""Update an SSE MCP server"""
|
||||
|
||||
server_name: Optional[str] = Field(None, description="The name of the MCP server")
|
||||
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
token: Optional[str] = Field(None, description="The access token or API key for the MCP server (used for SSE authentication)")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom authentication headers as key-value pairs")
|
||||
@@ -156,6 +157,7 @@ class UpdateSSEMCPServer(LettaBase):
|
||||
class UpdateStdioMCPServer(LettaBase):
|
||||
"""Update a Stdio MCP server"""
|
||||
|
||||
server_name: Optional[str] = Field(None, description="The name of the MCP server")
|
||||
stdio_config: Optional[StdioServerConfig] = Field(
|
||||
None, description="The configuration for the server (MCP 'local' client will run this command)"
|
||||
)
|
||||
@@ -164,6 +166,7 @@ class UpdateStdioMCPServer(LettaBase):
|
||||
class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
"""Update a Streamable HTTP MCP server"""
|
||||
|
||||
server_name: Optional[str] = Field(None, description="The name of the MCP server")
|
||||
server_url: Optional[str] = Field(None, description="The URL path for the streamable HTTP server (e.g., 'example/mcp')")
|
||||
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
|
||||
@@ -41,18 +41,21 @@ class StdioMCPServer(CreateStdioMCPServer):
|
||||
"""A Stdio MCP server"""
|
||||
|
||||
id: str = BaseMCPServer.generate_id_field()
|
||||
type: MCPServerType = MCPServerType.STDIO
|
||||
|
||||
|
||||
class SSEMCPServer(CreateSSEMCPServer):
|
||||
"""An SSE MCP server"""
|
||||
|
||||
id: str = BaseMCPServer.generate_id_field()
|
||||
type: MCPServerType = MCPServerType.SSE
|
||||
|
||||
|
||||
class StreamableHTTPMCPServer(CreateStreamableHTTPMCPServer):
|
||||
"""A Streamable HTTP MCP server"""
|
||||
|
||||
id: str = BaseMCPServer.generate_id_field()
|
||||
type: MCPServerType = MCPServerType.STREAMABLE_HTTP
|
||||
|
||||
|
||||
MCPServerUnion = Union[StdioMCPServer, SSEMCPServer, StreamableHTTPMCPServer]
|
||||
@@ -74,9 +77,10 @@ class UpdateSSEMCPServer(LettaBase):
|
||||
|
||||
server_name: Optional[str] = Field(None, description="The name of the MCP server")
|
||||
server_url: Optional[str] = Field(None, description="The URL of the SSE MCP server")
|
||||
# Note: auth_token is renamed to token to match the ORM field
|
||||
token: Optional[str] = Field(None, description="The authentication token")
|
||||
# auth_header is excluded as it's derived from the token
|
||||
# Accept both `auth_token` (API surface) and `token` (internal ORM naming)
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
token: Optional[str] = Field(None, description="The authentication token (internal)")
|
||||
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to send with requests")
|
||||
|
||||
|
||||
@@ -85,9 +89,10 @@ class UpdateStreamableHTTPMCPServer(LettaBase):
|
||||
|
||||
server_name: Optional[str] = Field(None, description="The name of the MCP server")
|
||||
server_url: Optional[str] = Field(None, description="The URL of the Streamable HTTP MCP server")
|
||||
# Note: auth_token is renamed to token to match the ORM field
|
||||
token: Optional[str] = Field(None, description="The authentication token")
|
||||
# auth_header is excluded as it's derived from the token
|
||||
# Accept both `auth_token` (API surface) and `token` (internal ORM naming)
|
||||
auth_token: Optional[str] = Field(None, description="The authentication token or API key value")
|
||||
token: Optional[str] = Field(None, description="The authentication token (internal)")
|
||||
auth_header: Optional[str] = Field(None, description="The name of the authentication header (e.g., 'Authorization')")
|
||||
custom_headers: Optional[Dict[str, str]] = Field(None, description="Custom headers to send with requests")
|
||||
|
||||
|
||||
@@ -296,3 +301,49 @@ def convert_generic_to_union(server) -> MCPServerUnion:
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unknown server type: {server.server_type}")
|
||||
|
||||
|
||||
def convert_update_to_internal(request: Union[UpdateStdioMCPServer, UpdateSSEMCPServer, UpdateStreamableHTTPMCPServer]):
|
||||
"""Convert external API update models to internal UpdateMCPServer union used by the manager.
|
||||
|
||||
- Flattens stdio fields into StdioServerConfig inside UpdateStdioMCPServer
|
||||
- Maps `auth_token` to `token` for HTTP-based transports
|
||||
- Ignores `auth_header` at update time (header is derived from token)
|
||||
"""
|
||||
# Local import to avoid circulars
|
||||
from letta.functions.mcp_client.types import MCPServerType as MCPType, StdioServerConfig as StdioCfg
|
||||
from letta.schemas.mcp import (
|
||||
UpdateSSEMCPServer as InternalUpdateSSE,
|
||||
UpdateStdioMCPServer as InternalUpdateStdio,
|
||||
UpdateStreamableHTTPMCPServer as InternalUpdateHTTP,
|
||||
)
|
||||
|
||||
if isinstance(request, UpdateStdioMCPServer):
|
||||
stdio_cfg = None
|
||||
# Only build stdio_config if command and args are explicitly provided to avoid overwriting existing config
|
||||
if request.command is not None and request.args is not None:
|
||||
stdio_cfg = StdioCfg(
|
||||
server_name=request.server_name or "",
|
||||
type=MCPType.STDIO,
|
||||
command=request.command,
|
||||
args=request.args,
|
||||
env=request.env,
|
||||
)
|
||||
kwargs: dict = {}
|
||||
if request.server_name is not None:
|
||||
kwargs["server_name"] = request.server_name
|
||||
if stdio_cfg is not None:
|
||||
kwargs["stdio_config"] = stdio_cfg
|
||||
return InternalUpdateStdio(**kwargs)
|
||||
elif isinstance(request, UpdateSSEMCPServer):
|
||||
token_value = request.auth_token or request.token
|
||||
return InternalUpdateSSE(
|
||||
server_name=request.server_name, server_url=request.server_url, token=token_value, custom_headers=request.custom_headers
|
||||
)
|
||||
elif isinstance(request, UpdateStreamableHTTPMCPServer):
|
||||
token_value = request.auth_token or request.token
|
||||
return InternalUpdateHTTP(
|
||||
server_name=request.server_name, server_url=request.server_url, auth_token=token_value, custom_headers=request.custom_headers
|
||||
)
|
||||
else:
|
||||
raise TypeError(f"Unsupported update request type: {type(request)}")
|
||||
|
||||
@@ -11,6 +11,7 @@ from letta.server.rest_api.routers.v1.internal_runs import router as internal_ru
|
||||
from letta.server.rest_api.routers.v1.internal_templates import router as internal_templates_router
|
||||
from letta.server.rest_api.routers.v1.jobs import router as jobs_router
|
||||
from letta.server.rest_api.routers.v1.llms import router as llm_router
|
||||
from letta.server.rest_api.routers.v1.mcp_servers import router as mcp_servers_router
|
||||
from letta.server.rest_api.routers.v1.messages import router as messages_router
|
||||
from letta.server.rest_api.routers.v1.providers import router as providers_router
|
||||
from letta.server.rest_api.routers.v1.runs import router as runs_router
|
||||
@@ -34,6 +35,7 @@ ROUTERS = [
|
||||
internal_runs_router,
|
||||
internal_templates_router,
|
||||
llm_router,
|
||||
mcp_servers_router,
|
||||
blocks_router,
|
||||
jobs_router,
|
||||
health_router,
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Union
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Request
|
||||
from httpx import HTTPStatusError
|
||||
from starlette.responses import StreamingResponse
|
||||
|
||||
from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.letta_message import ToolReturnMessage
|
||||
from letta.schemas.mcp_server import (
|
||||
@@ -11,14 +13,20 @@ from letta.schemas.mcp_server import (
|
||||
MCPToolExecuteRequest,
|
||||
UpdateMCPServerUnion,
|
||||
convert_generic_to_union,
|
||||
convert_update_to_internal,
|
||||
)
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.server.rest_api.dependencies import (
|
||||
HeaderParams,
|
||||
get_headers,
|
||||
get_letta_server,
|
||||
)
|
||||
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.mcp.oauth_utils import drill_down_exception, oauth_stream_event
|
||||
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
||||
from letta.services.mcp.types import OauthStreamEvent
|
||||
from letta.settings import tool_settings
|
||||
|
||||
router = APIRouter(prefix="/mcp-servers", tags=["mcp-servers"])
|
||||
@@ -39,6 +47,7 @@ async def create_mcp_server(
|
||||
"""
|
||||
Add a new MCP server to the Letta MCP server config
|
||||
"""
|
||||
# TODO: add the tools to the MCP server table we made.
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
new_server = await server.mcp_server_manager.create_mcp_server_from_config_with_tools(request, actor=actor)
|
||||
return convert_generic_to_union(new_server)
|
||||
@@ -56,7 +65,6 @@ async def list_mcp_servers(
|
||||
"""
|
||||
Get a list of all configured MCP servers
|
||||
"""
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
mcp_servers = await server.mcp_server_manager.list_mcp_servers(actor=actor)
|
||||
return [convert_generic_to_union(mcp_server) for mcp_server in mcp_servers]
|
||||
@@ -112,8 +120,10 @@ async def update_mcp_server(
|
||||
Update an existing MCP server configuration
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
# Convert external update payload to internal manager union
|
||||
internal_update = convert_update_to_internal(request)
|
||||
updated_server = await server.mcp_server_manager.update_mcp_server_by_id(
|
||||
mcp_server_id=mcp_server_id, mcp_server_update=request, actor=actor
|
||||
mcp_server_id=mcp_server_id, mcp_server_update=internal_update, actor=actor
|
||||
)
|
||||
return convert_generic_to_union(updated_server)
|
||||
|
||||
@@ -127,24 +137,10 @@ async def list_mcp_tools_by_server(
|
||||
"""
|
||||
Get a list of all tools for a specific MCP server
|
||||
"""
|
||||
# TODO: implement this. We want to use the new tools table instead of going to the mcp server.
|
||||
pass
|
||||
# actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
# mcp_tools = await server.mcp_server_manager.list_mcp_server_tools(mcp_server_id, actor=actor)
|
||||
# # Convert MCPTool objects to Tool objects
|
||||
# tools = []
|
||||
# for mcp_tool in mcp_tools:
|
||||
# from letta.schemas.tool import ToolCreate
|
||||
# tool_create = ToolCreate.from_mcp(mcp_server_name="", mcp_tool=mcp_tool)
|
||||
# tools.append(Tool(
|
||||
# id=f"mcp-tool-{mcp_tool.name}", # Generate a temporary ID
|
||||
# name=mcp_tool.name,
|
||||
# description=tool_create.description,
|
||||
# json_schema=tool_create.json_schema,
|
||||
# source_code=tool_create.source_code,
|
||||
# tags=tool_create.tags,
|
||||
# ))
|
||||
# return tools
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
# Use the new efficient method that queries from the database using MCPTools mapping
|
||||
tools = await server.mcp_server_manager.list_tools_by_mcp_server_from_db(mcp_server_id, actor=actor)
|
||||
return tools
|
||||
|
||||
|
||||
@router.get("/{mcp_server_id}/tools/{tool_id}", response_model=Tool, operation_id="mcp_get_mcp_tool")
|
||||
@@ -158,13 +154,11 @@ async def get_mcp_tool(
|
||||
Get a specific MCP tool by its ID
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
# Use the tool_manager's existing method to get the tool by ID
|
||||
# Verify the tool belongs to the MCP server (optional check)
|
||||
tool = await server.tool_manager.get_tool_by_id_async(tool_id=tool_id, actor=actor)
|
||||
tool = await server.mcp_server_manager.get_tool_by_mcp_server(mcp_server_id, tool_id, actor=actor)
|
||||
return tool
|
||||
|
||||
|
||||
@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolReturnMessage, operation_id="mcp_run_tool")
|
||||
@router.post("/{mcp_server_id}/tools/{tool_id}/run", response_model=ToolExecutionResult, operation_id="mcp_run_tool")
|
||||
async def run_mcp_tool(
|
||||
mcp_server_id: str,
|
||||
tool_id: str,
|
||||
@@ -188,9 +182,10 @@ async def run_mcp_tool(
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# Create a ToolReturnMessage
|
||||
return ToolReturnMessage(
|
||||
id=f"tool-return-{tool_id}", tool_call_id=f"call-{tool_id}", tool_return=result, status="success" if success else "error"
|
||||
# Create a ToolExecutionResult
|
||||
return ToolExecutionResult(
|
||||
status="success" if success else "error",
|
||||
func_return=result,
|
||||
)
|
||||
|
||||
|
||||
@@ -231,6 +226,7 @@ async def refresh_mcp_server_tools(
|
||||
)
|
||||
async def connect_mcp_server(
|
||||
mcp_server_id: str,
|
||||
request: Request,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
) -> StreamingResponse:
|
||||
@@ -238,72 +234,76 @@ async def connect_mcp_server(
|
||||
Connect to an MCP server with support for OAuth via SSE.
|
||||
Returns a stream of events handling authorization state and exchange if OAuth is required.
|
||||
"""
|
||||
pass
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)
|
||||
|
||||
# async def oauth_stream_generator(
|
||||
# request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
|
||||
# http_request: Request,
|
||||
# ) -> AsyncGenerator[str, None]:
|
||||
# client = None
|
||||
# Convert the MCP server to the appropriate config type
|
||||
config = mcp_server.to_config(resolve_variables=False)
|
||||
|
||||
# oauth_flow_attempted = False
|
||||
# try:
|
||||
# # Acknolwedge connection attempt
|
||||
# yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=request.server_name)
|
||||
async def oauth_stream_generator(
|
||||
mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
|
||||
http_request: Request,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
client = None
|
||||
|
||||
# actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
oauth_flow_attempted = False
|
||||
try:
|
||||
# Acknowledge connection attempt
|
||||
yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=mcp_config.server_name)
|
||||
|
||||
# # Create MCP client with respective transport type
|
||||
# try:
|
||||
# request.resolve_environment_variables()
|
||||
# client = await server.mcp_server_manager.get_mcp_client(request, actor)
|
||||
# except ValueError as e:
|
||||
# yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
|
||||
# return
|
||||
# Create MCP client with respective transport type
|
||||
try:
|
||||
mcp_config.resolve_environment_variables()
|
||||
client = await server.mcp_server_manager.get_mcp_client(mcp_config, actor)
|
||||
except ValueError as e:
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
|
||||
return
|
||||
|
||||
# # Try normal connection first for flows that don't require OAuth
|
||||
# try:
|
||||
# await client.connect_to_server()
|
||||
# tools = await client.list_tools(serialize=True)
|
||||
# yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
|
||||
# return
|
||||
# except ConnectionError:
|
||||
# # TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
|
||||
# if isinstance(client, AsyncStdioMCPClient):
|
||||
# logger.warning("OAuth not supported for stdio")
|
||||
# yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
|
||||
# return
|
||||
# # Continue to OAuth flow
|
||||
# logger.info(f"Attempting OAuth flow for {request}...")
|
||||
# except Exception as e:
|
||||
# yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
|
||||
# return
|
||||
# finally:
|
||||
# if client:
|
||||
# try:
|
||||
# await client.cleanup()
|
||||
# # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py
|
||||
# # For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception
|
||||
# except* HTTPStatusError:
|
||||
# oauth_flow_attempted = True
|
||||
# async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request):
|
||||
# yield event
|
||||
# Try normal connection first for flows that don't require OAuth
|
||||
try:
|
||||
await client.connect_to_server()
|
||||
tools = await client.list_tools(serialize=True)
|
||||
yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
|
||||
return
|
||||
except ConnectionError:
|
||||
# TODO: jnjpng make this connection error check more specific to the 401 unauthorized error
|
||||
if isinstance(client, AsyncStdioMCPClient):
|
||||
logger.warning("OAuth not supported for stdio")
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
|
||||
return
|
||||
# Continue to OAuth flow
|
||||
logger.info(f"Attempting OAuth flow for {mcp_config}...")
|
||||
except Exception as e:
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {str(e)}")
|
||||
return
|
||||
finally:
|
||||
if client:
|
||||
try:
|
||||
await client.cleanup()
|
||||
# This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py
|
||||
# For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception
|
||||
except HTTPStatusError:
|
||||
oauth_flow_attempted = True
|
||||
async for event in server.mcp_server_manager.handle_oauth_flow(
|
||||
request=mcp_config, actor=actor, http_request=http_request
|
||||
):
|
||||
yield event
|
||||
|
||||
# # Failsafe to make sure we don't try to handle OAuth flow twice
|
||||
# if not oauth_flow_attempted:
|
||||
# async for event in server.mcp_server_manager.handle_oauth_flow(request=request, actor=actor, http_request=http_request):
|
||||
# yield event
|
||||
# return
|
||||
# except Exception as e:
|
||||
# detailed_error = drill_down_exception(e)
|
||||
# logger.error(f"Error in OAuth stream:\n{detailed_error}")
|
||||
# yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
|
||||
# Failsafe to make sure we don't try to handle OAuth flow twice
|
||||
if not oauth_flow_attempted:
|
||||
async for event in server.mcp_server_manager.handle_oauth_flow(request=mcp_config, actor=actor, http_request=http_request):
|
||||
yield event
|
||||
return
|
||||
except Exception as e:
|
||||
detailed_error = drill_down_exception(e)
|
||||
logger.error(f"Error in OAuth stream:\n{detailed_error}")
|
||||
yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")
|
||||
|
||||
# finally:
|
||||
# if client:
|
||||
# try:
|
||||
# await client.cleanup()
|
||||
# except Exception as cleanup_error:
|
||||
# logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
|
||||
finally:
|
||||
if client:
|
||||
try:
|
||||
await client.cleanup()
|
||||
except Exception as cleanup_error:
|
||||
logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")
|
||||
|
||||
# return StreamingResponseWithStatusCode(oauth_stream_generator(request, http_request), media_type="text/event-stream")
|
||||
return StreamingResponseWithStatusCode(oauth_stream_generator(config, request), media_type="text/event-stream")
|
||||
|
||||
@@ -94,6 +94,7 @@ from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
||||
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
||||
from letta.services.mcp_manager import MCPManager
|
||||
from letta.services.mcp_server_manager import MCPServerManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
@@ -154,6 +155,7 @@ class SyncServer(object):
|
||||
self.user_manager = UserManager()
|
||||
self.tool_manager = ToolManager()
|
||||
self.mcp_manager = MCPManager()
|
||||
self.mcp_server_manager = MCPServerManager()
|
||||
self.block_manager = BlockManager()
|
||||
self.source_manager = SourceManager()
|
||||
self.sandbox_config_manager = SandboxConfigManager()
|
||||
|
||||
1331
letta/services/mcp_server_manager.py
Normal file
1331
letta/services/mcp_server_manager.py
Normal file
File diff suppressed because it is too large
Load Diff
1050
tests/integration_test_mcp_servers.py
Normal file
1050
tests/integration_test_mcp_servers.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user