[feature](Nereids): generate left deep tree when stats is unknown (#25620)

Generate a left-deep tree when statistics are unknown.
This commit is contained in:
谢健
2023-10-23 15:41:55 +08:00
committed by GitHub
parent 6f7f0a24c5
commit 2c3bc65fae
4 changed files with 73 additions and 2 deletions

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecGather;
import org.apache.doris.nereids.properties.DistributionSpecHash;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDeferMaterializeOlapScan;
@ -302,17 +303,53 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
}
// TODO: since the outputs rows may expand a lot, penalty on it will cause bc never be chosen.
// will refine this in next generation cost model.
if (isStatsUnknown(physicalHashJoin, buildStats, probeStats)) {
// forbid broadcast join when stats is unknown
return CostV1.of(rightRowCount * buildSideFactor + 1 / leftRowCount,
rightRowCount,
0
);
}
return CostV1.of(leftRowCount + rightRowCount * buildSideFactor + outputRowCount * probeSideFactor,
rightRowCount,
0
);
}
if (isStatsUnknown(physicalHashJoin, buildStats, probeStats)) {
return CostV1.of(rightRowCount + 1 / leftRowCount,
rightRowCount,
0);
}
return CostV1.of(leftRowCount + rightRowCount + outputRowCount,
rightRowCount,
0
);
}
/**
 * Tells whether any slot referenced by the hash join's conditions lacks usable column statistics.
 *
 * <p>A slot counts as known when {@code build} or {@code probe} holds an entry for it whose
 * {@code isUnKnown} flag is false. A join without condition slots is reported as known.
 *
 * @param join hash join whose condition slots are inspected
 * @param build statistics consulted first for every slot
 * @param probe statistics consulted only when {@code build} has no known entry for the slot
 * @return {@code true} if at least one condition slot is known in neither statistics object
 */
private boolean isStatsUnknown(PhysicalHashJoin<? extends Plan, ? extends Plan> join,
        Statistics build, Statistics probe) {
    for (Slot conditionSlot : join.getConditionSlot()) {
        boolean missingInBuild = !build.columnStatistics().containsKey(conditionSlot)
                || build.columnStatistics().get(conditionSlot).isUnKnown;
        if (!missingInBuild) {
            // known in build: this slot is fine, probe is not consulted
            continue;
        }
        boolean missingInProbe = !probe.columnStatistics().containsKey(conditionSlot)
                || probe.columnStatistics().get(conditionSlot).isUnKnown;
        if (missingInProbe) {
            return true;
        }
    }
    return false;
}
/**
 * Tells whether any slot referenced by the nested loop join's conditions lacks usable column statistics.
 *
 * <p>A slot counts as known when {@code build} or {@code probe} holds an entry for it whose
 * {@code isUnKnown} flag is false. A join without condition slots is reported as known.
 *
 * @param join nested loop join whose condition slots are inspected
 * @param build statistics consulted first for every slot
 * @param probe statistics consulted only when {@code build} has no known entry for the slot
 * @return {@code true} if at least one condition slot is known in neither statistics object
 */
