[optimize](Nereids): optimize Nereids performance (#22885)
This commit is contained in:
@ -62,7 +62,7 @@ public class GroupExpression {
|
||||
|
||||
private double estOutputRowCount = -1;
|
||||
|
||||
//Record the rule that generate this plan. It's used for debugging
|
||||
// Record the rule that generate this plan. It's used for debugging
|
||||
private Rule fromRule;
|
||||
|
||||
// Mapping from output properties to the corresponding best cost, statistics, and child properties.
|
||||
@ -84,6 +84,7 @@ public class GroupExpression {
|
||||
}
|
||||
|
||||
/**
|
||||
* Notice!!!: children will use param `children` directly, So don't modify it after this constructor outside.
|
||||
* Constructor for GroupExpression.
|
||||
*
|
||||
* @param plan {@link Plan} to reference
|
||||
@ -92,7 +93,7 @@ public class GroupExpression {
|
||||
public GroupExpression(Plan plan, List<Group> children) {
|
||||
this.plan = Objects.requireNonNull(plan, "plan can not be null")
|
||||
.withGroupExpression(Optional.of(this));
|
||||
this.children = Lists.newArrayList(Objects.requireNonNull(children, "children can not be null"));
|
||||
this.children = Objects.requireNonNull(children, "children can not be null");
|
||||
this.children.forEach(childGroup -> childGroup.addParentExpression(this));
|
||||
this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal());
|
||||
this.statDerived = false;
|
||||
@ -311,7 +312,7 @@ public class GroupExpression {
|
||||
}
|
||||
|
||||
public Statistics childStatistics(int idx) {
|
||||
return new Statistics(child(idx).getStatistics());
|
||||
return child(idx).getStatistics();
|
||||
}
|
||||
|
||||
public void setEstOutputRowCount(double estOutputRowCount) {
|
||||
|
||||
@ -161,10 +161,11 @@ public class GroupExpressionMatching implements Iterable<Plan> {
|
||||
|
||||
// assemble all combination of plan tree by current root plan and children plan
|
||||
while (offset < childrenPlans.size()) {
|
||||
List<Plan> children = Lists.newArrayList();
|
||||
ImmutableList.Builder<Plan> childrenBuilder = ImmutableList.builder();
|
||||
for (int i = 0; i < childrenPlans.size(); i++) {
|
||||
children.add(childrenPlans.get(i).get(childrenPlanIndex[i]));
|
||||
childrenBuilder.add(childrenPlans.get(i).get(childrenPlanIndex[i]));
|
||||
}
|
||||
List<Plan> children = childrenBuilder.build();
|
||||
|
||||
LogicalProperties logicalProperties = groupExpression.getOwnerGroup().getLogicalProperties();
|
||||
// assemble children: replace GroupPlan to real plan,
|
||||
|
||||
@ -74,10 +74,7 @@ public class AppliedAwareRule extends Rule {
|
||||
/** provide this method for the child class get the applied state */
|
||||
public final boolean isAppliedRule(Rule rule, Plan plan) {
|
||||
Optional<BitSet> appliedRules = plan.getMutableState("applied_rules");
|
||||
if (!appliedRules.isPresent()) {
|
||||
return false;
|
||||
}
|
||||
return appliedRules.get().get(rule.getRuleType().ordinal());
|
||||
return appliedRules.map(bitSet -> bitSet.get(rule.getRuleType().ordinal())).orElse(false);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -76,6 +76,7 @@ import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
import com.google.common.collect.Lists;
|
||||
import com.google.common.collect.Sets;
|
||||
|
||||
@ -83,7 +84,6 @@ import java.util.ArrayList;
|
||||
import java.util.Collection;
|
||||
import java.util.Collections;
|
||||
import java.util.HashSet;
|
||||
import java.util.LinkedHashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
@ -149,7 +149,7 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
Set<Expression> boundConjuncts = filter.getConjuncts().stream()
|
||||
.map(expr -> bindSlot(expr, filter.children(), ctx.cascadesContext))
|
||||
.map(expr -> bindFunction(expr, ctx.cascadesContext))
|
||||
.collect(Collectors.toCollection(LinkedHashSet::new));
|
||||
.collect(ImmutableSet.toImmutableSet());
|
||||
return new LogicalFilter<>(boundConjuncts, filter.child());
|
||||
})
|
||||
),
|
||||
@ -613,9 +613,12 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
|
||||
private <E extends Expression> List<E> bindSlot(
|
||||
List<E> exprList, List<Plan> inputs, CascadesContext cascadesContext) {
|
||||
return exprList.stream()
|
||||
.map(expr -> bindSlot(expr, inputs, cascadesContext))
|
||||
.collect(Collectors.toList());
|
||||
List<E> slots = new ArrayList<>();
|
||||
for (E expr : exprList) {
|
||||
E result = bindSlot(expr, inputs, cascadesContext);
|
||||
slots.add(result);
|
||||
}
|
||||
return slots;
|
||||
}
|
||||
|
||||
private <E extends Expression> E bindSlot(E expr, Plan input, CascadesContext cascadesContext) {
|
||||
@ -641,9 +644,10 @@ public class BindExpression implements AnalysisRuleFactory {
|
||||
|
||||
private <E extends Expression> E bindSlot(E expr, List<Plan> inputs, CascadesContext cascadesContext,
|
||||
boolean enableExactMatch, boolean bindSlotInOuterScope) {
|
||||
List<Slot> boundedSlots = inputs.stream()
|
||||
.flatMap(input -> input.getOutput().stream())
|
||||
.collect(Collectors.toList());
|
||||
List<Slot> boundedSlots = new ArrayList<>();
|
||||
for (Plan input : inputs) {
|
||||
boundedSlots.addAll(input.getOutput());
|
||||
}
|
||||
return (E) new SlotBinder(toScope(cascadesContext, boundedSlots), cascadesContext,
|
||||
enableExactMatch, bindSlotInOuterScope).bind(expr);
|
||||
}
|
||||
|
||||
@ -17,7 +17,6 @@
|
||||
|
||||
package org.apache.doris.nereids.stats;
|
||||
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.EqualTo;
|
||||
@ -84,6 +83,8 @@ public class JoinEstimation {
|
||||
List<Double> unTrustEqualRatio = Lists.newArrayList();
|
||||
List<EqualTo> unTrustableCondition = Lists.newArrayList();
|
||||
boolean leftBigger = leftStats.getRowCount() > rightStats.getRowCount();
|
||||
double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
|
||||
double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());
|
||||
List<EqualTo> trustableConditions = join.getHashJoinConjuncts().stream()
|
||||
.map(expression -> (EqualTo) expression)
|
||||
.filter(
|
||||
@ -94,8 +95,6 @@ public class JoinEstimation {
|
||||
EqualTo equal = normalizeHashJoinCondition(expression, leftStats, rightStats);
|
||||
ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats);
|
||||
ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats);
|
||||
double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount());
|
||||
double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount());
|
||||
boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold
|
||||
|| eqLeftColStats.ndv / leftStatsRowCount > almostUniqueThreshold;
|
||||
if (!trustable) {
|
||||
@ -123,22 +122,16 @@ public class JoinEstimation {
|
||||
|
||||
double outputRowCount = 1;
|
||||
if (!trustableConditions.isEmpty()) {
|
||||
List<Pair<? extends Expression, Double>> sortedJoinConditions = trustableConditions.stream()
|
||||
.map(expression -> Pair.of(expression, estimateJoinConditionSel(crossJoinStats, expression)))
|
||||
.sorted((a, b) -> {
|
||||
double sub = a.second - b.second;
|
||||
if (sub > 0) {
|
||||
return 1;
|
||||
} else if (sub < 0) {
|
||||
return -1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
}).collect(Collectors.toList());
|
||||
List<Double> joinConditionSels = trustableConditions.stream()
|
||||
.map(expression -> estimateJoinConditionSel(crossJoinStats, expression))
|
||||
.sorted()
|
||||
.collect(Collectors.toList());
|
||||
|
||||
double sel = 1.0;
|
||||
for (int i = 0; i < sortedJoinConditions.size(); i++) {
|
||||
sel *= Math.pow(sortedJoinConditions.get(i).second, 1 / Math.pow(2, i));
|
||||
double denominator = 1.0;
|
||||
for (int i = 0; i < joinConditionSels.size(); i++) {
|
||||
sel *= Math.pow(joinConditionSels.get(i), 1 / denominator);
|
||||
denominator *= 2;
|
||||
}
|
||||
outputRowCount = Math.max(1, crossJoinStats.getRowCount() * sel);
|
||||
outputRowCount = outputRowCount * Math.pow(0.9, unTrustableCondition.size());
|
||||
|
||||
@ -107,7 +107,7 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
|
||||
|
||||
@Override
|
||||
public boolean canBind() {
|
||||
return !bound() && childrenBound();
|
||||
return !bound() && children().stream().allMatch(Plan::bound);
|
||||
}
|
||||
|
||||
/**
|
||||
@ -170,8 +170,7 @@ public abstract class AbstractPlan extends AbstractTreeNode<Plan> implements Pla
|
||||
@Override
|
||||
public LogicalProperties computeLogicalProperties() {
|
||||
boolean hasUnboundChild = children.stream()
|
||||
.map(Plan::getLogicalProperties)
|
||||
.anyMatch(UnboundLogicalProperties.class::isInstance);
|
||||
.anyMatch(child -> !child.bound());
|
||||
if (hasUnboundChild || hasUnboundExpression()) {
|
||||
return UnboundLogicalProperties.INSTANCE;
|
||||
} else {
|
||||
|
||||
@ -55,6 +55,7 @@ public interface Plan extends TreeNode<Plan> {
|
||||
boolean canBind();
|
||||
|
||||
default boolean bound() {
|
||||
// TODO: avoid to use getLogicalProperties()
|
||||
return !(getLogicalProperties() instanceof UnboundLogicalProperties);
|
||||
}
|
||||
|
||||
@ -62,12 +63,6 @@ public interface Plan extends TreeNode<Plan> {
|
||||
return getExpressions().stream().anyMatch(Expression::hasUnbound);
|
||||
}
|
||||
|
||||
default boolean childrenBound() {
|
||||
return children()
|
||||
.stream()
|
||||
.allMatch(Plan::bound);
|
||||
}
|
||||
|
||||
default LogicalProperties computeLogicalProperties() {
|
||||
throw new IllegalStateException("Not support compute logical properties for " + getClass().getName());
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user