[enhancement](Nereids): check whether stats are unreliable when deriving stats (#26103)

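Check whether stats are unreliable when deriving stats: the reliability flag is recorded on the owner Group during stats derivation, exposed through PlanContext, and consulted by the join cost model in place of the per-join isStatsUnknown checks.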
This commit is contained in:
谢健
2023-10-30 17:21:42 +08:00
committed by GitHub
parent 3a954cd1aa
commit 844b7c8cba
6 changed files with 33 additions and 37 deletions

View File

@ -33,9 +33,10 @@ import java.util.List;
public class PlanContext {
private final List<Statistics> childrenStats;
private Statistics planStats;
private final Statistics planStats;
private final int arity;
private boolean isBroadcastJoin = false;
private final boolean isStatsReliable;
/**
* Constructor for PlanContext.
@ -43,15 +44,18 @@ public class PlanContext {
public PlanContext(GroupExpression groupExpression) {
this.arity = groupExpression.arity();
// The plan-level statistics and their reliability flag are both read from the owner group.
this.planStats = groupExpression.getOwnerGroup().getStatistics();
this.isStatsReliable = groupExpression.getOwnerGroup().isStatsReliable();
// Collect the statistics of every child, in child-index order.
this.childrenStats = new ArrayList<>(groupExpression.arity());
for (int i = 0; i < groupExpression.arity(); i++) {
childrenStats.add(groupExpression.childStatistics(i));
}
}
// This is used in GraphSimplifier.
// Builds a context from explicitly supplied statistics: the given childrenStats list is
// stored as-is (no copy is made), arity is taken from its size, and the statistics are
// always flagged as not reliable on this construction path.
public PlanContext(Statistics planStats, List<Statistics> childrenStats) {
this.planStats = planStats;
this.childrenStats = childrenStats;
this.isStatsReliable = false;
this.arity = this.childrenStats.size();
}
@ -71,6 +75,10 @@ public class PlanContext {
return planStats;
}
/**
 * Tells whether the statistics held by this context are flagged as reliable.
 * The flag is assigned once in the constructor and never changes afterwards.
 *
 * @return the reliability flag captured at construction time
 */
public boolean isStatsReliable() {
return isStatsReliable;
}
/**
* Get child statistics.
*/

View File

@ -22,7 +22,6 @@ import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan;
@ -303,7 +302,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
}
// TODO: since the outputs rows may expand a lot, penalty on it will cause bc never be chosen.
// will refine this in next generation cost model.
if (isStatsUnknown(physicalHashJoin, buildStats, probeStats)) {
if (!context.isStatsReliable()) {
// forbid broadcast join when stats is unknown
return CostV1.of(rightRowCount * buildSideFactor + 1 / leftRowCount,
rightRowCount,
@ -315,7 +314,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
0
);
}
if (isStatsUnknown(physicalHashJoin, buildStats, probeStats)) {
if (!context.isStatsReliable()) {
return CostV1.of(rightRowCount + 1 / leftRowCount,
rightRowCount,
0);
@ -326,30 +325,6 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
);
}
/**
 * Checks the condition slots of a hash join against the column statistics of both inputs.
 *
 * @param join the hash join whose condition slots are inspected
 * @param build statistics of the build side
 * @param probe statistics of the probe side
 * @return true if, for at least one condition slot, neither side holds a column statistic
 *         whose isUnKnown flag is false; false otherwise (including when there are no
 *         condition slots)
 */
private boolean isStatsUnknown(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
Statistics build, Statistics probe) {
for (Slot slot : join.getConditionSlot()) {
// A slot counts as known when either side has a column statistic for it that is not unknown.
if ((build.columnStatistics().containsKey(slot) && !build.columnStatistics().get(slot).isUnKnown)
|| (probe.columnStatistics().containsKey(slot) && !probe.columnStatistics().get(slot).isUnKnown)) {
continue;
}
// First slot without a known statistic on either side decides the result.
return true;
}
return false;
}
/**
 * Checks the condition slots of a nested loop join against the column statistics of both inputs.
 *
 * @param join the nested loop join whose condition slots are inspected
 * @param build statistics passed for the build side
 * @param probe statistics passed for the probe side
 * @return true if, for at least one condition slot, neither side holds a column statistic
 *         whose isUnKnown flag is false; false otherwise (including when there are no
 *         condition slots)
 */
private boolean isStatsUnknown(PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> join,
Statistics build, Statistics probe) {
for (Slot slot : join.getConditionSlot()) {
// A slot counts as known when either side has a column statistic for it that is not unknown.
if ((build.columnStatistics().containsKey(slot) && !build.columnStatistics().get(slot).isUnKnown)
|| (probe.columnStatistics().containsKey(slot) && !probe.columnStatistics().get(slot).isUnKnown)) {
continue;
}
// First slot without a known statistic on either side decides the result.
return true;
}
return false;
}
@Override
public Cost visitPhysicalNestedLoopJoin(
PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
@ -358,7 +333,7 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
Preconditions.checkState(context.arity() == 2);
Statistics leftStatistics = context.getChildStatistics(0);
Statistics rightStatistics = context.getChildStatistics(1);
if (isStatsUnknown(nestedLoopJoin, leftStatistics, rightStatistics)) {
if (!context.isStatsReliable()) {
return CostV1.of(rightStatistics.getRowCount() + 1 / leftStatistics.getRowCount(),
rightStatistics.getRowCount(),
0);

View File

@ -59,7 +59,7 @@ public class Group {
private final List<GroupExpression> logicalExpressions = Lists.newArrayList();
private final List<GroupExpression> physicalExpressions = Lists.newArrayList();
private final List<GroupExpression> enforcers = Lists.newArrayList();
private boolean isStatsReliable = true;
private LogicalProperties logicalProperties;
// Map of cost lower bounds
@ -119,6 +119,14 @@ public class Group {
return groupExpression;
}
/**
 * Sets the flag that records whether this group's statistics are reliable.
 *
 * @param statsReliable the new value of the flag
 */
public void setStatsReliable(boolean statsReliable) {
this.isStatsReliable = statsReliable;
}
/**
 * Returns the flag that records whether this group's statistics are reliable.
 * The flag starts as true and changes only through setStatsReliable.
 *
 * @return the current value of the reliability flag
 */
public boolean isStatsReliable() {
return isStatsReliable;
}
public void addLogicalExpression(GroupExpression groupExpression) {
groupExpression.setOwnerGroup(this);
logicalExpressions.add(groupExpression);

View File

@ -221,8 +221,12 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
Plan plan = groupExpression.getPlan();
Statistics newStats = plan.accept(this, null);
newStats.enforceValid();
// We ensure that the rowCount remains unchanged in order to make the cost of each plan comparable.
if (groupExpression.getOwnerGroup().getStatistics() == null) {
boolean isReliable = groupExpression.getPlan().getExpressions().stream()
.noneMatch(e -> newStats.isInputSlotsUnknown(e.getInputSlots()));
groupExpression.getOwnerGroup().setStatsReliable(isReliable);
groupExpression.getOwnerGroup().setStatistics(newStats);
groupExpression.setEstOutputRowCount(newStats.getRowCount());
} else {

View File

@ -35,6 +35,7 @@ import org.apache.doris.statistics.Statistics;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import org.json.JSONObject;
@ -42,7 +43,9 @@ import java.util.Collection;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Abstract class for all physical join node.
@ -217,6 +220,11 @@ public abstract class AbstractPhysicalJoin<
.build();
}
/**
 * Collects the input slots of every join condition.
 *
 * @return an immutable set containing the input slots of all hash join conjuncts
 *         followed by those of all other join conjuncts, without duplicates
 */
public Set<Slot> getConditionSlot() {
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream()).collect(ImmutableSet.toImmutableSet());
}
@Override
public String toString() {
List<Object> args = Lists.newArrayList("type", joinType,

View File

@ -43,7 +43,6 @@ import org.apache.doris.statistics.Statistics;
import org.apache.doris.thrift.TRuntimeFilterType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
@ -52,7 +51,6 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Physical hash join plan.
@ -246,11 +244,6 @@ public class PhysicalHashJoin<
return pushedDown;
}
/**
 * Collects the input slots of every join condition.
 *
 * @return an immutable set containing the input slots of all hash join conjuncts
 *         followed by those of all other join conjuncts, without duplicates
 */
public Set<Slot> getConditionSlot() {
return Stream.concat(hashJoinConjuncts.stream(), otherJoinConjuncts.stream())
.flatMap(expr -> expr.getInputSlots().stream()).collect(ImmutableSet.toImmutableSet());
}
@Override
public String shapeInfo() {
StringBuilder builder = new StringBuilder();