Files
openGauss-OM/script/gs_sshexkey

1218 lines
45 KiB
Python

#!/usr/bin/env python3
# -*- coding:utf-8 -*-
#############################################################################
# Copyright (c) 2020 Huawei Technologies Co.,Ltd.
#
# openGauss is licensed under Mulan PSL v2.
# You can use this software according to the terms
# and conditions of the Mulan PSL v2.
# You may obtain a copy of Mulan PSL v2 at:
#
# http://license.coscl.org.cn/MulanPSL2
#
# THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS,
# WITHOUT WARRANTIES OF ANY KIND,
# EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
# MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
# See the Mulan PSL v2 for more details.
# ----------------------------------------------------------------------------
# Description : gs_sshexkey is a utility to create SSH trust among nodes in
# a cluster.
#############################################################################
import sys
import warnings
warnings.simplefilter('ignore', DeprecationWarning)
# Add <script dir>/../lib to the module search path.
sys.path.append(sys.path[0] + "/../lib")
import time
import os
import subprocess
import pwd
import grp
import socket
import getpass
import shutil
package_path = os.path.dirname(os.path.realpath(__file__))
ld_path = package_path + "/gspylib/clib"
# Make sure LD_LIBRARY_PATH starts with the bundled gspylib/clib directory.
# When it does not, the environment is updated and this script re-executes
# itself with os.execve (which replaces the current process and does not
# return), presumably so that the libraries in gspylib/clib are visible to
# the gspylib imports below.
if 'LD_LIBRARY_PATH' not in os.environ:
    os.environ['LD_LIBRARY_PATH'] = ld_path
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)
if not os.environ.get('LD_LIBRARY_PATH').startswith(ld_path):
    os.environ['LD_LIBRARY_PATH'] = \
        ld_path + ":" + os.environ['LD_LIBRARY_PATH']
    os.execve(os.path.realpath(__file__), sys.argv, os.environ)
from gspylib.common.GaussLog import GaussLog
from gspylib.common.ErrorCode import ErrorCode
from gspylib.threads.parallelTool import parallelTool
from gspylib.common.Common import DefaultValue, ClusterCommand
from gspylib.common.ParameterParsecheck import Parameter
from gspylib.os.gsfile import g_file
from gspylib.os.gsOSlib import g_OSlib
# Called before paramiko is imported; presumably it configures whatever
# paramiko needs at import time (TODO confirm in DefaultValue).
DefaultValue.doConfigForParamiko()
import paramiko
# Marker appended to every IP/hostname line this tool writes into
# /etc/hosts; lines containing it are filtered out before rewriting.
HOSTS_MAPPING_FLAG = "#Gauss OM IP Hosts Mapping"
# IP/hostname mapping text built by writeRemoteHosts() and read by
# writeRemoteHostName() (run through parallelTool).
ipHostInfo = ""
# the tmp path; set in run(). Whether this file exists decides if log
# calls carry the "addStep"/"constant" step markers.
tmp_files = ""
# tmp file name
TMP_TRUST_FILE = "step_preinstall_file.dat"
class PrintOnScreen():
    """
    Console logger used in place of GaussLog when no log file is given:
    log/logExit messages are printed to stdout, debug messages are dropped.
    """

    def __init__(self):
        """
        function : Constructor (keeps no state)
        input: NA
        output: NA
        """

    def log(self, msg):
        """
        function : write msg to stdout
        input: msg: str
        output: NA
        """
        print(msg)

    def debug(self, msg):
        """
        function : discard a debug message
        input: msg: debug message string
        output: NA
        """
        return None

    def logExit(self, msg):
        """
        function : write msg to stdout, then terminate with exit status 1
        input: msg: str
        output: NA
        """
        self.log(msg)
        sys.exit(1)
