diff --git a/letta/services/mcp/base_client.py b/letta/services/mcp/base_client.py index 4ce0ad0f..394064bc 100644 --- a/letta/services/mcp/base_client.py +++ b/letta/services/mcp/base_client.py @@ -14,9 +14,15 @@ logger = get_logger(__name__) # TODO: Get rid of Async prefix on this class name once we deprecate old sync code class AsyncBaseMCPClient: - def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None): + # HTTP headers + AGENT_ID_HEADER = "X-Agent-Id" + + def __init__( + self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None + ): self.server_config = server_config self.oauth_provider = oauth_provider + self.agent_id = agent_id self.exit_stack = AsyncExitStack() self.session: Optional[ClientSession] = None self.initialized = False diff --git a/letta/services/mcp/sse_client.py b/letta/services/mcp/sse_client.py index 950b4ae0..0327fc26 100644 --- a/letta/services/mcp/sse_client.py +++ b/letta/services/mcp/sse_client.py @@ -16,8 +16,10 @@ logger = get_logger(__name__) # TODO: Get rid of Async prefix on this class name once we deprecate old sync code class AsyncSSEMCPClient(AsyncBaseMCPClient): - def __init__(self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None): - super().__init__(server_config, oauth_provider) + def __init__( + self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None + ): + super().__init__(server_config, oauth_provider, agent_id) async def _initialize_connection(self, server_config: SSEServerConfig) -> None: headers = {} @@ -27,6 +29,9 @@ class AsyncSSEMCPClient(AsyncBaseMCPClient): if server_config.auth_header and server_config.auth_token: headers[server_config.auth_header] = server_config.auth_token + if self.agent_id: + headers[self.AGENT_ID_HEADER] = self.agent_id + # Use OAuth provider if available, otherwise use regular headers if self.oauth_provider: sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider) diff --git a/letta/services/mcp/stdio_client.py b/letta/services/mcp/stdio_client.py index f7b2c716..faec26cc 100644 --- a/letta/services/mcp/stdio_client.py +++ b/letta/services/mcp/stdio_client.py @@ -1,3 +1,5 @@ +from typing import Optional + from mcp import ClientSession, StdioServerParameters from mcp.client.stdio import stdio_client @@ -10,6 +12,9 @@ logger = get_logger(__name__) # TODO: Get rid of Async prefix on this class name once we deprecate old sync code class AsyncStdioMCPClient(AsyncBaseMCPClient): + def __init__(self, server_config: StdioServerConfig, oauth_provider=None, agent_id: Optional[str] = None): + super().__init__(server_config, oauth_provider, agent_id) + async def _initialize_connection(self, server_config: StdioServerConfig) -> None: args = [arg.split() for arg in server_config.args] # flatten diff --git a/letta/services/mcp/streamable_http_client.py b/letta/services/mcp/streamable_http_client.py index baf2f7c6..e2f256f5 100644 --- a/letta/services/mcp/streamable_http_client.py +++ b/letta/services/mcp/streamable_http_client.py @@ -12,8 +12,13 @@ logger = get_logger(__name__) class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient): - def __init__(self, server_config: StreamableHTTPServerConfig, oauth_provider: Optional[OAuthClientProvider] = None): - super().__init__(server_config, oauth_provider) + def __init__( + self, + server_config: StreamableHTTPServerConfig, + oauth_provider: Optional[OAuthClientProvider] = None, + agent_id: Optional[str] = None, + ): + super().__init__(server_config, oauth_provider, agent_id) async def _initialize_connection(self, server_config: BaseServerConfig) -> None: if not isinstance(server_config, StreamableHTTPServerConfig): @@ -28,6 +33,10 @@ class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient): if server_config.auth_header and server_config.auth_token: headers[server_config.auth_header] = server_config.auth_token + # Add agent ID header if provided + if self.agent_id: + headers[self.AGENT_ID_HEADER] = self.agent_id + # Use OAuth provider if available, otherwise use regular headers if self.oauth_provider: streamable_http_cm = streamablehttp_client( diff --git a/letta/services/mcp_manager.py b/letta/services/mcp_manager.py index 8e263efe..c645a34c 100644 --- a/letta/services/mcp_manager.py +++ b/letta/services/mcp_manager.py @@ -56,14 +56,14 @@ class MCPManager: self.cached_mcp_servers = {} # maps id -> async connection @enforce_types - async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser) -> List[MCPTool]: + async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]: """Get a list of all tools for a specific MCP server.""" mcp_client = None try: mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor) mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor) server_config = mcp_config.to_config() - mcp_client = await self.get_mcp_client(server_config, actor) + mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id) await mcp_client.connect_to_server() # list tools @@ -93,6 +93,7 @@ class MCPManager: tool_args: Optional[Dict[str, Any]], environment_variables: Dict[str, str], actor: PydanticUser, + agent_id: Optional[str] = None, ) -> Tuple[str, bool]: """Call a specific tool from a specific MCP server.""" try: @@ -108,7 +109,7 @@ class MCPManager: raise ValueError(f"MCP server {mcp_server_name} not found in config.") server_config = mcp_config[mcp_server_name] - mcp_client = await self.get_mcp_client(server_config, actor) + mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id) await mcp_client.connect_to_server() # call tool @@ -449,6 +450,7 @@ class MCPManager: server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig], actor: PydanticUser, oauth_provider: Optional[Any] = None, + agent_id: Optional[str] = None, ) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]: """ Helper function to create the appropriate MCP client based on server configuration. @@ -481,13 +483,13 @@ class MCPManager: if server_config.type == MCPServerType.SSE: server_config = SSEServerConfig(**server_config.model_dump()) - return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider) + return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id) elif server_config.type == MCPServerType.STDIO: server_config = StdioServerConfig(**server_config.model_dump()) - return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider) + return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id) elif server_config.type == MCPServerType.STREAMABLE_HTTP: server_config = StreamableHTTPServerConfig(**server_config.model_dump()) - return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider) + return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id) else: raise ValueError(f"Unsupported server config type: {type(server_config)}") diff --git a/letta/services/tool_executor/mcp_tool_executor.py b/letta/services/tool_executor/mcp_tool_executor.py index 0fd2eea5..a5f70165 100644 --- a/letta/services/tool_executor/mcp_tool_executor.py +++ b/letta/services/tool_executor/mcp_tool_executor.py @@ -37,8 +37,10 @@ class ExternalMCPToolExecutor(ToolExecutor): # TODO: may need to have better client connection management environment_variables = {} + agent_id = None if agent_state: environment_variables = agent_state.get_agent_env_vars_as_dict() + agent_id = agent_state.id function_response, success = await mcp_manager.execute_mcp_server_tool( mcp_server_name=mcp_server_name, @@ -46,6 +48,7 @@ class ExternalMCPToolExecutor(ToolExecutor): tool_args=function_args, environment_variables=environment_variables, actor=actor, + agent_id=agent_id, ) return ToolExecutionResult(