90 lines
4.3 KiB
Python
90 lines
4.3 KiB
Python
|
|
import sys
|
|
import subprocess
|
|
import os
|
|
import time
|
|
import jaydebeapi
|
|
|
|
# Abstract SQL connection
|
|
class SQLConnection:
|
|
def __init__(self, port = '3306', host = '127.0.0.1', user = 'root', password = ''):
|
|
self.host = str(host)
|
|
self.port = str(port)
|
|
self.user = str(user)
|
|
self.password = str(password)
|
|
|
|
# Connect to a server
|
|
def connect(self, options = ""):
|
|
try:
|
|
self.conn = jaydebeapi.connect("org.mariadb.jdbc.Driver", ["jdbc:mariadb://" + self.host + ":" + self.port + "/test?" + options, self.user, self.password],"./maxscale/java/mariadb-java-client-1.3.3.jar")
|
|
except Exception as ex:
|
|
print("Failed to connect to " + self.host + ":" + self.port + " as " + self.user + ":" + self.password)
|
|
print(unicode(ex))
|
|
exit(1)
|
|
|
|
# Start a transaction
|
|
def begin(self):
|
|
curs = self.conn.cursor()
|
|
curs.execute("BEGIN")
|
|
curs.close()
|
|
# Commit a transaction
|
|
def commit(self):
|
|
curs = self.conn.cursor()
|
|
curs.execute("COMMIT")
|
|
curs.close()
|
|
|
|
# Query and test if the result matches the expected value if one is provided
|
|
def query(self, query, compare = None, column = 0):
|
|
curs = self.conn.cursor()
|
|
curs.execute(query)
|
|
return curs.fetchall()
|
|
|
|
def query_and_compare(self, query, column):
|
|
data = self.query(query)
|
|
for row in data:
|
|
if str(row[column]) == compare:
|
|
return True
|
|
return False
|
|
|
|
def disconnect(self):
|
|
self.conn.close()
|
|
|
|
def query_and_close(self, query):
|
|
self.connect()
|
|
self.query(query)
|
|
self.disconnect()
|
|
|
|
# Test environment abstraction
|
|
class MaxScaleTest:
|
|
def __init__(self, testname = "python_test"):
|
|
|
|
self.testname = testname
|
|
prepare_test(testname)
|
|
|
|
# MaxScale connections
|
|
self.maxscale = dict()
|
|
self.maxscale['rwsplit'] = SQLConnection(host = os.getenv("maxscale_IP"), port = "4006", user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
self.maxscale['rcmaster'] = SQLConnection(host = os.getenv("maxscale_IP"), port = "4008", user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
self.maxscale['rcslave'] = SQLConnection(host = os.getenv("maxscale_IP"), port = "4009", user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
|
|
# Master-Slave nodes
|
|
self.repl = dict()
|
|
self.repl['node0'] = SQLConnection(host = os.getenv("node_000_network"), port = os.getenv("node_000_port"), user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
self.repl['node1'] = SQLConnection(host = os.getenv("node_001_network"), port = os.getenv("node_001_port"), user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
self.repl['node2'] = SQLConnection(host = os.getenv("node_002_network"), port = os.getenv("node_002_port"), user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
self.repl['node3'] = SQLConnection(host = os.getenv("node_003_network"), port = os.getenv("node_003_port"), user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
|
|
# Galera nodes
|
|
self.galera = dict()
|
|
self.galera['node0'] = SQLConnection(host = os.getenv("galera_000_network"), port = os.getenv("galera_000_port"), user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
self.galera['node1'] = SQLConnection(host = os.getenv("galera_001_network"), port = os.getenv("galera_001_port"), user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
self.galera['node2'] = SQLConnection(host = os.getenv("galera_002_network"), port = os.getenv("galera_002_port"), user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
self.galera['node3'] = SQLConnection(host = os.getenv("galera_003_network"), port = os.getenv("galera_003_port"), user = os.getenv("maxscale_user"), password = os.getenv("maxscale_password"))
|
|
|
|
def __del__(self):
|
|
subprocess.call(os.getcwd() + "/copy_logs.sh " + str(self.testname), shell=True)
|
|
|
|
# Read test environment variables
|
|
def prepare_test(testname = "replication"):
|
|
subprocess.call(os.getcwd() + "/non_native_setup " + str(testname), shell=True)
|