Implementing Monte Carlo Tree Search Algorithm

This commit is contained in:
flyly
2022-02-11 11:20:43 +08:00
parent 385f38dffb
commit a7b1911fe0
3 changed files with 457 additions and 46 deletions

View File

@ -31,7 +31,9 @@ class ExecuteFactory:
for pos, index in enumerate(cur_table_indexes[:-1]):
is_redundant = False
for candidate_index in cur_table_indexes[pos + 1:]:
if re.match(r'%s' % index.columns, candidate_index.columns):
existed_index1 = list(map(str.strip, index.columns.split(',')))
existed_index2 = list(map(str.strip, candidate_index.columns.split(',')))
if existed_index1 == existed_index2[0: len(existed_index1)]:
is_redundant = True
index.redundant_obj.append(candidate_index)
if is_redundant:
@ -97,14 +99,19 @@ class ExecuteFactory:
candidate_index.select_sql_num += obj.frequency
# SELECT scenes to filter out positive
if ind not in candidate_index.positive_pos and \
any(column in obj.statement.lower() for column in candidate_index.columns):
any(re.search(r'\b%s\b' % column, obj.statement.lower())
for column in candidate_index.columns.split(', ')):
candidate_index.ineffective_pos.append(ind)
candidate_index.total_sql_num += obj.frequency
@staticmethod
def match_last_result(table_name, index_column, history_indexes, history_invalid_indexes):
    """Invalidate historical indexes that are now covered by a wider index.

    Any column list in history_indexes[table_name] that is a prefix of
    *index_column* is removed from history_indexes[table_name] and appended
    to history_invalid_indexes[table_name].

    :param table_name: table whose historical indexes are checked
    :param index_column: comma-separated column list of the new index
    :param history_indexes: dict mapping table name -> list of column strings
    :param history_invalid_indexes: dict collecting the superseded entries
    """
    # the compared column list does not change inside the loop
    existed_index_column = list(map(str.strip, index_column.split(',')))
    # iterate over a snapshot: the original code removed entries from the
    # very list it was iterating, which silently skipped the element that
    # followed each removal
    for column in list(history_indexes.get(table_name, [])):
        history_index_column = list(map(str.strip, column.split(',')))
        # a longer historical index can never be a prefix of the new one
        if len(history_index_column) > len(existed_index_column):
            continue
        if history_index_column == existed_index_column[0:len(history_index_column)]:
            history_indexes[table_name].remove(column)
            history_invalid_indexes[table_name] = history_invalid_indexes.get(table_name, list())
            history_invalid_indexes[table_name].append(column)

View File

