1218 lines
45 KiB
Python
1218 lines
45 KiB
Python
#!/usr/bin/env python3
|
|
# -*- coding:utf-8 -*-
|
|
#############################################################################
|
|
# Copyright (c) 2020 Huawei Technologies Co.,Ltd.
|
|
#
|
|
# openGauss is licensed under Mulan PSL v2.
|
|
# You can use this software according to the terms
|
|
# and conditions of the Mulan PSL v2.
|
|
# You may obtain a copy of Mulan PSL v2 at:
|
|
#
|
|
# http://license.coscl.org.cn/MulanPSL2
|
|
#
|
|
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OF ANY KIND,
|
|
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
|
|
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
|
|
# See the Mulan PSL v2 for more details.
|
|
# ----------------------------------------------------------------------------
|
|
# Description : gs_sshexkey is a utility to create SSH trust among nodes in
|
|
# a cluster.
|
|
#############################################################################
|
|
|
|
import sys
|
|
import warnings
|
|
|
|
warnings.simplefilter('ignore', DeprecationWarning)
|
|
sys.path.append(sys.path[0] + "/../lib")
|
|
import time
|
|
import os
|
|
import subprocess
|
|
import pwd
|
|
import grp
|
|
import socket
|
|
import getpass
|
|
import shutil
|
|
# Directory containing this script; its gspylib/clib sub-directory holds
# shared libraries that must be on the loader path.
package_path = os.path.dirname(os.path.realpath(__file__))
ld_path = package_path + "/gspylib/clib"
# Put ld_path on LD_LIBRARY_PATH before the gspylib/paramiko imports below
# run. After changing the variable the script replaces the current process
# with a fresh copy of itself (os.execve) -- presumably because the dynamic
# loader only reads LD_LIBRARY_PATH at process start-up (confirm).
if 'LD_LIBRARY_PATH' not in os.environ:
    os.environ['LD_LIBRARY_PATH'] = ld_path
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)
# The variable is set but does not start with ld_path: prepend it and
# re-exec. On the re-executed run the prefix check passes, so this does not
# loop.
if not os.environ.get('LD_LIBRARY_PATH').startswith(ld_path):
    os.environ['LD_LIBRARY_PATH'] = \
        ld_path + ":" + os.environ['LD_LIBRARY_PATH']
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)
|
|
|
|
from gspylib.common.GaussLog import GaussLog
|
|
from gspylib.common.ErrorCode import ErrorCode
|
|
from gspylib.threads.parallelTool import parallelTool
|
|
from gspylib.common.Common import DefaultValue, ClusterCommand
|
|
from gspylib.common.ParameterParsecheck import Parameter
|
|
from gspylib.os.gsfile import g_file
|
|
from gspylib.os.gsOSlib import g_OSlib
|
|
|
|
DefaultValue.doConfigForParamiko()
|
|
import paramiko
|
|
|
|
# Tag appended to every "ip hostname" line this tool adds to /etc/hosts;
# tagged lines are stripped from /etc/hosts before new ones are appended.
HOSTS_MAPPING_FLAG = "#Gauss OM IP Hosts Mapping"
# Newline-separated "ip hostname flag" lines built by
# GaussCreateTrust.writeRemoteHosts() and read by the
# writeRemoteHostName() workers it starts.
ipHostInfo = ""
# Full path of the step file (/tmp/<TMP_TRUST_FILE>), assigned in
# GaussCreateTrust.run(). When a log file is used and this path does not
# exist, progress messages are logged with "addStep"/"constant" markers.
tmp_files = ""
# Base name of the step file above.
TMP_TRUST_FILE = "step_preinstall_file.dat"
|
|
|
|
|
|
class PrintOnScreen():
    """
    Console logger used instead of GaussLog when no log file is given
    (see GaussCreateTrust.initLogger).

    It provides the subset of the logger interface used in this file:
    log(), debug() and logExit(). Elsewhere in this file the logger's log()
    and debug() are also called with a second step-marker argument
    ("addStep" / "constant"); PrintOnScreen accepts and ignores any extra
    positional arguments so it can be substituted without a TypeError.
    """

    def __init__(self):
        """
        function : Constructor
        input: NA
        output: NA
        """
        pass

    def log(self, msg, *args):
        """
        function : print msg to standard output
        input: msg: str
               args: optional step markers, ignored
        output: NA
        """
        print(msg)

    def debug(self, msg, *args):
        """
        function : discard debug output (nothing is printed on screen)
        input: msg: debug message string
               args: optional step markers, ignored
        output: NA
        """
        pass

    def logExit(self, msg, *args):
        """
        function : print msg to standard output and exit with status 1
        input: msg: str
               args: optional step markers, ignored
        output: NA (raises SystemExit)
        """
        print(msg)
        sys.exit(1)
|
|
|
|
|
|
class GaussCreateTrust():
    """
    Create SSH mutual trust for the current OS user among the nodes given
    with -f (host file) or -h (IP list); run() is the entry point.
    """

    # Progress messages printed through logOut(index1, index2).
    # Entries 0 and 1 ("addStep" / "constant") are step markers passed as
    # a second argument to the logger when a log file is used and tmp_files
    # does not exist. The remaining entries form (start, success) message
    # pairs for the individual steps of run().
    log_list = ["addStep",
                "constant",
                "Checking network information.",
                "Successfully checked network information.",
                "Creating the local key file.",
                "Successfully created the local key files.",
                "Appending local ID to authorized_keys.",
                "Successfully appended local ID to authorized_keys.",
                "Updating the known_hosts file.",
                "Successfully updated the known_hosts file.",
                "Appending authorized_key on the remote node.",
                "Successfully appended authorized_key on all remote node.",
                "Checking common authentication file content.",
                "Successfully checked common authentication content.",
                "Distributing SSH trust file to all node.",
                "Successfully distributed SSH trust file to all node.",
                "Verifying SSH trust on all hosts.",
                "Successfully verified SSH trust on all hosts.",
                ]
