* auto fixes * auto fix pt2 and transitive deps and undefined var checking locals() * manual fixes (ignored or letta-code fixed) * fix circular import
311 lines
12 KiB
Python
311 lines
12 KiB
Python
from typing import AsyncGenerator, List, Optional, Union
|
|
|
|
from fastapi import APIRouter, Body, Depends, Request
|
|
from httpx import HTTPStatusError
|
|
from starlette.responses import StreamingResponse
|
|
|
|
from letta.errors import LettaMCPConnectionError
|
|
from letta.functions.mcp_client.types import SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig
|
|
from letta.log import get_logger
|
|
from letta.schemas.mcp_server import (
|
|
CreateMCPServerRequest,
|
|
MCPServerUnion,
|
|
ToolExecuteRequest,
|
|
UpdateMCPServerRequest,
|
|
convert_generic_to_union,
|
|
convert_update_to_internal,
|
|
)
|
|
from letta.schemas.tool import Tool
|
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
|
from letta.server.rest_api.dependencies import (
|
|
HeaderParams,
|
|
get_headers,
|
|
get_letta_server,
|
|
)
|
|
from letta.server.rest_api.streaming_response import StreamingResponseWithStatusCode
|
|
from letta.server.server import SyncServer
|
|
from letta.services.mcp.oauth_utils import drill_down_exception, oauth_stream_event
|
|
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
|
from letta.services.mcp.types import OauthStreamEvent
|
|
|
|
# Router that every endpoint in this module registers on; all paths are
# served under the "/mcp-servers" prefix and tagged "mcp-servers".
router = APIRouter(prefix="/mcp-servers", tags=["mcp-servers"])

