Files
letta-server/letta/functions/mcp_client/sse_client.py
2025-03-17 10:06:36 -07:00

34 lines
1.4 KiB
Python

import asyncio
from mcp import ClientSession
from mcp.client.sse import sse_client
from letta.functions.mcp_client.base_client import BaseMCPClient
from letta.functions.mcp_client.types import SSEServerConfig
from letta.log import get_logger
# see: https://modelcontextprotocol.io/quickstart/user
MCP_CONFIG_TOPLEVEL_KEY = "mcpServers"
logger = get_logger(__name__)
class SSEMCPClient(BaseMCPClient):
def _initialize_connection(self, server_config: SSEServerConfig, timeout: float) -> bool:
try:
sse_cm = sse_client(url=server_config.server_url)
sse_transport = self.loop.run_until_complete(asyncio.wait_for(sse_cm.__aenter__(), timeout=timeout))
self.stdio, self.write = sse_transport
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None)))
session_cm = ClientSession(self.stdio, self.write)
self.session = self.loop.run_until_complete(asyncio.wait_for(session_cm.__aenter__(), timeout=timeout))
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
return True
except asyncio.TimeoutError:
logger.error(f"Timed out while establishing SSE connection (timeout={timeout}s).")
return False
except Exception:
logger.exception("Exception occurred while initializing SSE client session.")
return False