|
|
|
|
def __init__(self):
|
|
"""
|
|
function : Constructor
|
|
input: NA
|
|
output: NA
|
|
"""
|
|
self.logger = None
|
|
self.hostFile = ""
|
|
self.hostList = []
|
|
self.passwd = []
|
|
self.logFile = ""
|
|
self.localHost = ""
|
|
self.flag = False
|
|
self.logger = None
|
|
self.localID = ""
|
|
self.user = pwd.getpwuid(os.getuid()).pw_name
|
|
self.group = grp.getgrgid(os.getgid()).gr_name
|
|
self.incorrectPasswdInfo = ""
|
|
self.failedToAppendInfo = ""
|
|
self.homeDir = os.path.expanduser("~" + self.user)
|
|
self.sshDir = "%s/.ssh" % self.homeDir
|
|
self.authorized_keys_fname = '%s/.ssh/authorized_keys' % self.homeDir
|
|
self.known_hosts_fname = '%s/.ssh/known_hosts' % self.homeDir
|
|
self.id_rsa_fname = '%s/.ssh/id_rsa' % self.homeDir
|
|
self.id_rsa_pub_fname = self.id_rsa_fname + '.pub'
|
|
self.skipHostnameSet = False
|
|
self.isKeyboardPassword = False
|
|
self.nodeduplicate = False
|
|
|
|
def usage(self):
    """
gs_sshexkey is a utility to create SSH trust among nodes in a cluster.

Usage:
  gs_sshexkey -? | --help
  gs_sshexkey -V | --version
  gs_sshexkey -f HOSTFILE [-l LOGFILE] [--skip-hostname-set]
  gs_sshexkey -h IPLIST [-l LOGFILE] [--skip-hostname-set]

General options:
  -f                                 Host file containing the IP address of nodes.
  -h                                 Host ip list. Separate multiple nodes with commas(,).
  -l                                 Path of log file.
  --skip-hostname-set                Whether to skip hostname setting.
                                     (The default value is set.)
  -?, --help                         Show help information for this utility,
                                     and exit the command line mode.
  -V, --version                      Show version information.
    """
    # The docstring above IS the help text: it is printed verbatim, so any
    # change to its wording or layout changes the tool's --help output.
    print(self.usage.__doc__)
|
|
|
|
def parseCommandLine(self):
    """
    function: Parse the command line and copy each recognised option onto
              the matching attribute; print the help and exit with status
              0 when the help flag is present.
    input : NA
    output: NA
    """
    options = Parameter().ParameterCommandLine("sshexkey")
    if "helpFlag" in options.keys():
        self.usage()
        sys.exit(0)

    # (key in the parsed option dict, attribute that receives its value)
    optionToAttr = (
        ("hostfile", "hostFile"),
        ("nodename", "hostList"),
        ("logFile", "logFile"),
        ("skipHostnameSet", "skipHostnameSet"),
        ("noDeduplicate", "nodeduplicate"),
    )
    for optionKey, attrName in optionToAttr:
        if optionKey in options.keys():
            setattr(self, attrName, options.get(optionKey))
|
|
|
|
def checkParameter(self):
    """
    function: Validate the parsed options.
              - IPs given with -h are checked one by one.
              - Otherwise the -f host file must be given, exist, be an
                absolute path and yield at least one IP (readHostFile()).
              - A -l log file must be an absolute path.
              Errors are reported through GaussLog.exitWithError. When no
              password is known yet, getUserPasswd() is called and the
              password is flagged as keyboard-entered.
    input : NA
    output: NA
    """
    if len(self.hostList) != 0:
        for candidate in self.hostList:
            if not DefaultValue.isIpValid(candidate):
                GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50000"]
                                       % candidate)
    else:
        if self.hostFile == "":
            self.usage()
            GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50001"]
                                   % 'f' + ".")
        if not os.path.exists(self.hostFile):
            GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50201"]
                                   % self.hostFile)
        if not os.path.isabs(self.hostFile):
            GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50213"]
                                   % self.hostFile)
        # Fill hostList from the host file.
        self.readHostFile()
        if self.hostList == []:
            GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50004"]
                                   % 'f' + " It cannot be empty.")

    if self.logFile != "" and not os.path.isabs(self.logFile):
        GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50213"]
                               % self.logFile)

    if not self.passwd:
        self.passwd = self.getUserPasswd()
        self.isKeyboardPassword = True
|
|
|
|
def logOut(self, log_index1, log_index2):
|
|
"""
|
|
function:logout
|
|
:param log_index1: index of the log
|
|
:param log_index2: indec of the log
|
|
:return:
|
|
"""
|
|
if (self.logFile != ""):
|
|
if (not os.path.exists(tmp_files)):
|
|
self.logger.log(GaussCreateTrust.log_list[log_index1],
|
|
GaussCreateTrust.log_list[log_index2])
|
|
else:
|
|
self.logger.log(GaussCreateTrust.log_list[log_index1])
|
|
else:
|
|
self.logger.log(GaussCreateTrust.log_list[log_index1])
|
|
|
|
def readHostFile(self):
    """
    function: Append every distinct, non-empty, valid IP found in
              self.hostFile (one per line) to self.hostList, keeping file
              order. Invalid entries are collected and reported through
              GaussLog.exitWithError; other failures are re-raised as
              GAUSS_50204.
    input : NA
    output: NA
    """
    badIps = []
    try:
        with open(self.hostFile, "r") as hostFileObj:
            for rawLine in hostFileObj:
                candidate = rawLine.strip()
                if candidate == "" or candidate in self.hostList:
                    continue
                if DefaultValue.isIpValid(candidate):
                    self.hostList.append(candidate)
                else:
                    badIps.append(candidate)
        if badIps:
            GaussLog.exitWithError(ErrorCode.GAUSS_506["GAUSS_50603"]
                                   + "The IP list is:%s." % badIps)
    except Exception as e:
        raise Exception(ErrorCode.GAUSS_502["GAUSS_50204"] % "host file"
                        + " Error: \n%s" % str(e))
|
|
|
|
def getAllHostsName(self, ip):
    """
    function:
        Log in to one node with self.passwd[0] and return its hostname.
        Executed once per node by parallelGetHosts() through parallelTool.
    precondition:
        1.User's password is correct on each node
    postcondition:
        The SSH transport is closed on every exit path.
    input: ip
    output: dict {ip: hostname}; when the node's login profile prints
            anything, {"Node[ip]": explanatory message} instead
    raises: Exception GAUSS_51220 if the transport cannot be created,
            GAUSS_50306 if the login is rejected
    """
    ipHostname = {}
    try:
        ssh = paramiko.Transport((ip, 22))
    except Exception as e:
        raise Exception(ErrorCode.GAUSS_512["GAUSS_51220"] % ip
                        + " Error: \n%s" % str(e))
    # try/finally: the early return for profile output used to leak the
    # transport.
    try:
        try:
            ssh.connect(username=self.user, password=self.passwd[0])
        except Exception:
            raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % ip)

        # "cd" prints nothing by itself, so any output comes from the login
        # profile and would corrupt the hostname read below.
        check_channel = ssh.open_session()
        check_channel.exec_command("cd")
        env_msg = check_channel.recv_stderr(9999).decode()
        while True:
            channel_read = check_channel.recv(9999).decode().strip()
            if len(channel_read) == 0:
                break
            env_msg += channel_read
        if env_msg != "":
            ipHostname["Node[%s]" % ip] = "Output: [" + env_msg \
                + " ] print by /etc/profile or" \
                " ~/.bashrc, please check it."
            return ipHostname

        channel = ssh.open_session()
        channel.exec_command("hostname")
        ipHostname[ip] = channel.recv(9999).decode().strip()
        return ipHostname
    finally:
        ssh.close()
