fix: test async context fix for mcp clients (#2880)
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from contextlib import AsyncExitStack
|
||||
from typing import Optional, Tuple
|
||||
|
||||
@@ -18,6 +19,9 @@ class AsyncBaseMCPClient:
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.initialized = False
|
||||
# Track the task that created this client
|
||||
self._creation_task = asyncio.current_task()
|
||||
self._cleanup_queue = asyncio.Queue(maxsize=1)
|
||||
|
||||
async def connect_to_server(self):
|
||||
try:
|
||||
@@ -74,8 +78,29 @@ class AsyncBaseMCPClient:
|
||||
raise RuntimeError("MCPClient has not been initialized")
|
||||
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
"""Clean up resources - ensure this runs in the same task"""
|
||||
if hasattr(self, "_cleanup_task"):
|
||||
# If we're in a different task, schedule cleanup in original task
|
||||
current_task = asyncio.current_task()
|
||||
if current_task != self._creation_task:
|
||||
# Create a future to signal completion
|
||||
cleanup_done = asyncio.Future()
|
||||
self._cleanup_queue.put_nowait((self.exit_stack, cleanup_done))
|
||||
await cleanup_done
|
||||
return
|
||||
|
||||
# Normal cleanup
|
||||
await self.exit_stack.aclose()
|
||||
|
||||
def to_sync_client(self):
|
||||
raise NotImplementedError("Subclasses must implement to_sync_client")
|
||||
|
||||
async def __aenter__(self):
|
||||
"""Enter the async context manager."""
|
||||
await self.connect_to_server()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Exit the async context manager."""
|
||||
await self.cleanup()
|
||||
return False # Don't suppress exceptions
|
||||
|
||||
@@ -75,23 +75,23 @@ class MCPManager:
|
||||
server_config = mcp_config[mcp_server_name]
|
||||
|
||||
if isinstance(server_config, SSEServerConfig):
|
||||
mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
||||
# mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
||||
async with AsyncSSEMCPClient(server_config=server_config) as mcp_client:
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
return result, success
|
||||
elif isinstance(server_config, StdioServerConfig):
|
||||
mcp_client = AsyncStdioMCPClient(server_config=server_config)
|
||||
async with AsyncStdioMCPClient(server_config=server_config) as mcp_client:
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
return result, success
|
||||
elif isinstance(server_config, StreamableHTTPServerConfig):
|
||||
mcp_client = AsyncStreamableHTTPMCPClient(server_config=server_config)
|
||||
async with AsyncStreamableHTTPMCPClient(server_config=server_config) as mcp_client:
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
return result, success
|
||||
else:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# call tool
|
||||
result, success = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
logger.info(f"MCP Result: {result}, Success: {success}")
|
||||
# TODO: change to pydantic tool
|
||||
|
||||
await mcp_client.cleanup()
|
||||
|
||||
return result, success
|
||||
|
||||
@enforce_types
|
||||
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
|
||||
Reference in New Issue
Block a user