515 lines
17 KiB
Python
515 lines
17 KiB
Python
#!/bin/env python
|
|
# -*- coding: utf-8 -*-
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
"""
|
|
util模块,提供一些辅助函数
|
|
Date: 2014/08/05 17:19:26
|
|
"""
|
|
import datetime
|
|
import os
|
|
import sys
|
|
import struct
|
|
import random
|
|
import inspect
|
|
import time
|
|
import subprocess
|
|
import pexpect
|
|
from decimal import Decimal
|
|
import hashlib
|
|
import functools
|
|
import palo_logger
|
|
|
|
LOG = palo_logger.Logger.getLogger()
|
|
L = palo_logger.StructedLogMessage
|
|
|
|
|
|
def pretty(data):
|
|
"""
|
|
将data转成string方便打印,需要list中的元素实现__str__方法
|
|
"""
|
|
result = ""
|
|
if isinstance(data, dict):
|
|
result += '{'
|
|
for key, value in data.iteritems():
|
|
result += '%s:%s' % (str(key), pretty(value))
|
|
result += ', '
|
|
result = result.rstrip(", ")
|
|
result += '}'
|
|
elif isinstance(data, list):
|
|
for val in data:
|
|
result += '['
|
|
result += pretty(val)
|
|
result += ', '
|
|
result = result.rstrip(", ")
|
|
result += ']'
|
|
else:
|
|
result += str(data)
|
|
return result
|
|
|
|
|
|
def gen_name_list(prefix=""):
|
|
"""
|
|
根据调用的文件名和函数名生成database_name, table_name, index_name
|
|
"""
|
|
file_name = ""
|
|
for command in sys.argv:
|
|
if command.endswith(".py"):
|
|
file_name = command
|
|
dir_name, file_name = os.path.split(os.path.abspath(file_name))
|
|
file_name = file_name[:-3]
|
|
function_name = inspect.stack()[1][3]
|
|
database_name = "%s_%s_%s_db" % (prefix, file_name, function_name)
|
|
LOG.info(L('test case', file_name=file_name, case=function_name))
|
|
if len(database_name) > 60:
|
|
md5_str = get_md5(database_name)
|
|
database_name = 'd' + md5_str[-4:] + '_' + database_name[-50:]
|
|
database_name = database_name.lstrip("_")
|
|
table_name = "%s_%s_%s_tb" % (prefix, file_name, function_name)
|
|
if len(table_name) > 60:
|
|
md5_str = get_md5(table_name)
|
|
table_name = 't' + md5_str[-4:] + '_' + table_name[-50:]
|
|
table_name = table_name.lstrip("_")
|
|
index_name = "%s_%s_%s_index" % (prefix, file_name, function_name)
|
|
if len(index_name) > 60:
|
|
md5_str = get_md5(index_name)
|
|
index_name = 'i' + md5_str[-4:] + '_' + index_name[-50:]
|
|
index_name = index_name.lstrip("_")
|
|
return database_name, table_name, index_name
|
|
|
|
|
|
def gen_num_format_name_list(prefix=""):
|
|
"""
|
|
根据调用的文件名和函数名生成database_name, table_name, index_name
|
|
在文件名和函数名生成的字符串过长时,取其md5值的一部分,尽可能减少case之间重名的问题
|
|
"""
|
|
file_name = ""
|
|
for command in sys.argv:
|
|
if command.endswith(".py"):
|
|
file_name = command
|
|
dir_name, file_name = os.path.split(os.path.abspath(file_name))
|
|
file_name = file_name[:-3]
|
|
function_name = inspect.stack()[1][3]
|
|
|
|
# 分别获取database_name, table_name, index_name
|
|
suffixes = ['db', 'tb', 'index']
|
|
names = []
|
|
for suffix in suffixes:
|
|
name = "%s_%s_%s_%s" % (prefix, file_name, function_name, suffix)
|
|
if len(name) > 60:
|
|
# palo里的数据库和表名必须以字符串开始
|
|
md5_str = get_md5(name)
|
|
name = 's' + md5_str[-4:] + '_' + name[-50:]
|
|
name = name.lstrip("_")
|
|
names.append(name)
|
|
return names
|
|
|
|
|
|
def get_label():
|
|
"""
|
|
生成label字符串
|
|
"""
|
|
fmt = '%d_%H_%M_%S_%f'
|
|
return "label_%s_%d" % (datetime.datetime.now().strftime(fmt), random.randint(0, 2 ** 31 - 1))
|
|
|
|
|
|
def get_snapshot_label(prefix=None):
|
|
"""生成snapshot label"""
|
|
if prefix is None:
|
|
prefix = 'random'
|
|
fmt = '%d_%H_%M_%S_%f'
|
|
return "%s_snapshot_%s_%d" % (prefix, datetime.datetime.now().strftime(fmt),
|
|
random.randint(0, 2 ** 31 - 1))
|
|
|
|
|
|
def column_to_sql(column, set_null=False):
|
|
"""
|
|
将column 4元组转成palo格式的sql字符串
|
|
(column_name, column_type, aggregation_type, default_value)
|
|
"""
|
|
sql = "%s %s" % (column[0], column[1])
|
|
#value列有聚合方法
|
|
#key列有默认值时指定此项为None
|
|
if len(column) > 2:
|
|
if column[2]:
|
|
sql = "%s %s" % (sql, column[2])
|
|
if set_null is False:
|
|
sql = sql + ' NOT NULL'
|
|
elif set_null is True:
|
|
sql = sql + ' NULL'
|
|
else:
|
|
pass
|
|
#有默认值的列
|
|
if len(column) > 3:
|
|
if column[3] is None:
|
|
sql = '%s DEFAULT NULL' % sql
|
|
else:
|
|
sql = '%s DEFAULT "%s"' % (sql, column[3])
|
|
return sql
|
|
|
|
|
|
def column_to_no_agg_sql(column, set_null=False):
|
|
"""
|
|
将column 4元组转成SQL字符串
|
|
(column_name, column_type, aggregation_type, default_value)
|
|
"""
|
|
sql = "%s %s" % (column[0], column[1])
|
|
if set_null is False:
|
|
sql = sql + ' NOT NULL'
|
|
elif set_null is True:
|
|
sql = sql + ' NULL'
|
|
else:
|
|
pass
|
|
#有默认值的列
|
|
if len(column) > 3:
|
|
sql = '%s DEFAULT "%s"' % (sql, column[3])
|
|
return sql
|
|
|
|
|
|
def convert_agg_column_to_no_agg_column(column_list):
|
|
"""
|
|
将agg column 4元组转成no agg column 4元组
|
|
(column_name, column_type, aggregation_type, default_value)
|
|
"""
|
|
no_agg_column = []
|
|
for i in column_list:
|
|
if len(i) == 2:
|
|
no_agg_column.append(i)
|
|
elif len(i) == 3:
|
|
no_agg_column.append((i[0], i[1]))
|
|
elif len(i) == 4:
|
|
no_agg_column.append((i[0], i[1], "", i[3]))
|
|
return no_agg_column
|
|
|
|
|
|
def file_to_insert_sql_value(file_name, to_str=False):
|
|
"""
|
|
将文件中的列转为insert into类型sql中value字段
|
|
会将文件中的N转为NULL(broker load和insert中空值的差别)
|
|
:param file_name:
|
|
:param to_str: 默认是false,这样的话对数字不会加双引号
|
|
:return:
|
|
"""
|
|
fp = open(file_name, 'r')
|
|
total_sql_list = []
|
|
for line in fp.readlines():
|
|
items = line.split('\n')[0].split('\t')
|
|
str = ''
|
|
for i in range(len(items)):
|
|
item = items[i]
|
|
if not to_str and is_number(item):
|
|
str += item
|
|
elif item == '\\N':
|
|
str += 'NULL'
|
|
else:
|
|
str += '"' + item + '"'
|
|
|
|
if i < len(items) - 1:
|
|
str += ','
|
|
total_sql_list.append('(' + str + ')')
|
|
return ','.join(total_sql_list)
|
|
|
|
|
|
def is_number(s):
|
|
"""
|
|
验证字符串是否为数字/浮点数
|
|
:param s:
|
|
:return:
|
|
"""
|
|
try:
|
|
if s == 'NaN':
|
|
return False
|
|
float(s)
|
|
return True
|
|
except ValueError:
|
|
return False
|
|
|
|
|
|
def exec_cmd(cmd, user=None, password=None, host=None, timeout=30):
|
|
"""
|
|
执行shell命令
|
|
"""
|
|
if user is not None and password is not None and host is not None:
|
|
if cmd.find("'") >= 0:
|
|
raise Exception("We do not support quote ' now!")
|
|
output, status = pexpect.run("ssh %s@%s '%s'" % (user, host, cmd),
|
|
timeout=timeout, withexitstatus=True,
|
|
events = {"continue connecting":"yes\n", "password:":"%s\n" % password})
|
|
LOG.info(L('exec remote cmd', cmd=cmd, output=output, status=status))
|
|
else:
|
|
status, output = subprocess.getstatusoutput(cmd)
|
|
return status, output
|
|
|
|
|
|
def compare(a, b):
|
|
"""compare data to None"""
|
|
assert isinstance(a, (tuple, list))
|
|
assert isinstance(b, (tuple, list))
|
|
for i in range(0, len(a)):
|
|
if a[i] is not None and b[i] is not None:
|
|
if a[i] > b[i]:
|
|
return 1
|
|
elif a[i] < b[i]:
|
|
return -1
|
|
else:
|
|
continue
|
|
elif a[i] is None and b[i] is None:
|
|
continue
|
|
elif a[i] is None:
|
|
return -1
|
|
else:
|
|
return 1
|
|
return 0
|
|
|
|
|
|
def check(palo_result, mysql_result, force_order=False):
|
|
"""
|
|
check the palo result and mysql result
|
|
1. data format
|
|
2. order
|
|
"""
|
|
if force_order and palo_result != ():
|
|
try:
|
|
palo_result = list(palo_result)
|
|
mysql_result = list(mysql_result)
|
|
palo_result.sort()
|
|
mysql_result.sort()
|
|
except Exception as e:
|
|
if isinstance(e, TypeError):
|
|
palo_result.sort(key=functools.cmp_to_key(compare))
|
|
mysql_result.sort(key=functools.cmp_to_key(compare))
|
|
if mysql_result != palo_result:
|
|
if len(palo_result) != len(mysql_result):
|
|
LOG.error(L('check error', palo_length=len(palo_result),
|
|
expect_length=len(mysql_result)))
|
|
assert 0 == 1, "\npalo_result length: %d\nmysql_result length:%d" % \
|
|
(len(palo_result), len(mysql_result))
|
|
for palo_line, mysql_line in zip(palo_result, mysql_result):
|
|
if len(palo_line) != len(mysql_line):
|
|
LOG.error(L('check error, number of columns not match',
|
|
palo_colums_number=len(palo_line), expect_columns_number=len(mysql_line)))
|
|
LOG.error(L('check error', palo_line=str(palo_line)))
|
|
LOG.error(L('check error', expect_line=str(mysql_line)))
|
|
assert 0 == 1, "\npalo line: %s\nmysql line:%s" % (str(palo_line), str(mysql_line))
|
|
if palo_line != mysql_line:
|
|
for palo_data, mysql_data in zip(palo_line, mysql_line):
|
|
if palo_data != mysql_data:
|
|
same = False
|
|
# str vs bytes
|
|
if isinstance(palo_data, (str, bytes)) and isinstance(mysql_data, (str, bytes)):
|
|
if isinstance(palo_data, bytes):
|
|
palo_data = str(palo_data, "utf8")
|
|
if isinstance(mysql_data, bytes):
|
|
mysql_data = str(mysql_data, "utf8")
|
|
if palo_data == mysql_data:
|
|
return True
|
|
# chinese & palo largeint return unicode
|
|
if isinstance(palo_data, str):
|
|
if palo_data == str(mysql_data):
|
|
same = True
|
|
if palo_data == mysql_data:
|
|
same = True
|
|
# blank
|
|
if isinstance(palo_data, str):
|
|
if palo_data.strip() == "" and mysql_data == "":
|
|
same = True
|
|
# null string
|
|
if palo_data is None and mysql_data == "":
|
|
same = True
|
|
#float
|
|
elif isinstance(mysql_data, (Decimal, float)) \
|
|
and isinstance(palo_data, (Decimal, float)):
|
|
same = check_float(float(palo_data), float(mysql_data))
|
|
# list
|
|
elif isinstance(mysql_data, list):
|
|
same = check_list(palo_data, mysql_data)
|
|
if not same:
|
|
LOG.error(L('check error', palo_data=palo_data))
|
|
LOG.error(L('check error', expect_data=mysql_data))
|
|
LOG.error(L('check error', palo_line=palo_line))
|
|
LOG.error(L('check error', expect_line=mysql_line))
|
|
assert 0 == 1, "\ndiff data: \npalo:%s; \nexpect:%s;\npalo line:%s\nexpect line: %s" \
|
|
% (palo_data, mysql_data, palo_line, mysql_line)
|
|
|
|
|
|
def check_float(data1, data2):
|
|
"""
|
|
check float
|
|
"""
|
|
if float(data1) == float(data2):
|
|
return True
|
|
if abs(float(data1) - float(data2)) < 0.001:
|
|
return True
|
|
if data2 != 0 and abs(float(data1) / float(data2) - 1.0) < 0.001:
|
|
return True
|
|
return False
|
|
|
|
|
|
def check_list(data1, data2):
|
|
"""
|
|
check list
|
|
"""
|
|
if data1 is None and data2 is None:
|
|
return True
|
|
elif data1 is None or data2 is None:
|
|
return False
|
|
elif len(data2) > 0 and isinstance(data2[0], float):
|
|
for d1, d2 in zip(data1, data2):
|
|
if not check_float(d1, d2):
|
|
return False
|
|
else:
|
|
return data1 == data2
|
|
return True
|
|
|
|
|
|
def convert_dict2property(properties):
|
|
"""convert mat to property str"""
|
|
sql = '('
|
|
for k, v in properties.items():
|
|
if v is not None:
|
|
sql += ' "%s" = "%s", ' % (k, v)
|
|
sql = sql.rstrip(', ')
|
|
sql += ')'
|
|
return sql
|
|
|
|
|
|
def get_timestamp(palo_data):
|
|
"""将datetime 转为时间戳"""
|
|
timeArray = palo_data[0][0].timetuple()
|
|
timestamp = time.mktime(timeArray)
|
|
return timestamp
|
|
|
|
|
|
def check2_time_zone(palo_re, mysql_res, gap=2):
|
|
"""比较两个datetime时间,转为时间戳,允许的差值是gap"""
|
|
palo_data = get_timestamp(palo_re)
|
|
mysql_data = get_timestamp(mysql_res)
|
|
assert abs(palo_data - mysql_data) <= 2 + gap, "palo_res %s, mysql_res %s, \
|
|
excepted gap < %s" % (palo_data, mysql_data, gap)
|
|
|
|
|
|
def assert_return(expected_flag, expected_msg, func, *args, **kwargs):
|
|
"""
|
|
验证一个函数的执行结果
|
|
验证是否正确;如果错误的话,验证错误信息
|
|
:param expected_flag:True/False
|
|
:param expected_msg:
|
|
:param func:
|
|
:param args:
|
|
:param kwargs:
|
|
:return:
|
|
"""
|
|
|
|
try:
|
|
func(*args, **kwargs)
|
|
except Exception as e:
|
|
print(str(e))
|
|
print(expected_msg)
|
|
LOG.info(L('get an error', msg=str(e)))
|
|
assert not expected_flag, "real sql status is False, expect %s" % expected_flag
|
|
assert expected_msg in str(e), "expect:%s, doris:%s" % (expected_msg, str(e))
|
|
else:
|
|
assert expected_flag, "real sql status is True, expect %s" % expected_flag
|
|
|
|
|
|
def get_md5(target):
|
|
"""
|
|
计算字符串的md5值
|
|
:param target:
|
|
:return:
|
|
"""
|
|
obj = hashlib.md5(b"fkldsajlkfjlaksdjfkladsjfkladsjkldsjfklfjs")
|
|
obj.update(target.encode("utf-8"))
|
|
return obj.hexdigest()
|
|
|
|
|
|
def assert_return_flag(expected_flag, func, *args, **kwargs):
|
|
"""
|
|
验证一个函数的返回值
|
|
验证是否正确
|
|
:param expected_flag:
|
|
:param func:
|
|
:param args:
|
|
:param kwargs:
|
|
:return:
|
|
"""
|
|
return_flag = func(*args, **kwargs)
|
|
assert expected_flag == return_flag
|
|
|
|
|
|
def bitmap_index_to_sql(bitmap_index, set_null=False):
|
|
"""
|
|
bitmap_index: 由3元组(index_name, column_name, index_type)组成
|
|
"""
|
|
sql = "INDEX %s (%s) USING %s" % (bitmap_index[0], bitmap_index[1], bitmap_index[2])
|
|
return sql
|
|
|
|
|
|
def get_attr(ret, column_idx):
|
|
"""
|
|
获取返回结果ret的第n列
|
|
ret = client.show_backend()
|
|
be_ip = get_attr(ret, palo_job.BackendShowInfo.IP)
|
|
"""
|
|
return_list = list()
|
|
for record in ret:
|
|
return_list.append(record[column_idx])
|
|
LOG.info(L('get column from ret', idx=column_idx, ret=return_list))
|
|
return return_list
|
|
|
|
|
|
def get_attr_condition_value(ret, condition_col_idx, condition_value, retrun_col_idx=None):
|
|
"""寻找ret的第condition_col_idx列的值为condition_value的行,返回该行的第retrun_col_idx列,只返回一个符合条件的"""
|
|
if retrun_col_idx is None:
|
|
retrun_col_idx = condition_col_idx
|
|
for record in ret:
|
|
if record[condition_col_idx] == condition_value:
|
|
LOG.info(L('get result from ret', searced_key=condition_value, value=record[retrun_col_idx]))
|
|
return record[retrun_col_idx]
|
|
LOG.info(L('can not get result from ret', searched_key=condition_value))
|
|
return None
|
|
|
|
|
|
def get_attr_condition_list(ret, condition_col_idx, condition_value, retrun_col_idx=None):
|
|
"""寻找ret的第condition_col_idx列的值为condition_value的行,返回所有行的第retrun_col_idx列,返回说有符合条件的"""
|
|
result_list = list()
|
|
if retrun_col_idx is None:
|
|
retrun_col_idx = condition_col_idx
|
|
for record in ret:
|
|
if record[condition_col_idx] == condition_value:
|
|
result_list.append(record[retrun_col_idx])
|
|
if len(result_list) == 0:
|
|
return None
|
|
else:
|
|
LOG.info(L('get all result from ret', searced_key=condition_value, value=result_list))
|
|
return result_list
|
|
|
|
|
|
def gen_tuple_num_str(begin, end):
|
|
"""gen_tuple_num_str(1, 3) -> ('1', '2')"""
|
|
return tuple(map(str, range(begin, end)))
|
|
|
|
|
|
def get_string_md5(st):
|
|
"""get string md5"""
|
|
hl = hashlib.md5()
|
|
hl.update(st.encode(encoding='utf-8'))
|
|
return hl.hexdigest()
|
|
|
|
|