Co-authored-by: Matthew Zhou <mattzh1314@gmail.com> Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
20 lines
925 B
Python
20 lines
925 B
Python
from contextlib import AsyncExitStack
|
|
|
|
from mcp import ClientSession, StdioServerParameters
|
|
from mcp.client.stdio import stdio_client
|
|
|
|
from letta.functions.mcp_client.types import StdioServerConfig
|
|
from letta.log import get_logger
|
|
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
|
|
|
# Module-level logger, requested from letta's logging helper under this module's name.
logger = get_logger(__name__)
|
|
|
|
|
|
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
|
class AsyncStdioMCPClient(AsyncBaseMCPClient):
    """Async MCP client that talks to its server over a stdio transport."""

    async def _initialize_connection(
        self,
        exit_stack: AsyncExitStack[bool | None],
        server_config: StdioServerConfig,
    ) -> None:
        """Open the stdio transport and an MCP client session on top of it.

        Both the transport and the session are entered through ``exit_stack``,
        so their cleanup is driven by whoever closes that stack.

        Args:
            exit_stack: Async exit stack that takes ownership of the transport
                and session context managers.
            server_config: Stdio server settings; only ``command`` and ``args``
                are read here.
        """
        # Describe the server to connect to using the configured command line.
        command = server_config.command
        arguments = server_config.args
        launch_spec = StdioServerParameters(command=command, args=arguments)

        # The transport context yields a pair of streams; keep both on the
        # instance under the attribute names the rest of the client uses.
        reader, writer = await exit_stack.enter_async_context(stdio_client(launch_spec))
        self.stdio = reader
        self.write = writer

        # Layer the MCP session over the freshly opened streams.
        session_context = ClientSession(self.stdio, self.write)
        self.session = await exit_stack.enter_async_context(session_context)
|