!996 Update SQLDiag: Add two threshold judgments

Merge pull request !996 from wangtq/master
This commit is contained in:
opengauss-bot
2021-05-28 19:34:54 +08:00
committed by Gitee
12 changed files with 235 additions and 239 deletions

View File

@@ -71,13 +71,17 @@ predict dataset sample: [sample_data/predict.csv](data/predict.csv)
 ## example for template method using sample dataset
-# training
+# train
 python main.py train -f ./sample_data/train.csv --model template --model-path ./template
 # predict
 python main.py predict -f ./sample_data/predict.csv --model template --model-path ./template
     --predicted-file ./result/t_result
+# predict with threshold
+python main.py predict -f ./sample_data/predict.csv --threshold 0.02 --model template --model-path ./template
+    --predicted-file ./result/t_result
 # update model
 python main.py finetune -f ./sample_data/train.csv --model template --model-path ./template
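
Note: per the new predict logic in main.py below, --threshold 0.02 keeps only statements whose predicted execution time is at least 0.02 (in the same unit as the training durations); when the flag is omitted, no filtering is applied.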

View File

@@ -9,9 +9,9 @@ W2V_SUFFIX = 'word2vector'

 def check_template_algorithm(param):
-    if param and param not in ["list", "levenshtein", "parse_tree"]:
+    if param and param not in ["list", "levenshtein", "parse_tree", "cosine_distance"]:
         raise ValueError("The similarity algorithm '%s' is invalid, "
-                         "please choose from ['list', 'levenshtein', 'parse_tree']" % param)
+                         "please choose from ['list', 'levenshtein', 'parse_tree', 'cosine_distance']" % param)

 class ModelConfig(object):

@@ -62,7 +62,7 @@ SUPPORTED_ALGORITHM = {'dnn': lambda config: DnnModel(DnnConfig.init_from_config

 class SQLDiag:
-    def __init__(self, model_algorithm, csv_file, params):
+    def __init__(self, model_algorithm, params):
         if model_algorithm not in SUPPORTED_ALGORITHM:
             raise NotImplementedError("do not support {}".format(model_algorithm))
         try:

@@ -70,20 +70,16 @@ class SQLDiag:
         except ValueError as e:
             logging.error(e, exc_info=True)
             sys.exit(1)
-        self.load_data = LoadData(csv_file)

-    def __getattr__(self, item):
-        return getattr(self.load_data, item)
+    def fit(self, data):
+        self._model.fit(data)

-    def fit(self):
-        self._model.fit(self.train_data)
+    def transform(self, data):
+        return self._model.transform(data)

-    def transform(self):
-        return self._model.transform(self.predict_data)
+    def fine_tune(self, filepath, data):
+        self._model.load(filepath)
+        self._model.fit(data)

-    def fine_tune(self, filepath):
-        self._model.load(filepath)
-        self._model.fit(self.train_data)

     def load(self, filepath):
         self._model.load(filepath)
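
This refactor decouples data loading from SQLDiag: fit, transform and fine_tune now receive their data from the caller instead of a csv_file bound at construction. A minimal sketch of the resulting call pattern (the file paths and the params object are placeholders, not part of this change):

    from algorithm.diag import SQLDiag
    from preprocessing import LoadData

    model = SQLDiag('template', params)           # params: parsed model config
    model.fit(LoadData('train.csv').train_data)   # the caller supplies the data
    model.save('./template')
    result = model.transform(LoadData('predict.csv').predict_data)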

View File

@@ -36,7 +36,7 @@ class KerasRegression:
         shape = features.shape[1]
         if self.model is None:
             self.model = self.build_model(shape=shape, encoding_dim=self.encoding_dim)
-        self.model.fit(features, labels, epochs=epochs, batch_size=batch_size, shuffle=True, verbose=2)
+        self.model.fit(features, labels, epochs=epochs, batch_size=batch_size, shuffle=True, verbose=0)

     def predict(self, features):
         predict_result = self.model.predict(features)

@@ -129,10 +129,6 @@ class DnnModel(AbstractModel, ABC):
             self.w2v.load(word2vector_path)
             with open(scaler_path, 'rb') as f:
                 self.scaler = pickle.load(f)
-            logging.info("dnn model is loaded: '{}'; w2v model is loaded: '{}'; scaler model is loaded: '{}'."
-                         .format(dnn_path,
-                                 word2vector_path,
-                                 scaler_path))
         else:
             logging.error("{} not exist.".format(realpath))

@@ -149,7 +145,5 @@ class DnnModel(AbstractModel, ABC):
         self.w2v.save(word2vector_path)
         with open(scaler_path, 'wb') as f:
             pickle.dump(self.scaler, f)
-        logging.info("dnn model is saved: '{}'; w2v model is saved: '{}'; scaler model is saved: '{}'."
-                     .format(dnn_path,
-                             word2vector_path,
-                             scaler_path))
+        print("DNN model is stored in '{}'".format(realpath))

View File

@@ -18,9 +18,10 @@ import logging
 import os
 import stat
 from functools import reduce
+from collections import defaultdict

 from algorithm.sql_similarity import calc_sql_distance
-from preprocessing import get_sql_template
+from preprocessing import get_sql_template, templatize_sql
 from utils import check_illegal_sql, LRUCache
 from . import AbstractModel

@@ -40,97 +41,78 @@ class TemplateModel(AbstractModel):
         for sql, duration_time in data:
             if check_illegal_sql(sql):
                 continue
-            # get 'fine_template' and 'rough_template' of SQL
-            fine_template, rough_template = get_sql_template(sql)
-            sql_prefix = fine_template.split()[0]
-            # if prefix of SQL is not in 'update', 'delete', 'select' and 'insert',
-            # then convert prefix to 'other'
+            sql_template = templatize_sql(sql)
+            sql_prefix = sql_template.split()[0]
             if sql_prefix not in self.__hash_table:
                 sql_prefix = 'OTHER'
-            if rough_template not in self.__hash_table[sql_prefix]:
-                self.__hash_table[sql_prefix][rough_template] = dict()
-                self.__hash_table[sql_prefix][rough_template]['info'] = dict()
-            if fine_template not in self.__hash_table[sql_prefix][rough_template]['info']:
-                self.__hash_table[sql_prefix][rough_template]['info'][fine_template] = \
-                    dict(time_list=[], count=0, mean_time=0.0, iter_time=0.0)
-            # count the number of occurrences of fine template
-            self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['count'] += 1
-            # store the execution time of the matched template in the corresponding list
-            self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['time_list'].append(duration_time)
-            # iterative calculation of execution time based on historical data
-            if not self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time']:
-                self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time'] = duration_time
-            else:
-                self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time'] = \
-                    (self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time'] +
-                     duration_time) / 2
-        # calculate the average execution time of each template
+            if sql_template not in self.__hash_table[sql_prefix]:
+                self.__hash_table[sql_prefix][sql_template] = dict(time_list=[], count=0, mean_time=0.0, iter_time=0.0)
+            self.__hash_table[sql_prefix][sql_template]['count'] += 1
+            self.__hash_table[sql_prefix][sql_template]['time_list'].append(duration_time)
         for sql_prefix, sql_prefix_info in self.__hash_table.items():
-            for rough_template, rough_template_info in sql_prefix_info.items():
-                for _, fine_template_info in rough_template_info['info'].items():
-                    del fine_template_info['time_list'][:-self.time_list_size]
-                    fine_template_info['mean_time'] = \
-                        sum(fine_template_info['time_list']) / len(fine_template_info['time_list'])
-                rough_template_info['count'] = len(rough_template_info['info'])
-                rough_template_info['mean_time'] = sum(
-                    [value['mean_time'] for key, value in rough_template_info['info'].items()]) / len(
-                    rough_template_info['info'])
+            for sql_template, sql_template_info in sql_prefix_info.items():
+                del sql_template_info['time_list'][:-self.time_list_size]
+                sql_template_info['mean_time'] = sum(sql_template_info['time_list']) / len(sql_template_info['time_list'])
+                sql_template_info['iter_time'] = reduce(lambda x, y: (x + y) / 2, sql_template_info['time_list'])
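
The new iter_time folds the retained durations with reduce(lambda x, y: (x + y) / 2, time_list), which halves the weight of each older sample. A hand-checked illustration (durations invented):

    from functools import reduce

    time_list = [4.0, 2.0, 8.0]  # oldest -> newest execution times
    iter_time = reduce(lambda x, y: (x + y) / 2, time_list)
    # ((4.0 + 2.0) / 2 + 8.0) / 2 = 5.5 -- the newest sample carries weight 1/2,
    # the one before it 1/4, and so on, so iter_time tracks recent behaviour.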
     def transform(self, data):
-        predict_time_list = []
-        for sql in data:
-            predict_time = self.predict_duration_time(sql)
-            predict_time_list.append([sql, predict_time])
-        return predict_time_list
+        predict_result_dict = defaultdict(list)
+        for sql in data:
+            sql_, status, predict_time, top_similarity_sql = self.predict_duration_time(sql)
+            predict_result_dict[status].append([sql_, predict_time, top_similarity_sql])
+        for key, value in predict_result_dict.items():
+            if value:
+                value.sort(key=lambda item: item[1], reverse=True)
+        return predict_result_dict

     @LRUCache(max_size=1024)
     def predict_duration_time(self, sql):
+        top_similarity_sql = None
         if check_illegal_sql(sql):
-            return -1
+            predict_time = -1
+            status = 'Suspect illegal sql'
+            return sql, status, predict_time, top_similarity_sql
-        sql_prefix = sql.strip().split()[0]
-        # get 'fine_template' and 'rough_template' of SQL
-        fine_template, rough_template = get_sql_template(sql)
+        # get 'sql_template' of SQL
+        sql_template = templatize_sql(sql)
+        sql_prefix = sql_template.strip().split()[0]
         if sql_prefix not in self.__hash_table:
             sql_prefix = 'OTHER'
         if not self.__hash_table[sql_prefix]:
-            logging.warning("'{}' not in the templates.".format(sql))
+            status = 'No SQL information'
             predict_time = -1
-        elif rough_template not in self.__hash_table[sql_prefix] or fine_template not in \
-                self.__hash_table[sql_prefix][rough_template]['info']:
+        elif sql_template not in self.__hash_table[sql_prefix]:
             similarity_info = []
             """
             if the template does not exist in the hash table,
             then calculate the possible execution time based on template
             similarity and KNN algorithm in all other templates
             """
-            if rough_template not in self.__hash_table[sql_prefix]:
-                for local_rough_template, local_rough_template_info in self.__hash_table[sql_prefix].items():
-                    similarity_info.append(
-                        (self.similarity_algorithm(rough_template, local_rough_template),
-                         local_rough_template_info['mean_time']))
-            else:
-                for local_fine_template, local_fine_template_info in \
-                        self.__hash_table[sql_prefix][rough_template]['info'].items():
-                    similarity_info.append(
-                        (self.similarity_algorithm(fine_template, local_fine_template),
-                         local_fine_template_info['iter_time']))
+            status = 'No SQL template found'
+            for local_sql_template, local_sql_template_info in self.__hash_table[sql_prefix].items():
+                similarity_info.append(
+                    (self.similarity_algorithm(sql_template, local_sql_template),
+                     local_sql_template_info['mean_time'], local_sql_template))
             topn_similarity_info = heapq.nlargest(self.knn_number, similarity_info)
-            sum_similarity_scores = sum(item[0] for item in topn_similarity_info) + self.bias
-            similarity_proportions = (item[0] / sum_similarity_scores for item in topn_similarity_info)
-            topn_duration_time = (item[1] for item in topn_similarity_info)
+            sum_similarity_scores = sum(item[0] for item in topn_similarity_info)
+            if not sum_similarity_scores:
+                sum_similarity_scores = self.bias
+            top_similarity_sql = '\n'.join([item[2] for item in topn_similarity_info])
+            similarity_proportions = [item[0] / sum_similarity_scores for item in topn_similarity_info]
+            topn_duration_time = [item[1] for item in topn_similarity_info]
             predict_time = reduce(lambda x, y: x + y,
                                   map(lambda x, y: x * y, similarity_proportions,
                                       topn_duration_time))
         else:
-            predict_time = self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time']
-        return predict_time
+            status = 'Fine match'
+            predict_time = self.__hash_table[sql_prefix][sql_template]['iter_time']
+            top_similarity_sql = sql_template
+        return sql, status, predict_time, top_similarity_sql
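
For a template that has never been seen, the code above predicts a similarity-weighted average of the mean times of the knn_number most similar stored templates. A self-contained sketch of that weighting (scores, times and templates are invented):

    import heapq
    from functools import reduce

    # (similarity, mean_time, template) tuples, as built in similarity_info
    similarity_info = [(0.9, 1.2, 'SELECT * FROM T1 WHERE ID = ?'),
                       (0.5, 0.4, 'SELECT NAME FROM T2'),
                       (0.1, 3.0, 'SELECT COUNT ( * ) FROM T3')]
    topn = heapq.nlargest(2, similarity_info)       # knn_number = 2
    total = sum(item[0] for item in topn) or 1e-5   # self.bias guards a zero score sum
    weights = [item[0] / total for item in topn]
    times = [item[1] for item in topn]
    predict_time = reduce(lambda x, y: x + y,
                          map(lambda x, y: x * y, weights, times))
    # (0.9/1.4) * 1.2 + (0.5/1.4) * 0.4 ~= 0.914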
     def load(self, filepath):
         realpath = os.path.realpath(filepath)

@@ -138,7 +120,6 @@ class TemplateModel(AbstractModel):
             template_path = os.path.join(realpath, 'template.json')
             with open(template_path, mode='r') as f:
                 self.__hash_table = json.load(f)
-            logging.info("template model '{}' is loaded.".format(template_path))
         else:
             logging.error("{} not exist.".format(realpath))

@@ -151,4 +132,5 @@ class TemplateModel(AbstractModel):
         template_path = os.path.join(realpath, 'template.json')
         with open(template_path, mode='w') as f:
             json.dump(self.__hash_table, f, indent=4)
-        logging.info("template model is stored in '{}'".format(realpath))
+        print("Template model is stored in '{}'".format(realpath))

View File

@@ -5,6 +5,8 @@ def calc_sql_distance(algorithm):
         from .levenshtein import distance
     elif algorithm == 'parse_tree':
         from .parse_tree import distance
+    elif algorithm == 'cosine_distance':
+        from .cosine_distance import distance
     else:
         raise NotImplementedError("do not support '{}'".format(algorithm))
     return distance
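
TemplateModel presumably binds self.similarity_algorithm through this dispatcher; a one-line usage sketch:

    from algorithm.sql_similarity import calc_sql_distance

    similarity = calc_sql_distance('cosine_distance')  # returns cosine_distance.distance
    score = similarity('SELECT * FROM T1 WHERE ID = ?', 'SELECT * FROM T2 WHERE ID = ?')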

View File

@@ -0,0 +1,26 @@
+"""
+Copyright (c) 2020 Huawei Technologies Co.,Ltd.
+
+openGauss is licensed under Mulan PSL v2.
+You can use this software according to the terms and conditions of the Mulan PSL v2.
+You may obtain a copy of Mulan PSL v2 at:
+
+    http://license.coscl.org.cn/MulanPSL2
+
+THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+See the Mulan PSL v2 for more details.
+"""
+import math
+from collections import Counter
+
+
+def distance(str1, str2):
+    c_1 = Counter(str1)
+    c_2 = Counter(str2)
+    c_union = set(c_1).union(c_2)
+    dot_product = sum(c_1.get(item, 0) * c_2.get(item, 0) for item in c_union)
+    mag_c1 = math.sqrt(sum(c_1.get(item, 0) ** 2 for item in c_union))
+    mag_c2 = math.sqrt(sum(c_2.get(item, 0) ** 2 for item in c_union))
+    return dot_product / (mag_c1 * mag_c2)
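
The new metric treats each template as a bag of characters and returns the cosine of the angle between the two count vectors: 1.0 for identical strings, 0.0 for strings with no characters in common. Hand-checked values (note that an empty argument would zero a magnitude and raise ZeroDivisionError; the callers pass non-empty templates):

    distance('AB', 'AB')             # 1.0 -- identical count vectors
    distance('AB', 'CD')             # 0.0 -- disjoint characters
    round(distance('AAB', 'AB'), 3)  # 0.949 -- count vectors (2, 1) and (1, 1)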

View File

@@ -1,3 +1,19 @@
+"""
+Copyright (c) 2020 Huawei Technologies Co.,Ltd.
+
+openGauss is licensed under Mulan PSL v2.
+You can use this software according to the terms and conditions of the Mulan PSL v2.
+You may obtain a copy of Mulan PSL v2 at:
+
+    http://license.coscl.org.cn/MulanPSL2
+
+THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+See the Mulan PSL v2 for more details.
+"""
 def distance(str1, str2):
     """
     func: calculate levenshtein distance between two strings.

@@ -5,18 +21,20 @@ def distance(str1, str2):
     :param str2: string2
     :return: distance
     """
-    m, n = len(str1) + 1, len(str2) + 1
-    matrix = [[0] * n for _ in range(m)]
-    matrix[0][0] = 0
-    for i in range(1, m):
-        matrix[i][0] = matrix[i - 1][0] + 1
-    for j in range(1, n):
-        matrix[0][j] = matrix[0][j - 1] + 1
-    for i in range(1, m):
-        for j in range(1, n):
-            if str1[i - 1] == str2[j - 1]:
-                matrix[i][j] = matrix[i - 1][j - 1]
-            else:
-                matrix[i][j] = max(matrix[i - 1][j - 1], matrix[i - 1][j], matrix[i][j - 1]) + 1
-    return matrix[m - 1][n - 1]
+    len_str1 = len(str1) + 1
+    len_str2 = len(str2) + 1
+    mat = [[0] * len_str2 for _ in range(len_str1)]
+    mat[0][0] = 0
+    for i in range(1, len_str1):
+        mat[i][0] = mat[i - 1][0] + 1
+    for j in range(1, len_str2):
+        mat[0][j] = mat[0][j - 1] + 1
+    for i in range(1, len_str1):
+        for j in range(1, len_str2):
+            if str1[i - 1] == str2[j - 1]:
+                mat[i][j] = mat[i - 1][j - 1]
+            else:
+                mat[i][j] = min(mat[i - 1][j - 1], mat[i - 1][j], mat[i][j - 1]) + 1
+    return 1 / mat[len_str1 - 1][len_str2 - 1]
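
The rewritten recurrence uses min (the removed version's max did not compute the edit distance), and the reciprocal turns the distance into a similarity so that, like cosine_distance, larger means closer. Identical strings never reach this code: an exact template match is answered from the hash table before any similarity search runs, so the distance here is at least 1. For example:

    distance('KITTEN', 'SITTING')  # 1/3 -- the classic edit distance is 3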

View File

@@ -12,10 +12,10 @@ __description__ = "Get sql information based on wdr."

 def parse_args():
     parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                      description=__description__)
-    parser.add_argument('--port', help="User of remote server.", type=int, required=True)
-    parser.add_argument('--start_time', help="Start time of query", required=True)
-    parser.add_argument('--finish_time', help="Finish time of query", required=True)
-    parser.add_argument('--save_path', default='sample_data/data.csv', help="Path to save result")
+    parser.add_argument('--port', help="Port of database service.", type=int, required=True)
+    parser.add_argument('--start-time', help="Start time of query", required=True)
+    parser.add_argument('--finish-time', help="Finish time of query", required=True)
+    parser.add_argument('--save-path', default='sample_data/data.csv', help="Path to save result")
     return parser.parse_args()

View File

@@ -19,6 +19,7 @@ from configparser import ConfigParser

 from algorithm.diag import SQLDiag
 from utils import ResultSaver
+from preprocessing import LoadData, split_sql

 __version__ = '2.0.0'
 __description__ = 'SQLdiag integrated by openGauss.'

@@ -40,6 +41,8 @@ def parse_args():
     parser.add_argument('--predicted-file', help='The file path to save the predicted result.')
     parser.add_argument('--model', default='template', choices=['template', 'dnn'],
                         help='Choose the model to use.')
+    parser.add_argument('--query', help='Input the queries to predict.')
+    parser.add_argument('--threshold', help='Slow SQL threshold.')
     parser.add_argument('--model-path', required=True,
                         help='The storage path of the model file, used to read or save the model file.')
     parser.add_argument('--config-file', default='sqldiag.conf')
@@ -54,23 +57,64 @@ def get_config(filepath):

 def main(args):
-    logging.basicConfig(level=logging.INFO)
-    model = SQLDiag(args.model, args.csv_file, get_config(args.config_file))
-    if args.mode == 'train':
-        model.fit()
-        model.save(args.model_path)
-    elif args.mode == 'predict':
-        if not args.predicted_file:
-            logging.error("The [--predicted-file] parameter is required for predict mode")
-            sys.exit(1)
-        model.load(args.model_path)
-        pred_result = model.transform()
-        ResultSaver().save(pred_result, args.predicted_file)
-        logging.info('predicted result in saved in {}'.format(args.predicted_file))
-    elif args.mode == 'finetune':
-        model.fine_tune(args.model_path)
-        model.save(args.model_path)
+    logging.basicConfig(level=logging.WARNING)
+    model = SQLDiag(args.model, get_config(args.config_file))
+    if args.mode in ('train', 'finetune'):
+        if not args.csv_file:
+            logging.fatal('The [--csv-file] parameter is required for train and finetune modes')
+            sys.exit(1)
+        train_data = LoadData(args.csv_file).train_data
+        if args.mode == 'train':
+            model.fit(train_data)
+        else:
+            model.fine_tune(args.model_path, train_data)
+        model.save(args.model_path)
+    else:
+        model.load(args.model_path)
+        if args.csv_file and not args.query:
+            predict_data = LoadData(args.csv_file).predict_data
+        elif args.query and not args.csv_file:
+            predict_data = split_sql(args.query)
+        else:
+            logging.error('Predict mode requires exactly one of [--csv-file] or [--query].')
+            sys.exit(1)
+        args.threshold = -100 if not args.threshold else float(args.threshold)
+        pred_result = model.transform(predict_data)
+        if args.predicted_file:
+            if args.model == 'template':
+                info_sum = []
+                for stats, _info in pred_result.items():
+                    if _info:
+                        _info = list(filter(lambda item: item[1] >= args.threshold, _info))
+                        for item in _info:
+                            item.insert(1, stats)
+                        info_sum.extend(_info)
+                ResultSaver().save(info_sum, args.predicted_file)
+            else:
+                pred_result = list(filter(lambda item: float(item[1]) >= args.threshold, pred_result))
+                ResultSaver().save(pred_result, args.predicted_file)
+        else:
+            from prettytable import PrettyTable
+            display_table = PrettyTable()
+            if args.model == 'template':
+                display_table.field_names = ['sql', 'status', 'predicted time', 'most similar template']
+                display_table.align = 'l'
+                status = ('Suspect illegal sql', 'No SQL information', 'No SQL template found', 'Fine match')
+                for stats in status:
+                    if pred_result[stats]:
+                        for sql, predicted_time, most_similar_sql in pred_result[stats]:
+                            if predicted_time >= args.threshold or stats == 'Suspect illegal sql':
+                                display_table.add_row([sql, stats, predicted_time, most_similar_sql])
+            else:
+                display_table.field_names = ['sql', 'predicted time']
+                display_table.align = 'l'
+                for sql, predicted_time in pred_result:
+                    if float(predicted_time) >= args.threshold:
+                        display_table.add_row([sql, predicted_time])
+            print(display_table.get_string())

 if __name__ == '__main__':
     main(parse_args())
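
A hedged example of the new prediction entry points (paths as in the README): omitting --predicted-file prints a PrettyTable instead of writing a file, and an unset --threshold defaults to -100, so nothing is filtered.

    python main.py predict --query "select id from t1 where id = 1; select * from t2" --threshold 0.02 --model template --model-path ./template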

View File

@@ -18,8 +18,6 @@ import sqlparse
 from sqlparse.sql import Identifier, IdentifierList
 from sqlparse.tokens import Keyword, DML

-# split flag in SQL
-split_flag = ('!=', '<=', '>=', '==', '<', '>', '=', ',', '(', ')', '*', ';', '%', '+', ',', ';')

 DDL_WORDS = ('CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'COMMIT', 'RENAME')
 DML_WORDS = ('SELECT', 'INSERT INTO', 'UPDATE', 'DELETE', 'MERGE', 'CALL',

@@ -31,53 +29,6 @@ KEYWORDS = ('GRANT', 'REVOKE', 'DENY', 'ABORT', 'ADD', 'AGGREGATE', 'ANALYSE', '
 FUNC = ('FLOOR', 'SUM', 'NOW', 'UUID', 'COUNT')
 SQL_SIG = ('&', '&&')

-# filter like (insert into aa (c1, c2) values (v1, v2) => insert into aa * values *)
-BRACKET_FILTER = r'\(.*?\)'
-# filter (123, 123.123)
-PURE_DIGIT_FILTER = r'[\s]+\d+(\.\d+)?'
-# filter ('123', '123.123')
-SINGLE_QUOTE_DIGIT_FILTER = r'\'\d+(\.\d+)?\''
-# filter ("123", "123.123")
-DOUBLE_QUOTE_DIGIT_FILTER = r'"\d+(\.\d+)?"'
-# filter ('123', 123, '123,123', 123.123) not filter(table1, column1, table_2, column_2)
-DIGIT_FILTER = r'([^a-zA-Z])_?\d+(\.\d+)?'
-# filter date in sql ('1999-09-09', '1999/09/09', "1999-09-09 20:10:10", '1999/09/09 20:10:10.12345')
-PURE_TIME_FILTER = r'[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,2})?(\.\d+)?'
-SINGLE_QUOTE_TIME_FILTER = r'\'[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,2})?(\.\d+)?\''
-DOUBLE_QUOTE_TIME_FILTER = r'"[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,2})?(\.\d+)?"'
-# filter like "where id='abcd" => "where id=#"
-SINGLE_QUOTE_FILTER = r'\'.*?\''
-# filter like 'where id="abcd" => 'where id=#'
-DOUBLE_QUOTE_FILTER = r'".*?"'
-# filter annotation like "/* XXX */"
-ANNOTATION_FILTER_1 = r'/\s*\*[\w\W]*?\*\s*/\s*'
-ANNOTATION_FILTER_2 = r'^--.*\s?'
-# filter NULL character '\n \t' in sql
-NULL_CHARACTER_FILTER = r'\s+'
-# remove data in insert sql
-VALUE_BRACKET_FILETER = r'VALUES (\(.*\))'
-# remove equal data in sql
-WHERE_EQUAL_FILTER = r'= .*?\s'
-LESS_EQUAL_FILTER = r'(<= .*? |<= .*$)'
-GREATER_EQUAL_FILTER = r'(>= .*? |<= .*$)'
-LESS_FILTER = r'(< .*? |< .*$)'
-GREATER_FILTER = r'(> .*? |> .*$)'
-EQUALS_FILTER = r'(= .*? |= .*$)'
-LIMIT_DIGIT = r'LIMIT \d+'
 def _unify_sql(sql):
     """

@@ -85,65 +36,49 @@ def _unify_sql(sql):
     """
     index = 0
     sql = re.sub(r'\n', r' ', sql)
-    sql = re.sub(ANNOTATION_FILTER_1, r'', sql)
-    sql = re.sub(ANNOTATION_FILTER_2, r'', sql)
-    while index < len(sql):
-        if sql[index] in split_flag:
-            if sql[index:index + 2] in split_flag:
-                sql = sql[:index].strip() + ' ' + sql[index:index + 2] + ' ' + sql[index + 2:].strip()
-                index = index + 3
-            else:
-                sql = sql[:index].strip() + ' ' + sql[index] + ' ' + sql[index + 1:].strip()
-                index = index + 2
-        else:
-            index = index + 1
-    new_sql = list()
-    for word in sql.split():
-        new_sql.append(word.upper())
-    sql = ' '.join(new_sql)
+    sql = re.sub(r'/\s*\*[\w\W]*?\*\s*/\s*', r'', sql)
+    sql = re.sub(r'^--.*\s?', r'', sql)
+    sql = re.sub(r'([!><=]=)', r' \1 ', sql)
+    sql = re.sub(r'([^!><=])([=<>])', r'\1 \2 ', sql)
+    sql = re.sub(r'([,()*%/+])', r' \1 ', sql)
+    sql = re.sub(r'\s+', r' ', sql)
+    sql = sql.upper()
     return sql.strip()

+def split_sql(sqls):
+    if not sqls:
+        return []
+    sqls = sqls.split(';')
+    result = list(map(lambda item: _unify_sql(item), sqls))
+    return result

 def templatize_sql(sql):
     """
-    function: replace the message which is not important in sql
+    SQL desensitization
     """
-    sql = _unify_sql(sql)
-
-    sql = re.sub(r';', r'', sql)
-
-    # ? represent date or time
-    sql = re.sub(PURE_TIME_FILTER, r'?', sql)
-    sql = re.sub(SINGLE_QUOTE_TIME_FILTER, r'?', sql)
-    sql = re.sub(DOUBLE_QUOTE_TIME_FILTER, r'?', sql)
-    # $ represent insert value
-    if sql.startswith('INSERT'):
-        sql = re.sub(VALUE_BRACKET_FILETER, r'VALUES ()', sql)
-    # $$ represent select value
-    if sql.startswith('SELECT') and ' = ' in sql:
-        sql = re.sub(WHERE_EQUAL_FILTER, r'= $$ ', sql)
-    # $$$ represent delete value
-    if sql.startswith('DELETE') and ' = ' in sql:
-        sql = re.sub(WHERE_EQUAL_FILTER, r'= $$$ ', sql)
-    # & represent logical signal
-    sql = re.sub(LESS_EQUAL_FILTER, r'<= & ', sql)
-    sql = re.sub(LESS_FILTER, r'< & ', sql)
-    sql = re.sub(GREATER_EQUAL_FILTER, r'>= & ', sql)
-    sql = re.sub(GREATER_FILTER, r'> & ', sql)
-    sql = re.sub(LIMIT_DIGIT, r'LIMIT &', sql)
-    sql = re.sub(EQUALS_FILTER, r'= & ', sql)
-    sql = re.sub(PURE_DIGIT_FILTER, r' &', sql)
-    sql = re.sub(r'`', r'', sql)
-    # && represent quote str
-    sql = re.sub(SINGLE_QUOTE_FILTER, r'?', sql)
-    sql = re.sub(DOUBLE_QUOTE_FILTER, r'?', sql)
-    return sql
+    if not sql:
+        return ''
+    standard_sql = _unify_sql(sql)
+    if standard_sql.startswith('INSERT'):
+        standard_sql = re.sub(r'VALUES (\(.*\))', r'VALUES', standard_sql)
+    # remove digits like 12, 12.565
+    standard_sql = re.sub(r'[\s]+\d+(\.\d+)?', r' ?', standard_sql)
+    # remove '$n' in sql
+    standard_sql = re.sub(r'\$\d+', r'?', standard_sql)
+    # remove single quotes content
+    standard_sql = re.sub(r'\'.*?\'', r'?', standard_sql)
+    # remove double quotes content
+    standard_sql = re.sub(r'".*?"', r'?', standard_sql)
+    # remove '`' in sql
+    standard_sql = re.sub(r'`', r'', standard_sql)
+    # remove ';' in sql
+    standard_sql = re.sub(r';', r'', standard_sql)
+    return standard_sql.strip()
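
The slimmed-down templatize_sql routes everything through _unify_sql and masks literals with '?' instead of the old '$'/'&' markers. Tracing two statements through the regexes above gives:

    templatize_sql("select * from t1 where id = 100 and name = 'abc'")
    # -> 'SELECT * FROM T1 WHERE ID = ? AND NAME = ?'
    templatize_sql("insert into t1 (a, b) values (1, 'x');")
    # -> 'INSERT INTO T1 ( A , B ) VALUES'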
 def _is_select_clause(parsed_sql):

@@ -155,7 +90,6 @@ def _is_select_clause(parsed_sql):
     return False

-# todo: what is token list? from list?
 def _get_table_token_list(parsed_sql, token_list):
     flag = False
     for token in parsed_sql.tokens:

View File

@@ -26,6 +26,7 @@ min_limit = 6.78e-05

 # Template Method
 #------------------------------------------------------------------------------
 [template]
-similarity_algorithm =
-time_list_size =
-knn_number =
+# [cosine_distance, levenshtein, list, parse_tree]
+similarity_algorithm = cosine_distance
+time_list_size = 20
+knn_number = 1
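
With these defaults the template model scores similarity with cosine_distance, keeps at most the 20 most recent execution times per template (time_list_size), and, when no exact template matches, predicts from the single most similar stored template (knn_number = 1).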

View File

@@ -25,20 +25,15 @@ class ResultSaver:
             os.chmod(dirname, stat.S_IRWXU)
         if isinstance(data, (list, tuple)):
             self.save_list(data, realpath)
-        elif isinstance(data, dict):
-            self.save_dict(data, path)
         else:
             raise TypeError("data should be 'list' or 'tuple', but input type is '{}'".format(str(type(data))))

     @staticmethod
     def save_list(data, path):
-        data = pd.DataFrame(data)
-        data.to_csv(path, index=False, sep=',', header=False, quoting=csv.QUOTE_NONE, escapechar='\"')
-
-    @staticmethod
-    def save_dict(data, path):
-        data = pd.DataFrame(data.items())
-        data.to_csv(path, index=False, sep=',', header=False, quoting=csv.QUOTE_NONE, escapechar='\"')
+        with open(path, mode='w') as f:
+            for item in data:
+                content = ",".join([str(sub_item) for sub_item in item])
+                f.write(content + '\n')
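
save_list now writes one comma-joined line per record instead of routing through pandas; for the template model's rows ([sql, status, predict_time, top_similarity_sql] after the insert in main.py) an output line would look like the following illustrative example. Note that commas inside a template are not escaped in this format:

    SELECT * FROM T1 WHERE ID = ?,Fine match,0.05,SELECT * FROM T1 WHERE ID = ?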
 class DBAgent:

@@ -74,7 +69,5 @@ class DBAgent:
             result = list(self.cursor.fetchall())
             return result
         except Exception as e:
-            logging.getLogger('agent').warning(str(e))
+            logging.warning(str(e))

     def close(self):
         self.cursor.close()