#!/usr/bin/env python # Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. import os import subprocess import requests import json import time # Change the host port username password and database name on your need mycli_cmd = "mysql -h127.0.0.1 -P9030 -uroot -Dtpch1G" # FE http://host:port feHttp = "http://localhost:8030" trace_url = feHttp + '/rest/v2/manager/query/trace_id/{}' qerror_url = feHttp + '/rest/v2/manager/query/qerror/{}' # File path to save test results. # Sample: # 8 # { # "legacyPlanIdToPhysicalPlan": { # "0": { # "first": 1.0, # "second": 1.0 # }, # ....... # "qError": 34.5 # } # `8` represents q8 in the tpc-h test # `first` is the estimated row count for plan which with plan id 0, `second` is the actual returned row count qerr_saved_file_path = "" # SQL under this directory would be tested. original_sql_dir = "add your tpc-h/tpch-ds/ssb sql directory path here" sql_file_prefix_for_trace = """ SET enable_nereids_planner=true; SET session_context='trace_id:{}'; """ q_err_list = [] def extract_number(string): return int(''.join([c for c in string if c.isdigit()])) def write_results(path: str, title: str, result: list): with open(path, "a") as file: file.write(title) file.write("\n") for item in result: file.write(str(item) + " " + "\n") file.write("\n") def read_lines(path: str) -> list: with open(path, "r") as f: return f.readlines() def write_result(title: str, result: str): wrapped = [result] write_results(qerr_saved_file_path, title, wrapped) def execute_command(cmd: str): result = subprocess.run(cmd, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE) return result.stdout def execute_sql(sql_file: str): command = mycli_cmd + " < " + sql_file result = execute_command(command).decode("utf-8") return result def get_q_error(trace_id): time.sleep(1) # 'YWRtaW46' is the base64 encoded result for 'admin:' headers = {'Authorization': 'BASIC YWRtaW46'} resp_wrapper = requests.get(trace_url.format(trace_id), headers=headers) resp_text = resp_wrapper.text query_id = json.loads(resp_text)["data"] resp_wrapper = requests.get(qerror_url.format(query_id), headers=headers) resp_text = resp_wrapper.text write_result(str(trace_id), resp_text) print(trace_id) print(resp_text) qerr = json.loads(resp_text)["qError"] q_err_list.append(float(qerr)) def iterates_sqls(path: str, if_write_results: bool) -> list: cost_times = [] files = os.listdir(path) files.sort(key=extract_number) for filename in files: if filename.endswith(".sql"): filepath = os.path.join(path, filename) traced_sql_file = filepath + ".traced" content = read_lines(filepath) sql_num = extract_number(filename) print("sql num" + str(sql_num)) if if_write_results: write_results(traced_sql_file, str(sql_file_prefix_for_trace.format(sql_num)), content) execute_sql(traced_sql_file) get_q_error(sql_num) os.remove(traced_sql_file) else: execute_sql(filepath) return cost_times if __name__ == '__main__': execute_command("echo 'set global enable_nereids_planner=true' | mysql -h127.0.0.1 -P9030") execute_command("echo 'set global enable_fallback_to_original_planner=false' | mysql -h127.0.0.1 -P9030") print("Preparing") iterates_sqls(original_sql_dir, False) print("Started...") iterates_sqls(original_sql_dir, True) write_results(qerr_saved_file_path, "AVG\n", [sum(q_err_list) / len(q_err_list)])