|
|
|
|
def verifyPasswd(self, ssh, pswd=None):
|
|
try:
|
|
ssh.connect(username=self.user, password=pswd)
|
|
return True
|
|
except Exception:
|
|
ssh.close()
|
|
return False
|
|
|
|
def parallelGetHosts(self, sshIps):
    """
    function: Run getAllHostsName() for all nodes in parallel and merge
              the per-node dicts.
    input : sshIps: node IPs
    output: dict {ip: hostname}
    raises: Exception GAUSS_51808 listing the nodes whose login profile
            printed output (entries keyed "Node[...]")
    """
    collected = {}
    problems = ""
    for nodeResult in parallelTool.parallelExecute(self.getAllHostsName,
                                                   sshIps):
        for name, info in nodeResult.items():
            if "Node" in name:
                problems += str(nodeResult)
            else:
                collected[name] = info
    if problems:
        raise Exception(ErrorCode.GAUSS_518["GAUSS_51808"] % problems)
    return collected
|
|
|
|
def serialGetHosts(self, sshIps):
    """
    function:
        Resolve the hostname of each node one at a time. Every password in
        self.passwd is tried. In interactive mode (isKeyboardPassword) the
        user is prompted up to three more times per node, and an accepted
        password is appended to self.passwd.
    input : sshIps: node IPs
    output: dict {ip: hostname}
    raises: Exception GAUSS_50306 when no password is accepted by a
            reachable node, GAUSS_51808 when a node's login profile prints
            output, GAUSS_51101 listing nodes whose transport could not be
            created
    """
    serialResult = {}
    invalidIP = ""
    for sshIp in sshIps:
        ssh = None
        isPasswdOK = False
        unreachable = False
        for pswd in self.passwd:
            try:
                ssh = paramiko.Transport((sshIp, 22))
            except Exception as e:
                self.logger.debug(str(e))
                invalidIP += "Incorrect IP address: %s.\n" % sshIp
                unreachable = True
                break
            # verifyPasswd() closes the transport when the login fails.
            isPasswdOK = self.verifyPasswd(ssh, pswd)
            if isPasswdOK:
                break
        if unreachable:
            continue

        if not isPasswdOK and self.isKeyboardPassword:
            GaussLog.printMessage("Please enter password for current"
                                  " user[%s] on the node[%s]."
                                  % (self.user, sshIp))
            # Try entering the password 3 times interactively
            for _ in range(3):
                KeyboardPassword = getpass.getpass()
                DefaultValue.checkPasswordVaild(KeyboardPassword)
                ssh = paramiko.Transport((sshIp, 22))
                isPasswdOK = self.verifyPasswd(ssh, KeyboardPassword)
                if isPasswdOK:
                    self.passwd.append(KeyboardPassword)
                    break
        # Raise in non-interactive mode as well: previously execution went
        # on to open_session() on a closed (or never created) transport.
        if not isPasswdOK:
            raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % sshIp)

        try:
            # "cd" prints nothing by itself; any output comes from the
            # node's login profile and would pollute command results.
            check_channel = ssh.open_session()
            check_channel.exec_command("cd")
            check_result = check_channel.recv_stderr(9999).decode()
            while True:
                channel_read = check_channel.recv(9999).decode()
                if len(channel_read) == 0:
                    break
                check_result += channel_read
            if check_result != "":
                raise Exception(ErrorCode.GAUSS_518["GAUSS_51808"]
                                % check_result + "Please check %s node"
                                " /etc/profile or ~/.bashrc"
                                % sshIp)

            channel = ssh.open_session()
            channel.exec_command("hostname")
            # Concatenate the output: recv() may deliver it in pieces, and
            # the old loop kept only the last piece.
            chunks = []
            while True:
                chunk = channel.recv(9999).decode()
                if len(chunk) == 0:
                    break
                chunks.append(chunk)
            hostname = "".join(chunks).strip()
            if hostname:
                serialResult[sshIp] = hostname
        finally:
            # Also closes the transport when the profile check raises.
            ssh.close()

    if invalidIP:
        raise Exception(
            ErrorCode.GAUSS_511["GAUSS_51101"] % invalidIP.rstrip("\n"))

    return serialResult
|
|
|
|
def getAllHosts(self, sshIps):
    """
    function:
        Connect to all nodes, then get all hostnames.
        If no password is known, one is read interactively. With exactly
        one password the nodes are queried in parallel. If that fails with
        a wrong-password error in interactive mode, the serial path
        (per-node prompting) is used instead. With several passwords the
        nodes are queried serially.
    precondition:
        1.User's password is correct on each node
    input: sshIps: node IPs
    output: Dictionary ipHostname, key is IP and value is hostname
    """
    logToFile = self.logFile != ""
    if logToFile:
        if os.path.exists(tmp_files):
            self.logger.debug("Get hostnames for all nodes.")
        else:
            self.logger.debug("Get hostnames for all nodes.", "addStep")

    if len(self.passwd) == 0:
        self.isKeyboardPassword = True
        GaussLog.printMessage("Please enter password for current user[%s]."
                              % self.user)
        self.passwd.append(getpass.getpass())

    if len(self.passwd) != 1:
        result = self.serialGetHosts(sshIps)
    else:
        try:
            result = self.parallelGetHosts(sshIps)
        except Exception as e:
            passwordRejected = self.isKeyboardPassword and str(e).startswith(
                "[GAUSS-50306] : The password of")
            if not passwordRejected:
                raise Exception(str(e))
            GaussLog.printMessage(
                "Notice :The password of some nodes is incorrect.")
            result = self.serialGetHosts(sshIps)

    if logToFile:
        if os.path.exists(tmp_files):
            self.logger.debug("Successfully get hostnames for all nodes.")
        else:
            self.logger.debug("Successfully get hostnames for all nodes.",
                              "constant")
    return result
