Implementing Monte Carlo Tree Search Algorithm
This commit is contained in:
@ -31,7 +31,9 @@ class ExecuteFactory:
|
||||
for pos, index in enumerate(cur_table_indexes[:-1]):
|
||||
is_redundant = False
|
||||
for candidate_index in cur_table_indexes[pos + 1:]:
|
||||
if re.match(r'%s' % index.columns, candidate_index.columns):
|
||||
existed_index1 = list(map(str.strip, index.columns.split(',')))
|
||||
existed_index2 = list(map(str.strip, candidate_index.columns.split(',')))
|
||||
if existed_index1 == existed_index2[0: len(existed_index1)]:
|
||||
is_redundant = True
|
||||
index.redundant_obj.append(candidate_index)
|
||||
if is_redundant:
|
||||
@ -97,14 +99,19 @@ class ExecuteFactory:
|
||||
candidate_index.select_sql_num += obj.frequency
|
||||
# SELECT scenes to filter out positive
|
||||
if ind not in candidate_index.positive_pos and \
|
||||
any(column in obj.statement.lower() for column in candidate_index.columns):
|
||||
any(re.search(r'\b%s\b' % column, obj.statement.lower())
|
||||
for column in candidate_index.columns.split(', ')):
|
||||
candidate_index.ineffective_pos.append(ind)
|
||||
candidate_index.total_sql_num += obj.frequency
|
||||
|
||||
@staticmethod
|
||||
def match_last_result(table_name, index_column, history_indexes, history_invalid_indexes):
|
||||
for column in history_indexes.get(table_name, dict()):
|
||||
if re.match(r'%s' % column, index_column):
|
||||
history_index_column = list(map(str.strip, column.split(',')))
|
||||
existed_index_column = list(map(str.strip, index_column.split(',')))
|
||||
if len(history_index_column) > len(existed_index_column):
|
||||
continue
|
||||
if history_index_column == existed_index_column[0:len(history_index_column)]:
|
||||
history_indexes[table_name].remove(column)
|
||||
history_invalid_indexes[table_name] = history_invalid_indexes.get(table_name, list())
|
||||
history_invalid_indexes[table_name].append(column)
|
||||
|
||||
@ -22,7 +22,9 @@ import re
|
||||
import json
|
||||
import select
|
||||
import logging
|
||||
|
||||
from DAO.gsql_execute import GSqlExecute
|
||||
from mcts import MCTS
|
||||
|
||||
ENABLE_MULTI_NODE = False
|
||||
SAMPLE_NUM = 5
|
||||
@ -123,15 +125,15 @@ def print_header_boundary(header):
|
||||
print(green(title))
|
||||
|
||||
|
||||
def filter_low_benefit(pos_list, candidate_indexes, multi_iter_mode, workload):
|
||||
def filter_low_benefit(candidate_indexes, multi_iter_mode, workload):
|
||||
remove_list = []
|
||||
for key, index in enumerate(candidate_indexes):
|
||||
sql_optimzed = 0
|
||||
if multi_iter_mode:
|
||||
cost_list_pos = index.atomic_pos
|
||||
else:
|
||||
cost_list_pos = key + 1
|
||||
for ind, pos in enumerate(index.positive_pos):
|
||||
if multi_iter_mode:
|
||||
cost_list_pos = index.atomic_pos
|
||||
else:
|
||||
cost_list_pos = pos_list[key] + 1
|
||||
sql_optimzed += 1 - workload[pos].cost_list[cost_list_pos] / workload[pos].cost_list[0]
|
||||
negative_ratio = ((index.insert_sql_num + index.delete_sql_num +
|
||||
index.update_sql_num) / index.total_sql_num) if index.total_sql_num else 0
|
||||
@ -148,12 +150,8 @@ def display_recommend_result(workload, candidate_indexes, index_cost_total, mult
|
||||
display_info, integrate_indexes, history_invalid_indexes):
|
||||
cnt = 0
|
||||
index_current_storage = 0
|
||||
pos_list = []
|
||||
if not multi_iter_mode:
|
||||
pos_list = [item[0] for item in candidate_indexes]
|
||||
candidate_indexes = [item[1] for item in candidate_indexes]
|
||||
# filter candidate indexes with low benefit
|
||||
filter_low_benefit(pos_list, candidate_indexes, multi_iter_mode, workload)
|
||||
filter_low_benefit(candidate_indexes, multi_iter_mode, workload)
|
||||
# display determine result
|
||||
integrate_indexes['currentIndexes'] = dict()
|
||||
for key, index in enumerate(candidate_indexes):
|
||||
@ -178,7 +176,7 @@ def display_recommend_result(workload, candidate_indexes, index_cost_total, mult
|
||||
if multi_iter_mode:
|
||||
cost_list_pos = index.atomic_pos
|
||||
else:
|
||||
cost_list_pos = pos_list[key] + 1
|
||||
cost_list_pos = key + 1
|
||||
|
||||
sql_info = {'sqlDetails': []}
|
||||
benefit_types = [index.ineffective_pos, index.positive_pos, index.negative_pos]
|
||||
@ -328,28 +326,37 @@ def generate_candidate_indexes(workload, workload_table_name, db):
|
||||
if DRIVER:
|
||||
db.init_conn_handle()
|
||||
for k, query in enumerate(workload):
|
||||
if 'select ' in query.statement.lower():
|
||||
table_index_dict = db.query_index_advisor(query.statement, workload_table_name)
|
||||
valid_index_dict = get_valid_index_dict(table_index_dict, query, db)
|
||||
if not re.search(r'(\A|\s)select\s', query.statement.lower()):
|
||||
continue
|
||||
table_index_dict = db.query_index_advisor(query.statement, workload_table_name)
|
||||
valid_index_dict = get_valid_index_dict(table_index_dict, query, db)
|
||||
|
||||
# filter duplicate indexes
|
||||
for table in valid_index_dict.keys():
|
||||
if table not in index_dict.keys():
|
||||
index_dict[table] = {}
|
||||
for columns in valid_index_dict[table]:
|
||||
if len(workload[k].valid_index_list) >= FULL_ARRANGEMENT_THRESHOLD:
|
||||
break
|
||||
workload[k].valid_index_list.append(IndexItem(table, columns))
|
||||
if not any(re.match(r'%s' % columns, item) for item in index_dict[table]):
|
||||
column_sql = {columns: [k]}
|
||||
index_dict[table].update(column_sql)
|
||||
|
||||
elif columns in index_dict[table].keys():
|
||||
index_dict[table][columns].append(k)
|
||||
# record valid indexes for every sql of workload and generate candidate indexes
|
||||
for table in valid_index_dict.keys():
|
||||
if table not in index_dict.keys():
|
||||
index_dict[table] = {}
|
||||
for columns in valid_index_dict[table]:
|
||||
if len(workload[k].valid_index_list) >= FULL_ARRANGEMENT_THRESHOLD:
|
||||
break
|
||||
workload[k].valid_index_list.append(IndexItem(table, columns))
|
||||
if columns in index_dict[table]:
|
||||
index_dict[table][columns].append(k)
|
||||
else:
|
||||
column_sql = {columns: [k]}
|
||||
index_dict[table].update(column_sql)
|
||||
# filter redundant indexes for candidate indexes
|
||||
for table, column_sqls in index_dict.items():
|
||||
for column, sql in column_sqls.items():
|
||||
print("table: ", table, "columns: ", column)
|
||||
candidate_indexes.append(IndexItem(table, column, sql))
|
||||
sorted_column_sqls = sorted(column_sqls.items(), key=lambda item: item[0])
|
||||
for i in range(len(sorted_column_sqls) - 1):
|
||||
if re.match(sorted_column_sqls[i][0], sorted_column_sqls[i+1][0]):
|
||||
sorted_column_sqls[i+1][1].extend(sorted_column_sqls[i][1])
|
||||
else:
|
||||
print("table: ", table, "columns: ", sorted_column_sqls[i][0])
|
||||
candidate_indexes.append(IndexItem(table, sorted_column_sqls[i][0],
|
||||
sorted_column_sqls[i][1]))
|
||||
print("table: ", table, "columns: ", sorted_column_sqls[-1][0])
|
||||
candidate_indexes.append(
|
||||
IndexItem(table, sorted_column_sqls[-1][0], sorted_column_sqls[-1][1]))
|
||||
if DRIVER:
|
||||
db.close_conn()
|
||||
return candidate_indexes
|
||||
@ -502,7 +509,7 @@ def display_last_recommend_result(integrate_indexes, history_invalid_indexes, in
|
||||
print_header_boundary(" Historical effective indexes ")
|
||||
for table_name, index_list in integrate_indexes['historyIndexes'].items():
|
||||
for column in index_list:
|
||||
index_name = 'idx_' + table_name + '_' + '_'.join(column.split(', '))
|
||||
index_name = 'idx_' + table_name.split('.')[-1] + '_' + '_'.join(column.split(', '))
|
||||
statement = 'CREATE INDEX ' + index_name + ' ON ' + table_name + '(' + column + ');'
|
||||
print(statement)
|
||||
# display historical invalid indexes
|
||||
@ -510,7 +517,7 @@ def display_last_recommend_result(integrate_indexes, history_invalid_indexes, in
|
||||
print_header_boundary(" Historical invalid indexes ")
|
||||
for table_name, index_list in history_invalid_indexes.items():
|
||||
for column in index_list:
|
||||
index_name = 'idx_' + table_name + '_' + '_'.join(column.split(', '))
|
||||
index_name = 'idx_' + table_name.split('.')[-1] + '_' + '_'.join(column.split(', '))
|
||||
statement = 'CREATE INDEX ' + index_name + ' ON ' + table_name + '(' + column + ');'
|
||||
print(statement)
|
||||
# save integrate indexes result
|
||||
@ -531,21 +538,23 @@ def check_unused_index_workload(whole_indexes, redundant_indexes, workload_index
|
||||
unused_index = list(indexes_name.difference(workload_indexes))
|
||||
remove_list = []
|
||||
print_header_boundary(" Current workload useless indexes ")
|
||||
if not unused_index:
|
||||
print("No useless index!")
|
||||
detail_info['uselessIndexes'] = []
|
||||
# useless index
|
||||
unused_index_columns = dict()
|
||||
has_unused_index = False
|
||||
for cur_index in unused_index:
|
||||
for index in whole_indexes:
|
||||
if cur_index == index.indexname:
|
||||
unused_index_columns[cur_index] = index.columns
|
||||
if 'UNIQUE INDEX' not in index.indexdef:
|
||||
has_unused_index = True
|
||||
statement = "DROP INDEX %s;" % index.indexname
|
||||
print(statement)
|
||||
useless_index = {"schemaName": index.schema, "tbName": index.table, "type": 3,
|
||||
"columns": index.columns, "statement": statement}
|
||||
detail_info['uselessIndexes'].append(useless_index)
|
||||
if not has_unused_index:
|
||||
print("No useless index!")
|
||||
print_header_boundary(" Redundant indexes ")
|
||||
# filter redundant index
|
||||
for pos, index in enumerate(redundant_indexes):
|
||||
@ -599,16 +608,14 @@ def simple_index_advisor(input_path, max_index_num, integrate_indexes, db):
|
||||
index_cost_total = [ori_total_cost]
|
||||
for _, obj in enumerate(candidate_indexes):
|
||||
new_total_cost = db.estimate_workload_cost_file(workload, [obj])
|
||||
index_cost_total.append(new_total_cost)
|
||||
obj.benefit = ori_total_cost - new_total_cost
|
||||
if obj.benefit > 0:
|
||||
index_cost_total.append(new_total_cost)
|
||||
if DRIVER:
|
||||
db.close_conn()
|
||||
if len(candidate_indexes) == 0:
|
||||
if len(index_cost_total) == 1:
|
||||
print("No optimal indexes generated!")
|
||||
return ori_indexes_name, workload_table_name, display_info, history_invalid_indexes
|
||||
candidate_indexes = [item for item in candidate_indexes if item.benefit > 0]
|
||||
candidate_indexes = sorted(enumerate(candidate_indexes),
|
||||
key=lambda item: item[1].benefit, reverse=True)
|
||||
global MAX_INDEX_NUM
|
||||
MAX_INDEX_NUM = max_index_num
|
||||
# match the last recommendation result
|
||||
@ -642,6 +649,7 @@ def greedy_determine_opt_config(workload, atomic_config_total, candidate_indexes
|
||||
if cur_index and cur_min_cost < min_cost:
|
||||
if MAX_INDEX_STORAGE and sum([obj.storage for obj in opt_config]) + \
|
||||
cur_index.storage > MAX_INDEX_STORAGE:
|
||||
candidate_indexes.remove(cur_index)
|
||||
continue
|
||||
if len(opt_config) == MAX_INDEX_NUM:
|
||||
break
|
||||
@ -681,9 +689,11 @@ def complex_index_advisor(input_path, integrate_indexes, db):
|
||||
ori_indexes_name))
|
||||
if DRIVER:
|
||||
db.close_conn()
|
||||
|
||||
opt_config = greedy_determine_opt_config(workload, atomic_config_total,
|
||||
candidate_indexes, index_cost_total[0])
|
||||
if MAX_INDEX_STORAGE:
|
||||
opt_config = MCTS(workload, atomic_config_total, candidate_indexes, MAX_INDEX_STORAGE)
|
||||
else:
|
||||
opt_config = greedy_determine_opt_config(workload, atomic_config_total,
|
||||
candidate_indexes, index_cost_total[0])
|
||||
if len(opt_config) == 0:
|
||||
print("No optimal indexes generated!")
|
||||
return ori_indexes_name, workload_table_name, display_info, history_invalid_indexes
|
||||
|
||||
394
src/gausskernel/dbmind/tools/index_advisor/mcts.py
Normal file
394
src/gausskernel/dbmind/tools/index_advisor/mcts.py
Normal file
@ -0,0 +1,394 @@
|
||||
import sys
|
||||
import math
|
||||
import random
|
||||
import copy
|
||||
|
||||
|
||||
# Module-level search parameters; all four are (re)assigned by MCTS() before a
# search starts, so the values below are just placeholder defaults.
STORAGE_THRESHOLD = 0   # storage budget an index configuration must not exceed
AVAILABLE_CHOICES = []  # candidate index objects the search may combine
ATOMIC_CHOICES = []     # atomic index configs; positions map into each query's cost_list
WORKLOAD_INFO = []      # workload queries, each carrying a cost_list of estimated costs
|
||||
|
||||
|
||||
def is_same_index(index, compared_index):
    """Return True when the two index objects target the same table and columns."""
    same_table = index.table == compared_index.table
    same_columns = index.columns == compared_index.columns
    return same_table and same_columns
|
||||
|
||||
|
||||
def atomic_config_is_valid(atomic_config, candidate_indexes):
    """Return True when every atomic index of atomic_config occurs in candidate_indexes.

    Side effect (as in the original): the storage of each matched atomic index
    is copied onto the first matching candidate index.
    """
    for atomic_index in atomic_config:
        matched = next((candidate for candidate in candidate_indexes
                        if is_same_index(candidate, atomic_index)), None)
        if matched is None:
            # one atomic member is missing, so the whole config is invalid
            return False
        matched.storage = atomic_index.storage
    return True
|
||||
|
||||
|
||||
def find_subsets_num(choice):
    """Collect the positions of atomic configs that are subsets of *choice*.

    While scanning, also record on the newest member of *choice* the position
    of the single-index atomic config equal to it (choice[-1].atomic_pos),
    exactly as the original did.
    """
    subset_positions = []
    for pos, atomic in enumerate(ATOMIC_CHOICES):
        # empty configs and configs larger than the choice cannot be subsets
        if not atomic:
            continue
        if len(atomic) > len(choice):
            continue
        # find valid atomic index
        if atomic_config_is_valid(atomic, choice):
            subset_positions.append(pos)
        # find the same atomic index as the candidate index
        if len(atomic) == 1 and is_same_index(choice[-1], atomic[0]):
            choice[-1].atomic_pos = pos
    return subset_positions
|
||||
|
||||
|
||||
def find_best_benefit(choice):
    """Sum, over every query in the workload, the best cost reduction any
    atomic subset of *choice* provides.

    cost_list[0] is the cost without indexes; cost_list[pos] is the cost under
    the atomic config at position *pos*.  Negative "benefits" are ignored for
    each query, matching the original accumulation loop.
    """
    atomic_subsets_num = find_subsets_num(choice)
    total_benefit = 0
    # fix: the original iterated with enumerate() but never used the index
    for obj in WORKLOAD_INFO:
        # best benefit for this query across all applicable atomic subsets
        best = max((obj.cost_list[0] - obj.cost_list[pos]
                    for pos in atomic_subsets_num), default=0)
        # clamp at zero: a regression must not reduce the total
        total_benefit += max(best, 0)
    return total_benefit
|
||||
|
||||
|
||||
def get_diff(available_choices, choices):
    """Return the members of *available_choices* not already present in *choices*.

    Fix: the original removed a candidate from a shallow copy once per matching
    element of *choices* without breaking out of the inner loop, so a candidate
    equal to two or more chosen indexes raised ValueError on the second
    list.remove.  Filtering with a comprehension removes each candidate at most
    once while preserving order.
    """
    return [candidate for candidate in available_choices
            if not any(is_same_index(candidate, chosen) for chosen in choices)]
|
||||
|
||||
|
||||
class State(object):
    """
    The game state of the Monte Carlo tree search: the data recorded under a
    given Node — the benefit ("score") and storage accumulated so far, the
    indexes chosen from the beginning up to now, and the indexes that are
    still available to pick.

    It can report whether the terminal state has been reached and draw a
    random follow-up state from the remaining choices.
    """

    def __init__(self):
        self.current_storage = 0.0
        self.current_benefit = 0.0
        # record the sum of choices up to the current state
        self.accumulation_choices = []
        # record available choices of current state
        self.available_choices = []
        self.displayable_choices = []

    def get_available_choices(self):
        return self.available_choices

    def set_available_choices(self, choices):
        self.available_choices = choices

    def get_current_storage(self):
        return self.current_storage

    def set_current_storage(self, value):
        self.current_storage = value

    def get_current_benefit(self):
        return self.current_benefit

    def set_current_benefit(self, value):
        self.current_benefit = value

    def get_accumulation_choices(self):
        return self.accumulation_choices

    def set_accumulation_choices(self, choices):
        self.accumulation_choices = choices

    def is_terminal(self):
        # the search bottoms out once every available index has been taken
        return len(self.accumulation_choices) == len(AVAILABLE_CHOICES)

    def compute_benefit(self):
        return self.current_benefit

    def get_next_state_with_random_choice(self):
        # nothing left to try from this state: signal failure with None
        if not self.available_choices:
            return None
        picked = random.choice(list(self.available_choices))
        # never offer the same choice twice from this state
        self.available_choices.remove(picked)
        chosen = copy.copy(self.accumulation_choices)
        chosen.append(picked)
        benefit = find_best_benefit(chosen)
        # a pick that adds no benefit or blows the storage budget is discarded
        # and another one is drawn recursively
        if benefit <= self.current_benefit or \
                self.current_storage + picked.storage > STORAGE_THRESHOLD:
            return self.get_next_state_with_random_choice()

        successor = State()
        # initialize the properties of the new state
        successor.set_accumulation_choices(chosen)
        successor.set_current_benefit(benefit)
        successor.set_current_storage(self.current_storage + picked.storage)
        successor.set_available_choices(get_diff(AVAILABLE_CHOICES, chosen))
        return successor

    def __repr__(self):
        self.displayable_choices = ['{}: {}'.format(choice.table, choice.columns)
                                    for choice in self.accumulation_choices]
        return "reward: {}, storage :{}, choices: {}".format(
            self.current_benefit, self.current_storage, self.displayable_choices)
|
||||
|
||||
|
||||
class Node(object):
    """
    One node of the Monte Carlo search tree.  Holds the parent/children links
    plus the statistics the UCB formula needs — visit count and accumulated
    quality — together with the State reached at this node.
    """

    def __init__(self):
        self.visit_number = 0
        self.quality = 0.0

        self.parent = None
        self.children = []
        self.state = None

    def get_parent(self):
        return self.parent

    def set_parent(self, parent):
        self.parent = parent

    def get_children(self):
        return self.children

    def expand_child(self, node):
        # wire the child both ways: it learns its parent, we record it
        node.set_parent(self)
        self.children.append(node)

    def set_state(self, state):
        self.state = state

    def get_state(self):
        return self.state

    def get_visit_number(self):
        return self.visit_number

    def set_visit_number(self, number):
        self.visit_number = number

    def update_visit_number(self):
        self.visit_number += 1

    def get_quality_value(self):
        return self.quality

    def set_quality_value(self, value):
        self.quality = value

    def update_quality_value(self, reward):
        self.quality += reward

    def is_all_expand(self):
        # fully expanded when one child exists per not-yet-chosen index
        unchosen = len(AVAILABLE_CHOICES) - len(self.get_state().get_accumulation_choices())
        return len(self.children) == unchosen

    def __repr__(self):
        return "Node: {}, Q/N: {}/{}, State: {}".format(
            hash(self), self.quality, self.visit_number, self.state)
|
||||
|
||||
|
||||
def tree_policy(node):
    """
    In the Selection and Expansion stages of Monte Carlo tree search,
    the node that needs to be searched (such as the root node) is passed in,
    and the best node that needs to be expanded is returned according to the exploration/exploitation algorithm.
    Note that if the node is a leaf node, it will be returned directly.

    The basic strategy is to first find the child nodes that have not been selected at present,
    and select them randomly if there are more than one. If both are selected,
    find the one with the largest UCB value that has weighed exploration/exploitation,
    and randomly select if the UCB values are equal.
    """

    # check if the current node is leaf node
    while node and not node.get_state().is_terminal():

        if node.is_all_expand():
            # all children exist already: descend via the UCB-best child;
            # best_child may return None, which ends the loop via `while node`
            node = best_child(node, True)
        else:
            # return the new sub node
            sub_node = expand(node)
            # when there is no node that satisfies the condition in the remaining nodes,
            # this state is empty
            if sub_node.get_state():
                return sub_node
            # NOTE(review): an empty-state child stays attached to `node`, so it
            # still counts toward is_all_expand() on the next pass — presumably
            # intentional to drain the remaining choices; confirm.

    # return the leaf node
    return node
|
||||
|
||||
|
||||
def default_policy(node):
    """
    Simulation stage of Monte Carlo tree search: starting from a copy of
    *node*'s state, play random actions until the game is over (or no legal
    action remains) and return the benefit of the final state reached.

    The input node should have unexecuted actions left to expand; the basic
    strategy is to choose each action at random.
    """

    # never mutate the node's own state during the rollout
    rollout_state = copy.deepcopy(node.get_state())

    # run until the game over
    while not rollout_state.is_terminal():
        # pick one random action to play and get next state
        successor = rollout_state.get_next_state_with_random_choice()
        if successor is None:
            break
        rollout_state = successor

    return rollout_state.compute_benefit()
|
||||
|
||||
|
||||
def expand(node):
    """
    Expansion stage: attach one new child to *node*, carrying a randomly
    chosen next state, and return that child.  The random draw guarantees the
    new child's action differs from its siblings'; when no admissible action
    remains the child's state is None.
    """

    child = Node()
    child.set_state(node.get_state().get_next_state_with_random_choice())
    node.expand_child(child)
    return child
|
||||
|
||||
|
||||
def best_child(node, is_exploration):
    """
    Pick the child of *node* with the highest UCB score, weighing exploration
    against exploitation.  During inference (is_exploration=False) the
    exploration constant is zero and the raw quality/visits ratio decides.
    Children whose state breaks the storage budget or has no benefit are never
    returned; the result is None when no child qualifies.
    """

    best_score = -sys.maxsize
    winner = None

    # travel all sub nodes to find the best one
    for child in node.get_children():
        state = child.get_state()
        # children produced from unsatisfiable picks carry an empty state —
        # skip them (they come from nodes that did not meet the conditions)
        if not state:
            continue
        # ignore exploration for inference
        exploration_const = 1 / math.sqrt(2.0) if is_exploration else 0.0

        # UCB = quality / times + C * sqrt(2 * ln(total_times) / times)
        exploitation = child.get_quality_value() / child.get_visit_number()
        exploration = 2.0 * math.log(node.get_visit_number()) / child.get_visit_number()
        score = exploitation + exploration_const * math.sqrt(exploration)
        # keep the maximum score among children that satisfy the space
        # constraint and actually bring revenue
        if score > best_score \
                and state.get_current_storage() <= STORAGE_THRESHOLD \
                and state.get_current_benefit() > 0:
            winner = child
            best_score = score

    return winner
|
||||
|
||||
|
||||
def backpropagate(node, reward):
    """
    Backpropagation stage of Monte Carlo tree search: credit *reward* to
    *node* and every ancestor up to the root, bumping each one's visit count
    and accumulated quality.
    """

    current = node
    # walk upward until past the root (whose parent is None)
    while current is not None:
        # update the visit number
        current.update_visit_number()

        # update the quality value
        current.update_quality_value(reward)

        # climb to the parent node
        current = current.parent
|
||||
|
||||
|
||||
def monte_carlo_tree_search(node):
    """
    Implement the Monte Carlo tree search algorithm, pass in a root node,
    expand new nodes and update data according to the
    tree structure that has been explored before in a limited time,
    and then return as long as the child node with the highest exploitation.

    When making predictions,
    you only need to select the node with the largest exploitation according to the Q value,
    and find the next optimal node.
    """

    # budget scales with the number of candidate indexes: 3 rollouts each
    computation_budget = len(AVAILABLE_CHOICES) * 3

    # run as much as possible under the computation budget
    for i in range(computation_budget):
        # 1. find the best node to expand
        expand_node = tree_policy(node)
        if not expand_node:
            # when it is None, it means that all nodes are added but no nodes meet the space limit
            break
        # 2. random get next action and get reward
        reward = default_policy(expand_node)

        # 3. update all passing nodes with reward
        backpropagate(expand_node, reward)

    # get the best next node (exploration disabled: pure exploitation)
    best_next_node = best_child(node, False)

    return best_next_node
|
||||
|
||||
|
||||
def MCTS(workload_info, atomic_choices, available_choices, storage_threshold):
    """Entry point: run Monte Carlo tree search to pick an index configuration.

    :param workload_info: workload queries; each must expose a ``cost_list``
        (consumed by find_best_benefit through the WORKLOAD_INFO global)
    :param atomic_choices: atomic index configurations; their positions index
        into each query's cost_list
    :param available_choices: candidate index objects the search may combine
    :param storage_threshold: storage budget a configuration must not exceed
    :return: the accumulated choices (list of index objects) of the best node
        found, or [] when no configuration satisfies the constraints
    """
    global ATOMIC_CHOICES, STORAGE_THRESHOLD, WORKLOAD_INFO, AVAILABLE_CHOICES
    WORKLOAD_INFO = workload_info
    AVAILABLE_CHOICES = available_choices
    ATOMIC_CHOICES = atomic_choices
    STORAGE_THRESHOLD = storage_threshold

    # create the initialized state and initialized node
    init_state = State()
    # shallow copy so that consuming choices during the search does not mutate
    # the caller's list (State.get_next_state_with_random_choice removes items)
    choices = copy.copy(available_choices)
    init_state.set_available_choices(choices)
    init_node = Node()
    init_node.set_state(init_state)
    current_node = init_node

    opt_config = []
    # set the rounds to play: at most one index can be added per round
    for i in range(len(AVAILABLE_CHOICES)):
        if current_node:
            current_node = monte_carlo_tree_search(current_node)
            # NOTE(review): indentation reconstructed from a mangled source —
            # confirm this inner if pairs here and the else with the outer if.
            if current_node:
                opt_config = current_node.state.accumulation_choices
        else:
            break
    return opt_config
|
||||
Reference in New Issue
Block a user