!996 Update SQLDiag: Add two threshold judgments

Merge pull request !996 from wangtq/master
opengauss-bot authored 2021-05-28 19:34:54 +08:00, committed by Gitee
12 changed files with 235 additions and 239 deletions

View File

@@ -71,13 +71,17 @@ predict dataset sample: [sample_data/predict.csv](data/predict.csv)
## example for template method using sample dataset
# training
# train
python main.py train -f ./sample_data/train.csv --model template --model-path ./template
# predict
python main.py predict -f ./sample_data/predict.csv --model template --model-path ./template --predicted-file ./result/t_result
# predict with threshold
python main.py predict -f ./sample_data/predict.csv --threshold 0.02 --model template --model-path ./template --predicted-file ./result/t_result
# update model
python main.py finetune -f ./sample_data/train.csv --model template --model-path ./template
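# predict a single query via the new --query option (illustrative command; the option is added by this change in main.py)
python main.py predict --query "select * from t1 where id = 10" --threshold 0.02 --model template --model-path ./template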

View File

@@ -9,9 +9,9 @@ W2V_SUFFIX = 'word2vector'
def check_template_algorithm(param):
if param and param not in ["list", "levenshtein", "parse_tree"]:
if param and param not in ["list", "levenshtein", "parse_tree", "cosine_distance"]:
raise ValueError("The similarity algorithm '%s' is invaild, "
"please choose from ['list', 'levenshtein', 'parse_tree']" % param)
"please choose from ['list', 'levenshtein', 'parse_tree', 'cosine_distance']" % param)
class ModelConfig(object):
@@ -62,7 +62,7 @@ SUPPORTED_ALGORITHM = {'dnn': lambda config: DnnModel(DnnConfig.init_from_config
class SQLDiag:
def __init__(self, model_algorithm, csv_file, params):
def __init__(self, model_algorithm, params):
if model_algorithm not in SUPPORTED_ALGORITHM:
raise NotImplementedError("do not support {}".format(model_algorithm))
try:
@@ -70,20 +70,16 @@ class SQLDiag:
except ValueError as e:
logging.error(e, exc_info=True)
sys.exit(1)
self.load_data = LoadData(csv_file)
def __getattr__(self, item):
return getattr(self.load_data, item)
def fit(self, data):
self._model.fit(data)
def fit(self):
self._model.fit(self.train_data)
def transform(self, data):
return self._model.transform(data)
def transform(self):
return self._model.transform(self.predict_data)
def fine_tune(self, filepath):
def fine_tune(self, filepath, data):
self._model.load(filepath)
self._model.fit(self.train_data)
self._model.fit(data)
def load(self, filepath):
self._model.load(filepath)
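With csv_file gone from the constructor, data loading moves to the caller; a minimal sketch of the new call pattern, following the main.py changes later in this diff (`params` stands for the parsed configuration):
from algorithm.diag import SQLDiag
from preprocessing import LoadData
model = SQLDiag('template', params)                        # no csv_file argument any more
train_data = LoadData('./sample_data/train.csv').train_data
model.fit(train_data)                                      # data is now passed in explicitly
model.save('./template')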

View File

@@ -36,7 +36,7 @@ class KerasRegression:
shape = features.shape[1]
if self.model is None:
self.model = self.build_model(shape=shape, encoding_dim=self.encoding_dim)
self.model.fit(features, labels, epochs=epochs, batch_size=batch_size, shuffle=True, verbose=2)
self.model.fit(features, labels, epochs=epochs, batch_size=batch_size, shuffle=True, verbose=0)
def predict(self, features):
predict_result = self.model.predict(features)
@@ -129,10 +129,6 @@ class DnnModel(AbstractModel, ABC):
self.w2v.load(word2vector_path)
with open(scaler_path, 'rb') as f:
self.scaler = pickle.load(f)
logging.info("dnn model is loaded: '{}'; w2v model is loaded: '{}'; scaler model is loaded: '{}'."
.format(dnn_path,
word2vector_path,
scaler_path))
else:
logging.error("{} not exist.".format(realpath))
@@ -149,7 +145,5 @@ class DnnModel(AbstractModel, ABC):
self.w2v.save(word2vector_path)
with open(scaler_path, 'wb') as f:
pickle.dump(self.scaler, f)
logging.info("dnn model is saved: '{}'; w2v model is saved: '{}'; scaler model is saved: '{}'."
.format(dnn_path,
word2vector_path,
scaler_path))
print("DNN model is stored in '{}'".format(realpath))

View File

@@ -18,9 +18,10 @@ import logging
import os
import stat
from functools import reduce
from collections import defaultdict
from algorithm.sql_similarity import calc_sql_distance
from preprocessing import get_sql_template
from preprocessing import get_sql_template, templatize_sql
from utils import check_illegal_sql, LRUCache
from . import AbstractModel
@@ -40,97 +41,78 @@ class TemplateModel(AbstractModel):
for sql, duration_time in data:
if check_illegal_sql(sql):
continue
# get 'fine_template' and 'rough_template' of SQL
fine_template, rough_template = get_sql_template(sql)
sql_prefix = fine_template.split()[0]
# if prefix of SQL is not in 'update', 'delete', 'select' and 'insert',
# then convert prefix to 'other'
sql_template = templatize_sql(sql)
sql_prefix = sql_template.split()[0]
if sql_prefix not in self.__hash_table:
sql_prefix = 'OTHER'
if rough_template not in self.__hash_table[sql_prefix]:
self.__hash_table[sql_prefix][rough_template] = dict()
self.__hash_table[sql_prefix][rough_template]['info'] = dict()
if fine_template not in self.__hash_table[sql_prefix][rough_template]['info']:
self.__hash_table[sql_prefix][rough_template]['info'][fine_template] = \
dict(time_list=[], count=0, mean_time=0.0, iter_time=0.0)
# count the number of occurrences of fine template
self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['count'] += 1
# store the execution time of the matched template in the corresponding list
self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
'time_list'].append(duration_time)
# iterative calculation of execution time based on historical data
if not self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
'iter_time']:
self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
'iter_time'] = duration_time
else:
self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time'] = \
(self.__hash_table[sql_prefix][rough_template]['info'][fine_template][
'iter_time'] + duration_time) / 2
# calculate the average execution time of each template
if sql_template not in self.__hash_table[sql_prefix]:
self.__hash_table[sql_prefix][sql_template] = dict(time_list=[], count=0, mean_time=0.0, iter_time=0.0)
self.__hash_table[sql_prefix][sql_template]['count'] += 1
self.__hash_table[sql_prefix][sql_template]['time_list'].append(duration_time)
for sql_prefix, sql_prefix_info in self.__hash_table.items():
for rough_template, rough_template_info in sql_prefix_info.items():
for _, fine_template_info in rough_template_info['info'].items():
del fine_template_info['time_list'][:-self.time_list_size]
fine_template_info['mean_time'] = \
sum(fine_template_info['time_list']) / len(
fine_template_info['time_list'])
rough_template_info['count'] = len(rough_template_info['info'])
rough_template_info['mean_time'] = sum(
[value['mean_time'] for key, value in
rough_template_info['info'].items()]) / len(
rough_template_info['info'])
for sql_template, sql_template_info in sql_prefix_info.items():
del sql_template_info['time_list'][:-self.time_list_size]
sql_template_info['mean_time'] = sum(sql_template_info['time_list']) / len(sql_template_info['time_list'])
sql_template_info['iter_time'] = reduce(lambda x, y: (x+y)/2, sql_template_info['time_list'])
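# (illustrative) the reduce above weights recent executions more heavily, e.g.
#   reduce(lambda x, y: (x + y) / 2, [0.010, 0.020, 0.040]) -> 0.0275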
def transform(self, data):
predict_time_list = []
predict_result_dict = defaultdict(list)
for sql in data:
predict_time = self.predict_duration_time(sql)
predict_time_list.append([sql, predict_time])
return predict_time_list
sql_, status, predict_time, top_similarity_sql = self.predict_duration_time(sql)
predict_result_dict[status].append([sql_, predict_time, top_similarity_sql])
for key, value in predict_result_dict.items():
if value:
value.sort(key=lambda item: item[1], reverse=True)
return predict_result_dict
@LRUCache(max_size=1024)
def predict_duration_time(self, sql):
top_similarity_sql = None
if check_illegal_sql(sql):
return -1
sql_prefix = sql.strip().split()[0]
# get 'fine_template' and 'rough_template' of SQL
fine_template, rough_template = get_sql_template(sql)
predict_time = -1
status = 'Suspect illegal sql'
return sql, status, predict_time, top_similarity_sql
sql_template = templatize_sql(sql)
# get 'sql_template' of SQL
sql_prefix = sql_template.strip().split()[0]
if sql_prefix not in self.__hash_table:
sql_prefix = 'OTHER'
if not self.__hash_table[sql_prefix]:
logging.warning("'{}' not in the templates.".format(sql))
status = 'No SQL information'
predict_time = -1
elif rough_template not in self.__hash_table[sql_prefix] or fine_template not in \
self.__hash_table[sql_prefix][rough_template]['info']:
elif sql_template not in self.__hash_table[sql_prefix]:
similarity_info = []
"""
if the template does not exist in the hash table,
then calculate the possible execution time based on template
similarity and KNN algorithm in all other templates
"""
if rough_template not in self.__hash_table[sql_prefix]:
for local_rough_template, local_rough_template_info in self.__hash_table[sql_prefix].items():
similarity_info.append(
(self.similarity_algorithm(rough_template, local_rough_template),
local_rough_template_info['mean_time']))
else:
for local_fine_template, local_fine_template_info in \
self.__hash_table[sql_prefix][rough_template]['info'].items():
similarity_info.append(
(self.similarity_algorithm(fine_template, local_fine_template),
local_fine_template_info['iter_time']))
status = 'No SQL template found'
for local_sql_template, local_sql_template_info in self.__hash_table[sql_prefix].items():
similarity_info.append(
(self.similarity_algorithm(sql_template, local_sql_template),
local_sql_template_info['mean_time'], local_sql_template))
topn_similarity_info = heapq.nlargest(self.knn_number, similarity_info)
sum_similarity_scores = sum(item[0] for item in topn_similarity_info) + self.bias
similarity_proportions = (item[0] / sum_similarity_scores for item in
topn_similarity_info)
topn_duration_time = (item[1] for item in topn_similarity_info)
sum_similarity_scores = sum(item[0] for item in topn_similarity_info)
if not sum_similarity_scores:
sum_similarity_scores = self.bias
top_similarity_sql = '\n'.join([item[2] for item in topn_similarity_info])
similarity_proportions = [item[0] / sum_similarity_scores for item in
topn_similarity_info]
topn_duration_time = [item[1] for item in topn_similarity_info]
predict_time = reduce(lambda x, y: x + y,
map(lambda x, y: x * y, similarity_proportions,
topn_duration_time))
else:
predict_time = self.__hash_table[sql_prefix][rough_template]['info'][fine_template]['iter_time']
return predict_time
else:
status = 'Fine match'
predict_time = self.__hash_table[sql_prefix][sql_template]['iter_time']
top_similarity_sql = sql_template
return sql, status, predict_time, top_similarity_sql
def load(self, filepath):
realpath = os.path.realpath(filepath)
@@ -138,7 +120,6 @@ class TemplateModel(AbstractModel):
template_path = os.path.join(realpath, 'template.json')
with open(template_path, mode='r') as f:
self.__hash_table = json.load(f)
logging.info("template model '{}' is loaded.".format(template_path))
else:
logging.error("{} not exist.".format(realpath))
@@ -151,4 +132,5 @@ class TemplateModel(AbstractModel):
template_path = os.path.join(realpath, 'template.json')
with open(template_path, mode='w') as f:
json.dump(self.__hash_table, f, indent=4)
logging.info("template model is stored in '{}'".format(realpath))
print("Template model is stored in '{}'".format(realpath))

View File

@@ -5,6 +5,8 @@ def calc_sql_distance(algorithm):
from .levenshtein import distance
elif algorithm == 'parse_tree':
from .parse_tree import distance
elif algorithm == 'cosine_distance':
from .cosine_distance import distance
else:
raise NotImplementedError("do not support '{}'".format(algorithm))
return distance
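Dispatch is unchanged apart from the extra branch; typical use (sketch):
distance = calc_sql_distance('cosine_distance')   # the new option added by this change
score = distance('SELECT * FROM T1', 'SELECT * FROM T2')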

View File

@@ -0,0 +1,26 @@
"""
Copyright (c) 2020 Huawei Technologies Co.,Ltd.
openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""
import math
from collections import Counter
def distance(str1, str2):
c_1 = Counter(str1)
c_2 = Counter(str2)
c_union = set(c_1).union(c_2)
dot_product = sum(c_1.get(item, 0) * c_2.get(item, 0) for item in c_union)
mag_c1 = math.sqrt(sum(c_1.get(item, 0)**2 for item in c_union))
mag_c2 = math.sqrt(sum(c_2.get(item, 0)**2 for item in c_union))
return dot_product / (mag_c1 * mag_c2)
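Note that Counter(str) counts characters, so this is cosine similarity over character frequencies; two quick checks computed from the formula above (approximate values):
distance('SELECT * FROM T1', 'SELECT * FROM T2')   # ~0.96, near-identical character profiles
distance('SELECT', 'DELETE')                       # ~0.82, fewer shared characters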

View File

@@ -1,3 +1,19 @@
"""
Copyright (c) 2020 Huawei Technologies Co.,Ltd.
openGauss is licensed under Mulan PSL v2.
You can use this software according to the terms and conditions of the Mulan PSL v2.
You may obtain a copy of Mulan PSL v2 at:
http://license.coscl.org.cn/MulanPSL2
THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND,
EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT,
MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE.
See the Mulan PSL v2 for more details.
"""
def distance(str1, str2):
"""
func: calculate levenshtein distance between two strings.
@@ -5,18 +21,20 @@ def distance(str1, str2):
:param str2: string2
:return: distance
"""
m, n = len(str1) + 1, len(str2) + 1
matrix = [[0] * n for _ in range(m)]
matrix[0][0] = 0
for i in range(1, m):
matrix[i][0] = matrix[i - 1][0] + 1
for j in range(1, n):
matrix[0][j] = matrix[0][j - 1] + 1
for i in range(1, m):
for j in range(1, n):
if str1[i - 1] == str2[j - 1]:
matrix[i][j] = matrix[i - 1][j - 1]
else:
matrix[i][j] = max(matrix[i - 1][j - 1], matrix[i - 1][j], matrix[i][j - 1]) + 1
len_str1 = len(str1) + 1
len_str2 = len(str2) + 1
return matrix[m - 1][n - 1]
mat = [[0]*len_str2 for i in range(len_str1)]
mat[0][0] = 0
for i in range(1,len_str1):
mat[i][0] = mat[i-1][0] + 1
for j in range(1,len_str2):
mat[0][j] = mat[0][j-1]+1
for i in range(1,len_str1):
for j in range(1,len_str2):
if str1[i-1] == str2[j-1]:
mat[i][j] = mat[i-1][j-1]
else:
mat[i][j] = min(mat[i-1][j-1],mat[i-1][j],mat[i][j-1])+1
return 1 / mat[len_str1-1][len_str2-1]
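The reworked function now returns a similarity score (the reciprocal of the edit distance) rather than the raw distance, so it can be used interchangeably with the other similarity algorithms; identical strings would divide by zero, but in this diff the template model only compares a new template against existing, different keys. A quick illustrative check:
distance('SELECT * FROM T1', 'SELECT * FROM T2')   # edit distance 1 -> 1.0
distance('SELECT * FROM T1', 'DELETE FROM T1')     # larger edit distance -> smaller score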

View File

@@ -12,10 +12,10 @@ __description__ = "Get sql information based on wdr."
def parse_args():
parser = argparse.ArgumentParser(formatter_class=argparse.RawDescriptionHelpFormatter,
description=__description__)
parser.add_argument('--port', help="User of remote server.", type=int, required=True)
parser.add_argument('--start_time', help="Start time of query", required=True)
parser.add_argument('--finish_time', help="Finish time of query", required=True)
parser.add_argument('--save_path', default='sample_data/data.csv', help="Path to save result")
parser.add_argument('--port', help="Port of database service.", type=int, required=True)
parser.add_argument('--start-time', help="Start time of query", required=True)
parser.add_argument('--finish-time', help="Finish time of query", required=True)
parser.add_argument('--save-path', default='sample_data/data.csv', help="Path to save result")
return parser.parse_args()

View File

@@ -19,6 +19,7 @@ from configparser import ConfigParser
from algorithm.diag import SQLDiag
from utils import ResultSaver
from preprocessing import LoadData, split_sql
__version__ = '2.0.0'
__description__ = 'SQLdiag integrated by openGauss.'
@@ -40,6 +41,8 @@ def parse_args():
parser.add_argument('--predicted-file', help='The file path to save the predicted result.')
parser.add_argument('--model', default='template', choices=['template', 'dnn'],
help='Choose the model to use.')
parser.add_argument('--query', help='Input the queries to predict.')
parser.add_argument('--threshold', help='Slow SQL threshold.')
parser.add_argument('--model-path', required=True,
help='The storage path of the model file, used to read or save the model file.')
parser.add_argument('--config-file', default='sqldiag.conf')
@@ -54,23 +57,64 @@ def get_config(filepath):
def main(args):
logging.basicConfig(level=logging.INFO)
model = SQLDiag(args.model, args.csv_file, get_config(args.config_file))
if args.mode == 'train':
model.fit()
model.save(args.model_path)
elif args.mode == 'predict':
if not args.predicted_file:
logging.error("The [--predicted-file] parameter is required for predict mode")
logging.basicConfig(level=logging.WARNING)
model = SQLDiag(args.model, get_config(args.config_file))
if args.mode in ('train', 'finetune'):
if not args.csv_file:
logging.fatal('The [--csv-file] parameter is required for train and finetune modes')
sys.exit(1)
model.load(args.model_path)
pred_result = model.transform()
ResultSaver().save(pred_result, args.predicted_file)
logging.info('predicted result in saved in {}'.format(args.predicted_file))
elif args.mode == 'finetune':
model.fine_tune(args.model_path)
train_data = LoadData(args.csv_file).train_data
if args.mode == 'train':
model.fit(train_data)
else:
model.fine_tune(args.model_path, train_data)
model.save(args.model_path)
else:
model.load(args.model_path)
if args.csv_file and not args.query:
predict_data = LoadData(args.csv_file).predict_data
elif args.query and not args.csv_file:
predict_data = split_sql(args.query)
else:
logging.error('The predict mode requires exactly one of [--csv-file] and [--query].')
sys.exit(1)
args.threshold = -100 if not args.threshold else float(args.threshold)
pred_result = model.transform(predict_data)
if args.predicted_file:
if args.model == 'template':
info_sum = []
for stats, _info in pred_result.items():
if _info:
_info = list(filter(lambda item: item[1]>=args.threshold, _info))
for item in _info:
item.insert(1, stats)
info_sum.extend(_info)
ResultSaver().save(info_sum, args.predicted_file)
else:
pred_result = list(filter(lambda item: float(item[1])>=args.threshold, pred_result))
ResultSaver().save(pred_result, args.predicted_file)
else:
from prettytable import PrettyTable
display_table = PrettyTable()
if args.model == 'template':
display_table.field_names = ['sql', 'status', 'predicted time', 'most similar template']
display_table.align = 'l'
status = ('Suspect illegal sql', 'No SQL information', 'No SQL template found', 'Fine match')
for stats in status:
if pred_result[stats]:
for sql, predicted_time, similariest_sql in pred_result[stats]:
if predicted_time >= args.threshold or stats == 'Suspect illegal sql':
display_table.add_row([sql, stats, predicted_time, similariest_sql])
else:
display_table.field_names = ['sql', 'predicted time']
display_table.align = 'l'
for sql, predicted_time in pred_result:
if float(predicted_time) >= args.threshold:
display_table.add_row([sql, predicted_time])
print(display_table.get_string())
if __name__ == '__main__':
main(parse_args())
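Both of the new threshold judgments boil down to the same comparison on the predicted time; a standalone sketch with invented numbers (variable names are illustrative):
threshold = 0.02
template_result = {'Fine match': [['SELECT * FROM T1 WHERE ID = ?', 0.012, 'SELECT * FROM T1 WHERE ID = ?']],
                   'No SQL template found': [['UPDATE T1 SET C1 = ? WHERE ID = ?', 0.031, 'UPDATE T2 SET C1 = ? WHERE ID = ?']]}
kept = {status: [row for row in rows if row[1] >= threshold] for status, rows in template_result.items()}
# only the 0.031 row passes; 'Suspect illegal sql' rows (predicted time fixed at -1)
# are still shown in the table branch because of the explicit status check above.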

View File

@@ -18,8 +18,6 @@ import sqlparse
from sqlparse.sql import Identifier, IdentifierList
from sqlparse.tokens import Keyword, DML
# split flag in SQL
split_flag = ('!=', '<=', '>=', '==', '<', '>', '=', ',', '(', ')', '*', ';', '%', '+', ',', ';')
DDL_WORDS = ('CREATE', 'ALTER', 'DROP', 'TRUNCATE', 'COMMIT', 'RENAME')
DML_WORDS = ('SELECT', 'INSERT INTO', 'UPDATE', 'DELETE', 'MERGE', 'CALL',
@@ -31,53 +29,6 @@ KEYWORDS = ('GRANT', 'REVOKE', 'DENY', 'ABORT', 'ADD', 'AGGREGATE', 'ANALYSE', '
FUNC = ('FLOOR', 'SUM', 'NOW', 'UUID', 'COUNT')
SQL_SIG = ('&', '&&')
# filter like (insert into aa (c1, c2) values (v1, v2) => insert into aa * values *)
BRACKET_FILTER = r'\(.*?\)'
# filter (123, 123.123)
PURE_DIGIT_FILTER = r'[\s]+\d+(\.\d+)?'
# filter ('123', '123.123')
SINGLE_QUOTE_DIGIT_FILTER = r'\'\d+(\.\d+)?\''
# filter ("123", "123.123")
DOUBLE_QUOTE_DIGIT_FILTER = r'"\d+(\.\d+)?"'
# filter ('123', 123, '123,123', 123.123) not filter(table1, column1, table_2, column_2)
DIGIT_FILTER = r'([^a-zA-Z])_?\d+(\.\d+)?'
# filter date in sql ('1999-09-09', '1999/09/09', "1999-09-09 20:10:10", '1999/09/09 20:10:10.12345')
PURE_TIME_FILTER = r'[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,2})?(\.\d+)?'
SINGLE_QUOTE_TIME_FILTER = r'\'[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,' \
r'2})?(\.\d+)?\' '
DOUBLE_QUOTE_TIME_FILTER = r'"[0-9]{4}[-/][0-9]{1,2}[-/][0-9]{1,2}\s*([0-9]{1,2}[:][0-9]{1,2}[:][0-9]{1,2})?(\.\d+)?"'
# filter like "where id='abcd" => "where id=#"
SINGLE_QUOTE_FILTER = r'\'.*?\''
# filter like 'where id="abcd" => 'where id=#'
DOUBLE_QUOTE_FILTER = r'".*?"'
# filter annotation like "/* XXX */"
ANNOTATION_FILTER_1 = r'/\s*\*[\w\W]*?\*\s*/\s*'
ANNOTATION_FILTER_2 = r'^--.*\s?'
# filter NULL character '\n \t' in sql
NULL_CHARACTER_FILTER = r'\s+'
# remove data in insert sql
VALUE_BRACKET_FILETER = r'VALUES (\(.*\))'
# remove equal data in sql
WHERE_EQUAL_FILTER = r'= .*?\s'
LESS_EQUAL_FILTER = r'(<= .*? |<= .*$)'
GREATER_EQUAL_FILTER = r'(>= .*? |<= .*$)'
LESS_FILTER = r'(< .*? |< .*$)'
GREATER_FILTER = r'(> .*? |> .*$)'
EQUALS_FILTER = r'(= .*? |= .*$)'
LIMIT_DIGIT = r'LIMIT \d+'
def _unify_sql(sql):
"""
@@ -85,65 +36,49 @@ def _unify_sql(sql):
"""
index = 0
sql = re.sub(r'\n', r' ', sql)
sql = re.sub(ANNOTATION_FILTER_1, r'', sql)
sql = re.sub(ANNOTATION_FILTER_2, r'', sql)
while index < len(sql):
if sql[index] in split_flag:
if sql[index:index + 2] in split_flag:
sql = sql[:index].strip() + ' ' + sql[index:index + 2] + ' ' + sql[index + 2:].strip()
index = index + 3
else:
sql = sql[:index].strip() + ' ' + sql[index] + ' ' + sql[index + 1:].strip()
index = index + 2
else:
index = index + 1
new_sql = list()
for word in sql.split():
new_sql.append(word.upper())
sql = ' '.join(new_sql)
sql = re.sub(r'/\s*\*[\w\W]*?\*\s*/\s*', r'', sql)
sql = re.sub(r'^--.*\s?', r'', sql)
sql = re.sub(r'([!><=]=)', r' \1 ', sql)
sql = re.sub(r'([^!><=])([=<>])', r'\1 \2 ', sql)
sql = re.sub(r'([,()*%/+])', r' \1 ', sql)
sql = re.sub(r'\s+', r' ', sql)
sql = sql.upper()
return sql.strip()
def split_sql(sqls):
if not sqls:
return []
sqls = sqls.split(';')
result = list(map(lambda item: _unify_sql(item), sqls))
return result
def templatize_sql(sql):
"""
function: replace the message which is not important in sql
SQL desensitization
"""
sql = _unify_sql(sql)
if not sql:
return ''
standard_sql = _unify_sql(sql)
sql = re.sub(r';', r'', sql)
if standard_sql.startswith('INSERT'):
standard_sql = re.sub(r'VALUES (\(.*\))', r'VALUES', standard_sql)
# remove digital like 12, 12.565
standard_sql = re.sub(r'[\s]+\d+(\.\d+)?', r' ?', standard_sql)
# remove '$n' in sql
standard_sql = re.sub(r'\$\d+', r'?', standard_sql)
# remove single quotes content
standard_sql = re.sub(r'\'.*?\'', r'?', standard_sql)
# remove double quotes content
standard_sql = re.sub(r'".*?"', r'?', standard_sql)
# remove '`' in sql
standard_sql = re.sub(r'`', r'', standard_sql)
# remove ; in sql
standard_sql = re.sub(r';', r'', standard_sql)
# ? represent date or time
sql = re.sub(PURE_TIME_FILTER, r'?', sql)
sql = re.sub(SINGLE_QUOTE_TIME_FILTER, r'?', sql)
sql = re.sub(DOUBLE_QUOTE_TIME_FILTER, r'?', sql)
# $ represent insert value
if sql.startswith('INSERT'):
sql = re.sub(VALUE_BRACKET_FILETER, r'VALUES ()', sql)
# $$ represent select value
if sql.startswith('SELECT') and ' = ' in sql:
sql = re.sub(WHERE_EQUAL_FILTER, r'= $$ ', sql)
# $$$ represent delete value
if sql.startswith('DELETE') and ' = ' in sql:
sql = re.sub(WHERE_EQUAL_FILTER, r'= $$$ ', sql)
# & represent logical signal
sql = re.sub(LESS_EQUAL_FILTER, r'<= & ', sql)
sql = re.sub(LESS_FILTER, r'< & ', sql)
sql = re.sub(GREATER_EQUAL_FILTER, r'>= & ', sql)
sql = re.sub(GREATER_FILTER, r'> & ', sql)
sql = re.sub(LIMIT_DIGIT, r'LIMIT &', sql)
sql = re.sub(EQUALS_FILTER, r'= & ', sql)
sql = re.sub(PURE_DIGIT_FILTER, r' &', sql)
sql = re.sub(r'`', r'', sql)
# && represent quote str
sql = re.sub(SINGLE_QUOTE_FILTER, r'?', sql)
sql = re.sub(DOUBLE_QUOTE_FILTER, r'?', sql)
return sql
return standard_sql.strip()
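# Illustrative examples of the templatization above (output is upper-cased with
# punctuation spaced out by _unify_sql):
#   templatize_sql("insert into t1 (c1, c2) values (3, 'abc');")
#       -> "INSERT INTO T1 ( C1 , C2 ) VALUES"
#   templatize_sql("select * from t1 where id = 25 and name = 'bob'")
#       -> "SELECT * FROM T1 WHERE ID = ? AND NAME = ?"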
def _is_select_clause(parsed_sql):
@@ -155,7 +90,6 @@ def _is_select_clause(parsed_sql):
return False
# todo: what is token list? from list?
def _get_table_token_list(parsed_sql, token_list):
flag = False
for token in parsed_sql.tokens:

View File

@@ -26,6 +26,7 @@ min_limit = 6.78e-05
# Template Method
#------------------------------------------------------------------------------
[template]
similarity_algorithm =
time_list_size =
knn_number =
# [cosine_distance, levenshtein, list, parse_tree]
similarity_algorithm = cosine_distance
time_list_size = 20
knn_number = 1

View File

@@ -25,20 +25,15 @@ class ResultSaver:
os.chmod(dirname, stat.S_IRWXU)
if isinstance(data, (list, tuple)):
self.save_list(data, realpath)
elif isinstance(data, dict):
self.save_dict(data, path)
else:
raise TypeError("mode should be 'list', 'tuple' or 'dict', but input type is '{}'".format(str(type(data))))
@staticmethod
def save_list(data, path):
data = pd.DataFrame(data)
data.to_csv(path, index=False, sep=',', header=False, quoting=csv.QUOTE_NONE, escapechar='\"')
@staticmethod
def save_dict(data, path):
data = pd.DataFrame(data.items())
data.to_csv(path, index=False, sep=',', header=False, quoting=csv.QUOTE_NONE, escapechar='\"')
with open(path, mode='w') as f:
for item in data:
content = ",".join([str(sub_item) for sub_item in item])
f.write(content + '\n')
class DBAgent:
@@ -74,7 +69,7 @@ class DBAgent:
result = list(self.cursor.fetchall())
return result
except Exception as e:
logging.getLogger('agent').warning(str(e))
logging.warning(str(e))
def close(self):
self.cursor.close()