From 76b9dc1599baaec8a710b06587ed457f98e8e5aa Mon Sep 17 00:00:00 2001 From: jnjpng Date: Tue, 17 Jun 2025 16:19:27 -0700 Subject: [PATCH] fix: mcp fixes and update flow (#2851) Co-authored-by: Jin Peng --- letta/server/rest_api/routers/v1/tools.py | 133 ++++++++++++++++------ letta/services/mcp_manager.py | 25 +++- 2 files changed, 118 insertions(+), 40 deletions(-) diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index afc14166..4519402b 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -13,11 +13,12 @@ from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query from letta.errors import LettaToolCreateError from letta.functions.mcp_client.exceptions import MCPTimeoutError -from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig +from letta.functions.mcp_client.types import MCPTool, SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig from letta.helpers.composio_helpers import get_composio_api_key from letta.log import get_logger from letta.orm.errors import UniqueConstraintViolationError from letta.schemas.letta_message import ToolReturnMessage +from letta.schemas.mcp import UpdateSSEMCPServer, UpdateStreamableHTTPMCPServer from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer @@ -325,15 +326,6 @@ async def add_composio_tool( "composio_action_name": composio_action_name, }, ) - except ComposioClientError as e: - raise HTTPException( - status_code=400, # Bad Request - detail={ - "code": "ComposioClientError", - "message": str(e), - "composio_action_name": composio_action_name, - }, - ) except ApiKeyNotProvidedError as e: raise HTTPException( status_code=400, # Bad Request @@ -343,6 +335,15 @@ async def add_composio_tool( "composio_action_name": composio_action_name, }, ) + except ComposioClientError as e: + raise HTTPException( + status_code=400, # Bad Request + detail={ + "code": "ComposioClientError", + "message": str(e), + "composio_action_name": composio_action_name, + }, + ) except ComposioSDKError as e: raise HTTPException( status_code=400, # Bad Request @@ -355,7 +356,11 @@ async def add_composio_tool( # Specific routes for MCP -@router.get("/mcp/servers", response_model=dict[str, Union[SSEServerConfig, StdioServerConfig]], operation_id="list_mcp_servers") +@router.get( + "/mcp/servers", + response_model=dict[str, Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig]], + operation_id="list_mcp_servers", +) async def list_mcp_servers(server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id")): """ Get a list of all configured MCP servers @@ -466,44 +471,102 @@ async def add_mcp_tool( return await server.mcp_manager.add_tool_from_mcp_server(mcp_server_name=mcp_server_name, mcp_tool_name=mcp_tool_name, actor=actor) -@router.put("/mcp/servers", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="add_mcp_server") +@router.put( + "/mcp/servers", + response_model=List[Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig]], + operation_id="add_mcp_server", +) async def add_mcp_server_to_config( - request: Union[StdioServerConfig, SSEServerConfig] = Body(...), + request: Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig] = Body(...), server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Add a new MCP server to the Letta MCP server config """ + try: + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) - actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + if tool_settings.mcp_read_from_config: + # write to config file + return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True) + else: + # log to DB + from letta.schemas.mcp import MCPServer - if tool_settings.mcp_read_from_config: - # write to config file - return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True) - else: - # log to DB - from letta.schemas.mcp import MCPServer + if isinstance(request, StdioServerConfig): + mapped_request = MCPServer(server_name=request.server_name, server_type=request.type, stdio_config=request) + # don't allow stdio servers + if tool_settings.mcp_disable_stdio: # protected server + raise HTTPException( + status_code=400, + detail="stdio is not supported in the current environment, please use a self-hosted Letta server in order to add a stdio MCP server", + ) + elif isinstance(request, SSEServerConfig): + mapped_request = MCPServer( + server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token() + ) + elif isinstance(request, StreamableHTTPServerConfig): + mapped_request = MCPServer( + server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token() + ) - if isinstance(request, StdioServerConfig): - mapped_request = MCPServer(server_name=request.server_name, server_type=request.type, stdio_config=request) - # don't allow stdio servers - if tool_settings.mcp_disable_stdio: # protected server - raise HTTPException(status_code=400, detail="StdioServerConfig is not supported") - elif isinstance(request, SSEServerConfig): - mapped_request = MCPServer( - server_name=request.server_name, server_type=request.type, server_url=request.server_url, token=request.resolve_token() + await server.mcp_manager.create_mcp_server(mapped_request, actor=actor) + + # TODO: don't do this in the future (just return MCPServer) + all_servers = await server.mcp_manager.list_mcp_servers(actor=actor) + return [server.to_config() for server in all_servers] + except UniqueConstraintViolationError: + # If server name already exists, throw 409 conflict error + raise HTTPException( + status_code=409, + detail={ + "code": "MCPServerNameAlreadyExistsError", + "message": f"MCP server with name '{request.server_name}' already exists", + "server_name": request.server_name, + }, + ) + except Exception as e: + print(f"Unexpected error occurred while adding MCP server: {e}") + raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") + + +@router.patch( + "/mcp/servers/{mcp_server_name}", + response_model=Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig], + operation_id="update_mcp_server", +) +async def update_mcp_server( + mcp_server_name: str, + request: Union[UpdateSSEMCPServer, UpdateStreamableHTTPMCPServer] = Body(...), + server: SyncServer = Depends(get_letta_server), + actor_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Update an existing MCP server configuration + """ + try: + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) + + if tool_settings.mcp_read_from_config: + raise HTTPException(status_code=501, detail="Update not implemented for config file mode, config files to be deprecated.") + else: + updated_server = await server.mcp_manager.update_mcp_server_by_name( + mcp_server_name=mcp_server_name, mcp_server_update=request, actor=actor ) - # TODO: add HTTP streaming - mcp_server = await server.mcp_manager.create_or_update_mcp_server(mapped_request, actor=actor) - - # TODO: don't do this in the future (just return MCPServer) - all_servers = await server.mcp_manager.list_mcp_servers(actor=actor) - return [server.to_config() for server in all_servers] + return updated_server.to_config() + except HTTPException: + # Re-raise HTTP exceptions (like 404) + raise + except Exception as e: + print(f"Unexpected error occurred while updating MCP server: {e}") + raise HTTPException(status_code=500, detail=f"An unexpected error occurred: {str(e)}") @router.delete( - "/mcp/servers/{mcp_server_name}", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="delete_mcp_server" + "/mcp/servers/{mcp_server_name}", + response_model=List[Union[StdioServerConfig, SSEServerConfig, StreamableHTTPServerConfig]], + operation_id="delete_mcp_server", ) async def delete_mcp_server_from_config( mcp_server_name: str, diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 8a5f8af3..9763e553 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -149,19 +149,19 @@ class MCPManager: return mcp_server @enforce_types - async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> PydanticTool: - """Create a new tool based on the ToolCreate schema.""" - with db_registry.session() as session: + async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer: + """Create a new MCP server.""" + async with db_registry.async_session() as session: # Set the organization id at the ORM layer pydantic_mcp_server.organization_id = actor.organization_id mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True) mcp_server = MCPServerModel(**mcp_server_data) - mcp_server.create(session, actor=actor) # Re-raise other database-related errors + mcp_server = await mcp_server.create_async(session, actor=actor) return mcp_server.to_pydantic() @enforce_types - async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> PydanticTool: + async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer: """Update a tool by its ID with the given ToolUpdate object.""" async with db_registry.async_session() as session: # Fetch the tool by ID @@ -177,6 +177,21 @@ class MCPManager: # Save the updated tool to the database mcp_server = await mcp_server.update_async(db_session=session, actor=actor) return mcp_server.to_pydantic() + @enforce_types + async def update_mcp_server_by_name(self, mcp_server_name: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> MCPServer: + """Update an MCP server by its name.""" + mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) + if not mcp_server_id: + raise HTTPException( + status_code=404, + detail={ + "code": "MCPServerNotFoundError", + "message": f"MCP server {mcp_server_name} not found", + "mcp_server_name": mcp_server_name, + }, + ) + return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor) + @enforce_types async def get_mcp_server_id_by_name(self, mcp_server_name: str, actor: PydanticUser) -> Optional[str]: """Retrieve a MCP server by its name and a user"""