From 19f5109c8d9072b311e38e799d4feb358d49e4bc Mon Sep 17 00:00:00 2001 From: gw115 <877801999@qq.com> Date: Wed, 20 Jul 2022 07:26:14 +0800 Subject: [PATCH] syn code for workload index advisor --- .../index_advisor/index_advisor_workload.py | 147 ++++--- .../tools/components/index_advisor/mcts.py | 397 ------------------ 2 files changed, 72 insertions(+), 472 deletions(-) delete mode 100644 src/gausskernel/dbmind/tools/components/index_advisor/mcts.py diff --git a/src/gausskernel/dbmind/tools/components/index_advisor/index_advisor_workload.py b/src/gausskernel/dbmind/tools/components/index_advisor/index_advisor_workload.py index 224ffca9a..65757b93f 100644 --- a/src/gausskernel/dbmind/tools/components/index_advisor/index_advisor_workload.py +++ b/src/gausskernel/dbmind/tools/components/index_advisor/index_advisor_workload.py @@ -26,11 +26,9 @@ import logging try: from .dao.gsql_execute import GSqlExecute from .dao.execute_factory import ExecuteFactory - from .mcts import MCTS except ImportError: from dao.gsql_execute import GSqlExecute from dao.execute_factory import ExecuteFactory - from mcts import MCTS ENABLE_MULTI_NODE = False SAMPLE_NUM = 5 @@ -44,11 +42,12 @@ JSON_TYPE = False DRIVER = None BLANK = ' ' SQL_TYPE = ['select', 'delete', 'insert', 'update'] -SQL_PATTERN = [r'\((\s*(\d+(\.\d+)?\s*)[,]?)+\)', # match integer set in the IN collection - r'([^\\])\'((\')|(.*?([^\\])\'))', # match all content in single quotes +NUMBER_SET_PARTTERN = r'\((\s*(\-|\+)?\d+(\.\d+)?\s*)(,\s*(\-|\+)?\d+(\.\d+)?\s*)*[,]?\)' +SQL_PATTERN = [r'([^\\])\'((\')|(.*?([^\\])\'))', # match all content in single quotes + NUMBER_SET_PARTTERN, # match integer set in the IN collection r'(([^<>]\s*=\s*)|([^<>]\s+))(\d+)(\.\d+)?'] # match single integer -SQL_DISPLAY_PATTERN = [r'\((\s*(\d+(\.\d+)?\s*)[,]?)+\)', # match integer set in the IN collection - r'\'((\')|(.*?\'))', # match all content in single quotes +SQL_DISPLAY_PATTERN = [r'\'((\')|(.*?\'))', # match all content in single quotes + NUMBER_SET_PARTTERN, # match integer set in the IN collection r'([^\_\d])\d+(\.\d+)?'] # match single integer logging.basicConfig(level=logging.INFO, format='%(levelname)s: %(message)s') @@ -118,16 +117,23 @@ class IndexItem: self.index_type = index_type self.is_candidate = False + def update_positive_pos(self, position): + self.positive_pos.append(position) + if position in self.ineffective_pos: + self.ineffective_pos.remove(position) + def __str__(self): return f'{self.table} {self.columns} {self.index_type}' def singleton(cls): instances = {} + def _singleton(*args, **kwargs): if not cls in instances: instances[cls] = cls(*args, **kwargs) return instances[cls] + return _singleton @@ -157,13 +163,13 @@ class IndexAdvisor: def retain_lower_cost_index(self, candidate_indexes): remove_indexes = [] - for i in range(len(candidate_indexes)-1): - if candidate_indexes[i].table != candidate_indexes[i+1].table: + for i in range(len(candidate_indexes) - 1): + if candidate_indexes[i].table != candidate_indexes[i + 1].table: continue - if candidate_indexes[i].columns == candidate_indexes[i+1].columns: + if candidate_indexes[i].columns == candidate_indexes[i + 1].columns: if self.index_cost_total[candidate_indexes[i].atomic_pos] <= \ - self.index_cost_total[candidate_indexes[i+1].atomic_pos]: - remove_indexes.append(i+1) + self.index_cost_total[candidate_indexes[i + 1].atomic_pos]: + remove_indexes.append(i + 1) else: remove_indexes.append(i) for index in remove_indexes[::-1]: @@ -194,12 +200,9 @@ class IndexAdvisor: self.workload_used_index)) if DRIVER: self.db.close_conn() - if MAX_INDEX_STORAGE: - opt_config = MCTS(self.workload_info[0], atomic_config_total, candidate_indexes, - MAX_INDEX_STORAGE, MAX_INDEX_NUM) - else: - opt_config = greedy_determine_opt_config(self.workload_info[0], atomic_config_total, - candidate_indexes, self.index_cost_total[0]) + + opt_config = greedy_determine_opt_config(self.workload_info[0], atomic_config_total, + candidate_indexes, self.index_cost_total[0]) self.retain_lower_cost_index(candidate_indexes) if len(opt_config) == 0: print("No optimal indexes generated!") @@ -208,16 +211,16 @@ class IndexAdvisor: def retain_high_benefit_index(self, candidate_indexes): remove_indexes = [] - for i in range(len(candidate_indexes)-1): + for i in range(len(candidate_indexes) - 1): candidate_indexes[i].cost_pos = i + 1 - if candidate_indexes[i].table != candidate_indexes[i+1].table: + if candidate_indexes[i].table != candidate_indexes[i + 1].table: continue - if candidate_indexes[i].columns == candidate_indexes[i+1].columns: - if candidate_indexes[i].benefit >= candidate_indexes[i+1].benefit: - remove_indexes.append(i+1) + if candidate_indexes[i].columns == candidate_indexes[i + 1].columns: + if candidate_indexes[i].benefit >= candidate_indexes[i + 1].benefit: + remove_indexes.append(i + 1) else: remove_indexes.append(i) - candidate_indexes[len(candidate_indexes)-1].cost_pos = len(candidate_indexes) + candidate_indexes[len(candidate_indexes) - 1].cost_pos = len(candidate_indexes) for index in remove_indexes[::-1]: candidate_indexes.pop(index) @@ -252,8 +255,6 @@ class IndexAdvisor: if len(self.index_cost_total) == 1: print("No optimal indexes generated!") return None - global MAX_INDEX_NUM - MAX_INDEX_NUM = MAX_INDEX_NUM or 10 return candidate_indexes def filter_low_benefit_index(self, opt_indexes): @@ -266,7 +267,7 @@ class IndexAdvisor: # calculate the average benefit of each positive SQL for pos in index.positive_pos: sql_optimzed += 1 - self.workload_info[0][pos].cost_list[cost_list_pos] / \ - self.workload_info[0][pos].cost_list[0] + self.workload_info[0][pos].cost_list[0] negative_sql_ratio = 0 if index.total_sql_num: negative_sql_ratio = (index.insert_sql_num + index.delete_sql_num + @@ -305,7 +306,7 @@ class IndexAdvisor: - sql_info['deleteRatio'], 2) self.display_detail_info['recommendIndexes'].append(sql_info) - def computer_index_optimization_info(self, index, table_name, statement, opt_indexes): + def computer_index_optimization_info(self, index, table_name, statement): if self.multi_iter_mode: cost_list_pos = index.atomic_pos else: @@ -327,15 +328,15 @@ class IndexAdvisor: sql_detail['sql'] = self.workload_info[0][pos].statement sql_detail['sqlCount'] = int(round(sql_count)) if category == 1: - sql_optimzed = (self.workload_info[0][pos].cost_list[0] - - self.workload_info[0][pos].cost_list[cost_list_pos]) / \ - self.workload_info[0][pos].cost_list[cost_list_pos] - sql_detail['optimized'] = '%.3f' % sql_optimzed + sql_optimized = (self.workload_info[0][pos].cost_list[0] - + self.workload_info[0][pos].cost_list[cost_list_pos]) / \ + self.workload_info[0][pos].cost_list[cost_list_pos] + sql_detail['optimized'] = '%.3f' % sql_optimized sql_detail['correlationType'] = category sql_info['sqlDetails'].append(sql_detail) self.record_info(index, sql_info, cost_list_pos, table_name, statement) - def display_advise_indexes_info(self, opt_indexes, show_detail): + def display_advise_indexes_info(self, show_detail): index_current_storage = 0 cnt = 0 self.display_detail_info['recommendIndexes'] = [] @@ -352,16 +353,16 @@ class IndexAdvisor: # display determine indexes table_name = index.table.split('.')[-1] index_name = 'idx_%s_%s%s' % (table_name, (index.index_type - + '_' if index.index_type else '') \ - ,'_'.join(index.columns.split(', '))) + + '_' if index.index_type else '') \ + , '_'.join(index.columns.split(', '))) statement = 'CREATE INDEX %s ON %s%s%s;' % (index_name, index.table, - '(' + index.columns + ')', - (' '+index.index_type if index.index_type else '')) + '(' + index.columns + ')', + (' ' + index.index_type if index.index_type else '')) print(statement) if show_detail: # record detailed SQL optimization information for each index self.computer_index_optimization_info( - index, table_name, statement, opt_indexes) + index, table_name, statement) def generate_incremental_index(self, history_advise_indexes): self.integrate_indexes = copy.copy(history_advise_indexes) @@ -384,16 +385,18 @@ class IndexAdvisor: workload_file_path): def rm_schema(table_name): return table_name.split('.')[-1] + # display historical effective indexes if self.integrate_indexes['historyIndexes']: print_header_boundary(" Historical effective indexes ") for table_name, index_list in self.integrate_indexes['historyIndexes'].items(): for column in index_list: index_name = 'idx_%s_%s%s' % (rm_schema(table_name), - (column[1] + '_' if column[1] else ''), - '_'.join(column[0].split(', '))) + (column[1] + '_' if column[1] else ''), + '_'.join(column[0].split(', '))) statement = 'CREATE INDEX %s ON %s%s%s;' % (index_name, table_name, - '(' + column[0] + ')', (' ' + column[1] if column[1] else '')) + '(' + column[0] + ')', + (' ' + column[1] if column[1] else '')) print(statement) # display historical invalid indexes if history_invalid_indexes: @@ -401,10 +404,11 @@ class IndexAdvisor: for table_name, index_list in history_invalid_indexes.items(): for column in index_list: index_name = 'idx_%s_%s%s' % (rm_schema(table_name), - (column[1] + '_' if column[1] else ''), - '_'.join(column[0].split(', '))) + (column[1] + '_' if column[1] else ''), + '_'.join(column[0].split(', '))) statement = 'CREATE INDEX %s ON %s%s%s;' % (index_name, table_name, - '(' + column[0] + ')', (' ' + column[1] if column[1] else '')) + '(' + column[0] + ')', + (' ' + column[1] if column[1] else '')) print(statement) # save integrate indexes result integrate_indexes_file = os.path.join(os.path.realpath(os.path.dirname(workload_file_path)), @@ -440,7 +444,7 @@ def load_workload(file_path): wd_dict = {} workload = [] global BLANK - with open(file_path, 'r') as file: + with open(file_path, 'r', errors='ignore') as file: raw_text = ''.join(file.readlines()) sqls = raw_text.split(';') for sql in sqls: @@ -486,7 +490,7 @@ def workload_compression(input_path): compressed_workload = [] total_num = 0 if JSON_TYPE: - with open(input_path, 'r') as file: + with open(input_path, 'r', errors='ignore') as file: templates = json.load(file) else: workload = load_workload(input_path) @@ -573,10 +577,10 @@ def get_valid_index_dict(table_index_dict, query, db): def print_candidate_indexes(column_sqls, table, candidate_indexes): if column_sqls[0][1]: - print("table: ", table, "columns: ",column_sqls[0][0], + print("table: ", table, "columns: ", column_sqls[0][0], "type: ", column_sqls[0][1]) else: - print("table: ", table, "columns: ",column_sqls[0][0]) + print("table: ", table, "columns: ", column_sqls[0][0]) if (table, tuple(column_sqls[0][0]), column_sqls[0][1]) not in IndexItemFactory().indexes: index = IndexItemFactory().get_index(table, column_sqls[0][0], 'local') index.index_type = 'global' @@ -595,30 +599,30 @@ def filter_redundant_indexes(index_dict): merged_column_sqls = [] # merge sqls for i in range(len(sorted_column_sqls) - 1): - if re.match(sorted_column_sqls[i][0][0] + ',', sorted_column_sqls[i+1][0][0]) and \ - sorted_column_sqls[i][0][1] == sorted_column_sqls[i+1][0][1]: - sorted_column_sqls[i+1][1].extend(sorted_column_sqls[i][1]) + if re.match(sorted_column_sqls[i][0][0] + ',', sorted_column_sqls[i + 1][0][0]) and \ + sorted_column_sqls[i][0][1] == sorted_column_sqls[i + 1][0][1]: + sorted_column_sqls[i + 1][1].extend(sorted_column_sqls[i][1]) else: merged_column_sqls.append(sorted_column_sqls[i]) else: merged_column_sqls.append(sorted_column_sqls[-1]) # sort using columns merged_column_sqls.sort(key=lambda item: item[0][0]) - for i in range(len(merged_column_sqls)-1): + for i in range(len(merged_column_sqls) - 1): # same columns if merged_column_sqls[i][0][0] == \ - merged_column_sqls[i+1][0][0]: + merged_column_sqls[i + 1][0][0]: print_candidate_indexes(merged_column_sqls[i], - table, - candidate_indexes) + table, + candidate_indexes) continue # left match for the partation table if re.match(merged_column_sqls[i][0][0] + ',', - merged_column_sqls[i+1][0][0]): - merged_column_sqls[i+1][1].extend( - merged_column_sqls[i][1]) - merged_column_sqls[i+1] = ((merged_column_sqls[i+1][0][0], 'global'), - merged_column_sqls[i+1][1]) + merged_column_sqls[i + 1][0][0]): + merged_column_sqls[i + 1][1].extend( + merged_column_sqls[i][1]) + merged_column_sqls[i + 1] = ((merged_column_sqls[i + 1][0][0], 'global'), + merged_column_sqls[i + 1][1]) continue print_candidate_indexes(merged_column_sqls[i], table, candidate_indexes) else: @@ -640,7 +644,7 @@ def filter_duplicate_indexes(valid_index_dict, index_dict, workload, pos): column_sql = {(columns, index_type): [pos]} index_dict[table].update(column_sql) workload[pos].valid_index_list.append( - IndexItemFactory().get_index(table, columns, index_type=index_type)) + IndexItemFactory().get_index(table, columns, index_type=index_type)) def generate_candidate_indexes(workload, db): @@ -761,13 +765,10 @@ def find_subsets_num(config, atomic_config_total): # infer the total cost of workload for a config according to the cost of atomic configs def infer_workload_cost(workload, config, atomic_config_total): total_cost = 0 - is_computed = False atomic_subsets_num, cur_index_atomic_pos = find_subsets_num( config, atomic_config_total) if len(atomic_subsets_num) == 0: raise ValueError("No atomic configs found for current config!") - if not config[-1].total_sql_num: - is_computed = True for ind, obj in enumerate(workload): if max(atomic_subsets_num) >= len(obj.cost_list): raise ValueError("Wrong atomic config for current query!") @@ -778,10 +779,6 @@ def infer_workload_cost(workload, config, atomic_config_total): min_cost = obj.cost_list[num] total_cost += min_cost - # record ineffective sql and negative sql for candidate indexes - if is_computed: - ExecuteFactory.record_ineffective_negative_sql( - config[-1], obj, ind) return total_cost, cur_index_atomic_pos @@ -812,8 +809,8 @@ def display_redundant_indexes(redundant_indexes, unused_index_columns, remove_li # redundant objects are not in the useless index set or # both redundant objects and redundant index in useless index must be redundant index index_exist = redundant_obj.indexname not in unused_index_columns.keys() or \ - (unused_index_columns.get(redundant_obj.indexname) and - unused_index_columns.get(index.indexname)) + (unused_index_columns.get(redundant_obj.indexname) and + unused_index_columns.get(index.indexname)) if index_exist: is_redundant = True if not is_redundant: @@ -826,7 +823,7 @@ def display_redundant_indexes(redundant_indexes, unused_index_columns, remove_li # redundant index for index in redundant_indexes: statement = "DROP INDEX %s.%s;" % \ - (index.schema, index.indexname) + (index.schema, index.indexname) print(statement) existing_index = [item.indexname + ':' + item.columns for item in index.redundant_obj] @@ -911,7 +908,7 @@ def get_last_indexes_result(input_path): integrate_indexes = {'historyIndexes': {}} if os.path.exists(last_indexes_result_file): try: - with open(last_indexes_result_file, 'r') as file: + with open(last_indexes_result_file, 'r', errors='ignore') as file: integrate_indexes['historyIndexes'] = json.load(file) except json.JSONDecodeError: return integrate_indexes @@ -928,7 +925,7 @@ def index_advisor_workload(history_advise_indexes, db, workload_file_path, opt_indexes = index_advisor.simple_index_advisor() if opt_indexes: index_advisor.filter_low_benefit_index(opt_indexes) - index_advisor.display_advise_indexes_info(opt_indexes, show_detail) + index_advisor.display_advise_indexes_info(show_detail) index_advisor.generate_incremental_index(history_advise_indexes) history_invalid_indexes = {} @@ -948,7 +945,7 @@ def check_parameter(args): raise argparse.ArgumentTypeError("%s is an invalid positive int value" % args.max_index_num) if args.max_index_storage is not None and args.max_index_storage <= 0: - raise argparse.ArgumentTypeError("%s is an invalid positive float value" % + raise argparse.ArgumentTypeError("%s is an invalid positive int value" % args.max_index_storage) JSON_TYPE = args.json MAX_INDEX_NUM = args.max_index_num @@ -966,7 +963,7 @@ def main(argv): arg_parser.add_argument("p", help="Port of database", type=int) arg_parser.add_argument("d", help="Name of database", action=CheckValid) arg_parser.add_argument( - "--h", help="Host for database", action=CheckValid) + "--h", help="Host for database", action=CheckValid) arg_parser.add_argument( "-U", help="Username for database log-in", action=CheckValid) arg_parser.add_argument( @@ -976,7 +973,7 @@ def main(argv): arg_parser.add_argument( "--max_index_num", help="Maximum number of suggested indexes", type=int) arg_parser.add_argument("--max_index_storage", - help="Maximum storage of suggested indexes/MB", type=float) + help="Maximum storage of suggested indexes/MB", type=int) arg_parser.add_argument("--multi_iter_mode", action='store_true', help="Whether to use multi-iteration algorithm", default=False) arg_parser.add_argument("--multi_node", action='store_true', diff --git a/src/gausskernel/dbmind/tools/components/index_advisor/mcts.py b/src/gausskernel/dbmind/tools/components/index_advisor/mcts.py deleted file mode 100644 index 760b42763..000000000 --- a/src/gausskernel/dbmind/tools/components/index_advisor/mcts.py +++ /dev/null @@ -1,397 +0,0 @@ -import sys -import math -import random -import copy - -STORAGE_THRESHOLD = 0 -AVAILABLE_CHOICES = None -ATOMIC_CHOICES = None -WORKLOAD_INFO = None -MAX_INDEX_NUM = 0 - - -def is_same_index(index, compared_index): - return index.table == compared_index.table and \ - index.columns == compared_index.columns and \ - index.index_type == compared_index.index_type - - -def atomic_config_is_valid(atomic_config, config): - # if candidate indexes contains all atomic index of current config1, then record it - for atomic_index in atomic_config: - is_exist = False - for index in config: - if is_same_index(index, atomic_index): - index.storage = atomic_index.storage - is_exist = True - break - if not is_exist: - return False - return True - - -def find_subsets_num(choice): - atomic_subsets_num = [] - for pos, atomic in enumerate(ATOMIC_CHOICES): - if not atomic or len(atomic) > len(choice): - continue - # find valid atomic index - if atomic_config_is_valid(atomic, choice): - atomic_subsets_num.append(pos) - # find the same atomic index as the candidate index - if len(atomic) == 1 and (is_same_index(choice[-1], atomic[0])): - choice[-1].atomic_pos = pos - return atomic_subsets_num - - -def find_best_benefit(choice): - atomic_subsets_num = find_subsets_num(choice) - total_benefit = 0 - for ind, obj in enumerate(WORKLOAD_INFO): - # calculate the best benefit for the current sql - max_benefit = 0 - for pos in atomic_subsets_num: - if (obj.cost_list[0] - obj.cost_list[pos]) > max_benefit: - max_benefit = obj.cost_list[0] - obj.cost_list[pos] - total_benefit += max_benefit - return total_benefit - - -def get_diff(available_choices, choices): - except_choices = copy.copy(available_choices) - for i in available_choices: - for j in choices: - if is_same_index(i, j): - except_choices.remove(i) - return except_choices - - -class State(object): - """ - The game state of the Monte Carlo tree search, - the state data recorded under a certain Node node, - including the current game score, the current number of game rounds, - and the execution record from the beginning to the current. - - It is necessary to realize whether the current state has reached the end of the game state, - and support the operation of randomly fetching from the Action collection. - """ - - def __init__(self): - self.current_storage = 0.0 - self.current_benefit = 0.0 - # record the sum of choices up to the current state - self.accumulation_choices = [] - # record available choices of current state - self.available_choices = [] - self.displayable_choices = [] - - def get_available_choices(self): - return self.available_choices - - def set_available_choices(self, choices): - self.available_choices = choices - - def get_current_storage(self): - return self.current_storage - - def set_current_storage(self, value): - self.current_storage = value - - def get_current_benefit(self): - return self.current_benefit - - def set_current_benefit(self, value): - self.current_benefit = value - - def get_accumulation_choices(self): - return self.accumulation_choices - - def set_accumulation_choices(self, choices): - self.accumulation_choices = choices - - def is_terminal(self): - # the current node is a leaf node - return len(self.accumulation_choices) == MAX_INDEX_NUM - - def compute_benefit(self): - return self.current_benefit - - def get_next_state_with_random_choice(self): - # ensure that the choices taken are not repeated - if not self.available_choices: - return None - random_choice = random.choice([choice for choice in self.available_choices]) - self.available_choices.remove(random_choice) - choice = copy.copy(self.accumulation_choices) - choice.append(random_choice) - benefit = find_best_benefit(choice) - # if current choice not satisfy restrictions, then continue get next choice - if benefit <= self.current_benefit or \ - self.current_storage + random_choice.storage > STORAGE_THRESHOLD: - return self.get_next_state_with_random_choice() - - next_state = State() - # initialize the properties of the new state - next_state.set_accumulation_choices(choice) - next_state.set_current_benefit(benefit) - next_state.set_current_storage(self.current_storage + random_choice.storage) - next_state.set_available_choices(get_diff(AVAILABLE_CHOICES, choice)) - return next_state - - def __repr__(self): - self.displayable_choices = ['{}: {}'.format(choice.table, choice.columns) - for choice in self.accumulation_choices] - return "reward: {}, storage :{}, choices: {}".format( - self.current_benefit, self.current_storage, self.displayable_choices) - - -class Node(object): - """ - The Node of the Monte Carlo tree search tree contains the parent node and - current point information, - which is used to calculate the traversal times and quality value of the UCB, - and the State of the Node selected by the game. - """ - def __init__(self): - self.visit_number = 0 - self.quality = 0.0 - - self.parent = None - self.children = [] - self.state = None - - def get_parent(self): - return self.parent - - def set_parent(self, parent): - self.parent = parent - - def get_children(self): - return self.children - - def expand_child(self, node): - node.set_parent(self) - self.children.append(node) - - def set_state(self, state): - self.state = state - - def get_state(self): - return self.state - - def get_visit_number(self): - return self.visit_number - - def set_visit_number(self, number): - self.visit_number = number - - def update_visit_number(self): - self.visit_number += 1 - - def get_quality_value(self): - return self.quality - - def set_quality_value(self, value): - self.quality = value - - def update_quality_value(self, reward): - self.quality += reward - - def is_all_expand(self): - return len(self.children) == \ - len(AVAILABLE_CHOICES) - len(self.get_state().get_accumulation_choices()) - - def __repr__(self): - return "Node: {}, Q/N: {}/{}, State: {}".format( - hash(self), self.quality, self.visit_number, self.state) - - -def tree_policy(node): - """ - In the Selection and Expansion stages of Monte Carlo tree search, - the node that needs to be searched (such as the root node) is passed in, - and the best node that needs to be expanded is returned - according to the exploration/exploitation algorithm. - Note that if the node is a leaf node, it will be returned directly. - - The basic strategy is to first find the child nodes that have not been selected at present, - and select them randomly if there are more than one. If both are selected, - find the one with the largest UCB value that has weighed exploration/exploitation, - and randomly select if the UCB values are equal. - """ - - # check if the current node is leaf node - while node and not node.get_state().is_terminal(): - - if node.is_all_expand(): - node = best_child(node, True) - else: - # return the new sub node - sub_node = expand(node) - # when there is no node that satisfies the condition in the remaining nodes, - # this state is empty - if sub_node.get_state(): - return sub_node - - # return the leaf node - return node - - -def default_policy(node): - """ - In the Simulation stage of Monte Carlo tree search, input a node that needs to be expanded, - create a new node after random operation, and return the reward of the new node. - Note that the input node should not be a child node, - and there are unexecuted Actions that can be expendable. - - The basic strategy is to choose the Action at random. - """ - - # get the state of the game - current_state = copy.deepcopy(node.get_state()) - - # run until the game over - while not current_state.is_terminal(): - # pick one random action to play and get next state - next_state = current_state.get_next_state_with_random_choice() - if not next_state: - break - current_state = next_state - - final_state_reward = current_state.compute_benefit() - return final_state_reward - - -def expand(node): - """ - Enter a node, expand a new node on the node, use the random method to execute the Action, - and return the new node. Note that it is necessary to ensure that the newly - added nodes are different from other node Action - """ - - new_state = node.get_state().get_next_state_with_random_choice() - sub_node = Node() - sub_node.set_state(new_state) - node.expand_child(sub_node) - - return sub_node - - -def best_child(node, is_exploration): - """ - Using the UCB algorithm, - select the child node with the highest score after weighing the exploration and exploitation. - Note that if it is the prediction stage, - the current Q-value score with the highest score is directly selected. - """ - - best_score = -sys.maxsize - best_sub_node = None - - # travel all sub nodes to find the best one - for sub_node in node.get_children(): - # The children nodes of the node contains the children node whose state is empty, - # this kind of node comes from the node that does not meet the conditions. - if not sub_node.get_state(): - continue - # ignore exploration for inference - if is_exploration: - C = 1 / math.sqrt(2.0) - else: - C = 0.0 - - # UCB = quality / times + C * sqrt(2 * ln(total_times) / times) - left = sub_node.get_quality_value() / sub_node.get_visit_number() - right = 2.0 * math.log(node.get_visit_number()) / sub_node.get_visit_number() - score = left + C * math.sqrt(right) - # get the maximum score, while filtering nodes that do not meet the space constraints and - # nodes that have no revenue - if score > best_score \ - and sub_node.get_state().get_current_storage() <= STORAGE_THRESHOLD \ - and sub_node.get_state().get_current_benefit() > 0: - best_sub_node = sub_node - best_score = score - - return best_sub_node - - -def backpropagate(node, reward): - """ - In the Backpropagation stage of Monte Carlo tree search, - input the node that needs to be expended and the reward of the newly executed Action, - feed it back to the expend node and all upstream nodes, - and update the corresponding data. - """ - - # update util the root node - while node is not None: - # update the visit number - node.update_visit_number() - - # update the quality value - node.update_quality_value(reward) - - # change the node to the parent node - node = node.parent - - -def monte_carlo_tree_search(node): - """ - Implement the Monte Carlo tree search algorithm, pass in a root node, - expand new nodes and update data according to the - tree structure that has been explored before in a limited time, - and then return as long as the child node with the highest exploitation. - - When making predictions, - you only need to select the node with the largest exploitation according to the Q value, - and find the next optimal node. - """ - - computation_budget = len(AVAILABLE_CHOICES) * 3 - - # run as much as possible under the computation budget - for i in range(computation_budget): - # 1. find the best node to expand - expand_node = tree_policy(node) - if not expand_node: - # when it is None, it means that all nodes are added but no nodes meet the space limit - break - # 2. random get next action and get reward - reward = default_policy(expand_node) - - # 3. update all passing nodes with reward - backpropagate(expand_node, reward) - - # get the best next node - best_next_node = best_child(node, False) - - return best_next_node - - -def MCTS(workload_info, atomic_choices, available_choices, storage_threshold, max_index_num): - global ATOMIC_CHOICES, STORAGE_THRESHOLD, WORKLOAD_INFO, AVAILABLE_CHOICES, MAX_INDEX_NUM - WORKLOAD_INFO = workload_info - AVAILABLE_CHOICES = available_choices - ATOMIC_CHOICES = atomic_choices - STORAGE_THRESHOLD = storage_threshold - MAX_INDEX_NUM = max_index_num if max_index_num else len(available_choices) - - # create the initialized state and initialized node - init_state = State() - choices = copy.copy(available_choices) - init_state.set_available_choices(choices) - init_node = Node() - init_node.set_state(init_state) - current_node = init_node - - opt_config = [] - # set the rounds to play - for i in range(len(AVAILABLE_CHOICES)): - if current_node: - current_node = monte_carlo_tree_search(current_node) - if current_node: - opt_config = current_node.state.accumulation_choices - else: - break - return opt_config