#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""palo client verify"""
|
|
import petl
|
|
import math
|
|
from decimal import Decimal
|
|
from collections import OrderedDict
|
|
from datetime import datetime
|
|
|
|
import palo_logger
|
|
LOG = palo_logger.Logger.getLogger()
|
|
L = palo_logger.StructedLogMessage
|
|
|
|
|
|
class VerifyFile(object):
    """An expected-data file: a file path plus the column delimiter used in it."""

    def __init__(self, file_name, delimiter='\t'):
        self.file_name = file_name
        self.delimiter = delimiter

    def get_file_name(self):
        """Return the file name."""
        return self.file_name

    def get_delimiter(self):
        """Return the delimiter."""
        return self.delimiter

    def __str__(self):
        return str(self.file_name)
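

# A typical expected file is tab separated; other delimiters can be passed in
# explicitly, e.g. (hypothetical path, for illustration only):
#     expected = VerifyFile('./data/verify/test_table.expected', ',')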


class Verify(object):
    """verify class"""

    def __init__(self, expected_file_list, datas, schema, table_name, database_name, encoding=None):
        """
        expected_file_list: the expected-data file(s); a str, a list such as
            ['file1', 'file2'], or a VerifyFile
        datas: the result of executing the SQL statement
        schema: the DESC result of the table under verification; for
            verify_by_sql it is the schema of the query result, a list of
            4-tuples (name, type, agg_type, default_value)
        table_name, database_name: used to build the default verify-file names
        """
        self.expected_file_list = expected_file_list
        self.table_name = table_name
        self.database_name = database_name
        self.schema = schema
        self.datas = datas
        self.encoding = encoding

    @staticmethod
    def __get_type_convert_handler(field_type):
        """Return a conversion function for the given Palo field type."""
        def __int_type(min_value, max_value):
            """Return a function that parses the value as an integer and
            range-checks it; out-of-range values become None.
            """

            def f(v):
                """Check and convert the value."""
                value = int(v)
                if min_value <= value <= max_value:
                    return value
                else:
                    return None

            return f

        def __char_type():
            """Return a function that maps the literal string "None" to None."""

            def f(v):
                """Treat the string "None" as NULL."""
                if v == "None":
                    v = None
                return v

            return f

        tinyint = __int_type(-2 ** 7, 2 ** 7 - 1)
        smallint = __int_type(-2 ** 15, 2 ** 15 - 1)
        paloint = __int_type(-2 ** 31, 2 ** 31 - 1)
        bigint = __int_type(-2 ** 63, 2 ** 63 - 1)
        largeint = __int_type(-2 ** 127, 2 ** 127 - 1)
        datetime = petl.datetimeparser('%Y-%m-%d %H:%M:%S')
        date = petl.dateparser('%Y-%m-%d')
        char = __char_type()
        varchar = __char_type()

        # Strip type parameters, e.g. 'array<int>' -> 'array', 'decimal(9, 3)' -> 'decimal'.
        field_type = field_type.lower().split('<')[0]
        field_type = field_type.split('(')[0]
        field_type_handler_dict = {'char': char, 'varchar': varchar, 'decimal': Decimal,
                                   'tinyint': tinyint, 'smallint': smallint, 'int': paloint,
                                   'bigint': bigint, 'largeint': largeint, 'text': varchar,
                                   'float': float, 'double': float, 'datetime': datetime,
                                   'date': date, 'boolean': tinyint,
                                   'array': varchar, 'decimalv3': Decimal}
        return field_type_handler_dict[field_type]
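
    # A minimal sketch of how a converter behaves (illustration only; the
    # method is private, so it is reached here via its name-mangled form):
    #     to_tinyint = Verify._Verify__get_type_convert_handler('tinyint(4)')
    #     to_tinyint('100')   # -> 100
    #     to_tinyint('1000')  # -> None, outside [-128, 127]
    #     to_tinyint('abc')   # raises ValueError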

    def __get_convert_dict(self):
        """Map each column name from the schema to its type's convert function."""
        convert_dict = {field[0]: self.__get_type_convert_handler(field[1]) for field in self.schema}
        return convert_dict

    def __get_field_list(self):
        """Get the column names from the schema."""
        field_list = [field[0] for field in self.schema]
        return field_list

    def __get_key_list(self):
        """Get the key columns from the schema."""
        key_list = [field[0] for field in self.schema if field[3] == 'true']
        return tuple(key_list)

    def __get_type_list(self):
        """Get the column types from the schema."""
        type_list = [field[1] for field in self.schema]
        return type_list

    @staticmethod
    def __get_aggregate_key(key_list):
        """Return the single key column if there is only one, else the tuple of keys."""
        if len(key_list) == 1:
            return key_list[0]
        else:
            return key_list

    def __get_aggregation_ordereddict(self):
        """Build the petl aggregation spec from each value column's aggregation type."""
        def _sum(l):
            """SUM: sum the non-null values; an all-null group aggregates to None."""
            items = [i for i in l if i is not None]
            if len(items) == 0:
                return None
            else:
                return sum(items)

        def __agg_replace(l):
            """REPLACE: keep the last value of the group."""
            items = list(l)
            return items[-1]

        def __agg_replace_if_not_null(l):
            """REPLACE_IF_NOT_NULL: keep the last non-null value of the group."""
            items = [i for i in l if i is not None]
            if len(items) == 0:
                return None
            else:
                return items[-1]

        agg_function_dict = {'max': max, 'min': min, 'sum': _sum, 'replace': __agg_replace,
                             'replace_if_not_null': __agg_replace_if_not_null}

        aggregation = OrderedDict()

        aggtype_list = [(field[0], field[5]) for field in self.schema if field[5] != '']

        for item in aggtype_list:
            aggregation[item[0]] = item[0], agg_function_dict[item[1].lower()]

        return aggregation
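
    # Shape of the spec this builds (hypothetical columns v1 SUM, v2 REPLACE):
    #     OrderedDict([('v1', ('v1', _sum)), ('v2', ('v2', __agg_replace))])
    # petl.aggregate() then folds rows sharing the same key with these functions.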

    def __write_data_to_file(self, data_from_database, data_from_file, save_verifyfile_list):
        """Write both data sets to temporary verify files."""
        if self.encoding is not None:
            if save_verifyfile_list[0] is not None:
                petl.tocsv(data_from_database, save_verifyfile_list[0].get_file_name(),
                           encoding=self.encoding, delimiter=save_verifyfile_list[0].get_delimiter())
            if save_verifyfile_list[1] is not None:
                petl.tocsv(data_from_file, save_verifyfile_list[1].get_file_name(),
                           encoding=self.encoding, delimiter=save_verifyfile_list[1].get_delimiter())
        else:
            if save_verifyfile_list[0] is not None:
                petl.tocsv(data_from_database, save_verifyfile_list[0].get_file_name(),
                           delimiter=save_verifyfile_list[0].get_delimiter())
            if save_verifyfile_list[1] is not None:
                petl.tocsv(data_from_file, save_verifyfile_list[1].get_file_name(),
                           delimiter=save_verifyfile_list[1].get_delimiter())

    @staticmethod
    def __check_float(field_of_database, field_of_file, field_type):
        """Compare two floating-point fields with a type-dependent tolerance."""
        def __adjust_data(num):
            """Normalize the number into (-1, 1) by shifting its decimal point."""
            if num is None:
                return None
            num = float(num)
            if num == 0.0:
                return 0.0
            return num / 10 ** (math.floor(math.log10(abs(num))) + 1)

        data_of_database = __adjust_data(field_of_database)
        data_of_file = __adjust_data(field_of_file)
        # The last significant digit may differ by 1, e.g.:
        #     0.123456001 == 0.123456999 => True
        #     0.123456001 == 0.123457999 => True    0.123456001 == 0.123458999 => False
        #     0.123456001 == 0.123455999 => True    0.123456001 == 0.123454999 => False
        precision = None
        if field_type.lower() == 'float':
            precision = 2e-6
        elif field_type.lower() == 'double':
            precision = 2e-15
        if math.fabs(data_of_database - data_of_file) < precision:
            return True
        # Relative check: divide by the absolute value (a negative divisor would
        # make the comparison vacuously true), and guard against division by zero.
        if data_of_file != 0 and \
                math.fabs(data_of_database - data_of_file) / math.fabs(data_of_file) < 2e-3:
            return True
        return False
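
    # Tolerance in brief (hypothetical FLOAT values; both numbers are first
    # scaled into (-1, 1)):
    #     __check_float('1.230001', '1.230002', 'float')  # True: diff < 2e-6
    #     __check_float('1.23', '1.24', 'float')          # False: ~0.8% relative error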

    def __check_data(self, data_from_database, data_from_file):
        """Compare database data against file data, row by row and field by field."""
        rows_number_of_database = petl.nrows(data_from_database)
        rows_number_of_file = petl.nrows(data_from_file)

        if rows_number_of_database != rows_number_of_file:
            LOG.warning(L("verify data error", lines_of_database=rows_number_of_database,
                          lines_of_file=rows_number_of_file))
            return False
        result_of_database = petl.records(data_from_database)
        result_of_file = petl.records(data_from_file)
        type_list = self.__get_type_list()

        for record_of_database, record_of_file in zip(result_of_database, result_of_file):
            for field_of_database, field_of_file, field_type in \
                    zip(record_of_database, record_of_file, type_list):
                if field_of_database is None and field_of_file is None:
                    continue
                if field_of_database is None or field_of_file is None:
                    return False
                if field_type.lower() == 'float' or field_type.lower() == 'double':
                    if not self.__check_float(field_of_database, field_of_file,
                                              field_type=field_type.lower()):
                        LOG.error(L("FLOAT VERIFY FAIL", field_of_database=field_of_database,
                                    field_of_file=field_of_file, record_of_database=record_of_database,
                                    record_of_file=record_of_file))
                        return False
                elif field_of_database != field_of_file:
                    LOG.error(L("VERIFY FAIL", field_of_database=field_of_database,
                                field_of_file=field_of_file, record_of_database=record_of_database,
                                record_of_file=record_of_file))
                    return False
        return True

    def __get_data_from_database(self):
        """
        Process the data from the database; self.datas is the result of
        client.execute(sql).
        """
        header = self.__get_field_list()
        field_list = self.__get_field_list()
        convert_dict = {}
        # Only LARGEINT columns get an explicit converter here.
        for field in self.schema:
            if field[1].lower().startswith('largeint'):
                convert_dict[field[0]] = self.__get_type_convert_handler(field[1])
        dict_list = []
        for row in self.datas:
            field_value_dict = {}
            for field, value in zip(header, row):
                field_value_dict[field] = value
            dict_list.append(field_value_dict)
        table_database_from = petl.fromdicts(dict_list, header)
        table_database_convert = petl.convert(table_database_from, convert_dict)
        table_database_sort = petl.sort(table_database_convert, field_list)
        table_database_merge_sort = petl.mergesort(table_database_sort,
                                                   key=field_list, presorted=False)
        return table_database_merge_sort

    def __get_data_from_file(self):
        """
        Read the expected data from the files, sort it, and fold it according
        to the table's aggregation model.
        """
        # Kept for compatibility with older callers: accept str, list of str,
        # or VerifyFile.
        if isinstance(self.expected_file_list, str):
            from_verifyfile_list = [VerifyFile(self.expected_file_list, '\t')]
        elif isinstance(self.expected_file_list, list) and isinstance(self.expected_file_list[0], str):
            from_verifyfile_list = [VerifyFile(file, '\t') for file in self.expected_file_list]
        elif isinstance(self.expected_file_list, VerifyFile):
            from_verifyfile_list = [self.expected_file_list]
        else:
            from_verifyfile_list = None
        header = self.__get_field_list()
        key_list = self.__get_key_list()
        field_list = self.__get_field_list()
        convert_dict = self.__get_convert_dict()
        dup = False
        for col in self.schema:
            if col[5] == 'NONE':
                dup = True

        table_file_to_merge_list = []
        for etl_file in from_verifyfile_list:
            # Read the csv data.
            table_file_from = petl.fromcsv(etl_file.get_file_name(),
                                           encoding='utf8', delimiter=etl_file.get_delimiter())
            # Prepend the header row.
            table_file_push = petl.pushheader(table_file_from, header)
            # Apply the column types.
            table_file_convert = petl.convert(table_file_push, convert_dict)
            table_file_to_merge_list.append(table_file_convert)
        if not dup:
            table_file_merge_sort = petl.mergesort(*table_file_to_merge_list,
                                                   key=key_list, presorted=False)
            aggregation = self.__get_aggregation_ordereddict()
            aggregate_key = self.__get_aggregate_key(key_list)
            # Aggregate-model table: fold rows by key with the aggregation functions.
            table_file_aggregate = petl.aggregate(table_file_merge_sort,
                                                  key=aggregate_key, aggregation=aggregation,
                                                  presorted=True)
            table_file_merge_sort = petl.mergesort(table_file_aggregate,
                                                   key=key_list, presorted=True)
            return table_file_merge_sort
        else:
            table_file_merge_sort = petl.mergesort(*table_file_to_merge_list,
                                                   key=field_list, presorted=False)
            return table_file_merge_sort

    def __generate_default_save_verifyfile_list(self):
        """Build the verify-file names from the database name and table name."""
        name_prefix = ".%s.%s" % (self.database_name, self.table_name)
        name_for_database = "%s.%s" % (name_prefix, 'DB')
        name_for_file = "%s.%s" % (name_prefix, 'FILE')
        return [VerifyFile(name_for_database), VerifyFile(name_for_file)]

    def verify(self, save_file_list=None):
        """
        The main verification entry point.
        """
        LOG.info(L("check file:", file=self.expected_file_list))
        self.__adjust_schema_for_verify()
        # Fetch the data from the database.
        data_from_database = self.__get_data_from_database()
        # Fetch the data from the expected files.
        data_from_file = self.__get_data_from_file()
        if save_file_list is None:
            save_file_list = self.__generate_default_save_verifyfile_list()
        # Save both data sets for later inspection.
        self.__write_data_to_file(data_from_database, data_from_file, save_file_list)
        # Return the check result: True / False.
        return self.__check_data(data_from_database, data_from_file)

    def __adjust_schema_for_verify(self):
        """Normalize the DESC schema: for non-key columns, keep only the first
        comma-separated token of the Extra field (the aggregation type)."""
        adjust_schema = []
        for field in self.schema:
            adjust_field = list(field)
            if adjust_field[3] == 'false':
                if adjust_field[5] is not None:
                    adjust_field[5] = adjust_field[5].split(',')[0]
                else:
                    adjust_field[5] = ''
            adjust_schema.append(tuple(adjust_field))
        self.schema = tuple(adjust_schema)
        return self.schema

    def __adjust_schema_for_self_defined_sql(self):
        """Expand (name, type[, agg_type[, default]]) tuples into the 6-column
        DESC layout that the verify routines expect."""
        # TODO: this function may be buggy; revisit later.
        adjust_schema = []
        for column in self.schema:
            adjust_column = []
            adjust_column.append(column[0])
            adjust_column.append(column[1])
            adjust_column.append('No')
            if len(column) > 2 and column[2] is not None:
                adjust_column.append('false')
            else:
                adjust_column.append('true')
            adjust_column.append('N/A')
            if len(column) > 2 and column[2] is not None:
                adjust_column.append(column[2])
            else:
                adjust_column.append('')
            adjust_schema.append(tuple(adjust_column))
        self.schema = adjust_schema
        return self.schema

    def verify_by_sql(self, save_file_list=None):
        """
        Verify the query result of a self-defined SQL statement.
        expected_file_list: a list of VerifyFile objects
        sql: the SQL statement string
        schema: the schema of the query result, a list of 4-tuples
            (name, type, agg_type, default_value); the last two items may be
            omitted. Note that a key column with a default value must have its
            agg_type set to None.
            Example -> [("k1", "INT"), ("k2", "CHAR", None, ""), ("v", "DATE", "REPLACE")]
        save_file_list: a list of VerifyFile objects
        """
        self.__adjust_schema_for_self_defined_sql()
        data_from_database = self.__get_data_from_database()
        data_from_file = self.__get_data_from_file()
        if save_file_list is not None:
            self.__write_data_to_file(data_from_database, data_from_file, save_file_list)
        return self.__check_data(data_from_database, data_from_file)


def verify(file, sql_ret, schema, table_name, database_name, encoding, save_file_list):
    """
    Verify, where schema is the Palo DESC result.
    Suitable for:
    1. multiple expected files: they are concatenated and read in sorted order
    2. raw source files: when Palo filters or aggregates the raw file on load,
       no separate expected file needs to be kept; the raw file itself is
       processed into the expected data
    """
    verifier = Verify(file, sql_ret, schema, table_name, database_name, encoding)
    return verifier.verify(save_file_list)
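

# A minimal usage sketch (names and values are hypothetical; schema is the raw
# DESC output, sql_ret the rows returned by a palo client's execute call):
#     ret = client.execute('SELECT * FROM test_table')
#     schema = client.execute('DESC test_table')
#     ok = verify('./data/test_table.expected', ret, schema,
#                 'test_table', 'test_db', None, None)
#     assert ok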


def verify_by_sql(file, sql_ret, schema, table_name, database_name, encoding, save_file_list):
    """
    Verify by SQL; the schema is given as the 4-tuples described in
    Verify.verify_by_sql.
    """
    verifier = Verify(file, sql_ret, schema, table_name, database_name, encoding)
    return verifier.verify_by_sql(save_file_list)