private boolean isStatsUnknown(PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> join,
        Statistics build, Statistics probe) {
    for (Slot conditionSlot : join.getConditionSlot()) {
        boolean missingInBuild = !build.columnStatistics().containsKey(conditionSlot)
                || build.columnStatistics().get(conditionSlot).isUnKnown;
        if (!missingInBuild) {
            // known in build: this slot is fine, probe is not consulted
            continue;
        }
        boolean missingInProbe = !probe.columnStatistics().containsKey(conditionSlot)
                || probe.columnStatistics().get(conditionSlot).isUnKnown;
        if (missingInProbe) {
            return true;
        }
    }
    return false;
}
@Override
public Cost visitPhysicalNestedLoopJoin(
PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
@ -321,7 +358,11 @@ class CostModelV1 extends PlanVisitor<Cost, PlanContext> {
Preconditions.checkState(context.arity() == 2);
Statistics leftStatistics = context.getChildStatistics(0);
Statistics rightStatistics = context.getChildStatistics(1);
if (isStatsUnknown(nestedLoopJoin, leftStatistics, rightStatistics)) {
return CostV1.of(rightStatistics.getRowCount() + 1 / leftStatistics.getRowCount(),
rightStatistics.getRowCount(),
0);
}
return CostV1.of(
leftStatistics.getRowCount() * rightStatistics.getRowCount(),
rightStatistics.getRowCount(),

View File

@ -147,6 +147,21 @@ public class JoinEstimation {
}
private static Statistics estimateNestLoopJoin(Statistics leftStats, Statistics rightStats, Join join) {
if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
double rowCount = (leftStats.getRowCount() + rightStats.getRowCount());
// Slightly favor a nested loop join in which one side has a single row over an inner join.
if (leftStats.getRowCount() == 1 || rightStats.getRowCount() == 1) {
rowCount *= 0.99;
} else {
rowCount *= 1.01;
}
rowCount = Math.max(1, rowCount);
return new StatisticsBuilder()
.setRowCount(rowCount)
.putColumnStatistics(leftStats.columnStatistics())
.putColumnStatistics(rightStats.columnStatistics())
.build();
}
return new StatisticsBuilder()
.setRowCount(Math.max(1, leftStats.getRowCount() * rightStats.getRowCount()))
.putColumnStatistics(leftStats.columnStatistics())
@ -156,7 +171,7 @@ public class JoinEstimation {
private static Statistics estimateInnerJoin(Statistics leftStats, Statistics rightStats, Join join) {
if (hashJoinConditionContainsUnknownColumnStats(leftStats, rightStats, join)) {
double rowCount = Math.max(leftStats.getRowCount(), rightStats.getRowCount());
double rowCount = leftStats.getRowCount() + rightStats.getRowCount();
rowCount = Math.max(1, rowCount);
return new StatisticsBuilder()
.setRowCount(rowCount)

View File

@ -43,6 +43,7 @@ import org.apache.doris.statistics.Statistics;
import org.apache.doris.thrift.TRuntimeFilterType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
@ -51,6 +52,7 @@ import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* Physical hash join plan.
@ -244,6 +246,11 @@ public class PhysicalHashJoin<
return pushedDown;
}
/**
 * Collects the input slots of every join condition of this hash join.
 *
 * <p>Slots of the hash join conjuncts are added first, followed by those of the other join
 * conjuncts; duplicates are dropped.
 *
 * @return an immutable set of all slots referenced by the join conditions
 */
public Set<Slot> getConditionSlot() {
    ImmutableSet.Builder<Slot> conditionSlots = ImmutableSet.builder();
    hashJoinConjuncts.forEach(conjunct -> conditionSlots.addAll(conjunct.getInputSlots()));
    otherJoinConjuncts.forEach(conjunct -> conditionSlots.addAll(conjunct.getInputSlots()));
    return conditionSlots.build();
}
@Override
public String shapeInfo() {
StringBuilder builder = new StringBuilder();

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.JoinHint;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
@ -31,11 +32,13 @@ import org.apache.doris.nereids.util.MutableState;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Stream;
/**
* Use nested loop algorithm to do join.
@ -173,6 +176,11 @@ public class PhysicalNestedLoopJoin<
return bitMapRuntimeFilterConditions.isEmpty();
}
/**
 * Collects the input slots of every join condition of this nested loop join.
 *
 * <p>Slots of the hash join conjuncts are added first, followed by those of the other join
 * conjuncts; duplicates are dropped.
 *
 * @return an immutable set of all slots referenced by the join conditions
 */
public Set<Slot> getConditionSlot() {
    ImmutableSet.Builder<Slot> conditionSlots = ImmutableSet.builder();
    hashJoinConjuncts.forEach(conjunct -> conditionSlots.addAll(conjunct.getInputSlots()));
    otherJoinConjuncts.forEach(conjunct -> conditionSlots.addAll(conjunct.getInputSlots()));
    return conditionSlots.build();
}
@Override
public String shapeInfo() {
StringBuilder builder = new StringBuilder("NestedLoopJoin");