fix: test async context fix for mcp clients (#2880)

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-06-18 09:52:58 -07:00
committed by GitHub
parent 76b9dc1599
commit c6aca63d56
2 changed files with 39 additions and 14 deletions

View File

@@ -1,3 +1,4 @@
import asyncio
from contextlib import AsyncExitStack
from typing import Optional, Tuple
@@ -18,6 +19,9 @@ class AsyncBaseMCPClient:
self.exit_stack = AsyncExitStack()
self.session: Optional[ClientSession] = None
self.initialized = False
# Track the task that created this client
self._creation_task = asyncio.current_task()
self._cleanup_queue = asyncio.Queue(maxsize=1)
async def connect_to_server(self):
try:
@@ -74,8 +78,29 @@ class AsyncBaseMCPClient:
raise RuntimeError("MCPClient has not been initialized")
async def cleanup(self):
"""Clean up resources"""
"""Clean up resources - ensure this runs in the same task"""
if hasattr(self, "_cleanup_task"):
# If we're in a different task, schedule cleanup in original task
current_task = asyncio.current_task()
if current_task != self._creation_task:
# Create a future to signal completion
cleanup_done = asyncio.Future()
self._cleanup_queue.put_nowait((self.exit_stack, cleanup_done))
await cleanup_done
return
# Normal cleanup
await self.exit_stack.aclose()
def to_sync_client(self):
raise NotImplementedError("Subclasses must implement to_sync_client")
async def __aenter__(self):
"""Enter the async context manager."""
await self.connect_to_server()
return self
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Exit the async context manager."""
await self.cleanup()
return False # Don't suppress exceptions

View File

@@ -75,23 +75,23 @@ class MCPManager:
server_config = mcp_config[mcp_server_name]
if isinstance(server_config, SSEServerConfig):
mcp_client = AsyncSSEMCPClient(server_config=server_config)
# mcp_client = AsyncSSEMCPClient(server_config=server_config)
async with AsyncSSEMCPClient(server_config=server_config) as mcp_client:
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
return result, success
elif isinstance(server_config, StdioServerConfig):
mcp_client = AsyncStdioMCPClient(server_config=server_config)
async with AsyncStdioMCPClient(server_config=server_config) as mcp_client:
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
return result, success
elif isinstance(server_config, StreamableHTTPServerConfig):
mcp_client = AsyncStreamableHTTPMCPClient(server_config=server_config)
async with AsyncStreamableHTTPMCPClient(server_config=server_config) as mcp_client:
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
return result, success
else:
raise ValueError(f"Unsupported server config type: {type(server_config)}")
await mcp_client.connect_to_server()
# call tool
result, success = await mcp_client.execute_tool(tool_name, tool_args)
logger.info(f"MCP Result: {result}, Success: {success}")
# TODO: change to pydantic tool
await mcp_client.cleanup()
return result, success
@enforce_types
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool: