1. Estimate timearithmeticexpr instead of setting Double.MAX / Double.MIN directly. 2. Enable histograms to derive stats. 3. Loosen the condition for histogram usage. 4. Greatly improve the accuracy for agg on TPC-H 1G. 5. Fix the avg qerror calculation.
140 lines
4.4 KiB
Python
140 lines
4.4 KiB
Python
#!/usr/bin/env python
|
|
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
import os
|
|
import subprocess
|
|
|
|
import requests
|
|
import json
|
|
import time
|
|
|
|
# MySQL client command line. Adjust host, port, user name, password and
# database name to match your environment.
mycli_cmd = "mysql -h127.0.0.1 -P9030 -uroot -Dtpch1G"

# HTTP address of the FE, in the form http://host:port
feHttp = "http://localhost:8030"
# REST endpoint templates; `{}` is filled in with a trace id / query id.
trace_url = feHttp + "/rest/v2/manager/query/trace_id/{}"
qerror_url = feHttp + "/rest/v2/manager/query/qerror/{}"

# Path of the file the test results are appended to.
# Sample of one saved entry:
#   8
#   {
#     "legacyPlanIdToPhysicalPlan": {
#       "0": {
#         "first": 1.0,
#         "second": 1.0
#       },
#       .......
#     "qError": 34.5
#   }
# `8` stands for q8 of the TPC-H test.
# `first` is the estimated row count of the plan with plan id 0,
# `second` is the row count that was actually returned.
qerr_saved_file_path = ""

# Every *.sql file in this directory is tested.
original_sql_dir = "add your tpc-h/tpch-ds/ssb sql directory path here"

# Statements prepended to each traced SQL file; `{}` receives the trace id.
sql_file_prefix_for_trace = (
    "\n"
    "SET enable_nereids_planner=true;\n"
    "SET session_context='trace_id:{}';\n"
)

# q-error values collected by get_q_error(), one per traced query.
q_err_list = []
|
|
|
|
|
|
def extract_number(string):
    """Return the integer spelled by the digit characters of *string*.

    All digit characters are kept in their order of appearance and every
    other character is dropped, so ``"q12.sql"`` yields ``12`` and
    ``"a1b2"`` yields ``12`` as well.

    Raises:
        ValueError: if *string* contains no digit character at all.
    """
    digits_only = "".join(filter(str.isdigit, string))
    return int(digits_only)
|
|
|
|
|
|
def write_results(path: str, title: str, result: list):
    """Append a titled section to the file at *path*.

    The section consists of *title* on its own line, then one line per
    element of *result* (its ``str()`` form followed by a single space),
    then an empty line.

    Args:
        path: File to append to; it is created if it does not exist.
        title: Heading text written before the items.
        result: Items to write, one per line.
    """
    with open(path, "a") as out:
        out.write(title + "\n")
        out.writelines(f"{entry!s} \n" for entry in result)
        out.write("\n")
|
|
|
|
|
|
def read_lines(path: str) -> list:
    """Read the text file at *path* and return its lines as a list.

    Each returned line keeps its trailing newline character, if it had one.
    """
    with open(path, "r") as src:
        return list(src)
|
|
|
|
|
|
def write_result(title: str, result: str):
    """Append one titled result to the file named by ``qerr_saved_file_path``.

    Thin wrapper around ``write_results`` that passes *result* as a
    single-element list.
    """
    write_results(qerr_saved_file_path, title, [result])
|
|
|
|
|
|
def execute_command(cmd: str):
    """Run *cmd* through the shell and return what it wrote to stdout.

    The command is executed with ``shell=True`` so that pipes and
    redirections inside *cmd* work. A non-zero exit status does not raise;
    instead the exit status and the captured stderr are reported on this
    process's stderr so that failures are not silently lost.

    Args:
        cmd: Complete shell command line.

    Returns:
        The raw bytes the command wrote to its standard output.
    """
    import sys

    result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    if result.returncode != 0:
        # Surface the failure but keep the best-effort contract: callers
        # still receive whatever stdout was produced.
        print(
            "command failed with exit status {}: {}\n{}".format(
                result.returncode, cmd, result.stderr.decode("utf-8", errors="replace")
            ),
            file=sys.stderr,
        )
    return result.stdout
|
|
|
|
|
|
def execute_sql(sql_file: str):
    """Feed the SQL file at *sql_file* to the MySQL client and return its output.

    The client command line is taken from ``mycli_cmd`` and the file is
    supplied on the client's standard input via shell redirection.

    Args:
        sql_file: Path of the SQL file to execute.

    Returns:
        The client's standard output decoded as UTF-8.
    """
    import shlex

    # Quote the path so directories or file names containing spaces or
    # shell metacharacters do not break (or alter) the shell command.
    command = mycli_cmd + " < " + shlex.quote(sql_file)
    return execute_command(command).decode("utf-8")
