!996 Update SQLDiag: Add two threshold judgments
Merge pull request !996 from wangtq/master
@@ -71,12 +71,16 @@ predict dataset sample: [sample_data/predict.csv](data/predict.csv)

## example for template method using sample dataset

# training
# train
python main.py train -f ./sample_data/train.csv --model template --model-path ./template

# predict
python main.py predict -f ./sample_data/predict.csv --model template --model-path ./template --predicted-file ./result/t_result

# predict with threshold
python main.py predict -f ./sample_data/predict.csv --threshold 0.02 --model template --model-path ./template --predicted-file ./result/t_result
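With `--threshold`, only statements whose predicted execution time is at least the given value are written to `./result/t_result`; without it, every prediction is reported (see the threshold handling in main.py further down in this diff).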

# update model
python main.py finetune -f ./sample_data/train.csv --model template --model-path ./template

@@ -9,9 +9,9 @@ W2V_SUFFIX = 'word2vector'


def check_template_algorithm(param):
if param and param not in ["list", "levenshtein", "parse_tree"]:
if param and param not in ["list", "levenshtein", "parse_tree", "cosine_distance"]:
raise ValueError("The similarity algorithm '%s' is invaild, "
"please choose from ['list', 'levenshtein', 'parse_tree']" % param)
"please choose from ['list', 'levenshtein', 'parse_tree', 'cosine_distance']" % param)


class ModelConfig(object):

@@ -62,7 +62,7 @@ SUPPORTED_ALGORITHM = {'dnn': lambda config: DnnModel(DnnConfig.init_from_config


class SQLDiag:
def __init__(self, model_algorithm, csv_file, params):
def __init__(self, model_algorithm, params):
if model_algorithm not in SUPPORTED_ALGORITHM:
raise NotImplementedError("do not support {}".format(model_algorithm))
try:

@@ -70,20 +70,16 @@ class SQLDiag:
except ValueError as e:
logging.error(e, exc_info=True)
sys.exit(1)
self.load_data = LoadData(csv_file)

def __getattr__(self, item):
return getattr(self.load_data, item)
def fit(self, data):
self._model.fit(data)

def fit(self):
self._model.fit(self.train_data)
def transform(self, data):
return self._model.transform(data)

def transform(self):
return self._model.transform(self.predict_data)

def fine_tune(self, filepath):
def fine_tune(self, filepath, data):
self._model.load(filepath)
self._model.fit(self.train_data)
self._model.fit(data)

def load(self, filepath):
self._model.load(filepath)
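With this refactor, SQLDiag no longer loads the CSV itself; callers load the data and pass it into fit/transform/fine_tune. A minimal driver sketch of the new interface, mirroring what main.py does later in this diff (paths are placeholders):

```python
from main import get_config              # assumes the SQLdiag main.py shown below is importable
from algorithm.diag import SQLDiag
from preprocessing import LoadData

model = SQLDiag('template', get_config('sqldiag.conf'))    # no csv_file argument any more
train_data = LoadData('sample_data/train.csv').train_data  # data loading now happens in the caller
model.fit(train_data)
model.save('./template')
```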
@@ -36,7 +36,7 @@ class KerasRegression:
shape = features.shape[1]
if self.model is None:
self.model = self.build_model(shape=shape, encoding_dim=self.encoding_dim)
self.model.fit(features, labels, epochs=epochs, batch_size=batch_size, shuffle=True, verbose=2)
self.model.fit(features, labels, epochs=epochs, batch_size=batch_size, shuffle=True, verbose=0)

def predict(self, features):
predict_result = self.model.predict(features)

@@ -129,10 +129,6 @@ class DnnModel(AbstractModel, ABC):
self.w2v.load(word2vector_path)
with open(scaler_path, 'rb') as f:
self.scaler = pickle.load(f)
logging.info("dnn model is loaded: '{}'; w2v model is loaded: '{}'; scaler model is loaded: '{}'."
.format(dnn_path,
word2vector_path,
scaler_path))
else:
logging.error("{} not exist.".format(realpath))

@@ -149,7 +145,5 @@ class DnnModel(AbstractModel, ABC):
self.w2v.save(word2vector_path)
with open(scaler_path, 'wb') as f:
pickle.dump(self.scaler, f)
logging.info("dnn model is saved: '{}'; w2v model is saved: '{}'; scaler model is saved: '{}'."
.format(dnn_path,
word2vector_path,
scaler_path))
print("DNN model is stored in '{}'".format(realpath))

@@ -18,9 +18,10 @@ import logging
import os
import stat
from functools import reduce
from collections import defaultdict

from algorithm.sql_similarity import calc_sql_distance
from preprocessing import get_sql_template
from preprocessing import get_sql_template, templatize_sql
from utils import check_illegal_sql, LRUCache
from . import AbstractModel

@@ -40,97 +41,78 @@ class TemplateModel(AbstractModel):
for sql, duration_time in data:
if check_illegal_sql(sql):
continue
# get 'fine_template' and 'rough_template' of SQL
fine_template, rough_template = get_sql_template(sql)
sql_prefix = fine_template.split()[0]
# if prefix of SQL is not in 'update', 'delete', 'select' and 'insert',
# then convert prefix to 'other'
sql_template = templatize_sql(sql)
sql_prefix = sql_template.split()[0]
if sql_prefix not in self.__hash_table:
sql_prefix = 'OTHER'
if rough_template not in self.__hash_table[sql_prefix]:
self.__hash_table[sql_prefix][rough_template] = dict()
self.__hash_table[sql_prefix][rough_template]['info'] = dict()
if fine_template not in self.__hash_table[sql_prefix][rough_template]['info']:
self.__hash_table[sql_prefix][rough_template]['info'][fine_template] = \
dict(time_list=[], count=0, mean_time=0.0, iter_time=0.0)
# count the number of occurrences of fine template
self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['count'] += 1
# store the execution time of the matched template in the corresponding list
self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
'time_list'].append(duration_time)
# iterative calculation of execution time based on historical data
if not self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
'iter_time']:
self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
'iter_time'] = duration_time
else:
self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time'] = \
(self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
'iter_time'] + duration_time) / 2
# calculate the average execution time of each template
if sql_template not in self.__hash_table[sql_prefix]:
self.__hash_table[sql_prefix][sql_template] = dict(time_list=[], count=0, mean_time=0.0, iter_time=0.0)
self.__hash_table[sql_prefix][sql_template]['count'] += 1
self.__hash_table[sql_prefix][sql_template]['time_list'].append(duration_time)

for sql_prefix, sql_prefix_info in self.__hash_table.items():
for rough_template, rough_template_info in sql_prefix_info.items():
for _, fine_template_info in rough_template_info['info'].items():
del fine_template_info['time_list'][:-self.time_list_size]
fine_template_info['mean_time'] = \
sum(fine_template_info['time_list']) / len(
fine_template_info['time_list'])
rough_template_info['count'] = len(rough_template_info['info'])
rough_template_info['mean_time'] = sum(
[value['mean_time'] for key, value in
rough_template_info['info'].items()]) / len(
rough_template_info['info'])
for sql_template, sql_template_info in sql_prefix_info.items():
del sql_template_info['time_list'][:-self.time_list_size]
sql_template_info['mean_time'] = sum(sql_template_info['time_list']) / len(sql_template_info['time_list'])
sql_template_info['iter_time'] = reduce(lambda x, y: (x+y)/2, sql_template_info['time_list'])


def transform(self, data):
predict_time_list = []
predict_result_dict = defaultdict(list)
for sql in data:
predict_time = self.predict_duration_time(sql)
predict_time_list.append([sql, predict_time])
return predict_time_list
sql_, status, predict_time, top_similarity_sql = self.predict_duration_time(sql)
predict_result_dict[status].append([sql_, predict_time, top_similarity_sql])
for key, value in predict_result_dict.items():
if value:
value.sort(key=lambda item: item[1], reverse=True)
return predict_result_dict

@LRUCache(max_size=1024)
def predict_duration_time(self, sql):
top_similarity_sql = None
if check_illegal_sql(sql):
return -1
sql_prefix = sql.strip().split()[0]
# get 'fine_template' and 'rough_template' of SQL
fine_template, rough_template = get_sql_template(sql)
predict_time = -1
status = 'Suspect illegal sql'
return sql, status, predict_time, top_similarity_sql

sql_template = templatize_sql(sql)
# get 'sql_template' of SQL
sql_prefix = sql_template.strip().split()[0]
if sql_prefix not in self.__hash_table:
sql_prefix = 'OTHER'
if not self.__hash_table[sql_prefix]:
logging.warning("'{}' not in the templates.".format(sql))
status = 'No SQL information'
predict_time = -1
elif rough_template not in self.__hash_table[sql_prefix] or fine_template not in \
self.__hash_table[sql_prefix][rough_template]['info']:
elif sql_template not in self.__hash_table[sql_prefix]:
similarity_info = []
"""
"""
if the template does not exist in the hash table,
then calculate the possible execution time based on template
similarity and KNN algorithm in all other templates
"""
if rough_template not in self.__hash_table[sql_prefix]:
for local_rough_template, local_rough_template_info in self.__hash_table[sql_prefix].items():
similarity_info.append(
(self.similarity_algorithm(rough_template, local_rough_template),
local_rough_template_info['mean_time']))
else:
for local_fine_template, local_fine_template_info in \
self.__hash_table[sql_prefix][rough_template]['info'].items():
similarity_info.append(
(self.similarity_algorithm(fine_template, local_fine_template),
local_fine_template_info['iter_time']))
status = 'No SQL template found'
for local_sql_template, local_sql_template_info in self.__hash_table[sql_prefix].items():
similarity_info.append(
(self.similarity_algorithm(sql_template, local_sql_template),
local_sql_template_info['mean_time'], local_sql_template))
topn_similarity_info = heapq.nlargest(self.knn_number, similarity_info)
sum_similarity_scores = sum(item[0] for item in topn_similarity_info) + self.bias
similarity_proportions = (item[0] / sum_similarity_scores for item in
topn_similarity_info)
topn_duration_time = (item[1] for item in topn_similarity_info)
sum_similarity_scores = sum(item[0] for item in topn_similarity_info)
if not sum_similarity_scores:
sum_similarity_scores = self.bias
top_similarity_sql = '\n'.join([item[2] for item in topn_similarity_info])
similarity_proportions = [item[0] / sum_similarity_scores for item in
topn_similarity_info]
topn_duration_time = [item[1] for item in topn_similarity_info]
predict_time = reduce(lambda x, y: x + y,
map(lambda x, y: x * y, similarity_proportions,
topn_duration_time))
else:
predict_time = self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time']

return predict_time
else:
status = 'Fine match'
predict_time = self.__hash_table[sql_prefix][sql_template]['iter_time']
top_similarity_sql = sql_template

return sql, status, predict_time, top_similarity_sql
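For a statement whose template is not stored, the predicted time above is a similarity-weighted average over the knn_number most similar known templates. A standalone sketch of that calculation (the bias fallback value here is an assumption; the model takes it from its own configuration):

```python
import heapq

def knn_estimate(similarity_info, knn_number, bias=1e-5):
    """similarity_info: list of (similarity, mean_time, template) tuples."""
    topn = heapq.nlargest(knn_number, similarity_info)
    total = sum(sim for sim, _, _ in topn) or bias        # avoid division by zero
    return sum(sim / total * mean_time for sim, mean_time, _ in topn)

# knn_estimate([(0.9, 12.0, 'SELECT ...'), (0.5, 40.0, 'UPDATE ...')], knn_number=2) -> 22.0
```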

def load(self, filepath):
realpath = os.path.realpath(filepath)

@@ -138,7 +120,6 @@ class TemplateModel(AbstractModel):
template_path = os.path.join(realpath, 'template.json')
with open(template_path, mode='r') as f:
self.__hash_table = json.load(f)
logging.info("template model '{}' is loaded.".format(template_path))
else:
logging.error("{} not exist.".format(realpath))

@@ -151,4 +132,5 @@ class TemplateModel(AbstractModel):
template_path = os.path.join(realpath, 'template.json')
with open(template_path, mode='w') as f:
json.dump(self.__hash_table, f, indent=4)
logging.info("template model is stored in '{}'".format(realpath))
print("Template model is stored in '{}'".format(realpath))

@@ -5,6 +5,8 @@ def calc_sql_distance(algorithm):
from .levenshtein import distance
elif algorithm == 'parse_tree':
from .parse_tree import distance
elif algorithm == 'cosine_distance':
from .cosine_distance import distance
else:
raise NotImplementedError("do not support '{}'".format(algorithm))
return distance

@@ -0,0 +1,26 @@
"""
Copyright (c) 2020 Huawei Technologies Co.,Ltd.

openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:

http://license.coscl.org.cn/MulanPSL2

THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""
import math
from collections import Counter


def distance(str1, str2):
c_1 = Counter(str1)
c_2 = Counter(str2)
c_union = set(c_1).union(c_2)
dot_product = sum(c_1.get(item, 0) * c_2.get(item, 0) for item in c_union)
mag_c1 = math.sqrt(sum(c_1.get(item, 0)**2 for item in c_union))
mag_c2 = math.sqrt(sum(c_2.get(item, 0)**2 for item in c_union))
return dot_product / (mag_c1 * mag_c2)
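The new cosine_distance algorithm treats each statement as a bag of characters (the Counter objects above) and returns the cosine similarity of the two frequency vectors, so identical strings score 1.0 and strings with no characters in common score 0.0. Two illustrative calls; the import path is inferred from calc_sql_distance above:

```python
from algorithm.sql_similarity.cosine_distance import distance

distance("abc", "abc")   # 1.0   -- identical character counts
distance("abc", "abd")   # ~0.67 -- two of the three characters match
```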
@@ -1,3 +1,19 @@
"""
Copyright (c) 2020 Huawei Technologies Co.,Ltd.

openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:

http://license.coscl.org.cn/MulanPSL2

THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""


def distance(str1, str2):
"""
func: calculate levenshtein distance between two strings.

@@ -5,18 +21,20 @@ def distance(str1, str2):
:param str2: string2
:return: distance
"""
m, n = len(str1) + 1, len(str2) + 1
matrix = [[0] * n for _ in range(m)]
matrix[0][0] = 0
for i in range(1, m):
matrix[i][0] = matrix[i - 1][0] + 1
for j in range(1, n):
matrix[0][j] = matrix[0][j - 1] + 1
for i in range(1, m):
for j in range(1, n):
if str1[i - 1] == str2[j - 1]:
matrix[i][j] = matrix[i - 1][j - 1]
len_str1 = len(str1) + 1
len_str2 = len(str2) + 1

mat = [[0]*len_str2 for i in range(len_str1)]
mat[0][0] = 0
for i in range(1,len_str1):
mat[i][0] = mat[i-1][0] + 1
for j in range(1,len_str2):
mat[0][j] = mat[0][j-1]+1
for i in range(1,len_str1):
for j in range(1,len_str2):
if str1[i-1] == str2[j-1]:
mat[i][j] = mat[i-1][j-1]
else:
matrix[i][j] = max(matrix[i - 1][j - 1], matrix[i - 1][j], matrix[i][j - 1]) + 1

return matrix[m - 1][n - 1]
mat[i][j] = min(mat[i-1][j-1],mat[i-1][j],mat[i][j-1])+1

return 1 / mat[len_str1-1][j-1]
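The rewritten function now returns a similarity (the reciprocal of the edit distance) rather than the raw distance, so larger values mean more alike, consistent with the other similarity algorithms. A standalone sketch of the same idea; the guard for identical strings (edit distance 0) is an addition here, not part of the code above:

```python
def levenshtein_similarity(str1, str2):
    m, n = len(str1) + 1, len(str2) + 1
    mat = [[0] * n for _ in range(m)]
    for i in range(1, m):
        mat[i][0] = i
    for j in range(1, n):
        mat[0][j] = j
    for i in range(1, m):
        for j in range(1, n):
            cost = 0 if str1[i - 1] == str2[j - 1] else 1
            mat[i][j] = min(mat[i - 1][j - 1] + cost,  # substitution
                            mat[i - 1][j] + 1,         # deletion
                            mat[i][j - 1] + 1)         # insertion
    edit_distance = mat[m - 1][n - 1]
    return 1.0 if edit_distance == 0 else 1 / edit_distance
```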
@@ -12,10 +12,10 @@ __description__ = "Get sql information based on wdr."
def parse_args():
parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
description=__description__)
parser.add_argument('--port', help="User of remote server.", type=int, required=True)
parser.add_argument('--start_time', help="Start time of query", required=True)
parser.add_argument('--finish_time', help="Finish time of query", required=True)
parser.add_argument('--save_path', default='sample_data/data.csv', help="Path to save result")
parser.add_argument('--port', help="Port of database service.", type=int, required=True)
parser.add_argument('--start-time', help="Start time of query", required=True)
parser.add_argument('--finish-time', help="Finish time of query", required=True)
parser.add_argument('--save-path', default='sample_data/data.csv', help="Path to save result")
return parser.parse_args()


@@ -19,6 +19,7 @@ from configparser import ConfigParser

from algorithm.diag import SQLDiag
from utils import ResultSaver
from preprocessing import LoadData, split_sql

__version__ = '2.0.0'
__description__ = 'SQLdiag integrated by openGauss.'

@@ -40,6 +41,8 @@ def parse_args():
parser.add_argument('--predicted-file', help='The file path to save the predicted result.')
parser.add_argument('--model', default='template', choices=['template', 'dnn'],
help='Choose the model model to use.')
parser.add_argument('--query', help='Input the querys to predict.')
parser.add_argument('--threshold', help='Slow SQL threshold.')
parser.add_argument('--model-path', required=True,
help='The storage path of the model file, used to read or save the model file.')
parser.add_argument('--config-file', default='sqldiag.conf')

@@ -54,23 +57,64 @@ def get_config(filepath):


def main(args):
logging.basicConfig(level=logging.INFO)
model = SQLDiag(args.model, args.csv_file, get_config(args.config_file))
if args.mode == 'train':
model.fit()
model.save(args.model_path)
elif args.mode == 'predict':
if not args.predicted_file:
logging.error("The [--predicted-file] parameter is required for predict mode")
logging.basicConfig(level=logging.WARNING)
model = SQLDiag(args.model, get_config(args.config_file))
if args.mode in ('train', 'finetune'):
if not args.csv_file:
logging.fatal('The [--csv-file] parameter is required for train mode')
sys.exit(1)
model.load(args.model_path)
pred_result = model.transform()
ResultSaver().save(pred_result, args.predicted_file)
logging.info('predicted result in saved in {}'.format(args.predicted_file))
elif args.mode == 'finetune':
model.fine_tune(args.model_path)
train_data = LoadData(args.csv_file).train_data
if args.mode == 'train':
model.fit(train_data)
else:
model.fine_tune(args.model_path, train_data)
model.save(args.model_path)
else:
model.load(args.model_path)
if args.csv_file and not args.query:
predict_data = LoadData(args.csv_file).predict_data
elif args.query and not args.csv_file:
predict_data = split_sql(args.query)
else:
logging.error('The predict model only supports [--csv-file] or [--query] at the same time.')
sys.exit(1)
args.threshold = -100 if not args.threshold else float(args.threshold)
pred_result = model.transform(predict_data)
if args.predicted_file:
if args.model == 'template':
info_sum = []
for stats, _info in pred_result.items():
if _info:
_info = list(filter(lambda item: item[1]>=args.threshold, _info))
for item in _info:
item.insert(1, stats)
info_sum.extend(_info)
ResultSaver().save(info_sum, args.predicted_file)
else:
pred_result = list(filter(lambda item: float(item[1])>=args.threshold, pred_result))
ResultSaver().save(pred_result, args.predicted_file)
else:
from prettytable import PrettyTable

display_table = PrettyTable()
if args.model == 'template':
display_table.field_names = ['sql', 'status', 'predicted time', 'most similar template']
display_table.align = 'l'
status = ('Suspect illegal SQL', 'No SQL information', 'No SQL template found', 'Fine match')
for stats in status:
if pred_result[stats]:
for sql, predicted_time, similariest_sql in pred_result[stats]:
if predicted_time >= args.threshold or stats == 'Suspect illegal sql':
display_table.add_row([sql, stats, predicted_time, similariest_sql])
else:
display_table.field_names = ['sql', 'predicted time']
display_table.align = 'l'
for sql, predicted_time in pred_result:
if float(predicted_time) >= args.threshold:
display_table.add_row([sql, predicted_time])
print(display_table.get_string())


if __name__ == '__main__':
main(parse_args())
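In both output paths above, a result row is kept only when its predicted time reaches the threshold (the default of -100 effectively keeps everything). An illustrative filtering step with made-up values:

```python
threshold = 0.02
pred_result = [['SELECT ...', 0.31], ['UPDATE ...', 0.0004]]  # hypothetical [sql, predicted time] rows
kept = [row for row in pred_result if row[1] >= threshold]    # only the SELECT row survives
```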

@@ -18,8 +18,6 @@ import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.tokens import Keyword, DML

# split flag in SQL
split_flag = ('!=', '<=', '>=', '==', '<', '>', '=', ',', '(', ')', '*', ';', '%', '+', ',', ';')

DDL_WORDS = ('CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'COMMIT', 'RENAME')
DML_WORDS = ('SELECT', 'INSERT INTO', 'UPDATE', 'DELETE', 'MERGE', 'CALL',

@@ -31,53 +29,6 @@ KEYWORDS = ('GRANT', 'REVOKE', 'DENY', 'ABORT', 'ADD', 'AGGREGATE', 'ANALYSE', '
FUNC = ('FLOOR', 'SUM', 'NOW', 'UUID', 'COUNT')
SQL_SIG = ('&', '&&')

# filter like (insert into aa (c1, c2) values (v1, v2) => insert into aa * values *)
BRACKET_FILTER = r'\(.*?\)'

# filter (123, 123.123)
PURE_DIGIT_FILTER = r'[\s]+\d+(\.\d+)?'

# filter ('123', '123.123')
SINGLE_QUOTE_DIGIT_FILTER = r'\'\d+(\.\d+)?\''

# filter ("123", "123.123")
DOUBLE_QUOTE_DIGIT_FILTER = r'"\d+(\.\d+)?"'

# filter ('123', 123, '123,123', 123.123) not filter(table1, column1, table_2, column_2)
DIGIT_FILTER = r'([^a-zA-Z])_?\d+(\.\d+)?'

# filter date in sql ('1999-09-09', '1999/09/09', "1999-09-09 20:10:10", '1999/09/09 20:10:10.12345')
PURE_TIME_FILTER = r'[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,2})?(\.\d+)?'
SINGLE_QUOTE_TIME_FILTER = r'\'[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,' \
r'2})?(\.\d+)?\' '
DOUBLE_QUOTE_TIME_FILTER = r'"[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,2})?(\.\d+)?"'

# filter like "where id='abcd" => "where id=#"
SINGLE_QUOTE_FILTER = r'\'.*?\''

# filter like 'where id="abcd" => 'where id=#'
DOUBLE_QUOTE_FILTER = r'".*?"'

# filter annotation like "/* XXX */"
ANNOTATION_FILTER_1 = r'/\s*\*[\w\W]*?\*\s*/\s*'
ANNOTATION_FILTER_2 = r'^--.*\s?'

# filter NULL character '\n \t' in sql
NULL_CHARACTER_FILTER = r'\s+'

# remove data in insert sql
VALUE_BRACKET_FILETER = r'VALUES (\(.*\))'

# remove equal data in sql
WHERE_EQUAL_FILTER = r'= .*?\s'

LESS_EQUAL_FILTER = r'(<= .*? |<= .*$)'
GREATER_EQUAL_FILTER = r'(>= .*? |<= .*$)'
LESS_FILTER = r'(< .*? |< .*$)'
GREATER_FILTER = r'(> .*? |> .*$)'
EQUALS_FILTER = r'(= .*? |= .*$)'
LIMIT_DIGIT = r'LIMIT \d+'


def _unify_sql(sql):
"""

@@ -85,65 +36,49 @@ def _unify_sql(sql):
"""
index = 0
sql = re.sub(r'\n', r' ', sql)
sql = re.sub(ANNOTATION_FILTER_1, r'', sql)
sql = re.sub(ANNOTATION_FILTER_2, r'', sql)
while index < len(sql):
if sql[index] in split_flag:
if sql[index:index + 2] in split_flag:
sql = sql[:index].strip() + ' ' + sql[index:index + 2] + ' ' + sql[index + 2:].strip()
index = index + 3
else:
sql = sql[:index].strip() + ' ' + sql[index] + ' ' + sql[index + 1:].strip()
index = index + 2
else:
index = index + 1
new_sql = list()
for word in sql.split():
new_sql.append(word.upper())
sql = ' '.join(new_sql)
sql = re.sub(r'/\s*\*[\w\W]*?\*\s*/\s*', r'', sql)
sql = re.sub(r'^--.*\s?', r'', sql)

sql = re.sub(r'([!><=]=)', r' \1 ', sql)
sql = re.sub(r'([^!><=])([=<>])', r'\1 \2 ', sql)
sql = re.sub(r'([,()*%/+])', r' \1 ', sql)
sql = re.sub(r'\s+', r' ', sql)
sql = sql.upper()
return sql.strip()


def split_sql(sqls):
if not sqls:
return []
sqls = sqls.split(';')
result = list(map(lambda item: _unify_sql(item), sqls))
return result


def templatize_sql(sql):
"""
function: replace the message which is not important in sql
SQL desensitization
"""
sql = _unify_sql(sql)
if not sql:
return ''
standard_sql = _unify_sql(sql)

sql = re.sub(r';', r'', sql)
if standard_sql.startswith('INSERT'):
standard_sql = re.sub(r'VALUES (\(.*\))', r'VALUES', standard_sql)
# remove digital like 12, 12.565
standard_sql = re.sub(r'[\s]+\d+(\.\d+)?', r' ?', standard_sql)
# remove '$n' in sql
standard_sql = re.sub(r'\$\d+', r'?', standard_sql)
# remove single quotes content
standard_sql = re.sub(r'\'.*?\'', r'?', standard_sql)
# remove double quotes content
standard_sql = re.sub(r'".*?"', r'?', standard_sql)
# remove '`' in sql
standard_sql = re.sub(r'`', r'', standard_sql)
# remove ; in sql
standard_sql = re.sub(r';', r'', standard_sql)

# ? represent date or time
sql = re.sub(PURE_TIME_FILTER, r'?', sql)
sql = re.sub(SINGLE_QUOTE_TIME_FILTER, r'?', sql)
sql = re.sub(DOUBLE_QUOTE_TIME_FILTER, r'?', sql)

# $ represent insert value
if sql.startswith('INSERT'):
sql = re.sub(VALUE_BRACKET_FILETER, r'VALUES ()', sql)

# $$ represent select value
if sql.startswith('SELECT') and ' = ' in sql:
sql = re.sub(WHERE_EQUAL_FILTER, r'= $$ ', sql)

# $$$ represent delete value
if sql.startswith('DELETE') and ' = ' in sql:
sql = re.sub(WHERE_EQUAL_FILTER, r'= $$$ ', sql)

# & represent logical signal
sql = re.sub(LESS_EQUAL_FILTER, r'<= & ', sql)
sql = re.sub(LESS_FILTER, r'< & ', sql)
sql = re.sub(GREATER_EQUAL_FILTER, r'>= & ', sql)
sql = re.sub(GREATER_FILTER, r'> & ', sql)
sql = re.sub(LIMIT_DIGIT, r'LIMIT &', sql)
sql = re.sub(EQUALS_FILTER, r'= & ', sql)
sql = re.sub(PURE_DIGIT_FILTER, r' &', sql)
sql = re.sub(r'`', r'', sql)

# && represent quote str
sql = re.sub(SINGLE_QUOTE_FILTER, r'?', sql)
sql = re.sub(DOUBLE_QUOTE_FILTER, r'?', sql)

return sql
return standard_sql.strip()
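templatize_sql now builds on _unify_sql and replaces every literal (numbers, quoted strings, $n parameters and INSERT value lists) with a single '?' placeholder. An illustrative call, with the output inferred from the regexes above:

```python
templatize_sql("select * from t1 where id=3 and name='abc';")
# -> "SELECT * FROM T1 WHERE ID = ? AND NAME = ?"
```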

def _is_select_clause(parsed_sql):

@@ -155,7 +90,6 @@ def _is_select_clause(parsed_sql):
return False


# todo: what is token list? from list?
def _get_table_token_list(parsed_sql, token_list):
flag = False
for token in parsed_sql.tokens:

@@ -26,6 +26,7 @@ min_limit = 6.78e-05
# Template Methed
#------------------------------------------------------------------------------
[template]
similarity_algorithm =
time_list_size =
knn_number =
# [cosine_distance, levenshtein, list, parse_tree]
similarity_algorithm = cosine_distance
time_list_size = 20
knn_number = 1
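With these defaults the template model uses the new cosine_distance similarity, keeps the 20 most recent execution times per template, and estimates an unseen statement from its single most similar stored template (knn_number = 1).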
@@ -25,20 +25,15 @@ class ResultSaver:
os.chmod(dirname, stat.S_IRWXU)
if isinstance(data, (list, tuple)):
self.save_list(data, realpath)
elif isinstance(data, dict):
self.save_dict(data, path)
else:
raise TypeError("mode should be 'list', 'tuple' or 'dict', but input type is '{}'".format(str(type(data))))

@staticmethod
def save_list(data, path):
data = pd.DataFrame(data)
data.to_csv(path, index=False, sep=',', header=False, quoting=csv.QUOTE_NONE, escapechar='\"')

@staticmethod
def save_dict(data, path):
data = pd.DataFrame(data.items())
data.to_csv(path, index=False, sep=',', header=False, quoting=csv.QUOTE_NONE, escapechar='\"')
with open(path, mode='w') as f:
for item in data:
content = ",".join([str(sub_item) for sub_item in item])
f.write(content + '\n')


class DBAgent:

@@ -74,7 +69,7 @@ class DBAgent:
result = list(self.cursor.fetchall())
return result
except Exception as e:
logging.getLogger('agent').warning(str(e))
logging.warning(str(e))

def close(self):
self.cursor.close()