From 7315c133ef14a67fa2cdfa87f8fb307b12fd6602 Mon Sep 17 00:00:00 2001 From: jnjpng Date: Wed, 18 Jun 2025 09:52:58 -0700 Subject: [PATCH] fix: test async context fix for mcp clients (#2880) Co-authored-by: Jin Peng --- letta/services/mcp/base_client.py | 27 ++++++++++++++++++++++++++- letta/services/mcp_manager.py | 26 +++++++++++++------------- 2 files changed, 39 insertions(+), 14 deletions(-) diff --git a/letta/services/mcp/base_client.py b/letta/services/mcp/base_client.py index 98c2e81f..d0bc1094 100644 --- a/letta/services/mcp/base_client.py +++ b/letta/services/mcp/base_client.py @@ -1,3 +1,4 @@ +import asyncio from contextlib import AsyncExitStack from typing import Optional, Tuple @@ -18,6 +19,9 @@ class AsyncBaseMCPClient: self.exit_stack = AsyncExitStack() self.session: Optional[ClientSession] = None self.initialized = False + # Track the task that created this client + self._creation_task = asyncio.current_task() + self._cleanup_queue = asyncio.Queue(maxsize=1) async def connect_to_server(self): try: @@ -74,8 +78,29 @@ class AsyncBaseMCPClient: raise RuntimeError("MCPClient has not been initialized") async def cleanup(self): - """Clean up resources""" + """Clean up resources - ensure this runs in the same task""" + if hasattr(self, "_cleanup_task"): + # If we're in a different task, schedule cleanup in original task + current_task = asyncio.current_task() + if current_task != self._creation_task: + # Create a future to signal completion + cleanup_done = asyncio.Future() + self._cleanup_queue.put_nowait((self.exit_stack, cleanup_done)) + await cleanup_done + return + + # Normal cleanup await self.exit_stack.aclose() def to_sync_client(self): raise NotImplementedError("Subclasses must implement to_sync_client") + + async def __aenter__(self): + """Enter the async context manager.""" + await self.connect_to_server() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Exit the async context manager.""" + await self.cleanup() + return False # Don't suppress exceptions diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 9763e553..846a37f5 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -75,23 +75,23 @@ class MCPManager: server_config = mcp_config[mcp_server_name] if isinstance(server_config, SSEServerConfig): - mcp_client = AsyncSSEMCPClient(server_config=server_config) + # mcp_client = AsyncSSEMCPClient(server_config=server_config) + async with AsyncSSEMCPClient(server_config=server_config) as mcp_client: + result, success = await mcp_client.execute_tool(tool_name, tool_args) + logger.info(f"MCP Result: {result}, Success: {success}") + return result, success elif isinstance(server_config, StdioServerConfig): - mcp_client = AsyncStdioMCPClient(server_config=server_config) + async with AsyncStdioMCPClient(server_config=server_config) as mcp_client: + result, success = await mcp_client.execute_tool(tool_name, tool_args) + logger.info(f"MCP Result: {result}, Success: {success}") + return result, success elif isinstance(server_config, StreamableHTTPServerConfig): - mcp_client = AsyncStreamableHTTPMCPClient(server_config=server_config) + async with AsyncStreamableHTTPMCPClient(server_config=server_config) as mcp_client: + result, success = await mcp_client.execute_tool(tool_name, tool_args) + logger.info(f"MCP Result: {result}, Success: {success}") + return result, success else: raise ValueError(f"Unsupported server config type: {type(server_config)}") - await mcp_client.connect_to_server() - - # call tool - result, success = await mcp_client.execute_tool(tool_name, tool_args) - logger.info(f"MCP Result: {result}, Success: {success}") - # TODO: change to pydantic tool - - await mcp_client.cleanup() - - return result, success @enforce_types async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool: