fix: test_agent_serialization_v2.py and use bulk fetch when fetching mcp servers

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-07-25 11:13:05 -07:00
committed by GitHub
parent dbb8996442
commit 82fef362be
4 changed files with 26 additions and 14 deletions

View File

@@ -1,6 +1,7 @@
from datetime import datetime, timezone from datetime import datetime, timezone
from typing import Dict, List from typing import Dict, List
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
from letta.errors import AgentFileExportError, AgentFileImportError from letta.errors import AgentFileExportError, AgentFileImportError
from letta.helpers.pinecone_utils import should_use_pinecone from letta.helpers.pinecone_utils import should_use_pinecone
from letta.log import get_logger from letta.log import get_logger
@@ -262,8 +263,6 @@ class AgentSerializationManager:
async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List: async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List:
"""Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name.""" """Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name."""
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
mcp_server_ids = set() mcp_server_ids = set()
mcp_server_names = set() mcp_server_names = set()
for tool in tools: for tool in tools:
@@ -280,14 +279,11 @@ class AgentSerializationManager:
mcp_servers = [] mcp_servers = []
fetched_server_ids = set() fetched_server_ids = set()
if mcp_server_ids: if mcp_server_ids:
for server_id in mcp_server_ids: try:
try: mcp_servers = await self.mcp_manager.get_mcp_servers_by_ids(list(mcp_server_ids), actor)
mcp_server = await self.mcp_manager.get_mcp_server_by_id_async(server_id, actor) fetched_server_ids.update([mcp_server.id for mcp_server in mcp_servers])
if mcp_server: except Exception as e:
mcp_servers.append(mcp_server) logger.warning(f"Failed to fetch MCP servers by IDs {mcp_server_ids}: {e}")
fetched_server_ids.add(server_id)
except Exception as e:
logger.warning(f"Failed to fetch MCP server {server_id}: {e}")
# Fetch MCP servers by name if not already fetched by ID # Fetch MCP servers by name if not already fetched by ID
if mcp_server_names: if mcp_server_names:

View File

@@ -5,6 +5,7 @@ import uuid
from datetime import datetime, timedelta from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple, Union from typing import Any, Dict, List, Optional, Tuple, Union
from fastapi import HTTPException
from sqlalchemy import null from sqlalchemy import null
import letta.constants as constants import letta.constants as constants
@@ -199,7 +200,14 @@ class MCPManager:
"""Update an MCP server by its name.""" """Update an MCP server by its name."""
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
if not mcp_server_id: if not mcp_server_id:
raise ValueError(f"MCP server {mcp_server_name} not found") raise HTTPException(
status_code=404,
detail={
"code": "MCPServerNotFoundError",
"message": f"MCP server {mcp_server_name} not found",
"mcp_server_name": mcp_server_name,
},
)
return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor) return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor)
@enforce_types @enforce_types
@@ -240,7 +248,14 @@ class MCPManager:
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor) mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor) mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
if not mcp_server: if not mcp_server:
raise ValueError(f"MCP server {mcp_server_name} not found") raise HTTPException(
status_code=404, # Not Found
detail={
"code": "MCPServerNotFoundError",
"message": f"MCP server {mcp_server_name} not found",
"mcp_server_name": mcp_server_name,
},
)
return mcp_server.to_pydantic() return mcp_server.to_pydantic()
# @enforce_types # @enforce_types

View File

@@ -1092,6 +1092,7 @@ class TestAgentFileImport:
files=[], files=[],
sources=[], sources=[],
tools=[], tools=[],
mcp_servers=[],
) )
with pytest.raises(AgentFileImportError): with pytest.raises(AgentFileImportError):

View File

@@ -8579,8 +8579,8 @@ async def test_get_mcp_servers_by_ids(server, default_user, event_loop):
}, },
{ {
"name": "test_server_3", "name": "test_server_3",
"config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/sse"), "config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/mcp"),
"type": MCPServerType.SSE, "type": MCPServerType.STREAMABLE_HTTP,
}, },
] ]