# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import re
import subprocess
import sys
from curses.ascii import isalpha

# line format is "sql_name, groovy_name"

# Header of every generated Groovy suite. It is rendered with
# ``suite.format(num)``: the single ``{}`` receives the query number, while
# ``{{`` / ``}}`` are literal braces that must survive ``str.format``.
suite = """// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

suite("test_explain_tpch_sf_1_q{}", "tpch_sf1") {{
    String realDb = context.config.getDbNameByFile(context.file)
    // get parent directory's group
    realDb = realDb.substring(0, realDb.lastIndexOf("_"))

    sql "use ${{realDb}}"
"""

# Command line: <script> SQL_DIR OUTPUT_DIR
#   SQL_DIR    directory holding the query files
#   OUTPUT_DIR directory that receives one test_<number>.groovy per query
params = sys.argv
if len(params) < 3:
    sys.exit('usage: {} <sql_dir> <output_dir>'.format(params[0]))

sql_dir = params[1]
output_dir = params[2]

# makedirs creates missing parent directories too and tolerates an existing
# directory, so there is no check-then-create race.
os.makedirs(output_dir, exist_ok=True)

# Only regular files whose name contains a number can be turned into a suite;
# anything else (editor backups, hidden files, ...) is ignored instead of
# crashing the number extraction below. Sorting makes runs reproducible
# because os.listdir returns entries in arbitrary order.
sql_files = sorted(
    name for name in os.listdir(sql_dir)
    if os.path.isfile(os.path.join(sql_dir, name)) and re.search(r'\d+', name)
)

# The first run of digits in the file name is used as the query number.
sql_numbers = [re.search(r'\d+', name).group() for name in sql_files]

output_files = [
    os.path.join(output_dir, 'test_' + number + '.groovy')
    for number in sql_numbers
]

sql_files_path = [os.path.join(sql_dir, name) for name in sql_files]

# One (sql path, groovy path, query number) triple per query file.
tasks = list(zip(sql_files_path, output_files, sql_numbers))

# Regular expressions, applied with the ``re`` module to each line of the
# EXPLAIN output. A line that matches any of them is turned into an
# ``explainStr.contains(...)`` check in the generated suite.
patterns = [
    # sorting / top-n
    r"order by.*$",
    r"runtime filters.*$",
    r"VTOP-N",
    # joins
    r"join op.*$",
    r"equal join.*$",
    # scans
    r"TABLE.*$",
    r"PREDICATES.*$",
    r"output.*$",
    r"other join predicates.*$",
    r"other predicates.*$",
    # aggregation
    r"group by.*$",
    r"VAGGREGATE.*$",
    r"STREAMING",
    r"cross join.*$",
    r"predicates.*$",
]

# ---------------------------------------------------------------------------
# Suite generation.
#
# For every (sql path, groovy path, query number) task:
#   1. run ``explain <query>;`` through the mysql client,
#   2. keep the plan lines that match one of ``patterns`` and group lines
#      that are adjacent in the plan,
#   3. write a Groovy suite that runs the query under ``explain`` and checks
#      that the plan text contains each group.
# ---------------------------------------------------------------------------

# Client invocation used to obtain the plan; the statement is sent on stdin,
# so the query file on disk is never modified.
_MYSQL_CMD = [
    'mysql', '-h', '127.0.0.1', '-P', '9030', '-uroot',
    'regression_test_tpch_sf1',
]

# Groovy scaffolding placed around the query text and the generated checks.
_EXPLAIN_OPEN = '''
    explain {
        sql """
'''

_CHECK_OPEN = '''
        """
        check {
            explainStr ->
'''

_SUITE_CLOSE = '''

        }
    }
}'''

# Everything before the first ASCII letter of a line.
_LEADING_NON_ALPHA = re.compile(r'^[^A-Za-z]+')


def _escape_groovy(text):
    """Escape *text* for use inside a double-quoted Groovy string literal."""
    # Backslash first, so the escapes added for '"' and '$' are not doubled.
    return text.replace('\\', '\\\\').replace('"', '\\"').replace('$', '\\$')


def _group_matching_lines(lines, regexes):
    """Collect the lines matching any of *regexes*, grouped by adjacency.

    Args:
        lines: plan lines without their trailing newline.
        regexes: compiled regular expressions.

    Returns:
        A list of groups. Each group is a list of Groovy-escaped lines whose
        positions in *lines* are consecutive. The first line of every group
        has its leading non-letter characters removed.
    """
    groups = []
    last_index = None
    for index, line in enumerate(lines):
        # any() records a line once even when several patterns match it, so
        # it cannot be emitted twice or break up a run of adjacent lines.
        if not any(regex.search(line) for regex in regexes):
            continue
        if last_index is not None and index == last_index + 1:
            groups[-1].append(_escape_groovy(line))
        else:
            groups.append([_escape_groovy(_LEADING_NON_ALPHA.sub('', line))])
        last_index = index
    return groups


def _render_checks(groups):
    """Render *groups* as one ``&&``-chained Groovy boolean expression."""
    clauses = [
        # Lines of one group are concatenated with an embedded "\n" so the
        # check only passes when they are adjacent in the plan text.
        '\t\texplainStr.contains("' + '\\n" + \n\t\t\t\t"'.join(group) + '") '
        for group in groups
    ]
    return '&& \n'.join(clauses)


_compiled_patterns = [re.compile(p) for p in patterns]

for sql_path, groovy_path, num in tasks:
    print((sql_path, groovy_path, num))

    with open(sql_path, 'r') as sql_file:
        sql_lines = sql_file.readlines()

    print('explain ' + sql_path + ' via: ' + ' '.join(_MYSQL_CMD))
    completed = subprocess.run(
        _MYSQL_CMD,
        input='explain\n' + ''.join(sql_lines) + ';',
        stdout=subprocess.PIPE,
        universal_newlines=True,
    )
    explain_lines = completed.stdout.splitlines()
    print(len(explain_lines))

    groups = _group_matching_lines(explain_lines, _compiled_patterns)
    if not groups:
        # Without any matched line there is nothing to assert; a suite with
        # an empty check body would be misleading, so none is written.
        print(
            'no plan lines matched for {} (mysql exit code {}); '
            'suite not generated'.format(sql_path, completed.returncode),
            file=sys.stderr,
        )
        continue

    # The first line of the query file is sent to mysql above but is not
    # copied into the suite (presumably a header line; TODO confirm).
    query = ''.join('\t\t' + line for line in sql_lines[1:])

    with open(groovy_path, 'w') as groovy_file:
        groovy_file.write(
            suite.format(num)
            + _EXPLAIN_OPEN
            + query
            + _CHECK_OPEN
            + _render_checks(groups)
            + _SUITE_CLOSE
        )