92 lines
3.1 KiB
Python
92 lines
3.1 KiB
Python
# Licensed to the Apache Software Foundation (ASF) under one
|
|
# or more contributor license agreements. See the NOTICE file
|
|
# distributed with this work for additional information
|
|
# regarding copyright ownership. The ASF licenses this file
|
|
# to you under the Apache License, Version 2.0 (the
|
|
# "License"); you may not use this file except in compliance
|
|
# with the License. You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing,
|
|
# software distributed under the License is distributed on an
|
|
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
|
# KIND, either express or implied. See the License for the
|
|
# specific language governing permissions and limitations
|
|
# under the License.
|
|
|
|
from unittest import result
|
|
import mysql.connector
|
|
from typing import List, Tuple
|
|
import re
|
|
|
|
|
|
class SQLExecutor:
|
|
def __init__(self, user: str, password: str, host: str, port: int, database: str) -> None:
|
|
self.connection = mysql.connector.connect(
|
|
user=user,
|
|
password=password,
|
|
host=host,
|
|
port=port,
|
|
database=database
|
|
)
|
|
self.cursor = self.connection.cursor()
|
|
self.wait_fetch_time_index = 4
|
|
|
|
def execute_query(self, query: str, parameters: Tuple | None) -> List[Tuple]:
|
|
if parameters:
|
|
self.cursor.execute(query, parameters)
|
|
else:
|
|
self.cursor.execute(query)
|
|
results = self.cursor.fetchall()
|
|
return results
|
|
|
|
def get_execute_time(self, query: str) -> float:
|
|
self.execute_query(query, None)
|
|
profile = self.execute_query("show query profile\"\"", None)
|
|
return self.get_n_ms(profile[0][self.wait_fetch_time_index])
|
|
|
|
def get_n_ms(self, t: str):
|
|
res = re.search(r"(\d+h)*(\d+min)*(\d+s)*(\d+ms)", t)
|
|
if res is None:
|
|
raise Exception(f"invalid time {t}")
|
|
n = 0
|
|
|
|
h = res.group(1)
|
|
if h is not None:
|
|
n += int(h.replace("h", "")) * 60 * 60 * 1000
|
|
min = res.group(2)
|
|
if min is not None != 0:
|
|
n += int(min.replace("min", "")) * 60 * 1000
|
|
s = res.group(3)
|
|
if s is not None != 0:
|
|
n += int(s.replace("s", "")) * 1000
|
|
ms = res.group(4)
|
|
if len(ms) != 0:
|
|
n += int(ms.replace("ms", ""))
|
|
|
|
return n
|
|
|
|
def execute_many_queries(self, queries: List[Tuple[str, Tuple]]) -> List[List[Tuple]]:
|
|
results = []
|
|
for query, parameters in queries:
|
|
result = self.execute_query(query, parameters)
|
|
results.append(result)
|
|
return results
|
|
|
|
def get_plan_with_cost(self, query: str):
|
|
result = self.execute_query(f"explain optimized plan {query}", None)
|
|
cost = float(result[0][0].replace("cost = ", ""))
|
|
plan = "".join([s[0] for s in result[1:]])
|
|
return plan, cost
|
|
|
|
def commit(self) -> None:
|
|
self.connection.commit()
|
|
|
|
def rollback(self) -> None:
|
|
self.connection.rollback()
|
|
|
|
def close(self) -> None:
|
|
self.cursor.close()
|
|
self.connection.close()
|