|
|
|
|
def writeLocalHosts(self, result):
    """
    function:
        As root, rewrite the local /etc/hosts without the lines tagged with
        HOSTS_MAPPING_FLAG and without blank lines, then add one
        "ip hostname flag" line per node. Unless self.nodeduplicate is set,
        nodes whose IP already starts a retained line are not added.
        Nothing is written when not running as root.
    precondition:
        NA
    postcondition:
        NA
    input: Dictionary result, key is IP and value is hostname
    output: NA
    raises: Exception GAUSS_51221 when /etc/hosts is missing or cannot be
            read; errors while rewriting it are re-raised
    """
    if self.logFile != "":
        if not os.path.exists(tmp_files):
            self.logger.debug(
                "Write local hostname and Ip into /etc/hosts.", "addStep")
        else:
            self.logger.debug(
                "Write local hostname and Ip into /etc/hosts.")
    if os.getuid() == 0:
        tmpHostIpName = "./tmp_hostsiphostname_%d" % os.getpid()
        # Check if /etc/hosts exists.
        if not os.path.exists("/etc/hosts"):
            raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"] +
                            " Error: \nThe /etc/hosts does not exist.")
        # Filter in Python rather than with
        # "grep -v FLAG /etc/hosts | grep -v '^$'". The pipeline's status was
        # never checked, and a grep error message (merged into the captured
        # output) would have been written into /etc/hosts.
        try:
            with open("/etc/hosts", "r") as hostsFile:
                keptLines = [line for line in hostsFile.read().split("\n")
                             if line and HOSTS_MAPPING_FLAG not in line]
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"]
                            + " Error: \n%s" % str(e))
        output = "\n".join(keptLines)
        try:
            g_file.createFile(tmpHostIpName)
            g_file.changeMode(DefaultValue.KEY_FILE_MODE, tmpHostIpName)
            g_file.writeFile(tmpHostIpName, [output])
            shutil.copyfile(tmpHostIpName, '/etc/hosts')
            g_file.removeFile(tmpHostIpName)
        except Exception as e:
            if os.path.exists(tmpHostIpName):
                g_file.removeFile(tmpHostIpName)
            raise Exception(str(e))

        if not self.nodeduplicate:
            # IPs already present at the start of a retained line.
            ipCompare = set(line.replace("\t", " ").strip().split(' ')[0]
                            for line in keptLines)
            entries = [(ip, name) for (ip, name) in result.items()
                       if ip not in ipCompare]
        else:
            entries = list(result.items())
        hostIPInfo = "\n".join('%s %s %s' % (ip, name, HOSTS_MAPPING_FLAG)
                               for (ip, name) in entries)
        # Skip the call when every node is already listed instead of
        # passing an empty string.
        if hostIPInfo:
            # NOTE(review): relies on g_file.writeFile appending to the file
            # that was just rewritten above -- confirm.
            g_file.writeFile("/etc/hosts", [hostIPInfo])
    if self.logFile != "":
        if not os.path.exists(tmp_files):
            self.logger.debug(
                "Successfully write local hostname and Ip into "
                "/etc/hosts.",
                "constant")
        else:
            self.logger.debug(
                "Successfully write local hostname and Ip into "
                "/etc/hosts.")
|
|
|
|
def writeRemoteHostName(self, ip):
    """
    function:
        Worker run in parallel by writeRemoteHosts(), once per remote node.
        Rewrites the node's /etc/hosts without the lines matching
        " #Gauss.* IP Hosts Mapping" and without blank lines, then appends
        the mapping lines held in the module-level ipHostInfo. Unless
        self.nodeduplicate is set, lines whose IP already starts a retained
        line are skipped.
    precondition:
        ipHostInfo has been filled by writeRemoteHosts()
    postcondition:
        The SSH transport is closed on every exit path.
    input: ip: address (or resolvable hostname) of the node
    output: (True, {ip: []}) on success,
            (False, {ip: [error text, ...]}) on failure
    raises: Exception GAUSS_51107 when the transport cannot be created,
            GAUSS_50317 when the login fails
    """
    import shlex

    writeResult = []
    result = {}
    tmpHostIpName = "./tmp_hostsiphostname_%d_%s" % (os.getpid(), ip)
    username = pwd.getpwuid(os.getuid()).pw_name
    global ipHostInfo
    try:
        ssh = paramiko.Transport((ip, 22))
    except Exception as e:
        raise Exception(ErrorCode.GAUSS_511["GAUSS_51107"]
                        + " Error: \n%s" % str(e))

    def runRemote(command):
        """Run command on the node; return its stripped stdout and stderr,
        each read until EOF."""
        channel = ssh.open_session()
        try:
            channel.exec_command(command)
            out = channel.makefile("rb").read().decode().strip()
            err = channel.makefile_stderr("rb").read().decode().strip()
        finally:
            channel.close()
        return out, err

    try:
        try:
            ssh.connect(username=username, password=self.passwd[0])
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_503["GAUSS_50317"]
                            + " Error: \n%s" % str(e))

        # Read the current content completely. A single recv(9999) could
        # return only part of it, and that truncated text used to be copied
        # over /etc/hosts.
        ipHosts, errInfo = runRemote(
            "grep -v '%s' %s | grep -v '^$'"
            % (" #Gauss.* IP Hosts Mapping", '/etc/hosts'))
        if errInfo:
            # Never rewrite /etc/hosts from a failed read.
            writeResult.append(errInfo)
        else:
            # shlex.quote keeps quotes, '$' or backticks in /etc/hosts from
            # being interpreted by the remote shell.
            ipHosts1, errInfo1 = runRemote(
                "printf '%%s\\n' %s > %s ; cp %s %s && rm -rf %s"
                % (shlex.quote(ipHosts), tmpHostIpName, tmpHostIpName,
                   '/etc/hosts', tmpHostIpName))
            if errInfo1:
                writeResult.append(errInfo1)
            elif not ipHosts1:
                newLines = [info for info in ipHostInfo.split("\n") if info]
                if not self.nodeduplicate:
                    ipCompare = set(
                        line.replace("\t", " ").strip().split(' ')[0]
                        for line in ipHosts.split("\n") if line)
                    newLines = [info for info in newLines
                                if info.split(' ')[0] not in ipCompare]
                # Append nothing (not even a blank line) when all nodes
                # are already listed.
                if newLines:
                    _, errInfo = runRemote(
                        "printf '%%s\\n' %s >> /etc/hosts"
                        % shlex.quote("\n".join(newLines)))
                    if errInfo:
                        writeResult.append(errInfo)
    finally:
        ssh.close()

    result[ip] = writeResult
    return (len(writeResult) == 0, result)
