Files
letta-server/letta/functions/mcp_client/stdio_client.py
2025-03-17 10:06:36 -07:00

109 lines
4.5 KiB
Python

import asyncio
import sys
from contextlib import asynccontextmanager
import anyio
import anyio.lowlevel
import mcp.types as types
from anyio.streams.text import TextReceiveStream
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import get_default_environment
from letta.functions.mcp_client.base_client import BaseMCPClient
from letta.functions.mcp_client.types import StdioServerConfig
from letta.log import get_logger
# Module-level logger for this file, obtained from letta's logging helper.
logger = get_logger(__name__)
class StdioMCPClient(BaseMCPClient):
    """MCP client whose transport is a child process spoken to over stdin/stdout."""

    def _initialize_connection(self, server_config: StdioServerConfig, timeout: float) -> bool:
        """Spawn the stdio server and open an MCP client session on top of it.

        Each async context manager is entered synchronously on ``self.loop``, and
        the matching exit callback is pushed onto ``self.cleanup_funcs`` right
        after that manager has been entered.

        Args:
            server_config: Supplies the ``command`` and ``args`` of the server process.
            timeout: Seconds allowed for entering each context manager.

        Returns:
            True once both the transport and the session are open. False if either
            step timed out or raised; the failure is logged, not re-raised.
        """
        try:
            params = StdioServerParameters(command=server_config.command, args=server_config.args)

            # Transport: spawn the subprocess and obtain its (read, write) stream pair.
            transport_ctx = forked_stdio_client(params)
            self.stdio, self.write = self.loop.run_until_complete(
                asyncio.wait_for(transport_ctx.__aenter__(), timeout=timeout)
            )
            self.cleanup_funcs.append(
                lambda: self.loop.run_until_complete(transport_ctx.__aexit__(None, None, None))
            )

            # Session: layer an MCP ClientSession over the transport streams.
            session_ctx = ClientSession(self.stdio, self.write)
            self.session = self.loop.run_until_complete(
                asyncio.wait_for(session_ctx.__aenter__(), timeout=timeout)
            )
            self.cleanup_funcs.append(
                lambda: self.loop.run_until_complete(session_ctx.__aexit__(None, None, None))
            )
        except asyncio.TimeoutError:
            logger.error(f"Timed out while establishing stdio connection (timeout={timeout}s).")
            return False
        except Exception:
            logger.exception("Exception occurred while initializing stdio client session.")
            return False
        return True
@asynccontextmanager
async def forked_stdio_client(server: StdioServerParameters):
    """
    Client transport for stdio: this will connect to a server by spawning a
    process and communicating with it over stdin/stdout.

    Yields a ``(read_stream, write_stream)`` pair of anyio memory object streams.
    ``read_stream`` delivers parsed ``types.JSONRPCMessage`` objects, or the
    exception raised while parsing a line. Messages sent on ``write_stream`` are
    serialized as newline-terminated JSON onto the child's stdin.

    A background task also watches the child process and raises ``RuntimeError``
    if it exits with a non-zero status. A short grace period before yielding lets
    a command that fails immediately surface that error while the context is
    still being entered.

    Raises:
        RuntimeError: if the process cannot be spawned. Also raised from the task
            group if the child exits non-zero; depending on the anyio version,
            this may arrive wrapped in an exception group.
    """
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    try:
        process = await anyio.open_process(
            [server.command, *server.args],
            env=server.env or get_default_environment(),
            # The child's stderr is forwarded straight to ours (it is not silenced).
            # Consider routing it through the logger instead.
            stderr=sys.stderr,
        )
    except OSError as exc:
        # Nothing will ever use these streams now; close them rather than
        # leaving four unclosed memory streams behind.
        for stream in (read_stream_writer, read_stream, write_stream, write_stream_reader):
            stream.close()
        raise RuntimeError(f"Failed to spawn process: {server.command} {server.args}") from exc

    async def stdout_reader():
        """Parse newline-delimited JSON-RPC messages from the child's stdout."""
        assert process.stdout, "Opened process is missing stdout"

        buffer = ""
        try:
            async with read_stream_writer:
                async for chunk in TextReceiveStream(
                    process.stdout,
                    encoding=server.encoding,
                    errors=server.encoding_error_handler,
                ):
                    # A chunk may end mid-line; keep the unfinished tail for the next chunk.
                    lines = (buffer + chunk).split("\n")
                    buffer = lines.pop()

                    for line in lines:
                        # Blank lines carry no message; parsing them would only push
                        # a spurious validation error to the consumer.
                        if not line.strip():
                            continue
                        try:
                            message = types.JSONRPCMessage.model_validate_json(line)
                        except Exception as exc:
                            # Hand parse failures to the consumer instead of stopping the reader.
                            await read_stream_writer.send(exc)
                            continue

                        await read_stream_writer.send(message)
        except anyio.ClosedResourceError:
            await anyio.lowlevel.checkpoint()

    async def stdin_writer():
        """Serialize outgoing messages as newline-terminated JSON onto the child's stdin."""
        assert process.stdin, "Opened process is missing stdin"

        try:
            async with write_stream_reader:
                async for message in write_stream_reader:
                    json = message.model_dump_json(by_alias=True, exclude_none=True)
                    await process.stdin.send(
                        (json + "\n").encode(
                            encoding=server.encoding,
                            errors=server.encoding_error_handler,
                        )
                    )
        except anyio.ClosedResourceError:
            await anyio.lowlevel.checkpoint()

    async def watch_process_exit():
        """Raise if the child exits with a non-zero status; this fails the whole task group."""
        returncode = await process.wait()
        if returncode != 0:
            raise RuntimeError(f"Subprocess exited with code {returncode}. Command: {server.command} {server.args}")

    async with anyio.create_task_group() as tg, process:
        tg.start_soon(stdout_reader)
        tg.start_soon(stdin_writer)
        tg.start_soon(watch_process_exit)

        # Grace period: if the command dies straight away, watch_process_exit raises
        # during this wait, so entering the context fails instead of yielding
        # streams connected to a dead process.
        with anyio.move_on_after(0.2):
            await anyio.sleep_forever()

        yield read_stream, write_stream