#!/bin/env python # -*- coding: utf-8 -*- # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. """ util模块,提供一些辅助函数 Date: 2014/08/05 17:19:26 """ import datetime import os import sys import struct import random import inspect import time import subprocess import pexpect from decimal import Decimal import hashlib import functools import palo_logger LOG = palo_logger.Logger.getLogger() L = palo_logger.StructedLogMessage def pretty(data): """ 将data转成string方便打印,需要list中的元素实现__str__方法 """ result = "" if isinstance(data, dict): result += '{' for key, value in data.iteritems(): result += '%s:%s' % (str(key), pretty(value)) result += ', ' result = result.rstrip(", ") result += '}' elif isinstance(data, list): for val in data: result += '[' result += pretty(val) result += ', ' result = result.rstrip(", ") result += ']' else: result += str(data) return result def gen_name_list(prefix=""): """ 根据调用的文件名和函数名生成database_name, table_name, index_name """ file_name = "" for command in sys.argv: if command.endswith(".py"): file_name = command dir_name, file_name = os.path.split(os.path.abspath(file_name)) file_name = file_name[:-3] function_name = inspect.stack()[1][3] database_name = "%s_%s_%s_db" % (prefix, file_name, function_name) LOG.info(L('test case', file_name=file_name, case=function_name)) if len(database_name) > 60: md5_str = get_md5(database_name) database_name = 'd' + md5_str[-4:] + '_' + database_name[-50:] database_name = database_name.lstrip("_") table_name = "%s_%s_%s_tb" % (prefix, file_name, function_name) if len(table_name) > 60: md5_str = get_md5(table_name) table_name = 't' + md5_str[-4:] + '_' + table_name[-50:] table_name = table_name.lstrip("_") index_name = "%s_%s_%s_index" % (prefix, file_name, function_name) if len(index_name) > 60: md5_str = get_md5(index_name) index_name = 'i' + md5_str[-4:] + '_' + index_name[-50:] index_name = index_name.lstrip("_") return database_name, table_name, index_name def gen_num_format_name_list(prefix=""): """ 根据调用的文件名和函数名生成database_name, table_name, index_name 在文件名和函数名生成的字符串过长时,取其md5值的一部分,尽可能减少case之间重名的问题 """ file_name = "" for command in sys.argv: if command.endswith(".py"): file_name = command dir_name, file_name = os.path.split(os.path.abspath(file_name)) file_name = file_name[:-3] function_name = inspect.stack()[1][3] # 分别获取database_name, table_name, index_name suffixes = ['db', 'tb', 'index'] names = [] for suffix in suffixes: name = "%s_%s_%s_%s" % (prefix, file_name, function_name, suffix) if len(name) > 60: # palo里的数据库和表名必须以字符串开始 md5_str = get_md5(name) name = 's' + md5_str[-4:] + '_' + name[-50:] name = name.lstrip("_") names.append(name) return names def get_label(): """ 生成label字符串 """ fmt = '%d_%H_%M_%S_%f' return "label_%s_%d" % (datetime.datetime.now().strftime(fmt), random.randint(0, 2 ** 31 - 1)) def get_snapshot_label(prefix=None): """生成snapshot label""" if prefix is None: prefix = 'random' fmt = '%d_%H_%M_%S_%f' return "%s_snapshot_%s_%d" % (prefix, datetime.datetime.now().strftime(fmt), random.randint(0, 2 ** 31 - 1)) def column_to_sql(column, set_null=False): """ 将column 4元组转成palo格式的sql字符串 (column_name, column_type, aggregation_type, default_value) """ sql = "%s %s" % (column[0], column[1]) #value列有聚合方法 #key列有默认值时指定此项为None if len(column) > 2: if column[2]: sql = "%s %s" % (sql, column[2]) if set_null is False: sql = sql + ' NOT NULL' elif set_null is True: sql = sql + ' NULL' else: pass #有默认值的列 if len(column) > 3: if column[3] is None: sql = '%s DEFAULT NULL' % sql else: sql = '%s DEFAULT "%s"' % (sql, column[3]) return sql def column_to_no_agg_sql(column, set_null=False): """ 将column 4元组转成SQL字符串 (column_name, column_type, aggregation_type, default_value) """ sql = "%s %s" % (column[0], column[1]) if set_null is False: sql = sql + ' NOT NULL' elif set_null is True: sql = sql + ' NULL' else: pass #有默认值的列 if len(column) > 3: sql = '%s DEFAULT "%s"' % (sql, column[3]) return sql def convert_agg_column_to_no_agg_column(column_list): """ 将agg column 4元组转成no agg column 4元组 (column_name, column_type, aggregation_type, default_value) """ no_agg_column = [] for i in column_list: if len(i) == 2: no_agg_column.append(i) elif len(i) == 3: no_agg_column.append((i[0], i[1])) elif len(i) == 4: no_agg_column.append((i[0], i[1], "", i[3])) return no_agg_column def file_to_insert_sql_value(file_name, to_str=False): """ 将文件中的列转为insert into类型sql中value字段 会将文件中的N转为NULL(broker load和insert中空值的差别) :param file_name: :param to_str: 默认是false,这样的话对数字不会加双引号 :return: """ fp = open(file_name, 'r') total_sql_list = [] for line in fp.readlines(): items = line.split('\n')[0].split('\t') str = '' for i in range(len(items)): item = items[i] if not to_str and is_number(item): str += item elif item == '\\N': str += 'NULL' else: str += '"' + item + '"' if i < len(items) - 1: str += ',' total_sql_list.append('(' + str + ')') return ','.join(total_sql_list) def is_number(s): """ 验证字符串是否为数字/浮点数 :param s: :return: """ try: if s == 'NaN': return False float(s) return True except ValueError: return False def exec_cmd(cmd, user=None, password=None, host=None, timeout=30): """ 执行shell命令 """ if user is not None and password is not None and host is not None: if cmd.find("'") >= 0: raise Exception("We do not support quote ' now!") output, status = pexpect.run("ssh %s@%s '%s'" % (user, host, cmd), timeout=timeout, withexitstatus=True, events = {"continue connecting":"yes\n", "password:":"%s\n" % password}) LOG.info(L('exec remote cmd', cmd=cmd, output=output, status=status)) else: status, output = subprocess.getstatusoutput(cmd) return status, output def compare(a, b): """compare data to None""" assert isinstance(a, (tuple, list)) assert isinstance(b, (tuple, list)) for i in range(0, len(a)): if a[i] is not None and b[i] is not None: if a[i] > b[i]: return 1 elif a[i] < b[i]: return -1 else: continue elif a[i] is None and b[i] is None: continue elif a[i] is None: return -1 else: return 1 return 0 def check(palo_result, mysql_result, force_order=False): """ check the palo result and mysql result 1. data format 2. order """ if force_order and palo_result != (): try: palo_result = list(palo_result) mysql_result = list(mysql_result) palo_result.sort() mysql_result.sort() except Exception as e: if isinstance(e, TypeError): palo_result.sort(key=functools.cmp_to_key(compare)) mysql_result.sort(key=functools.cmp_to_key(compare)) if mysql_result != palo_result: if len(palo_result) != len(mysql_result): LOG.error(L('check error', palo_length=len(palo_result), expect_length=len(mysql_result))) assert 0 == 1, "\npalo_result length: %d\nmysql_result length:%d" % \ (len(palo_result), len(mysql_result)) for palo_line, mysql_line in zip(palo_result, mysql_result): if len(palo_line) != len(mysql_line): LOG.error(L('check error, number of columns not match', palo_colums_number=len(palo_line), expect_columns_number=len(mysql_line))) LOG.error(L('check error', palo_line=str(palo_line))) LOG.error(L('check error', expect_line=str(mysql_line))) assert 0 == 1, "\npalo line: %s\nmysql line:%s" % (str(palo_line), str(mysql_line)) if palo_line != mysql_line: for palo_data, mysql_data in zip(palo_line, mysql_line): if palo_data != mysql_data: same = False # str vs bytes if isinstance(palo_data, (str, bytes)) and isinstance(mysql_data, (str, bytes)): if isinstance(palo_data, bytes): palo_data = str(palo_data, "utf8") if isinstance(mysql_data, bytes): mysql_data = str(mysql_data, "utf8") if palo_data == mysql_data: return True # chinese & palo largeint return unicode if isinstance(palo_data, str): if palo_data == str(mysql_data): same = True if palo_data == mysql_data: same = True # blank if isinstance(palo_data, str): if palo_data.strip() == "" and mysql_data == "": same = True # null string if palo_data is None and mysql_data == "": same = True #float elif isinstance(mysql_data, (Decimal, float)) \ and isinstance(palo_data, (Decimal, float)): same = check_float(float(palo_data), float(mysql_data)) # list elif isinstance(mysql_data, list): same = check_list(palo_data, mysql_data) if not same: LOG.error(L('check error', palo_data=palo_data)) LOG.error(L('check error', expect_data=mysql_data)) LOG.error(L('check error', palo_line=palo_line)) LOG.error(L('check error', expect_line=mysql_line)) assert 0 == 1, "\ndiff data: \npalo:%s; \nexpect:%s;\npalo line:%s\nexpect line: %s" \ % (palo_data, mysql_data, palo_line, mysql_line) def check_float(data1, data2): """ check float """ if float(data1) == float(data2): return True if abs(float(data1) - float(data2)) < 0.001: return True if data2 != 0 and abs(float(data1) / float(data2) - 1.0) < 0.001: return True return False def check_list(data1, data2): """ check list """ if data1 is None and data2 is None: return True elif data1 is None or data2 is None: return False elif len(data2) > 0 and isinstance(data2[0], float): for d1, d2 in zip(data1, data2): if not check_float(d1, d2): return False else: return data1 == data2 return True def convert_dict2property(properties): """convert mat to property str""" sql = '(' for k, v in properties.items(): if v is not None: sql += ' "%s" = "%s", ' % (k, v) sql = sql.rstrip(', ') sql += ')' return sql def get_timestamp(palo_data): """将datetime 转为时间戳""" timeArray = palo_data[0][0].timetuple() timestamp = time.mktime(timeArray) return timestamp def check2_time_zone(palo_re, mysql_res, gap=2): """比较两个datetime时间,转为时间戳,允许的差值是gap""" palo_data = get_timestamp(palo_re) mysql_data = get_timestamp(mysql_res) assert abs(palo_data - mysql_data) <= 2 + gap, "palo_res %s, mysql_res %s, \ excepted gap < %s" % (palo_data, mysql_data, gap) def assert_return(expected_flag, expected_msg, func, *args, **kwargs): """ 验证一个函数的执行结果 验证是否正确;如果错误的话,验证错误信息 :param expected_flag:True/False :param expected_msg: :param func: :param args: :param kwargs: :return: """ try: func(*args, **kwargs) except Exception as e: print(str(e)) print(expected_msg) LOG.info(L('get an error', msg=str(e))) assert not expected_flag, "real sql status is False, expect %s" % expected_flag assert expected_msg in str(e), "expect:%s, doris:%s" % (expected_msg, str(e)) else: assert expected_flag, "real sql status is True, expect %s" % expected_flag def get_md5(target): """ 计算字符串的md5值 :param target: :return: """ obj = hashlib.md5(b"fkldsajlkfjlaksdjfkladsjfkladsjkldsjfklfjs") obj.update(target.encode("utf-8")) return obj.hexdigest() def assert_return_flag(expected_flag, func, *args, **kwargs): """ 验证一个函数的返回值 验证是否正确 :param expected_flag: :param func: :param args: :param kwargs: :return: """ return_flag = func(*args, **kwargs) assert expected_flag == return_flag def bitmap_index_to_sql(bitmap_index, set_null=False): """ bitmap_index: 由3元组(index_name, column_name, index_type)组成 """ sql = "INDEX %s (%s) USING %s" % (bitmap_index[0], bitmap_index[1], bitmap_index[2]) return sql def get_attr(ret, column_idx): """ 获取返回结果ret的第n列 ret = client.show_backend() be_ip = get_attr(ret, palo_job.BackendShowInfo.IP) """ return_list = list() for record in ret: return_list.append(record[column_idx]) LOG.info(L('get column from ret', idx=column_idx, ret=return_list)) return return_list def get_attr_condition_value(ret, condition_col_idx, condition_value, retrun_col_idx=None): """寻找ret的第condition_col_idx列的值为condition_value的行,返回该行的第retrun_col_idx列,只返回一个符合条件的""" if retrun_col_idx is None: retrun_col_idx = condition_col_idx for record in ret: if record[condition_col_idx] == condition_value: LOG.info(L('get result from ret', searced_key=condition_value, value=record[retrun_col_idx])) return record[retrun_col_idx] LOG.info(L('can not get result from ret', searched_key=condition_value)) return None def get_attr_condition_list(ret, condition_col_idx, condition_value, retrun_col_idx=None): """寻找ret的第condition_col_idx列的值为condition_value的行,返回所有行的第retrun_col_idx列,返回说有符合条件的""" result_list = list() if retrun_col_idx is None: retrun_col_idx = condition_col_idx for record in ret: if record[condition_col_idx] == condition_value: result_list.append(record[retrun_col_idx]) if len(result_list) == 0: return None else: LOG.info(L('get all result from ret', searced_key=condition_value, value=result_list)) return result_list def gen_tuple_num_str(begin, end): """gen_tuple_num_str(1, 3) -> ('1', '2')""" return tuple(map(str, range(begin, end))) def get_string_md5(st): """get string md5""" hl = hashlib.md5() hl.update(st.encode(encoding='utf-8')) return hl.hexdigest()