diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index 166d1608..3cb1ca02 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -1,6 +1,7 @@ from datetime import datetime, timezone from typing import Dict, List +from letta.constants import MCP_TOOL_TAG_NAME_PREFIX from letta.errors import AgentFileExportError, AgentFileImportError from letta.helpers.pinecone_utils import should_use_pinecone from letta.log import get_logger @@ -262,8 +263,6 @@ class AgentSerializationManager: async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List: """Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name.""" - from letta.constants import MCP_TOOL_TAG_NAME_PREFIX - mcp_server_ids = set() mcp_server_names = set() for tool in tools: @@ -280,14 +279,11 @@ class AgentSerializationManager: mcp_servers = [] fetched_server_ids = set() if mcp_server_ids: - for server_id in mcp_server_ids: - try: - mcp_server = await self.mcp_manager.get_mcp_server_by_id_async(server_id, actor) - if mcp_server: - mcp_servers.append(mcp_server) - fetched_server_ids.add(server_id) - except Exception as e: - logger.warning(f"Failed to fetch MCP server {server_id}: {e}") + try: + mcp_servers = await self.mcp_manager.get_mcp_servers_by_ids(list(mcp_server_ids), actor) + fetched_server_ids.update([mcp_server.id for mcp_server in mcp_servers]) + except Exception as e: + logger.warning(f"Failed to fetch MCP servers by IDs {mcp_server_ids}: {e}") # Fetch MCP servers by name if not already fetched by ID if mcp_server_names: diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index a075b5d0..02d37a85 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime, timedelta from typing import Any, Dict, List, Optional, Tuple, Union +from fastapi import HTTPException from sqlalchemy import null import letta.constants as constants @@ -199,7 +200,14 @@ class MCPManager: """Update an MCP server by its name.""" mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) if not mcp_server_id: - raise ValueError(f"MCP server {mcp_server_name} not found") + raise HTTPException( + status_code=404, + detail={ + "code": "MCPServerNotFoundError", + "message": f"MCP server {mcp_server_name} not found", + "mcp_server_name": mcp_server_name, + }, + ) return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor) @enforce_types @@ -240,7 +248,14 @@ class MCPManager: mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) if not mcp_server: - raise ValueError(f"MCP server {mcp_server_name} not found") + raise HTTPException( + status_code=404, # Not Found + detail={ + "code": "MCPServerNotFoundError", + "message": f"MCP server {mcp_server_name} not found", + "mcp_server_name": mcp_server_name, + }, + ) return mcp_server.to_pydantic() # @enforce_types diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index 10f5480c..dd0e3b0d 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -1092,6 +1092,7 @@ class TestAgentFileImport: files=[], sources=[], tools=[], + mcp_servers=[], ) with pytest.raises(AgentFileImportError): diff --git a/tests/test_managers.py b/tests/test_managers.py index 395f3ebc..4b4b0c5f 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -8579,8 +8579,8 @@ async def test_get_mcp_servers_by_ids(server, default_user, event_loop): }, { "name": "test_server_3", - "config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/sse"), - "type": MCPServerType.SSE, + "config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/mcp"), + "type": MCPServerType.STREAMABLE_HTTP, }, ]