#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

"""palo client verify"""

import petl
import math
from decimal import Decimal
from collections import OrderedDict
from datetime import datetime

import palo_logger

LOG = palo_logger.Logger.getLogger()
L = palo_logger.StructedLogMessage


class VerifyFile(object):
    """VerifyFile"""
    def __init__(self, file_name, delimiter='\t'):
        self.file_name = file_name
        self.delimiter = delimiter

    def get_file_name(self):
        """get file name"""
        return self.file_name

    def get_delimiter(self):
        """get delimiter"""
        return self.delimiter

    def __str__(self):
        return str(self.file_name)


class Verify(object):
    """verify class"""
    def __init__(self, expected_file_list, datas, schema, table_name, database_name, encoding=None):
        """
        expected_file_list: expected files; a str, a list such as ['file1', 'file2'], or a VerifyFile
        datas: result of the executed SQL
        schema: for verify(), the DESC result of the table;
                for verify_by_sql(), the schema of the query result, a list of
                4-tuples (name, type, agg_type, default_value)
        table_name, database_name: used to generate the default verify file names
        """
        self.expected_file_list = expected_file_list
        self.table_name = table_name
        self.database_name = database_name
        self.schema = schema
        self.datas = datas
        self.encoding = encoding

    @staticmethod
    def __get_type_convert_handler(field_type):
        """return a parse/convert function for the given Palo column type"""
        def __int_type(min_value, max_value):
            """Return a function that parses the value as an integer within [min_value, max_value]"""
            def f(v):
                """check and return the converted value"""
                value = int(v)
                if min_value <= value <= max_value:
                    return value
                return None
            return f

        def __char_type():
            """return a function that maps the string "None" to None"""
            def f(v):
                """check whether v is null"""
                if v == "None":
                    v = None
                return v
            return f

        tinyint = __int_type(-2 ** 7, 2 ** 7 - 1)
        smallint = __int_type(-2 ** 15, 2 ** 15 - 1)
        paloint = __int_type(-2 ** 31, 2 ** 31 - 1)
        bigint = __int_type(-2 ** 63, 2 ** 63 - 1)
        largeint = __int_type(-2 ** 127, 2 ** 127 - 1)
        datetime = petl.datetimeparser('%Y-%m-%d %H:%M:%S')
        date = petl.dateparser('%Y-%m-%d')
        char = __char_type()
        varchar = __char_type()
        field_type = field_type.lower().split('<')[0]
        field_type = field_type.lower().split('(')[0]
        field_type_handler_dict = {'char': char, 'varchar': varchar, 'decimal': Decimal,
                                   'tinyint': tinyint, 'smallint': smallint, 'int': paloint,
                                   'bigint': bigint, 'largeint': largeint, 'text': varchar,
                                   'float': float, 'double': float,
                                   'datetime': datetime, 'date': date, 'boolean': tinyint,
                                   'array': varchar, 'decimalv3': Decimal}
        return field_type_handler_dict[field_type]

    def __get_convert_dict(self):
        """get column type from schema, and get convert func"""
        convert_dict = {field[0]: self.__get_type_convert_handler(field[1])
                        for field in self.schema}
        return convert_dict

    def __get_field_list(self):
        """get column name from schema"""
        field_list = [field[0] for field in self.schema]
        return field_list

    def __get_key_list(self):