|
|
|
|
def writeRemoteHosts(self, result, username, rootPasswd):
    """
    function:
        As root, update /etc/hosts on every remote node. Nodes whose
        hostname equals self.localHost are skipped.
        - One password: the mapping lines for all nodes are stored in the
          module-level ipHostInfo, and writeRemoteHostName() runs in
          parallel once per distinct remote hostname (the hostname is used
          as the connection target).
        - Several passwords: nodes are processed one by one, trying each
          password for `username`; unreachable nodes are skipped.
        Nothing is written when not running as root.
    precondition:
        NA
    postcondition:
        Every SSH transport opened here is closed.
    input: Dictionary result, key is IP and value is hostname
           username: user to log in as
           rootPasswd: list of candidate passwords
    output: NA
    raises: Exception GAUSS_51221 with the errors reported by the nodes
    """
    import shlex

    def runRemote(ssh, command):
        """Run command over ssh; return its stripped stdout and stderr,
        each read until EOF."""
        channel = ssh.open_session()
        try:
            channel.exec_command(command)
            out = channel.makefile("rb").read().decode().strip()
            err = channel.makefile_stderr("rb").read().decode().strip()
        finally:
            channel.close()
        return out, err

    if self.logFile != "":
        if not os.path.exists(tmp_files):
            self.logger.debug(
                "Write remote hostname and Ip into /etc/hosts.", "addStep")
        else:
            self.logger.debug(
                "Write remote hostname and Ip into /etc/hosts.")
    global ipHostInfo
    ipHostInfo = ""
    if os.getuid() == 0:
        writeResult = []
        tmpHostIpName = "./tmp_hostsiphostname_%d" % os.getpid()

        if len(rootPasswd) == 1:
            # remote hostname -> first IP seen for it
            result1 = {}
            for (key, value) in result.items():
                ipHostInfo += '%s %s %s\n' % (key, value,
                                              HOSTS_MAPPING_FLAG)
                if value != self.localHost and value not in result1:
                    result1[value] = key

            sshIps = result1.keys()
            ipHostInfo = ipHostInfo[:-1]
            if sshIps:
                ipRemoteHostname = parallelTool.parallelExecute(
                    self.writeRemoteHostName, sshIps)
                errorMsg = ""
                for (key, value) in ipRemoteHostname:
                    if not key:
                        errorMsg = errorMsg + '\n' + str(value)
                if errorMsg != "":
                    raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"]
                                    + " Error: %s" % errorMsg)
        else:
            for (key, value) in result.items():
                if value == self.localHost:
                    continue
                ssh = None
                unreachable = False
                for pswd in rootPasswd:
                    try:
                        candidate = paramiko.Transport((key, 22))
                    except Exception as e:
                        self.logger.debug(str(e))
                        unreachable = True
                        break
                    try:
                        candidate.connect(username=username, password=pswd)
                    except Exception as e:
                        self.logger.debug(str(e))
                        candidate.close()
                        continue
                    ssh = candidate
                    break
                if unreachable:
                    continue
                if ssh is None:
                    # Previously execution went on with an unauthenticated
                    # transport and failed inside open_session().
                    writeResult.append("Failed to log in to node %s as %s."
                                       % (key, username))
                    continue
                try:
                    # Read the whole file (a single recv(9999) could
                    # truncate it before it was copied back).
                    ipHosts, errInfo = runRemote(
                        ssh, "grep -v '%s' %s | grep -v '^$'"
                        % (" #Gauss.* IP Hosts Mapping", '/etc/hosts'))
                    if errInfo:
                        # Never rewrite /etc/hosts from a failed read.
                        writeResult.append(errInfo)
                        continue
                    # shlex.quote: /etc/hosts content must not be
                    # interpreted by the remote shell.
                    ipHosts1, errInfo1 = runRemote(
                        ssh, "printf '%%s\\n' %s > %s ; cp %s %s && rm -rf %s"
                        % (shlex.quote(ipHosts), tmpHostIpName,
                           tmpHostIpName, '/etc/hosts', tmpHostIpName))
                    if errInfo1:
                        writeResult.append(errInfo1)
                        continue
                    if ipHosts1:
                        continue

                    ipHostInfo = ""
                    if not self.nodeduplicate:
                        ipCompare = set(
                            line.replace("\t", " ").strip().split(' ')[0]
                            for line in ipHosts.split("\n") if line)
                        for (key1, value1) in result.items():
                            if key1 not in ipCompare:
                                ipHostInfo += '%s %s %s\n' % (
                                    key1, value1, HOSTS_MAPPING_FLAG)
                    else:
                        for (key1, value1) in result.items():
                            ipHostInfo += '%s %s %s\n' % (
                                key1, value1, HOSTS_MAPPING_FLAG)
                    ipHostInfo = ipHostInfo[:-1]
                    # Append nothing when every node is already listed.
                    if ipHostInfo:
                        _, errInfo = runRemote(
                            ssh, "printf '%%s\\n' %s >> /etc/hosts"
                            % shlex.quote(ipHostInfo))
                        if errInfo:
                            writeResult.append(errInfo)
                finally:
                    ssh.close()

        if len(writeResult) > 0:
            raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"]
                            + " Error: \n%s" % writeResult)
    if self.logFile != "":
        if not os.path.exists(tmp_files):
            self.logger.debug(
                "Successfully write remote hostname and Ip into "
                "/etc/hosts.",
                "constant")
        else:
            self.logger.debug(
                "Successfully write remote hostname and Ip into "
                "/etc/hosts.")
|
|
|
|
def initLogger(self):
    """
    function: Create self.logger: a GaussLog writing to self.logFile when
              one was given, otherwise a PrintOnScreen console logger.
    input : NA
    output: NA
    """
    useFile = self.logFile != ""
    self.logger = (GaussLog(self.logFile, "gs_sshexkey") if useFile
                   else PrintOnScreen())
|
|
|
|
def checkNetworkInfo(self):
    """
    function: Ping every node in self.hostList from the local node and
              exit through the logger if any of them is unreachable.
    input : NA
    output: NA
    """
    self.logOut(2, 0)
    try:
        unreachable = DefaultValue.checkIsPing(self.hostList)
        if unreachable:
            self.logger.logExit(ErrorCode.GAUSS_506["GAUSS_50600"]
                                + "The IP list is:%s." % unreachable)
        else:
            self.logger.log("All nodes in the network are Normal.")
    except Exception as e:
        self.logger.logExit(str(e))
    self.logOut(3, 1)
