# Licensed to the Apache Software Foundation (ASF) under one # or more contributor license agreements. See the NOTICE file # distributed with this work for additional information # regarding copyright ownership. The ASF licenses this file # to you under the Apache License, Version 2.0 (the # "License"); you may not use this file except in compliance # with the License. You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, # software distributed under the License is distributed on an # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. from unittest import result import mysql.connector from typing import List, Tuple import re class SQLExecutor: def __init__(self, user: str, password: str, host: str, port: int, database: str) -> None: self.connection = mysql.connector.connect( user=user, password=password, host=host, port=port, database=database ) self.cursor = self.connection.cursor() self.wait_fetch_time_index = 4 def execute_query(self, query: str, parameters: Tuple | None) -> List[Tuple]: if parameters: self.cursor.execute(query, parameters) else: self.cursor.execute(query) results = self.cursor.fetchall() return results def get_execute_time(self, query: str) -> float: self.execute_query(query, None) profile = self.execute_query("show query profile\"\"", None) return self.get_n_ms(profile[0][self.wait_fetch_time_index]) def get_n_ms(self, t: str): res = re.search(r"(\d+h)*(\d+min)*(\d+s)*(\d+ms)", t) if res is None: raise Exception(f"invalid time {t}") n = 0 h = res.group(1) if h is not None: n += int(h.replace("h", "")) * 60 * 60 * 1000 min = res.group(2) if min is not None != 0: n += int(min.replace("min", "")) * 60 * 1000 s = res.group(3) if s is not None != 0: n += int(s.replace("s", "")) * 1000 ms = res.group(4) if len(ms) != 0: n += int(ms.replace("ms", "")) return n def execute_many_queries(self, queries: List[Tuple[str, Tuple]]) -> List[List[Tuple]]: results = [] for query, parameters in queries: result = self.execute_query(query, parameters) results.append(result) return results def get_plan_with_cost(self, query: str): result = self.execute_query(f"explain optimized plan {query}", None) cost = float(result[0][0].replace("cost = ", "")) plan = "".join([s[0] for s in result[1:]]) return plan, cost def commit(self) -> None: self.connection.commit() def rollback(self) -> None: self.connection.rollback() def close(self) -> None: self.cursor.close() self.connection.close()