"""get key column from schema""" key_list = [field[0] for field in self.schema if field[3] == 'true'] return tuple(key_list) def __get_type_list(self): """get column type from schema""" type_list = [field[1] for field in self.schema] return type_list @staticmethod def __get_aggregate_key(key_list): """get key""" if len(key_list) == 1: return key_list[0] else: return key_list def __get_aggregation_ordereddict(self): """aggregation table value agg func""" def _sum(l): items = [] for i in l: if i is not None: items.append(i) if len(items) == 0: return None else: return sum(items) def __agg_replace(l): items = [] for i in l: items.append(i) return items[-1] def __agg_replace_if_not_null(l): """ replace if not null """ items = [] for i in l: if i is not None: items.append(i) if len(items) == 0: return None else: return items[-1] agg_function_dict = {'max': max, 'min': min, 'sum': _sum, 'replace': __agg_replace, 'replace_if_not_null': __agg_replace_if_not_null} aggregation = OrderedDict() aggtype_list = [(field[0], field[5]) for field in self.schema if field[5] != ''] for item in aggtype_list: aggregation[item[0]] = item[0], agg_function_dict[item[1].lower()] return aggregation def __write_data_to_file(self, data_from_database, data_from_file, save_verifyfile_list): """将文件中的数据写入tmp文件中""" if self.encoding is not None: if save_verifyfile_list[0] is not None: petl.tocsv(data_from_database, save_verifyfile_list[0].get_file_name(), encoding=self.encoding, delimiter=save_verifyfile_list[0].get_delimiter()) if save_verifyfile_list[1] is not None: petl.tocsv(data_from_file, save_verifyfile_list[1].get_file_name(), encoding=self.encoding, delimiter=save_verifyfile_list[1].get_delimiter()) else: if save_verifyfile_list[0] is not None: petl.tocsv(data_from_database, save_verifyfile_list[0].get_file_name(), delimiter=save_verifyfile_list[0].get_delimiter()) if save_verifyfile_list[1] is not None: petl.tocsv(data_from_file, save_verifyfile_list[1].get_file_name(), delimiter=save_verifyfile_list[1].get_delimiter()) @staticmethod def __check_float(field_of_database, field_of_file, type): def __adjust_data(num): if num is None: return None else: num = float(num) if num == 0.0: return 0.0 else: return num / 10 ** (math.floor(math.log10(abs(num))) + 1) data_of_database = __adjust_data(field_of_database) data_of_file = __adjust_data(field_of_file) # 最后一个有效数字可以相差 1,比如: 0.123456001 == 0.123456999 => True # 0.123456001 == 0.123457999 => True 0.123456001 == 0.123458999 => False # 0.123456001 == 0.123455999 => True 0.123456001 == 0.123454999 => False precision = None if type.lower() == 'float': precision = 2e-6 elif type.lower() == 'double': precision = 2e-15 if math.fabs(data_of_database - data_of_file) < precision or \ math.fabs(data_of_database - data_of_file) / data_of_file < 2e-3: return True else: return False def __check_data(self, data_from_database, data_from_file): rows_number_of_database = petl.nrows(data_from_database) rows_number_of_file = petl.nrows(data_from_file) if rows_number_of_database != rows_number_of_file: LOG.warning(L("verify data error", lines_of_database=rows_number_of_database, lines_of_file=rows_number_of_file)) return False result_of_database = petl.records(data_from_database) result_of_file = petl.records(data_from_file) type_list = self.__get_type_list() for record_of_database, record_of_file in zip(result_of_database, result_of_file): for field_of_database, field_of_file, field_type in \ zip(record_of_database, record_of_file, type_list): if field_of_database is None and field_of_file 
                    continue
                if field_of_database is None or field_of_file is None:
                    return False
                if field_type.lower() == 'float' or field_type.lower() == 'double':
                    if not self.__check_float(field_of_database, field_of_file,
                                              field_type.lower()):
                        LOG.error(L("FLOAT VERIFY FAIL",
                                    field_of_database=field_of_database,
                                    field_of_file=field_of_file,
                                    record_of_database=record_of_database,
                                    record_of_file=record_of_file))
                        return False
                elif field_of_database != field_of_file:
                    LOG.error(L("VERIFY FAIL",
                                field_of_database=field_of_database,
                                field_of_file=field_of_file,
                                record_of_database=record_of_database,
                                record_of_file=record_of_file))
                    return False
        return True

    def __get_data_from_database(self):
        """process the data from the database; datas is the result of client.execute(sql)"""
        header = self.__get_field_list()
        field_list = self.__get_field_list()
        convert_dict = {}
        for field in self.schema:
            if field[1].lower().startswith('largeint'):
                convert_dict[field[0]] = self.__get_type_convert_handler(field[1])
        dict_list = []
        for row in self.datas:
            field_value_dict = {}
            for field, value in zip(header, row):
                field_value_dict[field] = value
            dict_list.append(field_value_dict)
        table_database_from = petl.fromdicts(dict_list, header)
        table_database_convert = petl.convert(table_database_from, convert_dict)
        table_database_sort = petl.sort(table_database_convert, field_list)
        table_database_merge_sort = petl.mergesort(table_database_sort, key=field_list,
                                                   presorted=False)
        return table_database_merge_sort

    def __get_data_from_file(self):
        """read the data from the expected files, sort it, and process it
        according to the table's aggregation model
        """
        # for compatibility with older code
        if type(self.expected_file_list) is str:
            from_verifyfile_list = [VerifyFile(self.expected_file_list, '\t')]
        elif type(self.expected_file_list) is list and type(self.expected_file_list[0]) is str:
            from_verifyfile_list = [VerifyFile(file, '\t') for file in self.expected_file_list]
        elif type(self.expected_file_list) is VerifyFile:
            from_verifyfile_list = [self.expected_file_list]
        else:
            from_verifyfile_list = None
        header = self.__get_field_list()
        key_list = self.__get_key_list()
        field_list = self.__get_field_list()
        convert_dict = self.__get_convert_dict()
        dup = False
        for col in self.schema:
            if col[5] == 'NONE':
                dup = True
        table_file_to_merge_list = []
        for etl_file in from_verifyfile_list:
            # read the csv file data
            table_file_from = petl.fromcsv(etl_file.get_file_name(), encoding='utf8',
                                           delimiter=etl_file.get_delimiter())
            # add the header to the data
            table_file_push = petl.pushheader(table_file_from, header)
            # convert the data to the column types
            table_file_convert = petl.convert(table_file_push, convert_dict)
            table_file_to_merge_list.append(table_file_convert)
        if not dup:
            table_file_merge_sort = petl.mergesort(*table_file_to_merge_list, key=key_list,
                                                   presorted=False)
            aggregation = self.__get_aggregation_ordereddict()
            aggregate_key = self.__get_aggregate_key(key_list)
            # aggregate the table according to each column's aggregation type
            table_file_aggregate = petl.aggregate(table_file_merge_sort, key=aggregate_key,
                                                  aggregation=aggregation, presorted=True)
            table_file_merge_sort = petl.mergesort(table_file_aggregate, key=key_list,
                                                   presorted=True)
            return table_file_merge_sort
        else:
            table_file_merge_sort = petl.mergesort(*table_file_to_merge_list, key=field_list,
                                                   presorted=False)
            return table_file_merge_sort

    def __generate_default_save_verifyfile_list(self):
        """generate the verify file names from the database name and table name"""
        name_prefix = ".%s.%s" % (self.database_name, self.table_name)
        name_for_database = "%s.%s" % (name_prefix, 'DB')
        name_for_file = "%s.%s" % (name_prefix, 'FILE')
        return [VerifyFile(name_for_database), VerifyFile(name_for_file)]

    def verify(self, save_file_list=None):
        """verify the database data against the expected files"""
        LOG.info(L("check file:", file=self.expected_file_list))