|
|
|
|
def run(self):
    """
    function: Entry point of the tool.
              1. Parse and check the options, set up the logger.
              2. Resolve every node's hostname and check the network.
              3. Unless --skip-hostname-set, update /etc/hosts locally
                 and remotely.
              4. Build the SSH trust step by step; any exception is
                 reported through logger.logExit().
    input : NA
    output: NA
    """
    self.parseCommandLine()
    self.checkParameter()
    self.localHost = socket.gethostname()

    self.initLogger()
    global tmp_files
    tmp_files = "/tmp/%s" % TMP_TRUST_FILE
    if self.logFile != "" and not os.path.exists(tmp_files):
        totalSteps = ClusterCommand.countTotalSteps(
            "gs_sshexkey", "", self.skipHostnameSet)
        self.logger.debug(
            "gs_sshexkey execution takes %s steps in total" % totalSteps)

    result = self.getAllHosts(list(self.hostList))
    self.checkNetworkInfo()

    if not self.skipHostnameSet:
        self.writeLocalHosts(result)
        self.writeRemoteHosts(result, self.user, self.passwd)

    self.logger.log("Creating SSH trust.")
    try:
        self.localID = self.createPublicPrivateKeyFile()
        self.addLocalAuthorized()
        self.updateKnow_hostsFile(result)
        self.addRemoteAuthorization()
        self.determinePublicAuthorityFile()
        self.synchronizationLicenseFile()
        self.verifyTrust()
        self.logger.log("Successfully created SSH trust.")
    except Exception as e:
        self.logger.logExit(str(e))
|
|
|
|
def createPublicPrivateKeyFile(self):
    """
    function: Create the local RSA key pair when self.id_rsa_pub_fname does
              not exist yet, then return the local public key.
    input : NA
    output: str, first line of self.id_rsa_pub_fname, stripped
    raises: Exception GAUSS_51108 when ssh-keygen/chmod fails or the public
            key file cannot be read
    """
    self.logOut(4, 0)

    if not os.path.exists(self.id_rsa_pub_fname):
        # Generate at self.id_rsa_fname rather than "~/.ssh/id_rsa". The
        # shell expands "~" from $HOME, which need not be the home directory
        # used for the existence check above, the chmod and the read below.
        cmd = 'ssh-keygen -t rsa -N "" -f %s < /dev/null' % self.id_rsa_fname
        cmd += " && chmod %s %s %s" % (DefaultValue.KEY_FILE_MODE,
                                      self.id_rsa_fname,
                                      self.id_rsa_pub_fname)
        (status, output) = subprocess.getstatusoutput(cmd)
        if status != 0:
            self.logger.log("The cmd is %s " % cmd)
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51108"]
                            + " Error:\n%s" % output)
    try:
        with open(self.id_rsa_pub_fname, 'r') as f:
            localID = f.readline().strip()
    except IOError as e:
        self.logger.debug(str(e))
        raise Exception(ErrorCode.GAUSS_511["GAUSS_51108"]
                        + " Unable to read the generated file."
                        + self.id_rsa_pub_fname)
    # Report success only after the key was read; the former "finally"
    # also logged success when reading failed.
    self.logOut(5, 1)
    return localID
|
|
|
|
def addLocalAuthorized(self):
    """
    function: Append the local public key (self.localID) to the local
              authorized_keys file unless an identical line is already
              present, then set the file mode to DefaultValue.KEY_FILE_MODE.
    input : NA
    output: NA
    """
    self.logOut(6, 0)
    g_file.createFileInSafeMode(self.authorized_keys_fname)
    with open(self.authorized_keys_fname, 'a+') as f:
        # 'a+' opens positioned at end of file, so iterating it directly
        # read nothing and the key was appended on every run. Rewind first;
        # writes still go to the end in append mode.
        f.seek(0)
        content = f.read()
        alreadyPresent = any(line.strip() == self.localID
                             for line in content.split("\n"))
        if not alreadyPresent:
            # Keep the new key on its own line if the file lacks a final
            # newline.
            if content and not content.endswith("\n"):
                f.write('\n')
            f.write(self.localID)
            f.write('\n')
    # Log success and fix the mode in both cases (the former early return
    # skipped these when the key was already present).
    self.logOut(7, 1)
    g_file.changeMode(DefaultValue.KEY_FILE_MODE,
                      self.authorized_keys_fname)
|
|
|
|
def checkAuthentication(self, hostname):
    """
    function: Check password-less access by running `ssh ... hostname true`.
    input : hostname
    output: (True, hostname) when the command exits with status 0,
            otherwise (False, hostname) after logging the command and its
            output at debug level
    """
    cmd = 'ssh -n %s %s true' % (DefaultValue.SSH_OPTION, hostname)
    status, output = subprocess.getstatusoutput(cmd)
    succeeded = status == 0
    if not succeeded:
        self.logger.debug("The cmd is %s " % cmd)
        self.logger.debug(
            "Failed to check authentication. Hostname:%s. Error: \n%s"
            % (hostname, output))
    return (succeeded, hostname)
|
|
|
|
def updateKnow_hostsFile(self, result):
    """
    function: Append the RSA host key of every node, scanned by IP
              (self.hostList) and by hostname (values of result), to
              known_hosts, then verify password-less access to the local
              host.
    input : result: dict {ip: hostname}
    output: NA
    raises: Exception GAUSS_51400 when a scan fails, GAUSS_51100 when the
            local check fails
    """
    self.logOut(8, 0)
    scanTargets = list(self.hostList) + list(result.values())
    for target in scanTargets:
        cmd = 'ssh-keyscan -t rsa %s >> %s ' % (target,
                                                self.known_hosts_fname)
        cmd += "&& chmod %s %s" % (DefaultValue.KEY_FILE_MODE,
                                   self.known_hosts_fname)
        status, output = subprocess.getstatusoutput(cmd)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_514["GAUSS_51400"] % cmd
                            + " Error:\n%s" % output)
    status, output = self.checkAuthentication(self.localHost)
    if not status:
        raise Exception(
            ErrorCode.GAUSS_511["GAUSS_51100"] % self.localHost)
    self.logOut(9, 1)