class GaussCreateTrust():
    """
    class about create trust for user
    """
    # Step messages printed through logOut(); the step methods refer to
    # them by index, so the order of this list must not change.
    log_list = ["addStep",
                "constant",
                "Checking network information.",
                "Successfully checked network information.",
                "Creating the local key file.",
                "Successfully created the local key files.",
                "Appending local ID to authorized_keys.",
                "Successfully appended local ID to authorized_keys.",
                "Updating the known_hosts file.",
                "Successfully updated the known_hosts file.",
                "Appending authorized_key on the remote node.",
                "Successfully appended authorized_key on all remote node.",
                "Checking common authentication file content.",
                "Successfully checked common authentication content.",
                "Distributing SSH trust file to all node.",
                "Successfully distributed SSH trust file to all node.",
                "Verifying SSH trust on all hosts.",
                "Successfully verified SSH trust on all hosts.",
                ]

    def __init__(self):
        """
        function : Constructor
        input: NA
        output: NA
        """
        self.logger = None
        self.hostFile = ""
        self.hostList = []
        self.passwd = []
        self.logFile = ""
        self.localHost = ""
        self.flag = False
        self.localID = ""
        self.user = pwd.getpwuid(os.getuid()).pw_name
        self.group = grp.getgrgid(os.getgid()).gr_name
        # Filled by sendRemoteAuthorization(), which addRemoteAuthorization()
        # runs through parallelTool.
        self.incorrectPasswdInfo = ""
        self.failedToAppendInfo = ""
        self.homeDir = os.path.expanduser("~" + self.user)
        self.sshDir = "%s/.ssh" % self.homeDir
        self.authorized_keys_fname = '%s/.ssh/authorized_keys' % self.homeDir
        self.known_hosts_fname = '%s/.ssh/known_hosts' % self.homeDir
        self.id_rsa_fname = '%s/.ssh/id_rsa' % self.homeDir
        self.id_rsa_pub_fname = self.id_rsa_fname + '.pub'
        self.skipHostnameSet = False
        self.isKeyboardPassword = False
        self.nodeduplicate = False

    def usage(self):
        """
gs_sshexkey is a utility to create SSH trust among nodes in a cluster.
Usage:
gs_sshexkey -? | --help
gs_sshexkey -V | --version
gs_sshexkey -f HOSTFILE [-l LOGFILE] [--skip-hostname-set]
gs_sshexkey -h IPLIST [-l LOGFILE] [--skip-hostname-set]
General options:
-f Host file containing the IP address of nodes.
-h Host ip list. Separate multiple nodes with commas(,).
-l Path of log file.
--skip-hostname-set Whether to skip hostname setting.
(The default value is set.)
-?, --help Show help information for this utility,
and exit the command line mode.
-V, --version Show version information.
"""
        print(self.usage.__doc__)

    def parseCommandLine(self):
        """
        function: Check parameter from command line and copy the parsed
                  values onto the instance
        input : NA
        output: NA
        """
        paraObj = Parameter()
        paraDict = paraObj.ParameterCommandLine("sshexkey")
        if "helpFlag" in paraDict:
            self.usage()
            sys.exit(0)
        if "hostfile" in paraDict:
            self.hostFile = paraDict.get("hostfile")
        if "nodename" in paraDict:
            self.hostList = paraDict.get("nodename")
        if "logFile" in paraDict:
            self.logFile = paraDict.get("logFile")
        if "skipHostnameSet" in paraDict:
            self.skipHostnameSet = paraDict.get("skipHostnameSet")
        if "noDeduplicate" in paraDict:
            self.nodeduplicate = paraDict.get("noDeduplicate")

    def checkParameter(self):
        """
        function: Check parameter from command line. Reads the host file when
                  no -h list was given, validates IPs and the log file path,
                  and asks for the password when none is known yet.
        input : NA
        output: NA
        """
        # check required parameters
        if len(self.hostList) == 0:
            if self.hostFile == "":
                self.usage()
                GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50001"]
                                       % 'f' + ".")
            if not os.path.exists(self.hostFile):
                GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50201"]
                                       % self.hostFile)
            if not os.path.isabs(self.hostFile):
                GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50213"]
                                       % self.hostFile)
            # read host file to hostList
            self.readHostFile()
            if not self.hostList:
                GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50004"]
                                       % 'f' + " It cannot be empty.")
        else:
            for temp_host in self.hostList:
                if not DefaultValue.isIpValid(temp_host):
                    GaussLog.exitWithError(ErrorCode.GAUSS_500["GAUSS_50000"]
                                           % temp_host)
        # check logfile
        if self.logFile != "":
            if not os.path.isabs(self.logFile):
                GaussLog.exitWithError(ErrorCode.GAUSS_502["GAUSS_50213"]
                                       % self.logFile)
        if not self.passwd:
            self.passwd = self.getUserPasswd()
            self.isKeyboardPassword = True

    def logOut(self, log_index1, log_index2):
        """
        function: log a step message, adding the step marker log_list[index2]
                  only when logging to a file and the tmp step file is absent
        :param log_index1: index of the message in log_list
        :param log_index2: index of the step marker in log_list
        :return: NA
        """
        if self.logFile != "":
            if not os.path.exists(tmp_files):
                self.logger.log(GaussCreateTrust.log_list[log_index1],
                                GaussCreateTrust.log_list[log_index2])
            else:
                self.logger.log(GaussCreateTrust.log_list[log_index1])
        else:
            self.logger.log(GaussCreateTrust.log_list[log_index1])

    def readHostFile(self):
        """
        function: read host file to hostList (one IP per line, duplicates
                  and blank lines ignored); exits when an IP is invalid
        input : NA
        output: NA
        """
        inValidIp = []
        try:
            with open(self.hostFile, "r") as f:
                for readLine in f:
                    hostname = readLine.strip()
                    if hostname != "" and hostname not in self.hostList:
                        if not DefaultValue.isIpValid(hostname):
                            inValidIp.append(hostname)
                            continue
                        self.hostList.append(hostname)
            if len(inValidIp) > 0:
                GaussLog.exitWithError(ErrorCode.GAUSS_506["GAUSS_50603"]
                                       + "The IP list is:%s." % inValidIp)
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_502["GAUSS_50204"] % "host file"
                            + " Error: \n%s" % str(e))

    @staticmethod
    def _recvAll(channel, stderr=False):
        """
        function: read a paramiko channel stream until it is closed
        input : channel - paramiko Channel; stderr - read stderr instead of
                stdout
        output: str, the complete decoded stream
        Why: a single recv() returns at most one chunk, so larger outputs
        (e.g. a remote /etc/hosts) were silently truncated; decoding once
        also avoids splitting a multi-byte character across chunks.
        """
        read = channel.recv_stderr if stderr else channel.recv
        chunks = []
        while True:
            data = read(9999)
            if not data:
                # an empty read means the stream has been closed
                break
            chunks.append(data)
        return b"".join(chunks).decode()

    def _execRemoteCmd(self, ssh, cmd):
        """
        function: run cmd in a new session of an authenticated transport
        input : ssh - paramiko Transport; cmd - shell command
        output: (stdout, stderr), both complete and stripped
        """
        channel = ssh.open_session()
        try:
            channel.exec_command(cmd)
            out = self._recvAll(channel).strip()
            err = self._recvAll(channel, stderr=True).strip()
        finally:
            channel.close()
        return (out, err)

    def getAllHostsName(self, ip):
        """
        function:
          Connect to all nodes ,then get all hostaname by threading
        precondition:
          1.User's password is correct on each node
        postcondition:
          NA
        input: ip
        output:Dictionary ipHostname,key is IP and value is hostname; when
               the login prints unexpected text, the key is "Node[ip]" and
               the value describes that output
        hideninfo:NA
        """
        ipHostname = {}
        try:
            ssh = paramiko.Transport((ip, 22))
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_512["GAUSS_51220"] % ip
                            + " Error: \n%s" % str(e))
        # The transport is closed on every exit path (the early return for
        # unexpected login output used to leak it).
        try:
            try:
                ssh.connect(username=self.user, password=self.passwd[0])
            except Exception:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % ip)
            # "cd" itself prints nothing, so any output is reported as text
            # printed by /etc/profile or ~/.bashrc.
            check_channel = ssh.open_session()
            check_channel.exec_command("cd")
            env_msg = self._recvAll(check_channel, stderr=True)
            env_msg += self._recvAll(check_channel).strip()
            if env_msg != "":
                ipHostname["Node[%s]" % ip] = "Output: [" + env_msg \
                    + " ] print by /etc/profile or" \
                      " ~/.bashrc, please check it."
                return ipHostname
            channel = ssh.open_session()
            channel.exec_command("hostname")
            ipHostname[ip] = self._recvAll(channel).strip()
            return ipHostname
        finally:
            ssh.close()

    def verifyPasswd(self, ssh, pswd=None):
        """
        function: try to authenticate transport ssh with pswd; the transport
                  is closed when authentication fails
        input : ssh - paramiko Transport, pswd - password
        output: True/False
        """
        try:
            ssh.connect(username=self.user, password=pswd)
            return True
        except Exception:
            ssh.close()
            return False

    def parallelGetHosts(self, sshIps):
        """
        function: get hostnames of all nodes in parallel with passwd[0]
        input : sshIps
        output: dict {ip: hostname}; raises when a node's login prints text
        """
        parallelResult = {}
        ipHostname = parallelTool.parallelExecute(self.getAllHostsName, sshIps)
        err_msg = ""
        for i in ipHostname:
            for (key, value) in i.items():
                if key.find("Node") >= 0:
                    err_msg += str(i)
                else:
                    parallelResult[key] = value
        if len(err_msg) > 0:
            raise Exception(ErrorCode.GAUSS_518["GAUSS_51808"] % err_msg)
        return parallelResult

    def serialGetHosts(self, sshIps):
        """
        function: get hostnames node by node, trying every known password
                  and, for keyboard passwords, prompting up to 3 times
        input : sshIps
        output: dict {ip: hostname}
        """
        serialResult = {}
        invalidIP = ""
        for sshIp in sshIps:
            ssh = None
            isPasswdOK = False
            boolInvalidIp = False
            for pswd in self.passwd:
                try:
                    ssh = paramiko.Transport((sshIp, 22))
                except Exception as e:
                    self.logger.debug(str(e))
                    invalidIP += "Incorrect IP address: %s.\n" % sshIp
                    boolInvalidIp = True
                    break
                isPasswdOK = self.verifyPasswd(ssh, pswd)
                if isPasswdOK:
                    break
            if boolInvalidIp:
                continue
            if not isPasswdOK and self.isKeyboardPassword:
                GaussLog.printMessage("Please enter password for current"
                                      " user[%s] on the node[%s]."
                                      % (self.user, sshIp))
                # Try entering the password 3 times interactively
                for _ in range(3):
                    KeyboardPassword = getpass.getpass()
                    DefaultValue.checkPasswordVaild(KeyboardPassword)
                    ssh = paramiko.Transport((sshIp, 22))
                    isPasswdOK = self.verifyPasswd(ssh, KeyboardPassword)
                    if isPasswdOK:
                        self.passwd.append(KeyboardPassword)
                        break
            # if isKeyboardPassword is true, 3 times after the password is
            # also wrong to throw an unusual exit
            if not isPasswdOK:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50306"] % sshIp)
            # close the authenticated transport even when raising below
            try:
                check_channel = ssh.open_session()
                check_channel.exec_command("cd")
                check_result = self._recvAll(check_channel, stderr=True)
                check_result += self._recvAll(check_channel)
                if check_result != "":
                    raise Exception(ErrorCode.GAUSS_518["GAUSS_51808"]
                                    % check_result + "Please check %s node"
                                                     " /etc/profile or ~/.bashrc"
                                    % sshIp)
                channel = ssh.open_session()
                channel.exec_command("hostname")
                hostname = self._recvAll(channel).strip()
                if hostname:
                    serialResult[sshIp] = hostname
            finally:
                ssh.close()
        if invalidIP:
            raise Exception(
                ErrorCode.GAUSS_511["GAUSS_51101"] % invalidIP.rstrip("\n"))
        return serialResult

    def getAllHosts(self, sshIps):
        """
        function:
          Connect to all nodes ,then get all hostaname
        precondition:
          1.User's password is correct on each node
        postcondition:
          NA
        input: sshIps,username,passwd
        output:Dictionary ipHostname,key is IP and value is hostname
        hideninfo:NA
        """
        if self.logFile != "":
            if not os.path.exists(tmp_files):
                self.logger.debug("Get hostnames for all nodes.", "addStep")
            else:
                self.logger.debug("Get hostnames for all nodes.")
        if len(self.passwd) == 0:
            self.isKeyboardPassword = True
            GaussLog.printMessage("Please enter password for current user[%s]."
                                  % self.user)
            passwd = getpass.getpass()
            self.passwd.append(passwd)
        if len(self.passwd) == 1:
            try:
                result = self.parallelGetHosts(sshIps)
            except Exception as e:
                # With a typed password, fall back to node-by-node mode so
                # the user can enter per-node passwords.
                if (self.isKeyboardPassword and str(e).startswith(
                        "[GAUSS-50306] : The password of")):
                    GaussLog.printMessage(
                        "Notice :The password of some nodes is incorrect.")
                    result = self.serialGetHosts(sshIps)
                else:
                    raise Exception(str(e))
        else:
            result = self.serialGetHosts(sshIps)
        if self.logFile != "":
            if not os.path.exists(tmp_files):
                self.logger.debug("Successfully get hostnames for all nodes.",
                                  "constant")
            else:
                self.logger.debug("Successfully get hostnames for all nodes.")
        return result

    @staticmethod
    def _getHostsIps(hostsContent):
        """
        function: first field of every non-empty line of /etc/hosts text
        input : hostsContent - str
        output: list of str
        """
        ipCompare = []
        for line in hostsContent.split("\n"):
            if line:
                ipCompare.append(line.replace("\t", " ").strip().split(' ')[0])
        return ipCompare

    def writeLocalHosts(self, result):
        """
        function:
          Write hostname and Ip into /etc/hosts when there's not the same one
          in /etc/hosts file (only when running as root)
        precondition:
          NA
        postcondition:
          NA
        input: Dictionary result,key is IP and value is hostname
        output: NA
        hideninfo:NA
        """
        if self.logFile != "":
            if not os.path.exists(tmp_files):
                self.logger.debug(
                    "Write local hostname and Ip into /etc/hosts.", "addStep")
            else:
                self.logger.debug(
                    "Write local hostname and Ip into /etc/hosts.")
        hostIPInfo = ""
        if os.getuid() == 0:
            tmpHostIpName = "./tmp_hostsiphostname_%d" % os.getpid()
            # Check if /etc/hosts exists.
            if not os.path.exists("/etc/hosts"):
                raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"] +
                                " Error: \nThe /etc/hosts does not exist.")
            # Keep every non-empty line that this tool did not write. Read in
            # Python instead of a grep pipeline: a failing grep used to have
            # its error text copied over /etc/hosts below.
            with open("/etc/hosts", "r") as hostsFile:
                lines = [line.rstrip("\n") for line in hostsFile]
            output = "\n".join(line for line in lines
                               if line and HOSTS_MAPPING_FLAG not in line)
            try:
                g_file.createFile(tmpHostIpName)
                g_file.changeMode(DefaultValue.KEY_FILE_MODE, tmpHostIpName)
                g_file.writeFile(tmpHostIpName, [output])
                shutil.copyfile(tmpHostIpName, '/etc/hosts')
                g_file.removeFile(tmpHostIpName)
            except Exception as e:
                if os.path.exists(tmpHostIpName):
                    g_file.removeFile(tmpHostIpName)
                raise Exception(str(e))
            if not self.nodeduplicate:
                # only add IPs that have no entry in /etc/hosts yet
                ipCompare = self._getHostsIps(output)
                for (key, value) in result.items():
                    if key not in ipCompare:
                        hostIPInfo += '%s %s %s\n' % (key, value,
                                                      HOSTS_MAPPING_FLAG)
            else:
                for (key, value) in result.items():
                    hostIPInfo += '%s %s %s\n' % (key, value,
                                                  HOSTS_MAPPING_FLAG)
            hostIPInfo = hostIPInfo[:-1]
            # skip the write when there is nothing to add
            if hostIPInfo:
                g_file.writeFile("/etc/hosts", [hostIPInfo])
        if self.logFile != "":
            if not os.path.exists(tmp_files):
                self.logger.debug(
                    "Successfully write local hostname and Ip into "
                    "/etc/hosts.",
                    "constant")
            else:
                self.logger.debug(
                    "Successfully write local hostname and Ip into "
                    "/etc/hosts.")

    def _resetRemoteHosts(self, ssh, tmpHostIpName):
        """
        function: drop the Gauss IP/hostname mapping lines and empty lines
                  from /etc/hosts on the node reached through ssh
        input : ssh - authenticated paramiko Transport
                tmpHostIpName - temporary file name used on that node
        output: (ipHosts, errInfo, outInfo)
                ipHosts - the kept /etc/hosts content
                errInfo - stderr of the commands, "" on success
                outInfo - stdout of the rewrite command (normally "")
        """
        import shlex
        cmd = "grep -v '%s' %s | grep -v '^$'" \
              % (" #Gauss.* IP Hosts Mapping", '/etc/hosts')
        (ipHosts, errInfo) = self._execRemoteCmd(ssh, cmd)
        if errInfo:
            # never overwrite /etc/hosts with the result of a failed read
            return (ipHosts, errInfo, "")
        # quote the content so the shell does not expand $, ` or quotes
        cmd = "echo %s > %s ; cp %s %s && rm -rf %s" \
              % (shlex.quote(ipHosts), tmpHostIpName, tmpHostIpName,
                 '/etc/hosts', tmpHostIpName)
        (outInfo, errInfo) = self._execRemoteCmd(ssh, cmd)
        return (ipHosts, errInfo, outInfo)

    def _appendRemoteHosts(self, ssh, hostInfo):
        """
        function: append hostInfo to /etc/hosts on the node reached through
                  ssh; nothing is done when hostInfo is empty
        input : ssh - authenticated paramiko Transport; hostInfo - str
        output: stderr of the command ("" on success)
        """
        import shlex
        if not hostInfo:
            return ""
        cmd = "echo %s >> /etc/hosts" % shlex.quote(hostInfo)
        (_, errInfo) = self._execRemoteCmd(ssh, cmd)
        return errInfo

    def writeRemoteHostName(self, ip):
        """
        function:
          Write hostname and Ip into /etc/hosts when there's not the same one
          in /etc/hosts file by threading
        precondition:
          NA
        postcondition:
          NA
        input: ip (an address or hostname to connect to)
        output: (True/False, {ip: list of error messages})
        hideninfo:NA
        """
        # ipHostInfo is the mapping text prepared by writeRemoteHosts()
        global ipHostInfo
        writeResult = []
        result = {}
        tmpHostIpName = "./tmp_hostsiphostname_%d_%s" % (os.getpid(), ip)
        username = pwd.getpwuid(os.getuid()).pw_name
        try:
            ssh = paramiko.Transport((ip, 22))
        except Exception as e:
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51107"]
                            + " Error: \n%s" % str(e))
        try:
            try:
                ssh.connect(username=username, password=self.passwd[0])
            except Exception as e:
                raise Exception(ErrorCode.GAUSS_503["GAUSS_50317"]
                                + " Error: \n%s" % str(e))
            (ipHosts, errInfo, outInfo) = self._resetRemoteHosts(
                ssh, tmpHostIpName)
            if errInfo:
                writeResult.append(errInfo)
            elif not outInfo:
                if not self.nodeduplicate:
                    # only add IPs that have no entry in /etc/hosts yet
                    ipCompare = self._getHostsIps(ipHosts)
                    tmpIpHostInfo = ""
                    for info in ipHostInfo.split("\n"):
                        if info.split(' ')[0] not in ipCompare:
                            tmpIpHostInfo += info + "\n"
                    appendInfo = tmpIpHostInfo.rstrip("\n")
                else:
                    appendInfo = ipHostInfo
                errInfo = self._appendRemoteHosts(ssh, appendInfo)
                if errInfo:
                    writeResult.append(errInfo)
        finally:
            ssh.close()
        result[ip] = writeResult
        return (len(writeResult) == 0, result)

    def writeRemoteHosts(self, result, username, rootPasswd):
        """
        function:
          Write hostname and Ip into /etc/hosts when there's not the same one
          in /etc/hosts file (only when running as root)
        precondition:
          NA
        postcondition:
          NA
        input: Dictionary result,key is IP and value is hostname
               username, rootPasswd (list of passwords)
        output: NA
        hideninfo:NA
        """
        if self.logFile != "":
            if not os.path.exists(tmp_files):
                self.logger.debug(
                    "Write remote hostname and Ip into /etc/hosts.", "addStep")
            else:
                self.logger.debug(
                    "Write remote hostname and Ip into /etc/hosts.")
        global ipHostInfo
        ipHostInfo = ""
        if os.getuid() == 0:
            writeResult = []
            tmpHostIpName = "./tmp_hostsiphostname_%d" % os.getpid()
            if len(rootPasswd) == 1:
                # One password for all nodes: update them in parallel, one
                # connection per remote hostname.
                result1 = {}
                for (key, value) in result.items():
                    ipHostInfo += '%s %s %s\n' % (key, value,
                                                  HOSTS_MAPPING_FLAG)
                    if value != self.localHost and value not in result1:
                        result1[value] = key
                sshIps = list(result1.keys())
                ipHostInfo = ipHostInfo[:-1]
                if sshIps:
                    ipRemoteHostname = parallelTool.parallelExecute(
                        self.writeRemoteHostName, sshIps)
                    errorMsg = ""
                    for (key, value) in ipRemoteHostname:
                        if not key:
                            errorMsg = errorMsg + '\n' + str(value)
                    if errorMsg != "":
                        raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"]
                                        + " Error: %s" % errorMsg)
            else:
                for (key, value) in result.items():
                    if value == self.localHost:
                        continue
                    ssh = None
                    connected = False
                    boolInvalidIp = False
                    for pswd in rootPasswd:
                        try:
                            ssh = paramiko.Transport((key, 22))
                        except Exception as e:
                            self.logger.debug(str(e))
                            boolInvalidIp = True
                            break
                        try:
                            ssh.connect(username=username, password=pswd)
                            connected = True
                            break
                        except Exception as e:
                            self.logger.debug(str(e))
                            ssh.close()
                    if boolInvalidIp:
                        # unreachable node: skipped
                        continue
                    if not connected:
                        # no password worked; previously the code went on
                        # with an unauthenticated transport
                        writeResult.append(
                            ErrorCode.GAUSS_503["GAUSS_50306"] % key)
                        continue
                    try:
                        (ipHosts, errInfo, outInfo) = self._resetRemoteHosts(
                            ssh, tmpHostIpName)
                        if errInfo:
                            writeResult.append(errInfo)
                        elif not outInfo:
                            ipHostInfo = ""
                            if not self.nodeduplicate:
                                ipCompare = self._getHostsIps(ipHosts)
                                for (key1, value1) in result.items():
                                    if key1 not in ipCompare:
                                        ipHostInfo += '%s %s %s\n' % (
                                            key1, value1, HOSTS_MAPPING_FLAG)
                            else:
                                for (key1, value1) in result.items():
                                    ipHostInfo += '%s %s %s\n' % (
                                        key1, value1, HOSTS_MAPPING_FLAG)
                            ipHostInfo = ipHostInfo[:-1]
                            errInfo = self._appendRemoteHosts(ssh, ipHostInfo)
                            if errInfo:
                                writeResult.append(errInfo)
                    finally:
                        ssh.close()
            if len(writeResult) > 0:
                raise Exception(ErrorCode.GAUSS_512["GAUSS_51221"]
                                + " Error: \n%s" % writeResult)
        if self.logFile != "":
            if not os.path.exists(tmp_files):
                self.logger.debug(
                    "Successfully write remote hostname and Ip into "
                    "/etc/hosts.",
                    "constant")
            else:
                self.logger.debug(
                    "Successfully write remote hostname and Ip into "
                    "/etc/hosts.")

    def initLogger(self):
        """
        function: Init logger (GaussLog with -l, otherwise PrintOnScreen)
        input : NA
        output: NA
        """
        if self.logFile != "":
            self.logger = GaussLog(self.logFile, "gs_sshexkey")
        else:
            self.logger = PrintOnScreen()

    def checkNetworkInfo(self):
        """
        function: check local node to other node Network Information
        input : NA
        output: NA
        """
        self.logOut(2, 0)
        try:
            netWorkList = DefaultValue.checkIsPing(self.hostList)
            if not netWorkList:
                self.logger.log("All nodes in the network are Normal.")
            else:
                self.logger.logExit(ErrorCode.GAUSS_506["GAUSS_50600"]
                                    + "The IP list is:%s." % netWorkList)
        except Exception as e:
            self.logger.logExit(str(e))
        self.logOut(3, 1)

    def run(self):
        """
        function: Do create SSH trust
        input : NA
        output: NA
        """
        self.parseCommandLine()
        self.checkParameter()
        self.localHost = socket.gethostname()
        self.initLogger()
        global tmp_files
        tmp_files = "/tmp/%s" % TMP_TRUST_FILE
        if self.logFile != "":
            if not os.path.exists(tmp_files):
                self.logger.debug(
                    "gs_sshexkey execution takes %s steps in total" %
                    ClusterCommand.countTotalSteps("gs_sshexkey", "",
                                                   self.skipHostnameSet))
        Ips = []
        Ips.extend(self.hostList)
        result = self.getAllHosts(Ips)
        self.checkNetworkInfo()
        if not self.skipHostnameSet:
            self.writeLocalHosts(result)
            self.writeRemoteHosts(result, self.user, self.passwd)
        self.logger.log("Creating SSH trust.")
        try:
            self.localID = self.createPublicPrivateKeyFile()
            self.addLocalAuthorized()
            self.updateKnow_hostsFile(result)
            self.addRemoteAuthorization()
            self.determinePublicAuthorityFile()
            self.synchronizationLicenseFile()
            self.verifyTrust()
            self.logger.log("Successfully created SSH trust.")
        except Exception as e:
            self.logger.logExit(str(e))

    def createPublicPrivateKeyFile(self):
        """
        function: create local public private key file if missing
        input : NA
        output: str, first line of id_rsa.pub (the local ID)
        """
        self.logOut(4, 0)
        if not os.path.exists(self.id_rsa_pub_fname):
            # Generate at the same path that is checked and chmod-ed; the
            # shell "~" ($HOME) may differ from the passwd home directory.
            cmd = 'ssh-keygen -t rsa -N \"\" -f %s < /dev/null' \
                  % self.id_rsa_fname
            cmd += " && chmod %s %s %s" % (DefaultValue.KEY_FILE_MODE,
                                           self.id_rsa_fname,
                                           self.id_rsa_pub_fname)
            (status, output) = subprocess.getstatusoutput(cmd)
            if status != 0:
                self.logger.log("The cmd is %s " % cmd)
                raise Exception(ErrorCode.GAUSS_511["GAUSS_51108"]
                                + " Error:\n%s" % output)
        try:
            with open(self.id_rsa_pub_fname, 'r') as f:
                localID = f.readline().strip()
        except IOError as e:
            self.logger.debug(str(e))
            raise Exception(ErrorCode.GAUSS_511["GAUSS_51108"]
                            + " Unable to read the generated file."
                            + self.id_rsa_pub_fname)
        # log success only when the key could actually be read
        self.logOut(5, 1)
        return localID

    def addLocalAuthorized(self):
        """
        function: append the local id_rsa.pub value provided to authorized_keys
                  unless it is already present
        input : NA
        output: NA
        """
        self.logOut(6, 0)
        g_file.createFileInSafeMode(self.authorized_keys_fname)
        with open(self.authorized_keys_fname, 'a+') as f:
            # 'a+' starts positioned at EOF; rewind so existing keys are read
            f.seek(0)
            content = f.read()
            existing = [line.strip() for line in content.splitlines()]
            if self.localID not in existing:
                # writes in append mode always go to the end of the file;
                # keep the previous last key on its own line
                if content and not content.endswith('\n'):
                    f.write('\n')
                f.write(self.localID)
                f.write('\n')
        self.logOut(7, 1)
        g_file.changeMode(DefaultValue.KEY_FILE_MODE,
                          self.authorized_keys_fname)

    def checkAuthentication(self, hostname):
        """
        function: Ensure the proper password-less access to the remote host.
        input : hostname
        output: True/False, hostname
        """
        cmd = 'ssh -n %s %s true' % (DefaultValue.SSH_OPTION, hostname)
        (status, output) = subprocess.getstatusoutput(cmd)
        if status != 0:
            self.logger.debug("The cmd is %s " % cmd)
            self.logger.debug(
                "Failed to check authentication. Hostname:%s. Error: \n%s"
                % (hostname, output))
            return (False, hostname)
        return (True, hostname)

    def updateKnow_hostsFile(self, result):
        """
        function: keyscan all hosts (IPs and hostnames) and update the
                  known_hosts file, then check ssh to the local host
        input : result, dict {ip: hostname}
        output: NA
        """
        self.logOut(8, 0)
        hostnameList = []
        hostnameList.extend(self.hostList)
        hostnameList.extend(result.values())
        for hostname in hostnameList:
            cmd = 'ssh-keyscan -t rsa %s >> %s ' % (hostname,
                                                     self.known_hosts_fname)
            cmd += "&& chmod %s %s" % (DefaultValue.KEY_FILE_MODE,
                                       self.known_hosts_fname)
            (status, output) = subprocess.getstatusoutput(cmd)
            if status != 0:
                raise Exception(ErrorCode.GAUSS_514["GAUSS_51400"] % cmd
                                + " Error:\n%s" % output)
        (status, output) = self.checkAuthentication(self.localHost)
        if not status:
            raise Exception(
                ErrorCode.GAUSS_511["GAUSS_51100"] % self.localHost)
        self.logOut(9, 1)

    def tryParamikoConnect(self, hostname, client, pswd=None, silence=False):
        """
        function: try paramiko connect
        input : hostname, client, pswd, silence
        output: True/False (False on authentication failure; other errors
                are re-raised)
        """
        try:
            client.connect(hostname, password=pswd, allow_agent=False,
                           look_for_keys=False)
            return True
        except paramiko.AuthenticationException as e:
            if not silence:
                self.logger.debug("Incorrect password. Node: %s." % hostname
                                  + " Error:\n%s" % str(e))
            client.close()
            return False
        except Exception as e:
            if not silence:
                self.logger.debug('[SSHException %s] %s' % (hostname, str(e)))
            client.close()
            raise Exception(str(e))

    def addRemoteAuthorization(self):
        """
        function: Send local ID to remote over SSH, and append to
                  authorized_key
        input : NA
        output: NA
        """
        self.logOut(10, 0)
        try:
            parallelTool.parallelExecute(self.sendRemoteAuthorization,
                                         self.hostList)
            if self.incorrectPasswdInfo != "":
                self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51101"]
                                    % (self.incorrectPasswdInfo.rstrip("\n")))
            if self.failedToAppendInfo != "":
                self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51101"]
                                    % (self.failedToAppendInfo.rstrip("\n")))
        except Exception as e:
            self.logger.logExit(ErrorCode.GAUSS_511["GAUSS_51111"]
                                + " Error:%s." % str(e))
        self.logOut(11, 1)

    def sendRemoteAuthorization(self, hostname):
        """
        function: send remote authorization (append the local ID to
                  ~/.ssh/authorized_keys on hostname)
        input : hostname
        output: NA
        """
        if hostname == self.localHost:
            return
        p = None
        cin = cout = cerr = None
        try:
            # ssh Remote Connection other node
            p = paramiko.SSHClient()
            p.load_system_host_keys()
            ok = False
            for pswd in self.passwd:
                ok = self.tryParamikoConnect(hostname, p, pswd, silence=True)
                if ok:
                    break
            if not ok:
                self.incorrectPasswdInfo += "Without this node[%s] of " \
                                            "the correct password.\n" % \
                                            hostname
                return
            # Create .ssh directory and ensure content meets permission
            # requirements for password-less SSH
            cmd = ('mkdir -p .ssh; ' + "chown -R %s:%s %s; " %
                   (self.user, self.group, self.sshDir) + 'chmod %s .ssh; '
                   % DefaultValue.KEY_DIRECTORY_MODE
                   + 'touch .ssh/authorized_keys; '
                   + 'touch .ssh/known_hosts; '
                   + 'chmod %s .ssh/auth* .ssh/id* .ssh/known_hosts; '
                   % DefaultValue.KEY_FILE_MODE)
            (cin, cout, cerr) = p.exec_command(cmd)
            # exec_command does not wait; let the setup finish before the
            # append below runs in another channel.
            cout.channel.recv_exit_status()
            cin.close()
            cout.close()
            cerr.close()
            # Append the ID to authorized_keys; up to 3 retries after 0, 2
            # and 4 seconds until the confirmation is echoed back.
            cmd = 'echo \"%s\" >> .ssh/authorized_keys && echo ok ok ok' \
                  % self.localID
            line = ""
            for attempt in range(4):
                if attempt > 0:
                    time.sleep((attempt - 1) * 2)
                (cin, cout, cerr) = p.exec_command(cmd)
                cin.close()
                # read() rather than readline(): other messages may come first
                line = cout.read().decode()
                cout.close()
                cerr.close()
                if line.find("ok ok ok") >= 0:
                    break
            if line.find("ok ok ok") < 0:
                self.failedToAppendInfo += "...send to %s\nFailed to " \
                                           "append local ID to " \
                                           "authorized_keys on remote " \
                                           "node %s.\n" % (
                                               hostname, hostname)
                return
            self.logger.debug(
                "Send to %s\nSuccessfully appended authorized_key on "
                "remote node %s." % (hostname, hostname))
        finally:
            for stream in (cin, cout, cerr):
                if stream:
                    stream.close()
            if p:
                p.close()

    def determinePublicAuthorityFile(self):
        '''
        function: determine common authentication file content (remove
                  duplicate entries from known_hosts and authorized_keys)
        input : NA
        output: NA
        '''
        self.logOut(12, 0)
        # eliminate duplicates in known_hosts file
        try:
            tab = self.readKnownHosts()
            self.writeKnownHosts(tab)
        except IOError as e:
            self.logger.logExit(ErrorCode.GAUSS_502["GAUSS_50230"]
                                % "known hosts file" + " Error:\n%s" % str(e))
        # eliminate duplicates in authorized_keys file
        try:
            tab = self.readAuthorizedKeys()
            self.writeAuthorizedKeys(tab)
        except IOError as e:
            self.logger.logExit(ErrorCode.GAUSS_502["GAUSS_50230"]
                                % "authorized keys file" + " Error:\n%s"
                                % str(e))
        self.logOut(13, 1)

    @staticmethod
    def _addDedupEntry(tab, line, keyIndex):
        """
        function: record one line of a key file in the ordered dict tab
        input : tab - dict being built; line - one file line;
                keyIndex - field used as the dedup key for 3-field entries
        output: True when the line was keyed by that field; False for blank
                lines (dropped) and other lines, which are kept verbatim
                (deduplicated by their text) instead of being deleted
        """
        fields = line.strip().split()
        if not fields:
            return False
        # always end with a newline so entries cannot be glued together
        entry = line.rstrip("\n") + "\n"
        if len(fields) == 3 and line[0] != '#':
            tab[fields[keyIndex]] = entry
            return True
        tab[("raw", line.strip())] = entry
        return False

    def addRemoteID(self, tab, line):
        """
        function: add remote node id (authorized_keys entries are keyed by
                  their third field, the key comment)
        input : tab, line
        output: True/False
        """
        return self._addDedupEntry(tab, line, 2)

    def readAuthorizedKeys(self, tab=None, keysFile=None):
        """
        function: read authorized keys
        input : tab, keysFile
        output: tab
        """
        if not keysFile:
            keysFile = self.authorized_keys_fname
        if tab is None:
            tab = {}
        with open(keysFile, 'r') as f:
            for line in f:
                self.addRemoteID(tab, line)
        return tab

    def writeAuthorizedKeys(self, tab, keysFile=None):
        """
        function: write authorized keys
        input : tab, keysFile
        output: NA
        """
        if not keysFile:
            keysFile = self.authorized_keys_fname
        with open(keysFile, 'w') as f:
            f.writelines(tab.values())

    def addKnownHost(self, tab, line):
        """
        function: add known host (3-field entries are keyed by the host
                  field, so the last key seen for a host wins)
        input : tab, line
        output: True/False
        """
        return self._addDedupEntry(tab, line, 0)

    def readKnownHosts(self, tab=None, hostsFile=None):
        """
        function: read known host
        input : tab, hostsFile
        output: tab
        """
        if not hostsFile:
            hostsFile = self.known_hosts_fname
        if tab is None:
            tab = {}
        with open(hostsFile, 'r') as f:
            for line in f:
                self.addKnownHost(tab, line)
        return tab

    def writeKnownHosts(self, tab, hostsFile=None):
        """
        function: write known host
        input : tab, hostsFile
        output: NA
        """
        if not hostsFile:
            hostsFile = self.known_hosts_fname
        with open(hostsFile, 'w') as f:
            f.writelines(tab.values())

    def sendTrustFile(self, hostname):
        '''
        function: Set or update the authentication files on hostname
        input : hostname
        output: NA
        '''
        cmd = ('scp -q -o "BatchMode yes" -o "NumberOfPasswordPrompts '
               '0" ' + '%s %s %s %s %s:.ssh/' % (
                   self.authorized_keys_fname, self.known_hosts_fname,
                   self.id_rsa_fname, self.id_rsa_pub_fname, hostname))
        (status, output) = subprocess.getstatusoutput(cmd)
        if status != 0:
            raise Exception(ErrorCode.GAUSS_502["GAUSS_50223"]
                            % "the authentication" + " Node:%s. Error:\n%s."
                            % (hostname, output) + "The cmd is %s " % cmd)

    def synchronizationLicenseFile(self):
        '''
        function: Distribution of documents through concurrent execution
                  ThreadPool.
        input : NA
        output: NA
        '''
        self.logOut(14, 0)
        try:
            parallelTool.parallelExecute(self.sendTrustFile, self.hostList)
        except Exception as e:
            self.logger.logExit(str(e))
        self.logOut(15, 1)

    def verifyTrust(self):
        """
        function: Verify creating SSH trust is successful
        input : NA
        output: NA
        """
        self.logOut(16, 0)
        try:
            results = parallelTool.parallelExecute(self.checkAuthentication,
                                                   self.hostList)
            hostnames = ""
            for (key, value) in results:
                if not key:
                    hostnames = hostnames + ',' + value
            if hostnames != "":
                raise Exception(ErrorCode.GAUSS_511["GAUSS_51100"]
                                % hostnames.lstrip(','))
        except Exception as e:
            self.logger.logExit(str(e))
        self.logOut(17, 1)

    def getUserPasswd(self):
        """
        function: get user passwd from the terminal, or from one line of
                  stdin when it is not a terminal
        input: NA
        output: list holding the password
        """
        user_passwd = []
        if sys.stdin.isatty():
            GaussLog.printMessage(
                "Please enter password for current user[%s]." % self.user)
            user_passwd.append(getpass.getpass())
        else:
            user_passwd.append(sys.stdin.readline().strip('\n'))
        # the list always holds one entry, so check the password itself
        if not user_passwd[0]:
            GaussLog.exitWithError("Password should not be empty")
        return user_passwd
if __name__ == '__main__':
    # Script entry point: build the trust; any failure goes through
    # GaussLog.exitWithError, prefixed with GAUSS-50100 unless the message
    # already carries a "[GAUSS-" code.
    try:
        GaussCreateTrust().run()
    except Exception as e:
        message = str(e)
        if not message.startswith("[GAUSS-"):
            message = "[GAUSS-50100]:" + message
        GaussLog.exitWithError(message)
    sys.exit(0)