!996 Update SQLDiag: Add two threshold judgments
Merge pull request !996 from wangtq/master
@@ -71,13 +71,17 @@ predict dataset sample: [sample_data/predict.csv](data/predict.csv)
 
 ## example for template method using sample dataset
 
-# training
+# train
 python main.py train -f ./sample_data/train.csv --model template --model-path ./template
 
 # predict
 python main.py predict -f ./sample_data/predict.csv --model template --model-path ./template
 --predicted-file ./result/t_result
 
+# predict with threshold
+python main.py predict -f ./sample_data/predict.csv --threshold 0.02 --model template --model-path ./template
+--predicted-file ./result/t_result
+
 # update model
 python main.py finetune -f ./sample_data/train.csv --model template --model-path ./template
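The new `--threshold` option caps what the report keeps: only statements whose predicted execution time reaches the threshold survive. A minimal sketch of the rule, mirroring the filter added in `main.py` further down in this diff (the data values here are made up):

```python
# Sketch of the threshold filter added in main.py (illustrative values).
pred_result = {
    'Fine match': [['SELECT * FROM T WHERE ID = ?', 0.031, 'SELECT * FROM T WHERE ID = ?']],
    'No SQL template found': [['UPDATE T SET A = ?', 0.005, 'UPDATE T SET B = ?']],
}
threshold = 0.02  # same value as the README example above

info_sum = []
for stats, _info in pred_result.items():
    _info = [item for item in _info if item[1] >= threshold]
    for item in _info:
        item.insert(1, stats)  # row becomes [sql, status, time, similar template]
    info_sum.extend(_info)

print(info_sum)  # only the 0.031 prediction survives the 0.02 cutoff
```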
@@ -9,9 +9,9 @@ W2V_SUFFIX = 'word2vector'
 
 
 def check_template_algorithm(param):
-    if param and param not in ["list", "levenshtein", "parse_tree"]:
+    if param and param not in ["list", "levenshtein", "parse_tree", "cosine_distance"]:
         raise ValueError("The similarity algorithm '%s' is invaild, "
-                         "please choose from ['list', 'levenshtein', 'parse_tree']" % param)
+                         "please choose from ['list', 'levenshtein', 'parse_tree', 'cosine_distance']" % param)
 
 
 class ModelConfig(object):
@@ -62,7 +62,7 @@ SUPPORTED_ALGORITHM = {'dnn': lambda config: DnnModel(DnnConfig.init_from_config
 
 
 class SQLDiag:
-    def __init__(self, model_algorithm, csv_file, params):
+    def __init__(self, model_algorithm, params):
         if model_algorithm not in SUPPORTED_ALGORITHM:
             raise NotImplementedError("do not support {}".format(model_algorithm))
         try:
@@ -70,20 +70,16 @@ class SQLDiag:
         except ValueError as e:
             logging.error(e, exc_info=True)
             sys.exit(1)
-        self.load_data = LoadData(csv_file)
 
-    def __getattr__(self, item):
-        return getattr(self.load_data, item)
+    def fit(self, data):
+        self._model.fit(data)
 
-    def fit(self):
-        self._model.fit(self.train_data)
+    def transform(self, data):
+        return self._model.transform(data)
 
-    def transform(self):
-        return self._model.transform(self.predict_data)
-
-    def fine_tune(self, filepath):
+    def fine_tune(self, filepath, data):
         self._model.load(filepath)
-        self._model.fit(self.train_data)
+        self._model.fit(data)
 
     def load(self, filepath):
         self._model.load(filepath)
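With `csv_file` removed from the constructor, `SQLDiag` no longer owns a `LoadData` instance (the `__getattr__` delegation is gone too); callers load the dataset and pass it in. A minimal usage sketch, assuming it runs from the SQLdiag source tree; a plain `ConfigParser` stands in for whatever `main.py`'s `get_config()` actually returns:

```python
from configparser import ConfigParser

from algorithm.diag import SQLDiag
from preprocessing import LoadData

# assumed stand-in for main.py's get_config('sqldiag.conf')
params = ConfigParser()
params.read('sqldiag.conf')

model = SQLDiag('template', params)  # the csv_file constructor argument is gone
train_data = LoadData('sample_data/train.csv').train_data  # the caller loads data now
model.fit(train_data)  # fit/transform/fine_tune take the dataset explicitly
model.save('./template')
```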
@@ -36,7 +36,7 @@ class KerasRegression:
         shape = features.shape[1]
         if self.model is None:
             self.model = self.build_model(shape=shape, encoding_dim=self.encoding_dim)
-        self.model.fit(features, labels, epochs=epochs, batch_size=batch_size, shuffle=True, verbose=2)
+        self.model.fit(features, labels, epochs=epochs, batch_size=batch_size, shuffle=True, verbose=0)
 
     def predict(self, features):
         predict_result = self.model.predict(features)
@@ -129,10 +129,6 @@ class DnnModel(AbstractModel, ABC):
             self.w2v.load(word2vector_path)
             with open(scaler_path, 'rb') as f:
                 self.scaler = pickle.load(f)
-            logging.info("dnn model is loaded: '{}'; w2v model is loaded: '{}'; scaler model is loaded: '{}'."
-                         .format(dnn_path,
-                                 word2vector_path,
-                                 scaler_path))
         else:
             logging.error("{} not exist.".format(realpath))
 
@@ -149,7 +145,5 @@ class DnnModel(AbstractModel, ABC):
         self.w2v.save(word2vector_path)
         with open(scaler_path, 'wb') as f:
             pickle.dump(self.scaler, f)
-        logging.info("dnn model is saved: '{}'; w2v model is saved: '{}'; scaler model is saved: '{}'."
-                     .format(dnn_path,
-                             word2vector_path,
-                             scaler_path))
+        print("DNN model is stored in '{}'".format(realpath))
@@ -18,9 +18,10 @@ import logging
 import os
 import stat
 from functools import reduce
+from collections import defaultdict
 
 from algorithm.sql_similarity import calc_sql_distance
-from preprocessing import get_sql_template
+from preprocessing import get_sql_template, templatize_sql
 from utils import check_illegal_sql, LRUCache
 from . import AbstractModel
 
@@ -40,97 +41,78 @@ class TemplateModel(AbstractModel):
         for sql, duration_time in data:
             if check_illegal_sql(sql):
                 continue
-            # get 'fine_template' and 'rough_template' of SQL
-            fine_template, rough_template = get_sql_template(sql)
-            sql_prefix = fine_template.split()[0]
-            # if prefix of SQL is not in 'update', 'delete', 'select' and 'insert',
-            # then convert prefix to 'other'
+            sql_template = templatize_sql(sql)
+            sql_prefix = sql_template.split()[0]
             if sql_prefix not in self.__hash_table:
                 sql_prefix = 'OTHER'
-            if rough_template not in self.__hash_table[sql_prefix]:
-                self.__hash_table[sql_prefix][rough_template] = dict()
-                self.__hash_table[sql_prefix][rough_template]['info'] = dict()
-            if fine_template not in self.__hash_table[sql_prefix][rough_template]['info']:
-                self.__hash_table[sql_prefix][rough_template]['info'][fine_template] = \
-                    dict(time_list=[], count=0, mean_time=0.0, iter_time=0.0)
-            # count the number of occurrences of fine template
-            self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['count'] += 1
-            # store the execution time of the matched template in the corresponding list
-            self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
-                'time_list'].append(duration_time)
-            # iterative calculation of execution time based on historical data
-            if not self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
-                    'iter_time']:
-                self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
-                    'iter_time'] = duration_time
-            else:
-                self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time'] = \
-                    (self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
-                        'iter_time'] + duration_time) / 2
-        # calculate the average execution time of each template
+            if sql_template not in self.__hash_table[sql_prefix]:
+                self.__hash_table[sql_prefix][sql_template] = dict(time_list=[], count=0, mean_time=0.0, iter_time=0.0)
+            self.__hash_table[sql_prefix][sql_template]['count'] += 1
+            self.__hash_table[sql_prefix][sql_template]['time_list'].append(duration_time)
         for sql_prefix, sql_prefix_info in self.__hash_table.items():
-            for rough_template, rough_template_info in sql_prefix_info.items():
-                for _, fine_template_info in rough_template_info['info'].items():
-                    del fine_template_info['time_list'][:-self.time_list_size]
-                    fine_template_info['mean_time'] = \
-                        sum(fine_template_info['time_list']) / len(
-                            fine_template_info['time_list'])
-                rough_template_info['count'] = len(rough_template_info['info'])
-                rough_template_info['mean_time'] = sum(
-                    [value['mean_time'] for key, value in
-                     rough_template_info['info'].items()]) / len(
-                    rough_template_info['info'])
+            for sql_template, sql_template_info in sql_prefix_info.items():
+                del sql_template_info['time_list'][:-self.time_list_size]
+                sql_template_info['mean_time'] = sum(sql_template_info['time_list']) / len(sql_template_info['time_list'])
+                sql_template_info['iter_time'] = reduce(lambda x, y: (x+y)/2, sql_template_info['time_list'])
 
     def transform(self, data):
-        predict_time_list = []
+        predict_result_dict = defaultdict(list)
         for sql in data:
-            predict_time = self.predict_duration_time(sql)
-            predict_time_list.append([sql, predict_time])
-        return predict_time_list
+            sql_, status, predict_time, top_similarity_sql = self.predict_duration_time(sql)
+            predict_result_dict[status].append([sql_, predict_time, top_similarity_sql])
+        for key, value in predict_result_dict.items():
+            if value:
+                value.sort(key=lambda item: item[1], reverse=True)
+        return predict_result_dict
 
     @LRUCache(max_size=1024)
     def predict_duration_time(self, sql):
+        top_similarity_sql = None
         if check_illegal_sql(sql):
-            return -1
-        sql_prefix = sql.strip().split()[0]
-        # get 'fine_template' and 'rough_template' of SQL
-        fine_template, rough_template = get_sql_template(sql)
+            predict_time = -1
+            status = 'Suspect illegal sql'
+            return sql, status, predict_time, top_similarity_sql
+
+        sql_template = templatize_sql(sql)
+        # get 'sql_template' of SQL
+        sql_prefix = sql_template.strip().split()[0]
         if sql_prefix not in self.__hash_table:
             sql_prefix = 'OTHER'
         if not self.__hash_table[sql_prefix]:
-            logging.warning("'{}' not in the templates.".format(sql))
+            status = 'No SQL information'
             predict_time = -1
-        elif rough_template not in self.__hash_table[sql_prefix] or fine_template not in \
-                self.__hash_table[sql_prefix][rough_template]['info']:
+        elif sql_template not in self.__hash_table[sql_prefix]:
             similarity_info = []
             """
             if the template does not exist in the hash table,
             then calculate the possible execution time based on template
             similarity and KNN algorithm in all other templates
             """
-            if rough_template not in self.__hash_table[sql_prefix]:
-                for local_rough_template, local_rough_template_info in self.__hash_table[sql_prefix].items():
-                    similarity_info.append(
-                        (self.similarity_algorithm(rough_template, local_rough_template),
-                         local_rough_template_info['mean_time']))
-            else:
-                for local_fine_template, local_fine_template_info in \
-                        self.__hash_table[sql_prefix][rough_template]['info'].items():
-                    similarity_info.append(
-                        (self.similarity_algorithm(fine_template, local_fine_template),
-                         local_fine_template_info['iter_time']))
+            status = 'No SQL template found'
+            for local_sql_template, local_sql_template_info in self.__hash_table[sql_prefix].items():
+                similarity_info.append(
+                    (self.similarity_algorithm(sql_template, local_sql_template),
+                     local_sql_template_info['mean_time'], local_sql_template))
             topn_similarity_info = heapq.nlargest(self.knn_number, similarity_info)
-            sum_similarity_scores = sum(item[0] for item in topn_similarity_info) + self.bias
-            similarity_proportions = (item[0] / sum_similarity_scores for item in
-                                      topn_similarity_info)
-            topn_duration_time = (item[1] for item in topn_similarity_info)
+            sum_similarity_scores = sum(item[0] for item in topn_similarity_info)
+            if not sum_similarity_scores:
+                sum_similarity_scores = self.bias
+            top_similarity_sql = '\n'.join([item[2] for item in topn_similarity_info])
+            similarity_proportions = [item[0] / sum_similarity_scores for item in
+                                      topn_similarity_info]
+            topn_duration_time = [item[1] for item in topn_similarity_info]
             predict_time = reduce(lambda x, y: x + y,
                                   map(lambda x, y: x * y, similarity_proportions,
                                       topn_duration_time))
         else:
-            predict_time = self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time']
+            status = 'Fine match'
+            predict_time = self.__hash_table[sql_prefix][sql_template]['iter_time']
+            top_similarity_sql = sql_template
 
-        return predict_time
+        return sql, status, predict_time, top_similarity_sql
 
     def load(self, filepath):
         realpath = os.path.realpath(filepath)
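The `fit` rewrite flattens the old prefix, rough-template, fine-template hierarchy into a single prefix-to-template map, and `iter_time` is now recomputed once per fit as a recency-weighted fold over `time_list` rather than updated per sample. A rough sketch of the structure `fit` builds (the template and timing values are illustrative):

```python
# Illustrative shape of TemplateModel.__hash_table after fit().
hash_table = {
    'SELECT': {
        'SELECT * FROM T WHERE ID = ?': {
            'time_list': [0.01, 0.02, 0.03],  # last time_list_size samples
            'count': 3,
            'mean_time': 0.02,    # plain average of time_list
            'iter_time': 0.0225,  # ((0.01 + 0.02) / 2 + 0.03) / 2, biased toward recent samples
        },
    },
    'OTHER': {},  # catch-all bucket for prefixes outside the known DML set
}
```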
@@ -138,7 +120,6 @@ class TemplateModel(AbstractModel):
             template_path = os.path.join(realpath, 'template.json')
             with open(template_path, mode='r') as f:
                 self.__hash_table = json.load(f)
-            logging.info("template model '{}' is loaded.".format(template_path))
         else:
             logging.error("{} not exist.".format(realpath))
 
@@ -151,4 +132,5 @@ class TemplateModel(AbstractModel):
         template_path = os.path.join(realpath, 'template.json')
         with open(template_path, mode='w') as f:
             json.dump(self.__hash_table, f, indent=4)
-        logging.info("template model is stored in '{}'".format(realpath))
+        print("Template model is stored in '{}'".format(realpath))
+
@@ -5,6 +5,8 @@ def calc_sql_distance(algorithm):
         from .levenshtein import distance
     elif algorithm == 'parse_tree':
         from .parse_tree import distance
+    elif algorithm == 'cosine_distance':
+        from .cosine_distance import distance
     else:
         raise NotImplementedError("do not support '{}'".format(algorithm))
     return distance
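`calc_sql_distance` is a small factory: it returns the `distance` function of the requested backend, so `TemplateModel.similarity_algorithm` picks up the new metric with no further changes. A usage sketch, assuming the SQLdiag tree is on the import path:

```python
from algorithm.sql_similarity import calc_sql_distance

similarity = calc_sql_distance('cosine_distance')  # returns cosine_distance.distance
print(similarity('SELECT * FROM T1', 'SELECT * FROM T2'))  # score in [0, 1]; higher = more alike
```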
@@ -0,0 +1,26 @@
+"""
+Copyright (c) 2020 Huawei Technologies Co.,Ltd.
+
+openGauss is licensed under Mulan PSL v2.
+You can use this software according to the terms and conditions of the Mulan PSL v2.
+You may obtain a copy of Mulan PSL v2 at:
+
+         http://license.coscl.org.cn/MulanPSL2
+
+THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+See the Mulan PSL v2 for more details.
+"""
+import math
+from collections import Counter
+
+
+def distance(str1, str2):
+    c_1 = Counter(str1)
+    c_2 = Counter(str2)
+    c_union = set(c_1).union(c_2)
+    dot_product = sum(c_1.get(item, 0) * c_2.get(item, 0) for item in c_union)
+    mag_c1 = math.sqrt(sum(c_1.get(item, 0)**2 for item in c_union))
+    mag_c2 = math.sqrt(sum(c_2.get(item, 0)**2 for item in c_union))
+    return dot_product / (mag_c1 * mag_c2)
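The new metric treats each SQL string as a bag of characters and returns the cosine of the angle between the two count vectors. Despite the module name it is a similarity, 1.0 for identical bags and approaching 0 for disjoint ones, which is why `predict_duration_time` takes the n largest scores. A worked check, restating the function above so it runs standalone:

```python
import math
from collections import Counter

def distance(str1, str2):
    # same computation as cosine_distance.distance in the hunk above
    c_1, c_2 = Counter(str1), Counter(str2)
    c_union = set(c_1).union(c_2)
    dot_product = sum(c_1.get(item, 0) * c_2.get(item, 0) for item in c_union)
    mag_c1 = math.sqrt(sum(c_1.get(item, 0) ** 2 for item in c_union))
    mag_c2 = math.sqrt(sum(c_2.get(item, 0) ** 2 for item in c_union))
    return dot_product / (mag_c1 * mag_c2)

print(distance('abc', 'abc'))  # 1.0: identical character counts
print(distance('abc', 'abd'))  # 0.666...: dot product 2, magnitudes sqrt(3) each
```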
@@ -1,3 +1,19 @@
+"""
+Copyright (c) 2020 Huawei Technologies Co.,Ltd.
+
+openGauss is licensed under Mulan PSL v2.
+You can use this software according to the terms and conditions of the Mulan PSL v2.
+You may obtain a copy of Mulan PSL v2 at:
+
+         http://license.coscl.org.cn/MulanPSL2
+
+THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
+EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
+MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
+See the Mulan PSL v2 for more details.
+"""
+
+
 def distance(str1, str2):
     """
     func: calculate levenshtein distance between two strings.
@@ -5,18 +21,20 @@ def distance(str1, str2):
     :param str2: string2
     :return: distance
     """
-    m, n = len(str1) + 1, len(str2) + 1
-    matrix = [[0] * n for _ in range(m)]
-    matrix[0][0] = 0
-    for i in range(1, m):
-        matrix[i][0] = matrix[i - 1][0] + 1
-    for j in range(1, n):
-        matrix[0][j] = matrix[0][j - 1] + 1
-    for i in range(1, m):
-        for j in range(1, n):
-            if str1[i - 1] == str2[j - 1]:
-                matrix[i][j] = matrix[i - 1][j - 1]
-            else:
-                matrix[i][j] = max(matrix[i - 1][j - 1], matrix[i - 1][j], matrix[i][j - 1]) + 1
+    len_str1 = len(str1) + 1
+    len_str2 = len(str2) + 1
+
+    mat = [[0]*len_str2 for i in range(len_str1)]
+    mat[0][0] = 0
+    for i in range(1,len_str1):
+        mat[i][0] = mat[i-1][0] + 1
+    for j in range(1,len_str2):
+        mat[0][j] = mat[0][j-1]+1
+    for i in range(1,len_str1):
+        for j in range(1,len_str2):
+            if str1[i-1] == str2[j-1]:
+                mat[i][j] = mat[i-1][j-1]
+            else:
+                mat[i][j] = min(mat[i-1][j-1],mat[i-1][j],mat[i][j-1])+1
 
-    return matrix[m - 1][n - 1]
+    return 1 / mat[len_str1-1][j-1]
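Besides adding the license header, this rewrite changes the recurrence from `max` to `min` (the `max` variant did not compute an edit distance) and returns the reciprocal so that, like the other backends, a larger value means more similar. One trade-off visible in the code: identical strings give a raw distance of 0, so the `1 / mat[...]` return would raise ZeroDivisionError; in this PR's code path that cannot happen, because the similarity backends are only consulted for templates that are not already in the hash table. A standalone trace of the corrected DP, without the reciprocal:

```python
# Plain edit-distance DP with the corrected min-recurrence (returns the raw
# distance; the diff's version returns its reciprocal).
def levenshtein(str1, str2):
    m, n = len(str1) + 1, len(str2) + 1
    mat = [[0] * n for _ in range(m)]
    for i in range(1, m):
        mat[i][0] = i  # delete i characters
    for j in range(1, n):
        mat[0][j] = j  # insert j characters
    for i in range(1, m):
        for j in range(1, n):
            if str1[i - 1] == str2[j - 1]:
                mat[i][j] = mat[i - 1][j - 1]
            else:
                mat[i][j] = min(mat[i - 1][j - 1], mat[i - 1][j], mat[i][j - 1]) + 1
    return mat[m - 1][n - 1]

print(levenshtein('SELECT * FROM T1', 'SELECT * FROM T2'))  # 1, so the diff's version yields 1/1 = 1.0
```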
@@ -12,10 +12,10 @@ __description__ = "Get sql information based on wdr."
 def parse_args():
     parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
                                      description=__description__)
-    parser.add_argument('--port', help="User of remote server.", type=int, required=True)
-    parser.add_argument('--start_time', help="Start time of query", required=True)
-    parser.add_argument('--finish_time', help="Finish time of query", required=True)
-    parser.add_argument('--save_path', default='sample_data/data.csv', help="Path to save result")
+    parser.add_argument('--port', help="Port of database service.", type=int, required=True)
+    parser.add_argument('--start-time', help="Start time of query", required=True)
+    parser.add_argument('--finish-time', help="Finish time of query", required=True)
+    parser.add_argument('--save-path', default='sample_data/data.csv', help="Path to save result")
     return parser.parse_args()
 
 
@@ -19,6 +19,7 @@ from configparser import ConfigParser
 
 from algorithm.diag import SQLDiag
 from utils import ResultSaver
+from preprocessing import LoadData, split_sql
 
 __version__ = '2.0.0'
 __description__ = 'SQLdiag integrated by openGauss.'
@@ -40,6 +41,8 @@ def parse_args():
     parser.add_argument('--predicted-file', help='The file path to save the predicted result.')
     parser.add_argument('--model', default='template', choices=['template', 'dnn'],
                         help='Choose the model model to use.')
+    parser.add_argument('--query', help='Input the querys to predict.')
+    parser.add_argument('--threshold', help='Slow SQL threshold.')
     parser.add_argument('--model-path', required=True,
                         help='The storage path of the model file, used to read or save the model file.')
     parser.add_argument('--config-file', default='sqldiag.conf')
@@ -54,23 +57,64 @@ def get_config(filepath):
 
 
 def main(args):
-    logging.basicConfig(level=logging.INFO)
-    model = SQLDiag(args.model, args.csv_file, get_config(args.config_file))
-    if args.mode == 'train':
-        model.fit()
-        model.save(args.model_path)
-    elif args.mode == 'predict':
-        if not args.predicted_file:
-            logging.error("The [--predicted-file] parameter is required for predict mode")
+    logging.basicConfig(level=logging.WARNING)
+    model = SQLDiag(args.model, get_config(args.config_file))
+    if args.mode in ('train', 'finetune'):
+        if not args.csv_file:
+            logging.fatal('The [--csv-file] parameter is required for train mode')
             sys.exit(1)
-        model.load(args.model_path)
-        pred_result = model.transform()
-        ResultSaver().save(pred_result, args.predicted_file)
-        logging.info('predicted result in saved in {}'.format(args.predicted_file))
-    elif args.mode == 'finetune':
-        model.fine_tune(args.model_path)
+        train_data = LoadData(args.csv_file).train_data
+        if args.mode == 'train':
+            model.fit(train_data)
+        else:
+            model.fine_tune(args.model_path, train_data)
         model.save(args.model_path)
+    else:
+        model.load(args.model_path)
+        if args.csv_file and not args.query:
+            predict_data = LoadData(args.csv_file).predict_data
+        elif args.query and not args.csv_file:
+            predict_data = split_sql(args.query)
+        else:
+            logging.error('The predict model only supports [--csv-file] or [--query] at the same time.')
+            sys.exit(1)
+        args.threshold = -100 if not args.threshold else float(args.threshold)
+        pred_result = model.transform(predict_data)
+        if args.predicted_file:
+            if args.model == 'template':
+                info_sum = []
+                for stats, _info in pred_result.items():
+                    if _info:
+                        _info = list(filter(lambda item: item[1]>=args.threshold, _info))
+                        for item in _info:
+                            item.insert(1, stats)
+                        info_sum.extend(_info)
+                ResultSaver().save(info_sum, args.predicted_file)
+            else:
+                pred_result = list(filter(lambda item: float(item[1])>=args.threshold, pred_result))
+                ResultSaver().save(pred_result, args.predicted_file)
+        else:
+            from prettytable import PrettyTable
+
+            display_table = PrettyTable()
+            if args.model == 'template':
+                display_table.field_names = ['sql', 'status', 'predicted time', 'most similar template']
+                display_table.align = 'l'
+                status = ('Suspect illegal SQL', 'No SQL information', 'No SQL template found', 'Fine match')
+                for stats in status:
+                    if pred_result[stats]:
+                        for sql, predicted_time, similariest_sql in pred_result[stats]:
+                            if predicted_time >= args.threshold or stats == 'Suspect illegal sql':
+                                display_table.add_row([sql, stats, predicted_time, similariest_sql])
+            else:
+                display_table.field_names = ['sql', 'predicted time']
+                display_table.align = 'l'
+                for sql, predicted_time in pred_result:
+                    if float(predicted_time) >= args.threshold:
+                        display_table.add_row([sql, predicted_time])
+            print(display_table.get_string())
 
 
 if __name__ == '__main__':
     main(parse_args())
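Taken together, predict mode now accepts either a CSV file (`-f`) or a literal `--query` string, but not both; it prunes everything predicted below `--threshold` (defaulting to a `-100` sentinel that keeps all rows) and falls back to a PrettyTable on stdout when `--predicted-file` is omitted. Example invocations in the style of the README above (paths and values are the README's samples):

```shell
# CSV input, write a filtered result file (threshold in the predicted-time unit)
python main.py predict -f ./sample_data/predict.csv --threshold 0.02 --model template --model-path ./template --predicted-file ./result/t_result

# ad-hoc semicolon-separated statements, printed as a table on stdout
python main.py predict --query "select * from t where id = 1; select count(*) from t" --model template --model-path ./template
```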
@@ -18,8 +18,6 @@ import sqlparse
 from sqlparse.sql import Identifier, IdentifierList
 from sqlparse.tokens import Keyword, DML
 
-# split flag in SQL
-split_flag = ('!=', '<=', '>=', '==', '<', '>', '=', ',', '(', ')', '*', ';', '%', '+', ',', ';')
 
 DDL_WORDS = ('CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'COMMIT', 'RENAME')
 DML_WORDS = ('SELECT', 'INSERT INTO', 'UPDATE', 'DELETE', 'MERGE', 'CALL',
@@ -31,53 +29,6 @@ KEYWORDS = ('GRANT', 'REVOKE', 'DENY', 'ABORT', 'ADD', 'AGGREGATE', 'ANALYSE', '
 FUNC = ('FLOOR', 'SUM', 'NOW', 'UUID', 'COUNT')
 SQL_SIG = ('&', '&&')
 
-# filter like (insert into aa (c1, c2) values (v1, v2) => insert into aa * values *)
-BRACKET_FILTER = r'\(.*?\)'
-
-# filter (123, 123.123)
-PURE_DIGIT_FILTER = r'[\s]+\d+(\.\d+)?'
-
-# filter ('123', '123.123')
-SINGLE_QUOTE_DIGIT_FILTER = r'\'\d+(\.\d+)?\''
-
-# filter ("123", "123.123")
-DOUBLE_QUOTE_DIGIT_FILTER = r'"\d+(\.\d+)?"'
-
-# filter ('123', 123, '123,123', 123.123) not filter(table1, column1, table_2, column_2)
-DIGIT_FILTER = r'([^a-zA-Z])_?\d+(\.\d+)?'
-
-# filter date in sql ('1999-09-09', '1999/09/09', "1999-09-09 20:10:10", '1999/09/09 20:10:10.12345')
-PURE_TIME_FILTER = r'[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,2})?(\.\d+)?'
-SINGLE_QUOTE_TIME_FILTER = r'\'[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,' \
-                           r'2})?(\.\d+)?\' '
-DOUBLE_QUOTE_TIME_FILTER = r'"[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,2})?(\.\d+)?"'
-
-# filter like "where id='abcd" => "where id=#"
-SINGLE_QUOTE_FILTER = r'\'.*?\''
-
-# filter like 'where id="abcd" => 'where id=#'
-DOUBLE_QUOTE_FILTER = r'".*?"'
-
-# filter annotation like "/* XXX */"
-ANNOTATION_FILTER_1 = r'/\s*\*[\w\W]*?\*\s*/\s*'
-ANNOTATION_FILTER_2 = r'^--.*\s?'
-
-# filter NULL character '\n \t' in sql
-NULL_CHARACTER_FILTER = r'\s+'
-
-# remove data in insert sql
-VALUE_BRACKET_FILETER = r'VALUES (\(.*\))'
-
-# remove equal data in sql
-WHERE_EQUAL_FILTER = r'= .*?\s'
-
-LESS_EQUAL_FILTER = r'(<= .*? |<= .*$)'
-GREATER_EQUAL_FILTER = r'(>= .*? |<= .*$)'
-LESS_FILTER = r'(< .*? |< .*$)'
-GREATER_FILTER = r'(> .*? |> .*$)'
-EQUALS_FILTER = r'(= .*? |= .*$)'
-LIMIT_DIGIT = r'LIMIT \d+'
-
 
 def _unify_sql(sql):
     """
@@ -85,65 +36,49 @@ def _unify_sql(sql):
     """
     index = 0
     sql = re.sub(r'\n', r' ', sql)
-    sql = re.sub(ANNOTATION_FILTER_1, r'', sql)
-    sql = re.sub(ANNOTATION_FILTER_2, r'', sql)
-    while index < len(sql):
-        if sql[index] in split_flag:
-            if sql[index:index + 2] in split_flag:
-                sql = sql[:index].strip() + ' ' + sql[index:index + 2] + ' ' + sql[index + 2:].strip()
-                index = index + 3
-            else:
-                sql = sql[:index].strip() + ' ' + sql[index] + ' ' + sql[index + 1:].strip()
-                index = index + 2
-        else:
-            index = index + 1
-    new_sql = list()
-    for word in sql.split():
-        new_sql.append(word.upper())
-    sql = ' '.join(new_sql)
+    sql = re.sub(r'/\s*\*[\w\W]*?\*\s*/\s*', r'', sql)
+    sql = re.sub(r'^--.*\s?', r'', sql)
+    sql = re.sub(r'([!><=]=)', r' \1 ', sql)
+    sql = re.sub(r'([^!><=])([=<>])', r'\1 \2 ', sql)
+    sql = re.sub(r'([,()*%/+])', r' \1 ', sql)
+    sql = re.sub(r'\s+', r' ', sql)
+    sql = sql.upper()
     return sql.strip()
 
 
+def split_sql(sqls):
+    if not sqls:
+        return []
+    sqls = sqls.split(';')
+    result = list(map(lambda item: _unify_sql(item), sqls))
+    return result
+
+
 def templatize_sql(sql):
     """
-    function: replace the message which is not important in sql
+    SQL desensitization
     """
-    sql = _unify_sql(sql)
-
-    sql = re.sub(r';', r'', sql)
-
-    # ? represent date or time
-    sql = re.sub(PURE_TIME_FILTER, r'?', sql)
-    sql = re.sub(SINGLE_QUOTE_TIME_FILTER, r'?', sql)
-    sql = re.sub(DOUBLE_QUOTE_TIME_FILTER, r'?', sql)
-
-    # $ represent insert value
-    if sql.startswith('INSERT'):
-        sql = re.sub(VALUE_BRACKET_FILETER, r'VALUES ()', sql)
-
-    # $$ represent select value
-    if sql.startswith('SELECT') and ' = ' in sql:
-        sql = re.sub(WHERE_EQUAL_FILTER, r'= $$ ', sql)
-
-    # $$$ represent delete value
-    if sql.startswith('DELETE') and ' = ' in sql:
-        sql = re.sub(WHERE_EQUAL_FILTER, r'= $$$ ', sql)
-
-    # & represent logical signal
-    sql = re.sub(LESS_EQUAL_FILTER, r'<= & ', sql)
-    sql = re.sub(LESS_FILTER, r'< & ', sql)
-    sql = re.sub(GREATER_EQUAL_FILTER, r'>= & ', sql)
-    sql = re.sub(GREATER_FILTER, r'> & ', sql)
-    sql = re.sub(LIMIT_DIGIT, r'LIMIT &', sql)
-    sql = re.sub(EQUALS_FILTER, r'= & ', sql)
-    sql = re.sub(PURE_DIGIT_FILTER, r' &', sql)
-    sql = re.sub(r'`', r'', sql)
-
-    # && represent quote str
-    sql = re.sub(SINGLE_QUOTE_FILTER, r'?', sql)
-    sql = re.sub(DOUBLE_QUOTE_FILTER, r'?', sql)
-
-    return sql
+    if not sql:
+        return ''
+    standard_sql = _unify_sql(sql)
+
+    if standard_sql.startswith('INSERT'):
+        standard_sql = re.sub(r'VALUES (\(.*\))', r'VALUES', standard_sql)
+    # remove digital like 12, 12.565
+    standard_sql = re.sub(r'[\s]+\d+(\.\d+)?', r' ?', standard_sql)
+    # remove '$n' in sql
+    standard_sql = re.sub(r'\$\d+', r'?', standard_sql)
+    # remove single quotes content
+    standard_sql = re.sub(r'\'.*?\'', r'?', standard_sql)
+    # remove double quotes content
+    standard_sql = re.sub(r'".*?"', r'?', standard_sql)
+    # remove '`' in sql
+    standard_sql = re.sub(r'`', r'', standard_sql)
+    # remove ; in sql
+    standard_sql = re.sub(r';', r'', standard_sql)
+
+    return standard_sql.strip()
 
 
 def _is_select_clause(parsed_sql):
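`templatize_sql` now normalizes a statement with `_unify_sql` (comment stripping, spacing around operators, upper-casing) and then masks every literal with `?`, replacing the old `&`/`$$`/`$$$` placeholder scheme. A hand-traced sketch of the expected behaviour, derived from the regexes above, so treat the exact spacing as approximate:

```python
# Assumes the SQLdiag tree is on sys.path; outputs traced by hand.
from preprocessing import templatize_sql

print(templatize_sql("select name from t where id = 3;"))
# SELECT NAME FROM T WHERE ID = ?   (digit masked, ';' stripped)

print(templatize_sql("insert into t (a, b) values (1, 'x');"))
# INSERT INTO T ( A , B ) VALUES    (the whole VALUES (...) group is dropped)
```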
@@ -155,7 +90,6 @@ def _is_select_clause(parsed_sql):
     return False
 
 
-# todo: what is token list? from list?
 def _get_table_token_list(parsed_sql, token_list):
     flag = False
     for token in parsed_sql.tokens:
@@ -26,6 +26,7 @@ min_limit = 6.78e-05
 # Template Methed
 #------------------------------------------------------------------------------
 [template]
-similarity_algorithm =
-time_list_size =
-knn_number =
+# [cosine_distance, levenshtein, list, parse_tree]
+similarity_algorithm = cosine_distance
+time_list_size = 20
+knn_number = 1
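The `[template]` section now ships concrete defaults instead of blanks, and `check_template_algorithm` (earlier in this diff) validates exactly the values listed in the new comment. A minimal read sketch using Python's standard `ConfigParser`, as imported in `main.py` (how `get_config` itself parses the file is not shown in this diff, so this is an assumption):

```python
from configparser import ConfigParser

config = ConfigParser()
config.read('sqldiag.conf')
print(config.get('template', 'similarity_algorithm'))  # cosine_distance
print(config.getint('template', 'time_list_size'))     # 20
print(config.getint('template', 'knn_number'))          # 1
```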
@@ -25,20 +25,15 @@ class ResultSaver:
             os.chmod(dirname, stat.S_IRWXU)
         if isinstance(data, (list, tuple)):
             self.save_list(data, realpath)
-        elif isinstance(data, dict):
-            self.save_dict(data, path)
         else:
             raise TypeError("mode should be 'list', 'tuple' or 'dict', but input type is '{}'".format(str(type(data))))
 
     @staticmethod
     def save_list(data, path):
-        data = pd.DataFrame(data)
-        data.to_csv(path, index=False, sep=',', header=False, quoting=csv.QUOTE_NONE, escapechar='\"')
-
-    @staticmethod
-    def save_dict(data, path):
-        data = pd.DataFrame(data.items())
-        data.to_csv(path, index=False, sep=',', header=False, quoting=csv.QUOTE_NONE, escapechar='\"')
+        with open(path, mode='w') as f:
+            for item in data:
+                content = ",".join([str(sub_item) for sub_item in item])
+                f.write(content + '\n')
 
 
 class DBAgent:
@@ -74,7 +69,7 @@ class DBAgent:
             result = list(self.cursor.fetchall())
             return result
         except Exception as e:
-            logging.getLogger('agent').warning(str(e))
+            logging.warning(str(e))
 
     def close(self):
         self.cursor.close()