|
|
|
|
def tryParamikoConnect(self, hostname, client, pswd=None, silence=False):
    """
    function: Connect client to hostname with a password only (no agent,
              no key files).
    input : hostname, client (SSH client), pswd, silence (suppress debug
            logging)
    output: True on success; False (client closed) on an authentication
            failure
    raises: Exception carrying the original message for any other error
            (client closed first)
    """
    try:
        client.connect(hostname, password=pswd, allow_agent=False,
                       look_for_keys=False)
    except paramiko.AuthenticationException as e:
        if not silence:
            self.logger.debug("Incorrect password. Node: %s." % hostname
                              + " Error:\n%s" % str(e))
        client.close()
        return False
    except Exception as e:
        if not silence:
            self.logger.debug('[SSHException %s] %s' % (hostname, str(e)))
        client.close()
        raise Exception(str(e))
    else:
        return True
|
|
|
|
def addRemoteAuthorization(self):
    """
    function: Send the local ID to every node over SSH in parallel and
              append it to the remote authorized_keys. Failures collected
              by the workers (wrong password, append failure) or raised by
              them end the program through logger.logExit().
    input : NA
    output: NA
    """
    self.logOut(10, 0)
    try:
        parallelTool.parallelExecute(self.sendRemoteAuthorization,
                                     self.hostList)
        for failureText in (self.incorrectPasswdInfo,
                            self.failedToAppendInfo):
            if failureText != "":
                self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51101"]
                                    % failureText.rstrip("\n"))
    except Exception as e:
        self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51111"]
                            + " Error:%s." % str(e))
    self.logOut(11, 1)
|
|
|
|
def sendRemoteAuthorization(self, hostname):
    """
    function: For a remote node (the local host is skipped), log in with
              the first accepted password in self.passwd, prepare ~/.ssh
              there and append self.localID to its authorized_keys.
              Failures are accumulated in self.incorrectPasswdInfo /
              self.failedToAppendInfo instead of being raised. Run in
              parallel by addRemoteAuthorization().
    input : hostname
    output: NA
    """
    if hostname == self.localHost:
        return
    p = None
    cin = cout = cerr = None
    try:
        # ssh Remote Connection other node
        p = paramiko.SSHClient()
        p.load_system_host_keys()
        ok = False
        # Same order as before (passwd[0] first), but an empty password
        # list is now reported instead of raising IndexError.
        for pswd in self.passwd:
            ok = self.tryParamikoConnect(hostname, p, pswd, silence=True)
            if ok:
                break
        if not ok:
            self.incorrectPasswdInfo += "Without this node[%s] of " \
                                        "the correct password.\n" % \
                                        hostname
            return

        # Create .ssh directory and ensure content meets permission
        # requirements for password-less SSH
        cmd = ('mkdir -p .ssh; ' + "chown -R %s:%s %s; " %
               (self.user, self.group, self.sshDir) + 'chmod %s .ssh; '
               % DefaultValue.KEY_DIRECTORY_MODE
               + 'touch .ssh/authorized_keys; '
               + 'touch .ssh/known_hosts; '
               + 'chmod %s .ssh/auth* .ssh/id* .ssh/known_hosts; '
               % DefaultValue.KEY_FILE_MODE)
        (cin, cout, cerr) = p.exec_command(cmd)
        cin.close()
        # exec_command() does not wait for completion: wait here so the
        # append below cannot run before .ssh has been created.
        cout.channel.recv_exit_status()
        cout.close()
        cerr.close()

        # Append the ID to authorized_keys: one attempt plus up to three
        # retries, sleeping 0, 2 and 4 seconds before the retries.
        cmd = 'echo \"%s\" >> .ssh/authorized_keys && echo ok ok ok' \
              % self.localID
        appended = False
        for attempt in range(4):
            if attempt > 0:
                time.sleep((attempt - 1) * 2)
            (cin, cout, cerr) = p.exec_command(cmd)
            cin.close()
            # read() waits for the command to finish.
            if cout.read().decode().find("ok ok ok") >= 0:
                appended = True
                break
            cout.close()
            cerr.close()

        if not appended:
            self.failedToAppendInfo += "...send to %s\nFailed to " \
                                       "append local ID to " \
                                       "authorized_keys on remote " \
                                       "node %s.\n" % (
                                           hostname, hostname)
            return
        self.logger.debug(
            "Send to %s\nSuccessfully appended authorized_key on "
            "remote node %s." % (hostname, hostname))
    finally:
        if cin is not None:
            cin.close()
        if cout is not None:
            cout.close()
        if cerr is not None:
            cerr.close()
        if p is not None:
            p.close()
|
|
|
|
def determinePublicAuthorityFile(self):
|
|
'''
|
|
function: determine common authentication file content
|
|
input : NA
|
|
output: NA
|
|
'''
|
|
self.logOut(12, 0)
|
|
# eliminate duplicates in known_hosts file
|
|
try:
|
|
tab = self.readKnownHosts()
|
|
self.writeKnownHosts(tab)
|
|
except IOError as e:
|
|
self.logger.logExit(ErrorCode.GAUSS_502["GAUSS_50230"]
|
|
% "known hosts file" + " Error:\n%s" % str(e))
|
|
|
|
# eliminate duploicates in authorized_keys file
|
|
try:
|
|
tab = self.readAuthorizedKeys()
|
|
self.writeAuthorizedKeys(tab)
|
|
except IOError as e:
|
|
self.logger.logExit(ErrorCode.GAUSS_502["GAUSS_50230"]
|
|
% "authorized keys file" + " Error:\n%s"
|
|
% str(e))
|
|
self.logOut(13, 1)
|
|
|
|
def addRemoteID(self, tab, line):
    """
    function: Store an authorized_keys line in tab, keyed by its third
              whitespace-separated field (the key comment), so a later
              line with the same comment replaces an earlier one.
    input : tab: dict updated in place; line: raw line from the file
    output: True if stored; False for lines that do not have exactly
            three fields or that start with '#'
    """
    fields = line.strip().split()
    if len(fields) != 3 or line[0] == '#':
        return False
    tab[fields[2]] = line
    return True
|
|
|
|
def readAuthorizedKeys(self, tab=None, keysFile=None):
    """
    function: read an authorized_keys file into a table, letting
              self.addRemoteID decide which lines are kept and under
              which key
    input : tab      - dict to fill; a new dict is created when None
            keysFile - file to read; defaults to self.authorized_keys_fname
    output: tab (the same object that was passed in, if any)
    raises: IOError/OSError when the file cannot be opened or read
    """
    if not keysFile:
        keysFile = self.authorized_keys_fname
    # Compare against None explicitly: an empty dict supplied by the
    # caller must be filled in place, not silently replaced by a new one.
    if tab is None:
        tab = {}
    with open(keysFile, 'r') as f:
        for line in f:
            self.addRemoteID(tab, line)
    return tab
