fix: test_agent_serialization_v2.py and use bulk fetch when fetching mcp servers

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-07-25 11:13:05 -07:00
committed by GitHub
parent dbb8996442
commit 82fef362be
4 changed files with 26 additions and 14 deletions

View File

@@ -1,6 +1,7 @@
from datetime import datetime, timezone
from typing import Dict, List
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
from letta.errors import AgentFileExportError, AgentFileImportError
from letta.helpers.pinecone_utils import should_use_pinecone
from letta.log import get_logger
@@ -262,8 +263,6 @@ class AgentSerializationManager:
async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List:
"""Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name."""
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
mcp_server_ids = set()
mcp_server_names = set()
for tool in tools:
@@ -280,14 +279,11 @@ class AgentSerializationManager:
mcp_servers = []
fetched_server_ids = set()
if mcp_server_ids:
for server_id in mcp_server_ids:
try:
mcp_server = await self.mcp_manager.get_mcp_server_by_id_async(server_id, actor)
if mcp_server:
mcp_servers.append(mcp_server)
fetched_server_ids.add(server_id)
except Exception as e:
logger.warning(f"Failed to fetch MCP server {server_id}: {e}")
try:
mcp_servers = await self.mcp_manager.get_mcp_servers_by_ids(list(mcp_server_ids), actor)
fetched_server_ids.update([mcp_server.id for mcp_server in mcp_servers])
except Exception as e:
logger.warning(f"Failed to fetch MCP servers by IDs {mcp_server_ids}: {e}")
# Fetch MCP servers by name if not already fetched by ID
if mcp_server_names:

View File

@@ -5,6 +5,7 @@ import uuid
from datetime import datetime, timedelta
from typing import Any, Dict, List, Optional, Tuple, Union
from fastapi import HTTPException
from sqlalchemy import null
import letta.constants as constants
@@ -199,7 +200,14 @@ class MCPManager:
"""Update an MCP server by its name."""
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
if not mcp_server_id:
raise ValueError(f"MCP server {mcp_server_name} not found")
raise HTTPException(
status_code=404,
detail={
"code": "MCPServerNotFoundError",
"message": f"MCP server {mcp_server_name} not found",
"mcp_server_name": mcp_server_name,
},
)
return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor)
@enforce_types
@@ -240,7 +248,14 @@ class MCPManager:
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
if not mcp_server:
raise ValueError(f"MCP server {mcp_server_name} not found")
raise HTTPException(
status_code=404, # Not Found
detail={
"code": "MCPServerNotFoundError",
"message": f"MCP server {mcp_server_name} not found",
"mcp_server_name": mcp_server_name,
},
)
return mcp_server.to_pydantic()
# @enforce_types

View File

@@ -1092,6 +1092,7 @@ class TestAgentFileImport:
files=[],
sources=[],
tools=[],
mcp_servers=[],
)
with pytest.raises(AgentFileImportError):

View File

@@ -8579,8 +8579,8 @@ async def test_get_mcp_servers_by_ids(server, default_user, event_loop):
},
{
"name": "test_server_3",
"config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/sse"),
"type": MCPServerType.SSE,
"config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/mcp"),
"type": MCPServerType.STREAMABLE_HTTP,
},
]