file:", file=self.expected_file_list)) self.__adjust_schema_for_verify() # 获取db中的数据 data_from_database = self.__get_data_from_database() # 获取file中的文件 data_from_file = self.__get_data_from_file() if save_file_list is None: save_file_list = self.__generate_dafault_save_verifyfile_list() # 分别写入数据 self.__write_data_to_file(data_from_database, data_from_file, save_file_list) # 返回check结果, true / false return self.__check_data(data_from_database, data_from_file) def __adjust_schema_for_verify(self): adjust_schema = [] for field in self.schema: adjust_field = list(field) if adjust_field[3] == 'false': if adjust_field[5] is not None: adjust_field[5] = adjust_field[5].split(',')[0] else: adjust_field[5] = '' adjust_schema.append(tuple(adjust_field)) self.schema = tuple(adjust_schema) return self.schema def __adjust_schema_for_self_defined_sql(self): # TODO # 这个函数可能有问题,以后修改 adjust_schema = [] for column in self.schema: adjust_column = [] adjust_column.append(column[0]) adjust_column.append(column[1]) adjust_column.append('No') if len(column) > 2 and column[2] is not None: adjust_column.append('false') else: adjust_column.append('true') adjust_column.append('N/A') if len(column) > 2 and column[2] is not None: adjust_column.append(column[2]) else: adjust_column.append('') adjust_schema.append(tuple(adjust_column)) self.schema = adjust_schema return self.schema def verify_by_sql(self, save_file_list=None): """ 校验自定义的SQL语句的查询结果 expected_file_list: VerifyFile对象的list sql: SQL语句字符串 schema: 查询结果的schema, 由四元组(name, type, agg_type, default_value)组成的list 四元组中后两项可省略, 需要注意的是key列指定默认值是agg_type设置为None Example -> [("k1", "INT"), ("k2", "CHAR", None, ""), ("v", "DATE", "REPLACE")] save_file_list: VerifyFile对象的list """ self.__adjust_schema_for_self_defined_sql() data_from_database = self.__get_data_from_database() data_from_file = self.__get_data_from_file() if save_file_list is not None: self.__write_data_to_file(data_from_database, data_from_file, save_file_list) return self.__check_data(data_from_database, data_from_file) def verify(file, sql_ret, schema, table_name, database_name, encoding, save_file_list): """ verify, schema为palo desc结果 适用于 1. 多个文件的时候,会对文件进行拼接,排序读取 2. 适用于原始文件,palo对原始文件进行过滤、聚合等处理时,无需额外保存校验文件,直接使用原始文件进行处理生成校验文件 """ verifier = Verify(file, sql_ret, schema, table_name, database_name, encoding) return verifier.verify(save_file_list) def verify_by_sql(file, sql_ret, schema, table_name, database_name, encoding, save_file_list): """ verify by sql 指定四元组为schema """ verifier = Verify(file, sql_ret, schema, table_name, database_name, encoding) return verifier.verify_by_sql(save_file_list)