fix: test_agent_serialization_v2.py and use bulk fetch when fetching mcp servers
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Dict, List
|
from typing import Dict, List
|
||||||
|
|
||||||
|
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
|
||||||
from letta.errors import AgentFileExportError, AgentFileImportError
|
from letta.errors import AgentFileExportError, AgentFileImportError
|
||||||
from letta.helpers.pinecone_utils import should_use_pinecone
|
from letta.helpers.pinecone_utils import should_use_pinecone
|
||||||
from letta.log import get_logger
|
from letta.log import get_logger
|
||||||
@@ -262,8 +263,6 @@ class AgentSerializationManager:
|
|||||||
|
|
||||||
async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List:
|
async def _extract_unique_mcp_servers(self, tools: List, actor: User) -> List:
|
||||||
"""Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name."""
|
"""Extract unique MCP servers from tools based on metadata, using server_id if available, otherwise falling back to server_name."""
|
||||||
from letta.constants import MCP_TOOL_TAG_NAME_PREFIX
|
|
||||||
|
|
||||||
mcp_server_ids = set()
|
mcp_server_ids = set()
|
||||||
mcp_server_names = set()
|
mcp_server_names = set()
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
@@ -280,14 +279,11 @@ class AgentSerializationManager:
|
|||||||
mcp_servers = []
|
mcp_servers = []
|
||||||
fetched_server_ids = set()
|
fetched_server_ids = set()
|
||||||
if mcp_server_ids:
|
if mcp_server_ids:
|
||||||
for server_id in mcp_server_ids:
|
try:
|
||||||
try:
|
mcp_servers = await self.mcp_manager.get_mcp_servers_by_ids(list(mcp_server_ids), actor)
|
||||||
mcp_server = await self.mcp_manager.get_mcp_server_by_id_async(server_id, actor)
|
fetched_server_ids.update([mcp_server.id for mcp_server in mcp_servers])
|
||||||
if mcp_server:
|
except Exception as e:
|
||||||
mcp_servers.append(mcp_server)
|
logger.warning(f"Failed to fetch MCP servers by IDs {mcp_server_ids}: {e}")
|
||||||
fetched_server_ids.add(server_id)
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to fetch MCP server {server_id}: {e}")
|
|
||||||
|
|
||||||
# Fetch MCP servers by name if not already fetched by ID
|
# Fetch MCP servers by name if not already fetched by ID
|
||||||
if mcp_server_names:
|
if mcp_server_names:
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import uuid
|
|||||||
from datetime import datetime, timedelta
|
from datetime import datetime, timedelta
|
||||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
from sqlalchemy import null
|
from sqlalchemy import null
|
||||||
|
|
||||||
import letta.constants as constants
|
import letta.constants as constants
|
||||||
@@ -199,7 +200,14 @@ class MCPManager:
|
|||||||
"""Update an MCP server by its name."""
|
"""Update an MCP server by its name."""
|
||||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
|
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
|
||||||
if not mcp_server_id:
|
if not mcp_server_id:
|
||||||
raise ValueError(f"MCP server {mcp_server_name} not found")
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={
|
||||||
|
"code": "MCPServerNotFoundError",
|
||||||
|
"message": f"MCP server {mcp_server_name} not found",
|
||||||
|
"mcp_server_name": mcp_server_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor)
|
return await self.update_mcp_server_by_id(mcp_server_id, mcp_server_update, actor)
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
@@ -240,7 +248,14 @@ class MCPManager:
|
|||||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
|
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
|
||||||
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||||
if not mcp_server:
|
if not mcp_server:
|
||||||
raise ValueError(f"MCP server {mcp_server_name} not found")
|
raise HTTPException(
|
||||||
|
status_code=404, # Not Found
|
||||||
|
detail={
|
||||||
|
"code": "MCPServerNotFoundError",
|
||||||
|
"message": f"MCP server {mcp_server_name} not found",
|
||||||
|
"mcp_server_name": mcp_server_name,
|
||||||
|
},
|
||||||
|
)
|
||||||
return mcp_server.to_pydantic()
|
return mcp_server.to_pydantic()
|
||||||
|
|
||||||
# @enforce_types
|
# @enforce_types
|
||||||
|
|||||||
@@ -1092,6 +1092,7 @@ class TestAgentFileImport:
|
|||||||
files=[],
|
files=[],
|
||||||
sources=[],
|
sources=[],
|
||||||
tools=[],
|
tools=[],
|
||||||
|
mcp_servers=[],
|
||||||
)
|
)
|
||||||
|
|
||||||
with pytest.raises(AgentFileImportError):
|
with pytest.raises(AgentFileImportError):
|
||||||
|
|||||||
@@ -8579,8 +8579,8 @@ async def test_get_mcp_servers_by_ids(server, default_user, event_loop):
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "test_server_3",
|
"name": "test_server_3",
|
||||||
"config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/sse"),
|
"config": SSEServerConfig(server_name="test_server_3", server_url="https://test3.example.com/mcp"),
|
||||||
"type": MCPServerType.SSE,
|
"type": MCPServerType.STREAMABLE_HTTP,
|
||||||
},
|
},
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user