|
|
|
|
|
|
def get_q_error(trace_id, timeout=30.0):
    """Fetch the q-error of the query tagged with *trace_id* and record it.

    The trace id is first resolved to a query id through ``trace_url``; the
    q-error report for that query id is then fetched from ``qerror_url``.
    The raw report is appended to the results file via ``write_result`` and
    printed, and its ``qError`` value is appended to ``q_err_list``.

    Args:
        trace_id: Trace id that was set in the session of the executed query.
        timeout: Seconds to wait for each HTTP request before giving up, so
            an unresponsive FE cannot hang the script forever.

    Raises:
        requests.HTTPError: if either request returns an HTTP error status.
        requests.Timeout: if a request exceeds *timeout*.
        KeyError: if a response lacks the expected ``data`` / ``qError`` key.
    """
    # NOTE(review): presumably gives the FE time to register the finished
    # query before it is looked up - confirm.
    time.sleep(1)
    # 'YWRtaW46' is the base64 encoded result for 'admin:'
    headers = {'Authorization': 'BASIC YWRtaW46'}

    trace_resp = requests.get(trace_url.format(trace_id), headers=headers, timeout=timeout)
    trace_resp.raise_for_status()
    query_id = json.loads(trace_resp.text)["data"]

    qerror_resp = requests.get(qerror_url.format(query_id), headers=headers, timeout=timeout)
    resp_text = qerror_resp.text
    # Record and show the body before checking the status so that an error
    # response is still available for diagnosis.
    write_result(str(trace_id), resp_text)
    print(trace_id)
    print(resp_text)
    qerror_resp.raise_for_status()

    qerr = json.loads(resp_text)["qError"]
    q_err_list.append(float(qerr))
|
|
|
|
|
|
def iterates_sqls(path: str, if_write_results: bool) -> list:
    """Execute every ``*.sql`` file in *path*, ordered by the number in its name.

    When *if_write_results* is true, each query is copied into a temporary
    ``<file>.traced`` file prefixed with ``sql_file_prefix_for_trace`` (using
    the file's number as trace id), executed, and its q-error is collected
    with ``get_q_error``. Otherwise the original file is simply executed.

    Args:
        path: Directory containing the SQL files.
        if_write_results: Whether to trace the queries and record q-errors.

    Returns:
        An empty list; no timings are collected at present.

    Raises:
        ValueError: if a ``*.sql`` file name contains no digits.
    """
    cost_times = []
    # Filter before sorting: extract_number raises on names without digits,
    # so unrelated directory entries must not reach the sort key.
    sql_files = sorted(
        (name for name in os.listdir(path) if name.endswith(".sql")),
        key=extract_number,
    )
    for filename in sql_files:
        filepath = os.path.join(path, filename)
        sql_num = extract_number(filename)
        print("sql num" + str(sql_num))
        if not if_write_results:
            execute_sql(filepath)
            continue

        traced_sql_file = filepath + ".traced"
        content = read_lines(filepath)
        try:
            write_results(traced_sql_file, str(sql_file_prefix_for_trace.format(sql_num)), content)
            execute_sql(traced_sql_file)
            get_q_error(sql_num)
        finally:
            # write_results appends, so a leftover traced file would be
            # extended with duplicate statements on the next run; always
            # clean it up, even when execution or q-error retrieval fails.
            if os.path.exists(traced_sql_file):
                os.remove(traced_sql_file)
    return cost_times
|
|
|
|
|
|
if __name__ == '__main__':
    # Send the global settings through the configured client command so
    # they reach the same server as the queries when mycli_cmd is changed.
    execute_command("echo 'set global enable_nereids_planner=true' | " + mycli_cmd)
    execute_command("echo 'set global enable_fallback_to_original_planner=false' | " + mycli_cmd)
    # First pass runs every query without tracing.
    # NOTE(review): presumably a warm-up run - confirm.
    print("Preparing")
    iterates_sqls(original_sql_dir, False)
    # Second pass traces every query and collects its q-error.
    print("Started...")
    iterates_sqls(original_sql_dir, True)
    if q_err_list:
        write_results(qerr_saved_file_path, "AVG\n", [sum(q_err_list) / len(q_err_list)])
    else:
        # Avoid ZeroDivisionError when the directory held no SQL files.
        print("No q-error was collected; skipping AVG")
|