@ -22,7 +22,9 @@ import re
import json
import select
import logging
from DAO.gsql_execute import GSqlExecute
from mcts import MCTS
ENABLE_MULTI_NODE = False
SAMPLE_NUM = 5
@ -123,15 +125,15 @@ def print_header_boundary(header):
print(green(title))
def filter_low_benefit(pos_list, candidate_indexes, multi_iter_mode, workload):
def filter_low_benefit(candidate_indexes, multi_iter_mode, workload):
remove_list = []
for key, index in enumerate(candidate_indexes):
sql_optimzed = 0
if multi_iter_mode:
cost_list_pos = index.atomic_pos
else:
cost_list_pos = key + 1
for ind, pos in enumerate(index.positive_pos):
if multi_iter_mode:
cost_list_pos = index.atomic_pos
else:
cost_list_pos = pos_list[key] + 1
sql_optimzed += 1 - workload[pos].cost_list[cost_list_pos] / workload[pos].cost_list[0]
negative_ratio = ((index.insert_sql_num + index.delete_sql_num +
index.update_sql_num) / index.total_sql_num) if index.total_sql_num else 0
@ -148,12 +150,8 @@ def display_recommend_result(workload, candidate_indexes, index_cost_total, mult
display_info, integrate_indexes, history_invalid_indexes):
cnt = 0
index_current_storage = 0
pos_list = []
if not multi_iter_mode:
pos_list = [item[0] for item in candidate_indexes]
candidate_indexes = [item[1] for item in candidate_indexes]
# filter candidate indexes with low benefit
filter_low_benefit(pos_list, candidate_indexes, multi_iter_mode, workload)
filter_low_benefit(candidate_indexes, multi_iter_mode, workload)
# display determine result
integrate_indexes['currentIndexes'] = dict()
for key, index in enumerate(candidate_indexes):
@ -178,7 +176,7 @@ def display_recommend_result(workload, candidate_indexes, index_cost_total, mult
if multi_iter_mode:
cost_list_pos = index.atomic_pos
else:
cost_list_pos = pos_list[key] + 1
cost_list_pos = key + 1
sql_info = {'sqlDetails': []}
benefit_types = [index.ineffective_pos, index.positive_pos, index.negative_pos]
@ -328,28 +326,37 @@ def generate_candidate_indexes(workload, workload_table_name, db):
if DRIVER:
db.init_conn_handle()
for k, query in enumerate(workload):
if 'select ' in query.statement.lower():
table_index_dict = db.query_index_advisor(query.statement, workload_table_name)
valid_index_dict = get_valid_index_dict(table_index_dict, query, db)
if not re.search(r'(\A|\s)select\s', query.statement.lower()):
continue
table_index_dict = db.query_index_advisor(query.statement, workload_table_name)
valid_index_dict = get_valid_index_dict(table_index_dict, query, db)
# filter duplicate indexes
for table in valid_index_dict.keys():
if table not in index_dict.keys():
index_dict[table] = {}
for columns in valid_index_dict[table]:
if len(workload[k].valid_index_list) >= FULL_ARRANGEMENT_THRESHOLD:
break
workload[k].valid_index_list.append(IndexItem(table, columns))
if not any(re.match(r'%s' % columns, item) for item in index_dict[table]):
column_sql = {columns: [k]}
index_dict[table].update(column_sql)
elif columns in index_dict[table].keys():
index_dict[table][columns].append(k)
# record valid indexes for every sql of workload and generate candidate indexes
for table in valid_index_dict.keys():
if table not in index_dict.keys():
index_dict[table] = {}
for columns in valid_index_dict[table]:
if len(workload[k].valid_index_list) >= FULL_ARRANGEMENT_THRESHOLD:
break
workload[k].valid_index_list.append(IndexItem(table, columns))
if columns in index_dict[table]:
index_dict[table][columns].append(k)
else:
column_sql = {columns: [k]}
index_dict[table].update(column_sql)
# filter redundant indexes for candidate indexes
for table, column_sqls in index_dict.items():
for column, sql in column_sqls.items():
print("table: ", table, "columns: ", column)
candidate_indexes.append(IndexItem(table, column, sql))
sorted_column_sqls = sorted(column_sqls.items(), key=lambda item: item[0])
for i in range(len(sorted_column_sqls) - 1):
if re.match(sorted_column_sqls[i][0], sorted_column_sqls[i+1][0]):
sorted_column_sqls[i+1][1].extend(sorted_column_sqls[i][1])
else:
print("table: ", table, "columns: ", sorted_column_sqls[i][0])
candidate_indexes.append(IndexItem(table, sorted_column_sqls[i][0],
sorted_column_sqls[i][1]))
print("table: ", table, "columns: ", sorted_column_sqls[-1][0])
candidate_indexes.append(
IndexItem(table, sorted_column_sqls[-1][0], sorted_column_sqls[-1][1]))
if DRIVER:
db.close_conn()
return candidate_indexes
@ -502,7 +509,7 @@ def display_last_recommend_result(integrate_indexes, history_invalid_indexes, in
print_header_boundary(" Historical effective indexes ")
for table_name, index_list in integrate_indexes['historyIndexes'].items():
for column in index_list:
index_name = 'idx_' + table_name + '_' + '_'.join(column.split(', '))
index_name = 'idx_' + table_name.split('.')[-1] + '_' + '_'.join(column.split(', '))
statement = 'CREATE INDEX ' + index_name + ' ON ' + table_name + '(' + column + ');'
print(statement)
# display historical invalid indexes
@ -510,7 +517,7 @@ def display_last_recommend_result(integrate_indexes, history_invalid_indexes, in
print_header_boundary(" Historical invalid indexes ")
for table_name, index_list in history_invalid_indexes.items():
for column in index_list:
index_name = 'idx_' + table_name + '_' + '_'.join(column.split(', '))
index_name = 'idx_' + table_name.split('.')[-1] + '_' + '_'.join(column.split(', '))
statement = 'CREATE INDEX ' + index_name + ' ON ' + table_name + '(' + column + ');'
print(statement)
# save integrate indexes result
@ -531,21 +538,23 @@ def check_unused_index_workload(whole_indexes, redundant_indexes, workload_index
unused_index = list(indexes_name.difference(workload_indexes))
remove_list = []
print_header_boundary(" Current workload useless indexes ")
if not unused_index:
print("No useless index!")
detail_info['uselessIndexes'] = []
# useless index
unused_index_columns = dict()
has_unused_index = False
for cur_index in unused_index:
for index in whole_indexes:
if cur_index == index.indexname:
unused_index_columns[cur_index] = index.columns
if 'UNIQUE INDEX' not in index.indexdef:
has_unused_index = True
statement = "DROP INDEX %s;" % index.indexname
print(statement)
useless_index = {"schemaName": index.schema, "tbName": index.table, "type": 3,
"columns": index.columns, "statement": statement}
detail_info['uselessIndexes'].append(useless_index)
if not has_unused_index:
print("No useless index!")
print_header_boundary(" Redundant indexes ")
# filter redundant index
for pos, index in enumerate(redundant_indexes):
@ -599,16 +608,14 @@ def simple_index_advisor(input_path, max_index_num, integrate_indexes, db):
index_cost_total = [ori_total_cost]
for _, obj in enumerate(candidate_indexes):
new_total_cost = db.estimate_workload_cost_file(workload, [obj])
index_cost_total.append(new_total_cost)
obj.benefit = ori_total_cost - new_total_cost
if obj.benefit > 0:
index_cost_total.append(new_total_cost)
if DRIVER:
db.close_conn()
if len(candidate_indexes) == 0:
if len(index_cost_total) == 1:
print("No optimal indexes generated!")
return ori_indexes_name, workload_table_name, display_info, history_invalid_indexes
candidate_indexes = [item for item in candidate_indexes if item.benefit > 0]
candidate_indexes = sorted(enumerate(candidate_indexes),
key=lambda item: item[1].benefit, reverse=True)
global MAX_INDEX_NUM
MAX_INDEX_NUM = max_index_num
# match the last recommendation result
@ -642,6 +649,7 @@ def greedy_determine_opt_config(workload, atomic_config_total, candidate_indexes
if cur_index and cur_min_cost < min_cost:
if MAX_INDEX_STORAGE and sum([obj.storage for obj in opt_config]) + \
cur_index.storage > MAX_INDEX_STORAGE:
candidate_indexes.remove(cur_index)
continue
if len(opt_config) == MAX_INDEX_NUM:
break
@ -681,9 +689,11 @@ def complex_index_advisor(input_path, integrate_indexes, db):
ori_indexes_name))
if DRIVER:
db.close_conn()
opt_config = greedy_determine_opt_config(workload, atomic_config_total,
candidate_indexes, index_cost_total[0])
if MAX_INDEX_STORAGE:
opt_config = MCTS(workload, atomic_config_total, candidate_indexes, MAX_INDEX_STORAGE)
else:
opt_config = greedy_determine_opt_config(workload, atomic_config_total,
candidate_indexes, index_cost_total[0])
if len(opt_config) == 0:
print("No optimal indexes generated!")
return ori_indexes_name, workload_table_name, display_info, history_invalid_indexes

View File

@ -0,0 +1,394 @@
import sys
import math
import random
import copy
# Storage budget for a selected index configuration; assigned by MCTS().
STORAGE_THRESHOLD = 0
# Candidate indexes the search may pick from; assigned by MCTS().
AVAILABLE_CHOICES = []
# Atomic index configurations used to score benefits; assigned by MCTS().
ATOMIC_CHOICES = []
# Per-query workload objects carrying a cost_list; assigned by MCTS().
WORKLOAD_INFO = []
def is_same_index(index, compared_index):
    """Return True when both index objects target the same table and columns."""
    same_table = index.table == compared_index.table
    same_columns = index.columns == compared_index.columns
    return same_table and same_columns
def atomic_config_is_valid(atomic_config, candidate_indexes):
    """Return True when every atomic index of *atomic_config* appears in *candidate_indexes*.

    Side effect: the storage of each atomic index is copied onto the matching
    candidate index.
    """
    for atomic_index in atomic_config:
        matched = next(
            (candidate for candidate in candidate_indexes
             if is_same_index(candidate, atomic_index)),
            None)
        if matched is None:
            # one atomic index is missing, the whole config is invalid
            return False
        matched.storage = atomic_index.storage
    return True
def find_subsets_num(choice):
    """Return the positions of atomic configs wholly contained in *choice*.

    Also records on the most recently added index (choice[-1]) the position of
    the single-index atomic config that matches it, via its atomic_pos field.
    """
    subset_positions = []
    newest = choice[-1]
    for pos, atomic in enumerate(ATOMIC_CHOICES):
        # skip empty configs and configs larger than the candidate set
        if not atomic or len(atomic) > len(choice):
            continue
        if atomic_config_is_valid(atomic, choice):
            subset_positions.append(pos)
        # remember where the newest index appears as a standalone atomic config
        if len(atomic) == 1 and is_same_index(newest, atomic[0]):
            newest.atomic_pos = pos
    return subset_positions
def find_best_benefit(choice):
    """Sum, over the whole workload, each query's best cost reduction under *choice*."""
    subset_positions = find_subsets_num(choice)
    total_benefit = 0
    for query in WORKLOAD_INFO:
        # best positive cost reduction achievable for this query, or 0
        base_cost = query.cost_list[0]
        gains = (base_cost - query.cost_list[pos] for pos in subset_positions)
        total_benefit += max((gain for gain in gains if gain > 0), default=0)
    return total_benefit
def get_diff(available_choices, choices):
    """Return the indexes of *available_choices* not already present in *choices*.

    The original implementation called list.remove once per matching pair, so
    an available index matching more than one entry of *choices* raised
    ValueError on the second removal; this version drops each index at most
    once.
    """
    return [candidate for candidate in available_choices
            if not any(is_same_index(candidate, picked) for picked in choices)]
class State(object):
    """Search state of one node in the Monte Carlo tree.

    Tracks the benefit and storage accumulated so far, the indexes already
    picked on the path from the root (accumulation_choices) and the indexes
    that may still be picked from here (available_choices). A state is
    terminal once every available index has been accumulated.
    """

    def __init__(self):
        self.current_storage = 0.0
        self.current_benefit = 0.0
        # indexes chosen along the path from the root to this state
        self.accumulation_choices = []
        # indexes that may still be chosen from this state
        self.available_choices = []
        # human-readable form of accumulation_choices, filled by __repr__
        self.displayable_choices = []

    def get_available_choices(self):
        return self.available_choices

    def set_available_choices(self, choices):
        self.available_choices = choices

    def get_current_storage(self):
        return self.current_storage

    def set_current_storage(self, value):
        self.current_storage = value

    def get_current_benefit(self):
        return self.current_benefit

    def set_current_benefit(self, value):
        self.current_benefit = value

    def get_accumulation_choices(self):
        return self.accumulation_choices

    def set_accumulation_choices(self, choices):
        self.accumulation_choices = choices

    def is_terminal(self):
        # leaf state: every candidate index has been accumulated
        return len(self.accumulation_choices) == len(AVAILABLE_CHOICES)

    def compute_benefit(self):
        return self.current_benefit

    def get_next_state_with_random_choice(self):
        """Pick an untried index at random and return the resulting state.

        Picks that add no benefit or overflow the storage budget are consumed
        and skipped; None is returned once every remaining pick is exhausted.
        """
        while self.available_choices:
            picked = random.choice(list(self.available_choices))
            # never offer the same pick twice from this state
            self.available_choices.remove(picked)
            path = copy.copy(self.accumulation_choices)
            path.append(picked)
            benefit = find_best_benefit(path)
            # reject picks that do not improve benefit or break the budget
            if benefit <= self.current_benefit or \
                    self.current_storage + picked.storage > STORAGE_THRESHOLD:
                continue
            next_state = State()
            next_state.set_accumulation_choices(path)
            next_state.set_current_benefit(benefit)
            next_state.set_current_storage(self.current_storage + picked.storage)
            next_state.set_available_choices(get_diff(AVAILABLE_CHOICES, path))
            return next_state
        return None

    def __repr__(self):
        self.displayable_choices = ['{}: {}'.format(choice.table, choice.columns)
                                    for choice in self.accumulation_choices]
        return "reward: {}, storage :{}, choices: {}".format(
            self.current_benefit, self.current_storage, self.displayable_choices)
class Node(object):
    """A node of the Monte Carlo search tree.

    Holds the parent/children links, the UCB statistics (visit count and
    accumulated quality) and the State reached at this node.
    """

    def __init__(self):
        self.visit_number = 0
        self.quality = 0.0
        self.parent = None
        self.children = []
        self.state = None

    def get_parent(self):
        return self.parent

    def set_parent(self, parent):
        self.parent = parent

    def get_children(self):
        return self.children

    def expand_child(self, node):
        # wire the new child into the tree
        node.set_parent(self)
        self.children.append(node)

    def set_state(self, state):
        self.state = state

    def get_state(self):
        return self.state

    def get_visit_number(self):
        return self.visit_number

    def set_visit_number(self, number):
        self.visit_number = number

    def update_visit_number(self):
        self.visit_number += 1

    def get_quality_value(self):
        return self.quality

    def set_quality_value(self, value):
        self.quality = value

    def update_quality_value(self, reward):
        self.quality += reward

    def is_all_expand(self):
        # one child per index not yet on this node's accumulated path
        expandable = len(AVAILABLE_CHOICES) - len(self.get_state().get_accumulation_choices())
        return len(self.children) == expandable

    def __repr__(self):
        return "Node: {}, Q/N: {}/{}, State: {}".format(
            hash(self), self.quality, self.visit_number, self.state)
def tree_policy(node):
    """Selection/expansion phase of the Monte Carlo tree search.

    Walks down from *node*: fully expanded nodes are descended via the UCB
    best child, otherwise a new child is expanded and returned. Returns the
    node to simulate from, a leaf node, or None when best_child finds no
    child satisfying the constraints.
    """
    while node and not node.get_state().is_terminal():
        if node.is_all_expand():
            # every action tried here: descend by UCB (may yield None)
            node = best_child(node, True)
        else:
            child = expand(node)
            # an empty state means no remaining choice met the constraints;
            # stay on the current node and try again in that case
            if child.get_state():
                return child
    # reached a leaf (or ran out of valid children)
    return node
def default_policy(node):
    """Simulation phase: randomly play out from *node* and return the reward.

    The node's state is deep-copied so the rollout never mutates the tree.
    """
    rollout_state = copy.deepcopy(node.get_state())
    # play random actions until the state is terminal or no action remains
    while not rollout_state.is_terminal():
        successor = rollout_state.get_next_state_with_random_choice()
        if not successor:
            break
        rollout_state = successor
    return rollout_state.compute_benefit()
def expand(node):
    """Expansion phase: attach one new child to *node* via a random action.

    The child's state may be None when no remaining action satisfies the
    benefit/storage constraints; callers must check for that.
    """
    child = Node()
    child.set_state(node.get_state().get_next_state_with_random_choice())
    node.expand_child(child)
    return child
def best_child(node, is_exploration):
    """Return the child of *node* with the highest UCB score.

    Children whose state is empty, whose storage exceeds STORAGE_THRESHOLD or
    whose benefit is not positive are skipped. With is_exploration False the
    exploration term is dropped and the pure exploitation score is used.
    """
    # the exploration weight is constant for the whole scan
    exploration_weight = 1 / math.sqrt(2.0) if is_exploration else 0.0
    best_score = -sys.maxsize
    best_sub_node = None
    for child in node.get_children():
        state = child.get_state()
        # children created from exhausted/invalid expansions carry no state
        if not state:
            continue
        # UCB = quality / times + C * sqrt(2 * ln(total_times) / times)
        exploitation = child.get_quality_value() / child.get_visit_number()
        exploration = 2.0 * math.log(node.get_visit_number()) / child.get_visit_number()
        score = exploitation + exploration_weight * math.sqrt(exploration)
        if score > best_score \
                and state.get_current_storage() <= STORAGE_THRESHOLD \
                and state.get_current_benefit() > 0:
            best_sub_node = child
            best_score = score
    return best_sub_node
def backpropagate(node, reward):
    """Backpropagation phase: push *reward* from *node* up to the root.

    Every node on the path gets its visit count incremented and its quality
    increased by the reward.
    """
    current = node
    while current is not None:
        current.update_visit_number()
        current.update_quality_value(reward)
        # walk towards the root
        current = current.parent
def monte_carlo_tree_search(node):
    """Run one bounded round of MCTS from *node* and return the best next node.

    Repeats selection/expansion, simulation and backpropagation under a
    computation budget proportional to the number of candidate indexes, then
    picks the child with the best exploitation-only score (no exploration).
    """
    computation_budget = len(AVAILABLE_CHOICES) * 3
    for _ in range(computation_budget):
        # 1. selection/expansion: find the node to roll out from
        expand_node = tree_policy(node)
        if not expand_node:
            # every node was added but none satisfies the space limit
            break
        # 2. simulation: random playout reward
        reward = default_policy(expand_node)
        # 3. backpropagation: update statistics along the path
        backpropagate(expand_node, reward)
    return best_child(node, False)
def MCTS(workload_info, atomic_choices, available_choices, storage_threshold):
    """Search for the index configuration with the best workload benefit.

    :param workload_info: per-query objects carrying cost_list entries
    :param atomic_choices: atomic index configurations used for scoring
    :param available_choices: candidate indexes the search may pick
    :param storage_threshold: storage budget for the final configuration
    :return: list of chosen indexes (possibly empty)
    """
    global ATOMIC_CHOICES, STORAGE_THRESHOLD, WORKLOAD_INFO, AVAILABLE_CHOICES
    WORKLOAD_INFO = workload_info
    AVAILABLE_CHOICES = available_choices
    ATOMIC_CHOICES = atomic_choices
    STORAGE_THRESHOLD = storage_threshold
    # build the root of the search tree
    root_state = State()
    root_state.set_available_choices(copy.copy(available_choices))
    root = Node()
    root.set_state(root_state)
    opt_config = []
    current = root
    # at most one index can be added per round
    for _ in range(len(AVAILABLE_CHOICES)):
        if current:
            current = monte_carlo_tree_search(current)
            if current:
                opt_config = current.state.accumulation_choices
            else:
                # no further child satisfies the constraints
                break
    return opt_config