[enhancement](Nereids): use enforcer to choose the n-th plan (#22929)

This commit is contained in:
谢健
2023-09-28 15:16:24 +08:00
committed by GitHub
parent b50c1448df
commit a574f29d76
7 changed files with 277 additions and 126 deletions

View File

@ -27,10 +27,13 @@ import org.apache.doris.nereids.metrics.consumer.LogConsumer;
import org.apache.doris.nereids.metrics.event.GroupMergeEvent;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.properties.RequestPropertyDeriver;
import org.apache.doris.nereids.properties.RequirePropertiesSupplier;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.LeafPlan;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
@ -52,7 +55,9 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import javax.annotation.Nullable;
/**
@ -796,71 +801,101 @@ public class Memo {
public Pair<Long, Double> rank(long n) {
double threshold = 0.000000001;
Preconditions.checkArgument(n > 0, "the n %d must be greater than 0 in nthPlan", n);
List<Pair<Long, Double>> plans = rankGroup(root, PhysicalProperties.GATHER);
plans = plans.stream().filter(
p -> !p.second.equals(Double.NaN)
&& !p.second.equals(Double.POSITIVE_INFINITY)
&& !p.second.equals(Double.NEGATIVE_INFINITY))
List<Pair<Long, Cost>> plans = rankGroup(root, PhysicalProperties.GATHER);
plans = plans.stream()
.filter(p -> Double.isFinite(p.second.getValue()))
.collect(Collectors.toList());
// This is big heap, it always pops the element with larger cost or larger id.
PriorityQueue<Pair<Long, Double>> pq = new PriorityQueue<>((l, r) -> Math.abs(l.second - r.second) < threshold
? -Long.compare(l.first, r.first) : -Double.compare(l.second, r.second));
for (Pair<Long, Double> p : plans) {
PriorityQueue<Pair<Long, Cost>> pq = new PriorityQueue<>((l, r) ->
Math.abs(l.second.getValue() - r.second.getValue()) < threshold
? -Long.compare(l.first, r.first)
: -Double.compare(l.second.getValue(), r.second.getValue()));
for (Pair<Long, Cost> p : plans) {
pq.add(p);
if (pq.size() > n) {
pq.poll();
}
}
return pq.peek();
Preconditions.checkArgument(pq.peek() != null, "rank error because there is no valid plan");
return Pair.of(pq.peek().first, pq.peek().second.getValue());
}
private List<Pair<Long, Double>> rankGroup(Group group, PhysicalProperties prop) {
List<Pair<Long, Double>> res = new ArrayList<>();
int prefix = res.size();
for (GroupExpression groupExpression : extractGroupExpressionContainsProp(group, prop)) {
for (Pair<Long, Double> idCostPair : rankGroupExpression(groupExpression, prop)) {
/**
* return number of plan that can be ranked
*/
public int getRankSize() {
List<Pair<Long, Cost>> plans = rankGroup(root, PhysicalProperties.GATHER);
plans = plans.stream().filter(
p -> !p.second.equals(Double.NaN)
&& !p.second.equals(Double.POSITIVE_INFINITY)
&& !p.second.equals(Double.NEGATIVE_INFINITY))
.collect(Collectors.toList());
return plans.size();
}
private List<Pair<Long, Cost>> rankGroup(Group group, PhysicalProperties prop) {
List<Pair<Long, Cost>> res = new ArrayList<>();
int prefix = 0;
List<GroupExpression> validGroupExprList = extractGroupExpressionSatisfyProp(group, prop);
for (GroupExpression groupExpression : validGroupExprList) {
for (Pair<Long, Cost> idCostPair : rankGroupExpression(groupExpression, prop)) {
res.add(Pair.of(idCostPair.first + prefix, idCostPair.second));
}
prefix = res.size();
// avoid ranking all plans
if (res.size() > 1e2) {
break;
}
}
return res;
}
private List<Pair<Long, Double>> rankGroupExpression(GroupExpression groupExpression,
private List<Pair<Long, Cost>> rankGroupExpression(GroupExpression groupExpression,
PhysicalProperties prop) {
if (!groupExpression.getLowestCostTable().containsKey(prop)) {
return new ArrayList<>();
}
List<Pair<Long, Double>> res = new ArrayList<>();
List<PhysicalProperties> inputProperties = groupExpression.getInputPropertiesList(prop);
List<Pair<Long, Cost>> res = new ArrayList<>();
if (groupExpression.getPlan() instanceof LeafPlan) {
res.add(Pair.of(0L, groupExpression.getCostByProperties(prop)));
res.add(Pair.of(0L, groupExpression.getCostValueByProperties(prop)));
return res;
}
List<List<Pair<Long, Double>>> children = new ArrayList<>();
for (int i = 0; i < inputProperties.size(); i++) {
// To avoid reach a circle, we don't allow ranking the same group with the same physical properties.
Preconditions.checkArgument(!groupExpression.child(i).equals(groupExpression.getOwnerGroup())
|| !prop.equals(inputProperties.get(i)));
List<Pair<Long, Double>> idCostPair
= rankGroup(groupExpression.child(i), inputProperties.get(i));
children.add(idCostPair);
}
List<Pair<Long, List<Integer>>> childrenId = new ArrayList<>();
permute(children, 0, childrenId, new ArrayList<>());
Cost cost = CostCalculator.calculateCost(groupExpression, inputProperties);
for (Pair<Long, List<Integer>> c : childrenId) {
Cost totalCost = cost;
for (int i = 0; i < children.size(); i++) {
totalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
totalCost,
groupExpression.child(i).getLowestCostPlan(inputProperties.get(i)).get().first,
i);
List<List<PhysicalProperties>> inputPropertiesList = extractInputProperties(groupExpression, prop);
for (List<PhysicalProperties> inputProperties : inputPropertiesList) {
int prefix = res.size();
List<List<Pair<Long, Cost>>> children = new ArrayList<>();
for (int i = 0; i < inputProperties.size(); i++) {
// To avoid reach a circle, we don't allow ranking the same group with the same physical properties.
Preconditions.checkArgument(!groupExpression.child(i).equals(groupExpression.getOwnerGroup())
|| !prop.equals(inputProperties.get(i)));
List<Pair<Long, Cost>> idCostPair
= rankGroup(groupExpression.child(i), inputProperties.get(i));
children.add(idCostPair);
}
List<Pair<Long, List<Integer>>> childrenId = new ArrayList<>();
permute(children, 0, childrenId, new ArrayList<>());
Cost cost = CostCalculator.calculateCost(groupExpression, inputProperties);
for (Pair<Long, List<Integer>> c : childrenId) {
Cost totalCost = cost;
for (int i = 0; i < children.size(); i++) {
totalCost = CostCalculator.addChildCost(groupExpression.getPlan(),
totalCost,
children.get(i).get(c.second.get(i)).second,
i);
}
if (res.isEmpty()) {
Preconditions.checkArgument(
Math.abs(totalCost.getValue() - groupExpression.getCostByProperties(prop)) < 0.0001,
"Please check operator %s, expected cost %s but found %s",
groupExpression.getPlan().shapeInfo(), totalCost.getValue(),
groupExpression.getCostByProperties(prop));
}
res.add(Pair.of(prefix + c.first, totalCost));
}
res.add(Pair.of(c.first, totalCost.getValue()));
}
return res;
}
@ -869,7 +904,7 @@ public class Memo {
* for children [1, 2] [1, 2, 3]
* we can get: 0: [1,1] 1:[1, 2] 2:[1, 3] 3:[2, 1] 4:[2, 2] 5:[2, 3]
*/
private void permute(List<List<Pair<Long, Double>>> children, int index,
private void permute(List<List<Pair<Long, Cost>>> children, int index,
List<Pair<Long, List<Integer>>> result, List<Integer> current) {
if (index == children.size()) {
result.add(Pair.of(getUniqueId(children, current), current));
@ -889,7 +924,7 @@ public class Memo {
* [0, 0]: 0*1 + 0*1*2
* [0, 1]: 0*1 + 1*1*2
*/
private static long getUniqueId(List<List<Pair<Long, Double>>> lists, List<Integer> current) {
private static long getUniqueId(List<List<Pair<Long, Cost>>> lists, List<Integer> current) {
long id = 0;
long factor = 1;
for (int i = 0; i < lists.size(); i++) {
@ -899,54 +934,153 @@ public class Memo {
return id;
}
private List<GroupExpression> extractGroupExpressionContainsProp(Group group, PhysicalProperties prop) {
List<GroupExpression> validExpressions = new ArrayList<>();
private List<GroupExpression> extractGroupExpressionSatisfyProp(Group group, PhysicalProperties prop) {
GroupExpression bestExpr = group.getLowestCostPlan(prop).get().second;
validExpressions.add(bestExpr);
for (GroupExpression groupExpression : group.getPhysicalExpressions()) {
if (!groupExpression.equals(bestExpr) && groupExpression.getLowestCostTable().containsKey(prop)) {
validExpressions.add(groupExpression);
}
}
return validExpressions;
List<GroupExpression> exprs = Lists.newArrayList(bestExpr);
Set<GroupExpression> hasVisited = new HashSet<>();
hasVisited.add(bestExpr);
Stream.concat(group.getPhysicalExpressions().stream(), group.getEnforcers().stream())
.forEach(groupExpression -> {
if (!groupExpression.getInputPropertiesListOrEmpty(prop).isEmpty()
&& !groupExpression.equals(bestExpr) && !hasVisited.contains(groupExpression)) {
hasVisited.add(groupExpression);
exprs.add(groupExpression);
}
});
return exprs;
}
private PhysicalPlan unrankGroup(Group group, PhysicalProperties prop, long rank) {
// ----------------------------------------------------------------
// extract input properties for a given group expression and required output properties
// There are three cases:
// 1. If group expression is enforcer, return the input properties of the best expression
// 2. If group expression require any, return any input properties
// 3. Otherwise, return all input properties that satisfies the required output properties
private List<List<PhysicalProperties>> extractInputProperties(GroupExpression groupExpression,
PhysicalProperties prop) {
List<List<PhysicalProperties>> res = new ArrayList<>();
res.add(groupExpression.getInputPropertiesList(prop));
// return optimized input for enforcer
if (groupExpression.getOwnerGroup().getEnforcers().contains(groupExpression)) {
return res;
}
// return any if exits except RequirePropertiesSupplier and SetOperators
// Because PropRegulator could change their input properties
RequestPropertyDeriver requestPropertyDeriver = new RequestPropertyDeriver(prop);
List<List<PhysicalProperties>> requestList = requestPropertyDeriver
.getRequestChildrenPropertyList(groupExpression);
Optional<List<PhysicalProperties>> any = requestList.stream()
.filter(e -> e.stream().allMatch(PhysicalProperties.ANY::equals))
.findAny();
if (any.isPresent()
&& !(groupExpression.getPlan() instanceof RequirePropertiesSupplier)
&& !(groupExpression.getPlan() instanceof SetOperation)) {
res.clear();
res.add(any.get());
return res;
}
// return all optimized inputs
Set<List<PhysicalProperties>> inputProps = groupExpression.getLowestCostTable().keySet().stream()
.filter(physicalProperties -> physicalProperties.satisfy(prop))
.map(groupExpression::getInputPropertiesList)
.collect(Collectors.toSet());
res.addAll(inputProps);
return res;
}
private int getGroupSize(Group group, PhysicalProperties prop,
Map<GroupExpression, List<List<Integer>>> exprSizeCache) {
List<GroupExpression> validGroupExprs = extractGroupExpressionSatisfyProp(group, prop);
int groupCount = 0;
for (GroupExpression groupExpression : validGroupExprs) {
int exprCount = getExprSize(groupExpression, prop, exprSizeCache);
groupCount += exprCount;
if (groupCount > 1e2) {
break;
}
}
return groupCount;
}
// return size for each input properties
private int getExprSize(GroupExpression groupExpression, PhysicalProperties properties,
Map<GroupExpression, List<List<Integer>>> exprChildSizeCache) {
List<List<Integer>> exprCount = new ArrayList<>();
if (!groupExpression.getLowestCostTable().containsKey(properties)) {
exprCount.add(Lists.newArrayList(0));
} else if (groupExpression.getPlan() instanceof LeafPlan) {
exprCount.add(Lists.newArrayList(1));
} else {
List<List<PhysicalProperties>> inputPropertiesList = extractInputProperties(groupExpression, properties);
for (List<PhysicalProperties> inputProperties : inputPropertiesList) {
List<Integer> groupExprSize = new ArrayList<>();
for (int i = 0; i < inputProperties.size(); i++) {
groupExprSize.add(
getGroupSize(groupExpression.child(i), inputProperties.get(i), exprChildSizeCache));
}
exprCount.add(groupExprSize);
}
}
exprChildSizeCache.put(groupExpression, exprCount);
return exprCount.stream()
.mapToInt(s -> s.stream().reduce(1, (a, b) -> a * b))
.sum();
}
private PhysicalPlan unrankGroup(Group group, PhysicalProperties prop, long rank,
Map<GroupExpression, List<List<Integer>>> exprSizeCache) {
int prefix = 0;
for (GroupExpression groupExpression : extractGroupExpressionContainsProp(group, prop)) {
List<Pair<Long, Double>> possiblePlans = rankGroupExpression(groupExpression, prop);
if (!possiblePlans.isEmpty() && rank - prefix <= possiblePlans.get(possiblePlans.size() - 1).first) {
return unrankGroupExpression(groupExpression, prop, rank - prefix);
for (GroupExpression groupExpression : extractGroupExpressionSatisfyProp(group, prop)) {
int exprCount = exprSizeCache.get(groupExpression).stream()
.mapToInt(s -> s.stream().reduce(1, (a, b) -> a * b))
.sum();
// rank is start from 0
if (exprCount != 0 && rank + 1 - prefix <= exprCount) {
return unrankGroupExpression(groupExpression, prop, rank - prefix,
exprSizeCache);
}
prefix += possiblePlans.size();
prefix += exprCount;
}
Preconditions.checkArgument(false, "unrank Group error");
return null;
throw new RuntimeException("the group has no plan for prop %s in rank job");
}
private PhysicalPlan unrankGroupExpression(GroupExpression groupExpression,
PhysicalProperties prop, long rank) {
private PhysicalPlan unrankGroupExpression(GroupExpression groupExpression, PhysicalProperties prop, long rank,
Map<GroupExpression, List<List<Integer>>> exprSizeCache) {
if (groupExpression.getPlan() instanceof LeafPlan) {
Preconditions.checkArgument(rank == 0);
Preconditions.checkArgument(rank == 0,
"leaf plan's %s rank must be 0 but is %d", groupExpression, rank);
return ((PhysicalPlan) groupExpression.getPlan()).withPhysicalPropertiesAndStats(
groupExpression.getOutputProperties(prop),
groupExpression.getOwnerGroup().getStatistics());
}
List<List<Pair<Long, Double>>> children = new ArrayList<>();
List<PhysicalProperties> properties = groupExpression.getInputPropertiesList(prop);
for (int i = 0; i < properties.size(); i++) {
children.add(rankGroup(groupExpression.child(i), properties.get(i)));
}
List<Long> childrenRanks = extractChildRanks(rank, children);
List<Plan> childrenPlan = new ArrayList<>();
for (int i = 0; i < properties.size(); i++) {
childrenPlan.add(unrankGroup(groupExpression.child(i), properties.get(i), childrenRanks.get(i)));
List<List<PhysicalProperties>> inputPropertiesList = extractInputProperties(groupExpression, prop);
for (int i = 0; i < inputPropertiesList.size(); i++) {
List<PhysicalProperties> properties = inputPropertiesList.get(i);
List<Integer> childrenSize = exprSizeCache.get(groupExpression).get(i);
int count = childrenSize.stream().reduce(1, (a, b) -> a * b);
if (rank >= count) {
rank -= count;
continue;
}
List<Long> childrenRanks = extractChildRanks(rank, childrenSize);
List<Plan> childrenPlan = new ArrayList<>();
for (int j = 0; j < properties.size(); j++) {
Plan plan = unrankGroup(groupExpression.child(j), properties.get(j),
childrenRanks.get(j), exprSizeCache);
Preconditions.checkArgument(plan != null, "rank group get null");
childrenPlan.add(plan);
}
Plan plan = groupExpression.getPlan().withChildren(childrenPlan);
return ((PhysicalPlan) plan).withPhysicalPropertiesAndStats(
groupExpression.getOutputProperties(prop),
groupExpression.getOwnerGroup().getStatistics());
}
Plan plan = groupExpression.getPlan().withChildren(childrenPlan);
return ((PhysicalPlan) plan).withPhysicalPropertiesAndStats(
groupExpression.getOutputProperties(prop),
groupExpression.getOwnerGroup().getStatistics());
throw new RuntimeException("the groupExpr has no plan for prop in rank job");
}
/**
@ -955,20 +1089,20 @@ public class Memo {
* 1: [1%1, 1%(1*2)]
* 2: [2%1, 2%(1*2)]
*/
private List<Long> extractChildRanks(long rank, List<List<Pair<Long, Double>>> children) {
Preconditions.checkArgument(!children.isEmpty(), "children should not empty in extractChildRanks");
int factor = children.get(0).size();
private List<Long> extractChildRanks(long rank, List<Integer> childrenSize) {
Preconditions.checkArgument(!childrenSize.isEmpty(), "children should not empty in extractChildRanks");
List<Long> indices = new ArrayList<>();
for (int i = 0; i < children.size() - 1; i++) {
for (int i = 0; i < childrenSize.size(); i++) {
int factor = childrenSize.get(i);
indices.add(rank % factor);
rank = rank / factor;
factor *= children.get(i + 1).size();
}
indices.add(rank % factor);
return indices;
}
public PhysicalPlan unrank(long id) {
return unrankGroup(getRoot(), PhysicalProperties.GATHER, id);
Map<GroupExpression, List<List<Integer>>> exprSizeCache = new HashMap<>();
getGroupSize(getRoot(), PhysicalProperties.GATHER, exprSizeCache);
return unrankGroup(getRoot(), PhysicalProperties.GATHER, id, exprSizeCache);
}
}

View File

@ -72,6 +72,10 @@ public class RequestPropertyDeriver extends PlanVisitor<Void, PlanContext> {
this.requestPropertyFromParent = context.getRequiredProperties();
}
public RequestPropertyDeriver(PhysicalProperties requestPropertyFromParent) {
this.requestPropertyFromParent = requestPropertyFromParent;
}
/**
* get request children property list
*/