|
|
|
|
def writeAuthorizedKeys(self, tab, keysFile=None):
    """
    function: overwrite an authorized_keys file with the entries in tab,
              one line per entry, in the table's iteration order
    input : tab      - dict whose values are complete key lines
            keysFile - file to write; defaults to self.authorized_keys_fname
    output: NA
    raises: IOError/OSError when the file cannot be written
    """
    if not keysFile:
        keysFile = self.authorized_keys_fname
    with open(keysFile, 'w') as f:
        for entry in tab.values():
            # The last line of a file may lack a trailing newline. After
            # de-duplication such an entry can end up before other entries,
            # so terminate every entry to avoid gluing two keys together.
            f.write(entry if entry.endswith('\n') else entry + '\n')
|
|
|
|
def addKnownHost(self, tab, line):
    """
    function: register one known_hosts line in tab, keyed by its first
              whitespace-separated field (the host pattern); a later line
              for the same host field replaces an earlier one
    input : tab  - dict mapping host field -> full original line
            line - one raw line read from the file
    output: True if the line was stored, False if it was rejected
    """
    fields = line.strip().split()
    # Only three-field lines that do not start with '#' are accepted.
    if len(fields) == 3 and not line.startswith('#'):
        tab[fields[0]] = line
        return True
    return False
|
|
|
|
def readKnownHosts(self, tab=None, hostsFile=None):
    """
    function: read a known_hosts file into a table, letting
              self.addKnownHost decide which lines are kept and under
              which key
    input : tab       - dict to fill; a new dict is created when None
            hostsFile - file to read; defaults to self.known_hosts_fname
    output: tab (the same object that was passed in, if any)
    raises: IOError/OSError when the file cannot be opened or read
    """
    if not hostsFile:
        hostsFile = self.known_hosts_fname
    # Compare against None explicitly: an empty dict supplied by the
    # caller must be filled in place, not silently replaced by a new one.
    if tab is None:
        tab = {}
    with open(hostsFile, 'r') as f:
        for line in f:
            self.addKnownHost(tab, line)
    return tab
|
|
|
|
def writeKnownHosts(self, tab, hostsFile=None):
    """
    function: overwrite a known_hosts file with the entries in tab,
              one line per entry, in the table's iteration order
    input : tab       - dict whose values are complete known_hosts lines
            hostsFile - file to write; defaults to self.known_hosts_fname
    output: NA
    raises: IOError/OSError when the file cannot be written
    """
    if not hostsFile:
        hostsFile = self.known_hosts_fname
    with open(hostsFile, 'w') as f:
        for entry in tab.values():
            # The last line of a file may lack a trailing newline. After
            # de-duplication such an entry can end up before other entries,
            # so terminate every entry to avoid gluing two hosts together.
            f.write(entry if entry.endswith('\n') else entry + '\n')
|
|
|
|
def sendTrustFile(self, hostname):
    '''
    function: copy the local files named by self.authorized_keys_fname,
              self.known_hosts_fname, self.id_rsa_fname and
              self.id_rsa_pub_fname into .ssh/ under the remote login
              directory on hostname, using non-interactive scp
    input : hostname
    output: NA
    raises: Exception when scp exits with a non-zero status
    '''
    localFiles = '%s %s %s %s' % (self.authorized_keys_fname,
                                  self.known_hosts_fname,
                                  self.id_rsa_fname,
                                  self.id_rsa_pub_fname)
    # BatchMode / zero password prompts: fail fast instead of hanging on a
    # password prompt when trust to this host is not in place.
    cmd = ('scp -q -o "BatchMode yes" -o "NumberOfPasswordPrompts 0" '
           + '%s %s:.ssh/' % (localFiles, hostname))
    status, output = subprocess.getstatusoutput(cmd)
    if status == 0:
        return
    raise Exception(ErrorCode.GAUSS_502["GAUSS_50223"] % "the authentication"
                    + " Node:%s. Error:\n%s." % (hostname, output)
                    + "The cmd is %s " % cmd)
|
|
|
|
def synchronizationLicenseFile(self):
    '''
    function: run self.sendTrustFile for every host in self.hostList
              concurrently through parallelTool.parallelExecute; any
              exception is reported via self.logger.logExit
    input : NA
    output: NA
    '''
    self.logOut(14, 0)
    try:
        runAll = parallelTool.parallelExecute
        runAll(self.sendTrustFile, self.hostList)
    except Exception as err:
        self.logger.logExit(str(err))
    self.logOut(15, 1)
|
|
|
|
def verifyTrust(self):
    """
    function: run self.checkAuthentication for every host in self.hostList
              concurrently and exit through self.logger.logExit, listing
              the hosts whose check reported failure
    input : NA
    output: NA
    """
    self.logOut(16, 0)
    try:
        results = parallelTool.parallelExecute(self.checkAuthentication,
                                               self.hostList)
        # Each result is unpacked as (flag, value); presumably
        # (success, hostname) -- a falsy flag marks a failed host.
        failedHosts = [value for (ok, value) in results if not ok]
        if failedHosts:
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51100"]
                            % ",".join(failedHosts).lstrip(','))
    except Exception as e:
        self.logger.logExit(str(e))
    self.logOut(17, 1)
|
|
|
|
def getUserPasswd(self):
    """
    function: obtain the current user's password, prompting on the
              terminal when stdin is a TTY, otherwise reading one line
              from stdin (trailing newline characters removed)
    input : NA
    output: single-element list holding the password
    note  : exits through GaussLog.exitWithError if the password is empty
    """
    if sys.stdin.isatty():
        GaussLog.printMessage(
            "Please enter password for current user[%s]." % self.user)
        passwd = getpass.getpass()
    else:
        passwd = sys.stdin.readline().strip('\n')

    # Validate the entered value itself: the returned list always holds
    # exactly one element, so testing the list could never detect an
    # empty password.
    if not passwd:
        GaussLog.exitWithError("Password should not be empty")

    return [passwd]
|
|
|
|
|
|
if __name__ == '__main__':
    # Script entry point: build the trust-creation driver and run it.
    # Failures are reported through GaussLog.exitWithError; messages that
    # do not already start with "[GAUSS-" are prefixed with
    # "[GAUSS-50100]:".
    createTrust = None
    try:
        createTrust = GaussCreateTrust()
        createTrust.run()
    except Exception as e:
        message = str(e)
        if not message.startswith("[GAUSS-"):
            message = "[GAUSS-50100]:" + message
        GaussLog.exitWithError(message)

    sys.exit(0)
|