# Module-level logger, named after this module.
logger = get_logger(__name__)
|
|
|
|
|
|
@router.post("/", response_model=MCPServerUnion, operation_id="mcp_create_mcp_server")
async def create_mcp_server(
    request: CreateMCPServerRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Add a new MCP server to the Letta MCP server config
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.
    # TODO: add the tools to the MCP server table we made.

    # Resolve the acting user from the actor id carried in the request headers.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Persist the server described by the request body, then convert the
    # stored record into the union type declared as the response model.
    mcp_manager = server.mcp_server_manager
    created = await mcp_manager.create_mcp_server_from_request(request, actor=acting_user)
    return await convert_generic_to_union(created)
|
|
|
|
|
|
@router.get("/", response_model=List[MCPServerUnion], operation_id="mcp_list_mcp_servers")
async def list_mcp_servers(
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a list of all configured MCP servers
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.

    # Resolve the acting user from the actor id carried in the request headers.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    stored_servers = await server.mcp_server_manager.list_mcp_servers(actor=acting_user)

    # Convert each record one at a time, in listing order (the awaits inside
    # the comprehension run sequentially, not concurrently).
    return [await convert_generic_to_union(entry) for entry in stored_servers]
|
|
|
|
|
|
@router.get("/{mcp_server_id}", response_model=MCPServerUnion, operation_id="mcp_retrieve_mcp_server")
async def retrieve_mcp_server(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a specific MCP server
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.

    # Resolve the acting user from the actor id carried in the request headers.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Look the server up by the id taken from the URL path, then convert the
    # record into the union type declared as the response model.
    mcp_manager = server.mcp_server_manager
    found = await mcp_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=acting_user)
    return await convert_generic_to_union(found)
|
|
|
|
|
|
@router.delete("/{mcp_server_id}", status_code=204, operation_id="mcp_delete_mcp_server")
async def delete_mcp_server(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Delete an MCP server by its ID
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.

    # Resolve the acting user from the actor id carried in the request headers.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Nothing is returned: the route is declared with status code 204.
    mcp_manager = server.mcp_server_manager
    await mcp_manager.delete_mcp_server_by_id(mcp_server_id, actor=acting_user)
|
|
|
|
|
|
@router.patch("/{mcp_server_id}", response_model=MCPServerUnion, operation_id="mcp_update_mcp_server")
async def update_mcp_server(
    mcp_server_id: str,
    request: UpdateMCPServerRequest = Body(...),
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Update an existing MCP server configuration
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.

    # Resolve the acting user from the actor id carried in the request headers.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # The external update payload is translated to the manager's internal
    # update type right at the call site.
    refreshed = await server.mcp_server_manager.update_mcp_server_by_id(
        mcp_server_id=mcp_server_id,
        mcp_server_update=convert_update_to_internal(request),
        actor=acting_user,
    )
    return await convert_generic_to_union(refreshed)
|
|
|
|
|
|
@router.get(
    "/{mcp_server_id}/tools",
    response_model=List[Tool],
    operation_id="mcp_list_tools_for_mcp_server",
)
async def list_tools_for_mcp_server(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a list of all tools for a specific MCP server
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.

    # Resolve the acting user from the actor id carried in the request headers.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Tools are read through the manager's database-backed lookup and handed
    # back unchanged.
    mcp_manager = server.mcp_server_manager
    return await mcp_manager.list_tools_by_mcp_server_from_db(mcp_server_id, actor=acting_user)
|
|
|
|
|
|
@router.get(
    "/{mcp_server_id}/tools/{tool_id}",
    response_model=Tool,
    operation_id="mcp_retrieve_mcp_tool",
)
async def retrieve_mcp_tool(
    mcp_server_id: str,
    tool_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
):
    """
    Get a specific MCP tool by its ID
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.

    # Resolve the acting user from the actor id carried in the request headers.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Both ids come from the URL path; the manager's result is returned as-is.
    mcp_manager = server.mcp_server_manager
    return await mcp_manager.get_tool_by_mcp_server(mcp_server_id, tool_id, actor=acting_user)
|
|
|
|
|
|
@router.post(
    "/{mcp_server_id}/tools/{tool_id}/run",
    response_model=ToolExecutionResult,
    operation_id="mcp_run_tool",
)
async def run_mcp_tool(
    mcp_server_id: str,
    tool_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
    request: ToolExecuteRequest = Body(default=ToolExecuteRequest()),
):
    """
    Execute a specific MCP tool

    The request body should contain the tool arguments in the ToolExecuteRequest format.
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.

    # Resolve the acting user from the actor id carried in the request headers.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # Run the tool; the manager hands back a (return value, success flag) pair.
    outcome = await server.mcp_server_manager.execute_mcp_server_tool(
        mcp_server_id=mcp_server_id,
        tool_id=tool_id,
        tool_args=request.args,
        environment_variables={},  # TODO: Get environment variables from somewhere if needed
        actor=acting_user,
    )
    func_return, succeeded = outcome

    # Map the success flag onto the status string carried by the result model.
    status = "success" if succeeded else "error"
    return ToolExecutionResult(status=status, func_return=func_return)
|
|
|
|
|
|
@router.patch(
    "/{mcp_server_id}/refresh",
    operation_id="mcp_refresh_mcp_server_tools",
)
async def refresh_mcp_server_tools(
    mcp_server_id: str,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
    agent_id: Optional[str] = None,
):
    """
    Refresh tools for an MCP server by:
    1. Fetching current tools from the MCP server
    2. Deleting tools that no longer exist on the server
    3. Updating schemas for existing tools
    4. Adding new tools from the server

    Returns a summary of changes made.
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.

    # Resolve the acting user from the actor id carried in the request headers.
    acting_user = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    # The resync itself is delegated to the manager; the optional agent_id is
    # forwarded untouched and the manager's result is returned as-is.
    mcp_manager = server.mcp_server_manager
    return await mcp_manager.resync_mcp_server_tools(mcp_server_id, actor=acting_user, agent_id=agent_id)
|
|
|
|
|
|
@router.get(
    "/connect/{mcp_server_id}",
    response_model=None,
    # TODO: make this into a model?
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
    operation_id="mcp_connect_mcp_server",
)
async def connect_mcp_server(
    mcp_server_id: str,
    request: Request,
    server: SyncServer = Depends(get_letta_server),
    headers: HeaderParams = Depends(get_headers),
) -> StreamingResponse:
    """
    Connect to an MCP server with support for OAuth via SSE.
    Returns a stream of events handling authorization state and exchange if OAuth is required.
    """
    # NOTE: the docstring above is published as the OpenAPI description of
    # this operation, so implementation notes live in comments instead.
    #
    # Parameters:
    #   mcp_server_id - id of the stored MCP server, taken from the URL path.
    #   request       - the incoming HTTP request; forwarded to the OAuth flow.
    #   server        - injected SyncServer giving access to the managers.
    #   headers       - injected header params; only actor_id is read here.
    # Returns a streaming response (media type "text/event-stream") fed by the
    # nested generator below.

    # Resolve the acting user from the actor id carried in the request headers.
    actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
    mcp_server = await server.mcp_server_manager.get_mcp_server_by_id_async(mcp_server_id=mcp_server_id, actor=actor)

    # Convert the MCP server to the appropriate config type.
    # Variables are left unresolved here; they are resolved inside the
    # generator right before the client is created.
    config = await mcp_server.to_config_async(resolve_variables=False)

    async def oauth_stream_generator(
        mcp_config: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig],
        http_request: Request,
    ) -> AsyncGenerator[str, None]:
        # Yields stream events: first a connection-attempt event, then either a
        # success event carrying the tool list, an error event, or whatever
        # events the manager's OAuth flow produces.
        client = None

        # Set once the OAuth flow has been started from the cleanup workaround
        # below, so the failsafe further down does not start it a second time.
        oauth_flow_attempted = False
        try:
            # Acknowledge connection attempt
            yield oauth_stream_event(OauthStreamEvent.CONNECTION_ATTEMPT, server_name=mcp_config.server_name)

            # Create MCP client with respective transport type
            try:
                mcp_config.resolve_environment_variables()
                client = await server.mcp_server_manager.get_mcp_client(mcp_config, actor)
            except ValueError as e:
                yield oauth_stream_event(OauthStreamEvent.ERROR, message=str(e))
                return

            # Try normal connection first for flows that don't require OAuth
            try:
                await client.connect_to_server()
                tools = await client.list_tools(serialize=True)
                yield oauth_stream_event(OauthStreamEvent.SUCCESS, tools=tools)
                return
            except (ConnectionError, LettaMCPConnectionError):
                if isinstance(client, AsyncStdioMCPClient):
                    logger.warning("OAuth not supported for stdio")
                    yield oauth_stream_event(OauthStreamEvent.ERROR, message="OAuth not supported for stdio")
                    return
                # Continue to OAuth flow.
                # Log only the server name: resolve_environment_variables() was
                # applied to mcp_config above, so formatting the whole object
                # could write resolved values (possibly credentials) to the logs.
                logger.info(f"Attempting OAuth flow for MCP server '{mcp_config.server_name}'...")
            except Exception as e:
                yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Connection failed: {e}")
                return
            finally:
                # Runs on every exit path of the connection attempt above,
                # including the early returns.
                if client:
                    try:
                        await client.cleanup()
                    # This is a workaround to catch the expected 401 Unauthorized from the official MCP SDK, see their streamable_http.py
                    # For SSE transport types, we catch the ConnectionError above, but Streamable HTTP doesn't bubble up the exception
                    except HTTPStatusError:
                        oauth_flow_attempted = True
                        async for event in server.mcp_server_manager.handle_oauth_flow(
                            request=mcp_config, actor=actor, http_request=http_request
                        ):
                            yield event

            # Failsafe to make sure we don't try to handle OAuth flow twice
            if not oauth_flow_attempted:
                async for event in server.mcp_server_manager.handle_oauth_flow(request=mcp_config, actor=actor, http_request=http_request):
                    yield event
            return
        except Exception as e:
            # Anything not handled above is reported to the client as an error
            # event instead of propagating out of the generator.
            detailed_error = drill_down_exception(e)
            logger.error(f"Error in OAuth stream:\n{detailed_error}")
            yield oauth_stream_event(OauthStreamEvent.ERROR, message=f"Internal error: {detailed_error}")

        finally:
            # NOTE(review): on paths that went through the connection attempt,
            # cleanup() has already been called once by the inner finally;
            # presumably it is safe to call repeatedly - confirm.
            if client:
                try:
                    await client.cleanup()
                except Exception as cleanup_error:
                    logger.warning(f"Error during temp MCP client cleanup: {cleanup_error}")

    return StreamingResponseWithStatusCode(oauth_stream_generator(config, request), media_type="text/event-stream")
|