# Test framework for the cost model, following the paper
# "Testing the Accuracy of Query Optimizers".
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import re
# NOTE(review): this import looks unused in the visible code, and `distutils`
# was removed from the standard library in Python 3.12 (PEP 632). It is kept
# here only because other parts of the file may not be visible.
from distutils.command.config import config

import matplotlib.pyplot as plt

from config import Config
from index_calculator import IndexCalculator
from sql_executor import SQLExecutor


class Evaluator:
    """Collect (cost, execution time) pairs for alternative plans of a query.

    For one SQL query the evaluator asks the SQL executor for the plan chosen
    under each ``nth_optimized_plan`` index, records the cost returned with
    that plan and the execution time returned for the hinted query, and hands
    the resulting pairs to ``IndexCalculator``.
    """

    # Opening of a ``set_var(`` hint call. Whitespace before the parenthesis
    # and any letter case are accepted so that hints written as
    # ``/*+ set_var(`` or ``/*+SET_VAR (`` are recognised as well.
    _SET_VAR_CALL = re.compile(r"\bset_var\s*\(", re.IGNORECASE)

    # The SELECT keyword as a whole word, so identifiers that merely contain
    # "select" (e.g. ``selected_flag``) are never rewritten.
    _SELECT_KEYWORD = re.compile(r"\bselect\b", re.IGNORECASE)

    def __init__(self, config: Config, query: str) -> None:
        """Store the configuration and query and create the SQL executor.

        Args:
            config: Run configuration. This class reads ``user``,
                ``password``, ``host``, ``port``, ``database``, ``cold_run``,
                ``plan_number`` and ``plot`` from it.
            query: The SQL query whose plans are evaluated.
        """
        self.config = config
        # Keep the SQL text verbatim: lowercasing it would also lowercase
        # string literals and identifiers, changing the query being measured.
        # Keyword matching is done case-insensitively instead.
        self.query = query
        # Session statements issued by setup() before any measurement.
        self.setup_queries = [
            "set enable_nereids_planner=true;",
            "set enable_fallback_to_original_planner=false;",
            "set enable_profile=true;",
        ]
        self.sql_executor = SQLExecutor(
            config.user,
            config.password,
            config.host,
            config.port,
            config.database,
        )

    def cold_run(self) -> None:
        """Execute the unmodified query ``config.cold_run`` times.

        The results are discarded; presumably this warms caches before the
        timed runs (confirm against the executor).
        """
        for _ in range(self.config.cold_run):
            self.sql_executor.execute_query(self.query, None)

    def evaluate(self):
        """Run the full evaluation and return the calculated index.

        Steps: apply the session setup, perform the cold runs, enumerate the
        distinct plans, obtain an execution time for each hinted query,
        optionally plot the (cost, time) pairs, print them, and return the
        result of ``IndexCalculator(pairs).calculate()``.
        """
        self.setup()
        self.cold_run()
        plans = self.extract_all_plans()
        # Only the (hinted query, cost) values are needed; the plan index key
        # is not used here.
        res: list[tuple[float, float]] = [
            (cost, self.sql_executor.get_execute_time(hinted_query))
            for hinted_query, cost in plans.values()
        ]
        if self.config.plot:
            self.plot(res)
        print(res)
        return IndexCalculator(res).calculate()

    def plot(self, data) -> None:
        """Show a scatter plot of cost (x axis) against time (y axis).

        Args:
            data: Iterable of ``(cost, time)`` pairs.
        """
        costs = [pair[0] for pair in data]
        times = [pair[1] for pair in data]
        _, ax = plt.subplots()
        ax.scatter(costs, times)
        ax.set_xlabel('Cost')
        ax.set_ylabel('Time')
        plt.show()

    def setup(self) -> None:
        """Execute every statement in ``self.setup_queries`` in order."""
        for statement in self.setup_queries:
            self.sql_executor.execute_query(statement, None)

    def extract_all_plans(self):
        """Enumerate distinct plans by increasing ``nth_optimized_plan``.

        Returns:
            A dict mapping the plan index ``n`` to ``(hinted_query, cost)``
            for every index whose plan had not been seen before.
        """
        seen_plans = set()
        plan_map: dict[int, tuple[str, float]] = {}
        # NOTE(review): this tries indices 1 .. plan_number - 1; confirm
        # whether ``plan_number`` is meant to be an inclusive upper bound.
        for n in range(1, self.config.plan_number):
            hinted_query = self.inject_nth_optimized_hint(n)
            plan, cost = self.sql_executor.get_plan_with_cost(hinted_query)
            # Stop at the first repeated plan; presumably the optimizer
            # repeats a plan once n exceeds the number of alternatives.
            if plan in seen_plans:
                break
            seen_plans.add(plan)
            plan_map[n] = (hinted_query, cost)
        return plan_map

    def inject_nth_optimized_hint(self, n: int) -> str:
        """Return the query with ``nth_optimized_plan=n`` added as a hint.

        If the query already contains a ``set_var(`` call, the variable is
        inserted as the first argument of every such call. Otherwise a new
        ``/*+set_var(nth_optimized_plan=n)*/`` hint is placed after every
        SELECT keyword. Matching is case-insensitive and the original text
        of the query is otherwise left untouched.

        Args:
            n: Plan index to request.

        Returns:
            The rewritten SQL text.
        """
        if self._SET_VAR_CALL.search(self.query):
            # Extend the existing hint(s); m.group(0) keeps the original
            # spelling of "set_var(" exactly as written in the query.
            return self._SET_VAR_CALL.sub(
                lambda m: f"{m.group(0)}nth_optimized_plan={n}, ",
                self.query,
            )
        # No set_var hint yet: add one after each whole-word SELECT, keeping
        # the keyword's original letter case.
        return self._SELECT_KEYWORD.sub(
            lambda m: f"{m.group(0)} /*+set_var(nth_optimized_plan={n})*/",
            self.query,
        )