109 lines
4.5 KiB
Python
109 lines
4.5 KiB
Python
import asyncio
|
|
import sys
|
|
from contextlib import asynccontextmanager
|
|
|
|
import anyio
|
|
import anyio.lowlevel
|
|
import mcp.types as types
|
|
from anyio.streams.text import TextReceiveStream
|
|
from mcp import ClientSession, StdioServerParameters
|
|
from mcp.client.stdio import get_default_environment
|
|
|
|
from letta.functions.mcp_client.base_client import BaseMCPClient
|
|
from letta.functions.mcp_client.types import StdioServerConfig
|
|
from letta.log import get_logger
|
|
|
|
# Module-level logger, named after this module, used by the stdio MCP client below.
logger = get_logger(__name__)
|
|
|
|
|
|
class StdioMCPClient(BaseMCPClient):
    """MCP client whose transport is a spawned server process spoken to over stdin/stdout."""

    def _initialize_connection(self, server_config: StdioServerConfig, timeout: float) -> bool:
        """Launch the stdio MCP server and open a ``ClientSession`` on top of it.

        The transport context manager and then the session context manager are
        entered on ``self.loop``, each bounded by ``timeout`` seconds. After each
        successful entry, a callback that runs the matching ``__aexit__`` on
        ``self.loop`` is appended to ``self.cleanup_funcs``.

        Args:
            server_config: Provides the ``command`` and ``args`` used to start the server.
            timeout: Per-step limit, in seconds, for entering each context manager.

        Returns:
            ``True`` once ``self.stdio``, ``self.write`` and ``self.session`` are set;
            ``False`` if a step timed out or raised (the failure is logged, not
            propagated).
        """

        def exit_later(cm):
            # Bind `cm` eagerly so each callback tears down its own context manager.
            return lambda: self.loop.run_until_complete(cm.__aexit__(None, None, None))

        try:
            params = StdioServerParameters(command=server_config.command, args=server_config.args)

            # NOTE(review): run_until_complete runs each coroutine in a fresh asyncio
            # Task, so this transport's __aenter__ and its later __aexit__ execute in
            # different tasks; anyio task groups / cancel scopes may object to that at
            # cleanup time — confirm teardown behaves as expected.
            transport_cm = forked_stdio_client(params)
            self.stdio, self.write = self.loop.run_until_complete(
                asyncio.wait_for(transport_cm.__aenter__(), timeout=timeout)
            )
            self.cleanup_funcs.append(exit_later(transport_cm))

            session_cm = ClientSession(self.stdio, self.write)
            self.session = self.loop.run_until_complete(
                asyncio.wait_for(session_cm.__aenter__(), timeout=timeout)
            )
            self.cleanup_funcs.append(exit_later(session_cm))
        except Exception as exc:
            # Timeouts get a concise message; anything else is logged with its traceback.
            if isinstance(exc, asyncio.TimeoutError):
                logger.error(f"Timed out while establishing stdio connection (timeout={timeout}s).")
            else:
                logger.exception("Exception occurred while initializing stdio client session.")
            return False

        return True
|
|
|
|
|
|
@asynccontextmanager
async def forked_stdio_client(server: StdioServerParameters):
    """
    Client transport for stdio: this will connect to a server by spawning a
    process and communicating with it over stdin/stdout.

    Yields ``(read_stream, write_stream)``, a pair of anyio memory object streams:
    ``read_stream`` receives one ``types.JSONRPCMessage`` per newline-terminated
    line of the child's stdout (or the exception raised while parsing that
    line), and every message sent on ``write_stream`` is written to the child's
    stdin as one line of JSON.

    On exit the background reader/writer/watcher tasks are cancelled and, if the
    child is still running, it is stopped with ``process.terminate()``.

    Args:
        server: Command, arguments, environment and text-encoding settings for
            the server process. ``server.env`` is used as-is whenever it is not
            ``None`` (including an empty mapping); otherwise
            ``get_default_environment()`` is used.

    Raises:
        RuntimeError: If the process cannot be spawned. A non-zero exit of the
            child while the transport is open is also raised as ``RuntimeError``
            from the task group (possibly wrapped in an exception group,
            depending on the anyio version).
    """
    try:
        process = await anyio.open_process(
            [server.command, *server.args],
            # An explicit (even empty) env from the caller wins; only a missing
            # one falls back to the default environment.
            env=server.env if server.env is not None else get_default_environment(),
            # Inherit our stderr: the server's diagnostics go straight to this
            # process's stderr instead of being captured or parsed here.
            stderr=sys.stderr,
        )
    except OSError as exc:
        raise RuntimeError(f"Failed to spawn process: {server.command} {server.args}") from exc

    # Created only after a successful spawn, so a spawn failure leaves no streams open.
    read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
    write_stream, write_stream_reader = anyio.create_memory_object_stream(0)

    async def stdout_reader():
        # Split the child's stdout into lines and forward each parsed JSON-RPC
        # message (or the parse error) to read_stream.
        assert process.stdout, "Opened process is missing stdout"

        buffer = ""
        try:
            async with read_stream_writer:
                async for chunk in TextReceiveStream(
                    process.stdout,
                    encoding=server.encoding,
                    errors=server.encoding_error_handler,
                ):
                    # The last element is an incomplete line; carry it into the next chunk.
                    lines = (buffer + chunk).split("\n")
                    buffer = lines.pop()

                    for line in lines:
                        try:
                            message = types.JSONRPCMessage.model_validate_json(line)
                        except Exception as exc:
                            # Hand malformed lines to the consumer instead of stopping the transport.
                            await read_stream_writer.send(exc)
                            continue

                        await read_stream_writer.send(message)
        except anyio.ClosedResourceError:
            await anyio.lowlevel.checkpoint()

    async def stdin_writer():
        # Serialize each outgoing message as one line of JSON on the child's stdin.
        assert process.stdin, "Opened process is missing stdin"

        try:
            async with write_stream_reader:
                async for message in write_stream_reader:
                    json = message.model_dump_json(by_alias=True, exclude_none=True)
                    await process.stdin.send(
                        (json + "\n").encode(
                            encoding=server.encoding,
                            errors=server.encoding_error_handler,
                        )
                    )
        except anyio.ClosedResourceError:
            await anyio.lowlevel.checkpoint()

    async def watch_process_exit():
        # Turn an unexpected non-zero exit into an error that fails the task group.
        returncode = await process.wait()
        if returncode != 0:
            raise RuntimeError(f"Subprocess exited with code {returncode}. Command: {server.command} {server.args}")

    async with anyio.create_task_group() as tg, process:
        tg.start_soon(stdout_reader)
        tg.start_soon(stdin_writer)
        tg.start_soon(watch_process_exit)

        try:
            # Give the child a brief window to fail fast (bad command, missing
            # dependency, ...) so watch_process_exit can raise before the caller
            # starts using the streams.
            with anyio.move_on_after(0.2):
                await anyio.sleep_forever()

            yield read_stream, write_stream
        finally:
            # Stop the background tasks: otherwise exiting the task group could
            # wait indefinitely on stdin_writer (which only finishes once
            # write_stream is closed) and on watch_process_exit (which only
            # finishes once the child exits), and the terminate below would be
            # reported as a non-zero exit.
            tg.cancel_scope.cancel()
            if process.returncode is None:
                try:
                    process.terminate()
                except ProcessLookupError:
                    # The child exited between the returncode check and the signal.
                    pass
|