[opt](Nereids) outer join with is null stats estimation enhancement (#31875)

This commit is contained in:
xzj7019
2024-03-13 11:58:39 +08:00
committed by yiguolei
parent 94a75c27e7
commit fd1e0e933e
14 changed files with 237 additions and 24 deletions

View File

@ -67,7 +67,7 @@ public class JoinEdge extends Edge {
}
public JoinEdge withJoinTypeAndCleanCR(JoinType joinType) {
return new JoinEdge(join.withJoinType(joinType, null), getIndex(), getLeftChildEdges(), getRightChildEdges(),
return new JoinEdge(join.withJoinType(joinType), getIndex(), getLeftChildEdges(), getRightChildEdges(),
getSubTreeNodes(), getLeftRequiredNodes(), getRightRequiredNodes(), leftInputSlots, rightInputSlots);
}

View File

@ -39,12 +39,12 @@ public class ConvertInnerOrCrossJoin implements RewriteRuleFactory {
innerLogicalJoin()
.when(join -> join.getHashJoinConjuncts().isEmpty() && join.getOtherJoinConjuncts().isEmpty()
&& join.getMarkJoinConjuncts().isEmpty())
.then(join -> join.withJoinType(JoinType.CROSS_JOIN, join.getJoinReorderContext()))
.then(join -> join.withJoinTypeAndContext(JoinType.CROSS_JOIN, join.getJoinReorderContext()))
.toRule(RuleType.INNER_TO_CROSS_JOIN),
crossLogicalJoin()
.when(join -> !join.getHashJoinConjuncts().isEmpty() || !join.getOtherJoinConjuncts().isEmpty()
|| !join.getMarkJoinConjuncts().isEmpty())
.then(join -> join.withJoinType(JoinType.INNER_JOIN, join.getJoinReorderContext()))
.then(join -> join.withJoinTypeAndContext(JoinType.INNER_JOIN, join.getJoinReorderContext()))
.toRule(RuleType.CROSS_TO_INNER_JOIN)
);
}

View File

@ -68,10 +68,10 @@ public class ConvertOuterJoinToAntiJoin extends OneRewriteRuleFactory {
Plan newJoin = null;
if (join.getJoinType().isLeftOuterJoin() && !rightAlwaysNullSlots.isEmpty()) {
newJoin = join.withJoinType(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext());
newJoin = join.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, join.getJoinReorderContext());
}
if (join.getJoinType().isRightOuterJoin() && !leftAlwaysNullSlots.isEmpty()) {
newJoin = join.withJoinType(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext());
newJoin = join.withJoinTypeAndContext(JoinType.RIGHT_ANTI_JOIN, join.getJoinReorderContext());
}
if (newJoin == null) {
return null;

View File

@ -36,7 +36,7 @@ public class EliminateNullAwareLeftAntiJoin extends OneRewriteRuleFactory {
antiJoin.getOtherJoinConjuncts().stream()),
antiJoin.getMarkJoinConjuncts().stream())
.noneMatch(expression -> expression.nullable())) {
return antiJoin.withJoinType(JoinType.LEFT_ANTI_JOIN, antiJoin.getJoinReorderContext());
return antiJoin.withJoinTypeAndContext(JoinType.LEFT_ANTI_JOIN, antiJoin.getJoinReorderContext());
} else {
return null;
}

View File

@ -106,9 +106,9 @@ public class EliminateOuterJoin extends OneRewriteRuleFactory {
}
if (conjunctsChanged) {
return filter.withConjuncts(conjuncts.stream().collect(ImmutableSet.toImmutableSet()))
.withChildren(join.withJoinType(newJoinType, join.getJoinReorderContext()));
.withChildren(join.withJoinTypeAndContext(newJoinType, join.getJoinReorderContext()));
}
return filter.withChildren(join.withJoinType(newJoinType, join.getJoinReorderContext()));
return filter.withChildren(join.withJoinTypeAndContext(newJoinType, join.getJoinReorderContext()));
}).toRule(RuleType.ELIMINATE_OUTER_JOIN);
}

View File

@ -65,8 +65,11 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
public static final double DEFAULT_EQUALITY_COMPARISON_SELECTIVITY = 0.1;
public static final double DEFAULT_LIKE_COMPARISON_SELECTIVITY = 0.2;
public static final double DEFAULT_ISNULL_SELECTIVITY = 0.001;
private Set<Slot> aggSlots;
private boolean isOnBaseTable = false;
public FilterEstimation() {
}
@ -74,6 +77,10 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
this.aggSlots = aggSlots;
}
public FilterEstimation(boolean isOnBaseTable) {
this.isOnBaseTable = isOnBaseTable;
}
/**
* This method will update the stats according to the selectivity.
*/
@ -411,12 +418,20 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
@Override
public Statistics visitIsNull(IsNull isNull, EstimationContext context) {
ColumnStatistic childStats = ExpressionEstimation.estimate(isNull.child(), context.statistics);
if (childStats.isUnKnown()) {
ColumnStatistic childColStats = ExpressionEstimation.estimate(isNull.child(), context.statistics);
if (childColStats.isUnKnown()) {
return new StatisticsBuilder(context.statistics).build();
}
double outputRowCount = childStats.numNulls;
ColumnStatisticBuilder colBuilder = new ColumnStatisticBuilder(childStats);
double outputRowCount = childColStats.numNulls;
if (!isOnBaseTable) {
// for is null on base table, use the numNulls, otherwise
// nulls will be generated such as outer join and then we do a protection
Expression child = isNull.child();
Statistics childStats = child.accept(this, context);
outputRowCount = Math.max(childStats.getRowCount() * DEFAULT_ISNULL_SELECTIVITY, outputRowCount);
outputRowCount = Math.max(outputRowCount, 1);
}
ColumnStatisticBuilder colBuilder = new ColumnStatisticBuilder(childColStats);
colBuilder.setCount(outputRowCount).setNumNulls(outputRowCount)
.setMaxValue(Double.POSITIVE_INFINITY)
.setMinValue(Double.NEGATIVE_INFINITY)
@ -607,12 +622,8 @@ public class FilterEstimation extends ExpressionVisitor<Statistics, EstimationCo
//if (numNulls > rowCount - ndv) {
// numNulls = rowCount - ndv > 0 ? rowCount - ndv : 0;
//}
double notNullSel = rowCount <= 1.0 ? 1.0 : 1 - getValidSelectivity(numNulls / rowCount);
double notNullSel = rowCount <= 1.0 ? 1.0 : 1 - Statistics.getValidSelectivity(numNulls / rowCount);
double validSel = origSel * notNullSel;
return getValidSelectivity(validSel);
}
private static double getValidSelectivity(double nullSel) {
return nullSel < 0 ? 0 : (nullSel > 1 ? 1 : nullSel);
return Statistics.getValidSelectivity(validSel);
}
}

View File

@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.AssertNumRowsElement;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
@ -37,6 +38,8 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
import org.apache.doris.nereids.trees.plans.algebra.CatalogRelation;
@ -73,6 +76,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalOdbcScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSchemaScan;
@ -117,6 +121,7 @@ import org.apache.doris.nereids.trees.plans.physical.PhysicalUnion;
import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.ColumnStatistic;
import org.apache.doris.statistics.ColumnStatisticBuilder;
@ -140,6 +145,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
@ -591,8 +597,11 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
private Statistics computeFilter(Filter filter) {
Statistics stats = groupExpression.childStatistics(0);
Plan plan = tryToFindChild(groupExpression);
boolean isOnBaseTable = false;
if (plan != null) {
if (plan instanceof Aggregate) {
if (plan instanceof OlapScan) {
isOnBaseTable = true;
} else if (plan instanceof Aggregate) {
Aggregate agg = ((Aggregate<?>) plan);
List<NamedExpression> expressions = agg.getOutputExpressions();
Set<Slot> slots = expressions
@ -604,9 +613,108 @@ public class StatsCalculator extends DefaultPlanVisitor<Statistics, Void> {
if (predicate.anyMatch(s -> slots.contains(s))) {
return new FilterEstimation(slots).estimate(filter.getPredicate(), stats);
}
} else if (plan instanceof LogicalJoin && filter instanceof LogicalFilter
&& filter.getConjuncts().stream().anyMatch(e -> e instanceof IsNull)) {
Statistics isNullStats = computeGeneratedIsNullStats((LogicalJoin) plan, filter);
if (isNullStats != null) {
// overwrite the stats corrected as above before passing to filter estimation
stats = isNullStats;
Set<Expression> newConjuncts = filter.getConjuncts().stream()
.filter(e -> !(e instanceof IsNull))
.collect(Collectors.toSet());
if (newConjuncts.isEmpty()) {
return stats;
} else {
// overwrite the filter by removing is null and remain the others
filter = ((LogicalFilter<?>) filter).withConjunctsAndProps(newConjuncts,
((LogicalFilter<?>) filter).getGroupExpression(),
Optional.of(((LogicalFilter<?>) filter).getLogicalProperties()), plan);
}
}
}
}
return new FilterEstimation().estimate(filter.getPredicate(), stats);
return new FilterEstimation(isOnBaseTable).estimate(filter.getPredicate(), stats);
}
private Statistics computeGeneratedIsNullStats(LogicalJoin join, Filter filter) {
JoinType joinType = join.getJoinType();
Plan left = join.left();
Plan right = join.right();
if (left == null || right == null
|| ((GroupPlan) left).getGroup() == null || ((GroupPlan) right).getGroup() == null
|| ((GroupPlan) left).getGroup().getStatistics() == null
|| ((GroupPlan) right).getGroup().getStatistics() == null
|| !join.getGroupExpression().isPresent()) {
return null;
}
double leftRowCount = ((GroupPlan) left).getGroup().getStatistics().getRowCount();
double rightRowCount = ((GroupPlan) right).getGroup().getStatistics().getRowCount();
if (leftRowCount < 0 || Double.isInfinite(leftRowCount)
|| rightRowCount < 0 || Double.isInfinite(rightRowCount)) {
return null;
}
Statistics origJoinStats = join.getGroupExpression().get().getOwnerGroup().getStatistics();
// for outer join which is anti-like, use anti join to re-estimate the stats
// otherwise, return null and pass through to use the normal filter estimation logical
if (joinType.isOuterJoin()) {
boolean leftHasIsNull = false;
boolean rightHasIsNull = false;
boolean isLeftOuterJoin = join.getJoinType() == JoinType.LEFT_OUTER_JOIN;
boolean isRightOuterJoin = join.getJoinType() == JoinType.RIGHT_OUTER_JOIN;
boolean isFullOuterJoin = join.getJoinType() == JoinType.FULL_OUTER_JOIN;
for (Expression expr : filter.getConjuncts()) {
if (expr instanceof IsNull) {
Expression child = ((IsNull) expr).child();
if (PlanUtils.isColumnRef(child)) {
LogicalPlan leftChild = (LogicalPlan) join.left();
LogicalPlan rightChild = (LogicalPlan) join.right();
leftHasIsNull = PlanUtils.checkSlotFrom(((GroupPlan) leftChild)
.getGroup().getLogicalExpression().getPlan(), (SlotReference) child);
rightHasIsNull = PlanUtils.checkSlotFrom(((GroupPlan) rightChild)
.getGroup().getLogicalExpression().getPlan(), (SlotReference) child);
}
}
}
boolean isLeftAntiLikeJoin = (isLeftOuterJoin && rightHasIsNull) || (isFullOuterJoin && rightHasIsNull);
boolean isRightAntiLikeJoin = (isRightOuterJoin && leftHasIsNull) || (isFullOuterJoin && leftHasIsNull);
if (isLeftAntiLikeJoin || isRightAntiLikeJoin) {
// transform to anti estimation
Statistics newStats = null;
if (isLeftAntiLikeJoin) {
LogicalJoin<GroupPlan, GroupPlan> newJoin = join.withJoinType(JoinType.LEFT_ANTI_JOIN);
StatsCalculator statsCalculator = new StatsCalculator(join.getGroupExpression().get(),
false, getTotalColumnStatisticMap(), false,
cteIdToStats, cascadesContext);
newStats = ((Plan) newJoin).accept(statsCalculator, null);
} else if (isRightAntiLikeJoin) {
LogicalJoin<GroupPlan, GroupPlan> newJoin = join.withJoinType(JoinType.RIGHT_ANTI_JOIN);
StatsCalculator statsCalculator = new StatsCalculator(join.getGroupExpression().get(),
false, this.getTotalColumnStatisticMap(), false,
this.cteIdToStats, this.cascadesContext);
newStats = ((Plan) newJoin).accept(statsCalculator, null);
}
newStats.enforceValid();
double selectivity = Statistics.getValidSelectivity(
newStats.getRowCount() / (leftRowCount * rightRowCount));
double newRows = origJoinStats.getRowCount() * selectivity;
newStats.withRowCount(newRows);
return newStats;
} else {
return null;
}
} else {
return null;
}
}
private ColumnStatistic getColumnStatistic(TableIf table, String colName, long idxId) {

View File

@ -142,6 +142,12 @@ public class LogicalFilter<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T
return new LogicalFilter<>(conjuncts, child);
}
public LogicalFilter<Plan> withConjunctsAndProps(Set<Expression> conjuncts,
Optional<GroupExpression> groupExpression,
Optional<LogicalProperties> logicalProperties, Plan child) {
return new LogicalFilter<>(conjuncts, groupExpression, logicalProperties, child);
}
@Override
public FunctionalDependencies computeFuncDeps(Supplier<List<Slot>> outputSupplier) {
Builder fdBuilder = new Builder(

View File

@ -411,7 +411,14 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends
ImmutableList.of(left, right), otherJoinReorderContext);
}
public LogicalJoin<Plan, Plan> withJoinType(JoinType joinType, JoinReorderContext otherJoinReorderContext) {
public LogicalJoin<Plan, Plan> withJoinType(JoinType joinType) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, groupExpression, Optional.of(getLogicalProperties()),
children, joinReorderContext);
}
public LogicalJoin<Plan, Plan> withJoinTypeAndContext(JoinType joinType,
JoinReorderContext otherJoinReorderContext) {
return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, markJoinConjuncts,
hint, markJoinSlotReference, Optional.empty(), Optional.empty(),
children, otherJoinReorderContext);

View File

@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
@ -143,6 +144,28 @@ public class PlanUtils {
return output.build();
}
/**
* Check if slot is from the plan.
*/
public static boolean checkSlotFrom(Plan plan, SlotReference slot) {
Set<LogicalCatalogRelation> tableSets = PlanUtils.getLogicalScanFromRootPlan((LogicalPlan) plan);
for (LogicalCatalogRelation table : tableSets) {
if (table.getOutputExprIds().contains(slot.getExprId())) {
return true;
}
}
return false;
}
/**
* Check if the expression is a column reference.
*/
public static boolean isColumnRef(Expression expr) {
return expr instanceof SlotReference
&& ((SlotReference) expr).getColumn().isPresent()
&& ((SlotReference) expr).getTable().isPresent();
}
/**
* collect non_window_agg_func
*/

View File

@ -166,6 +166,10 @@ public class Statistics {
return zero;
}
public static double getValidSelectivity(double nullSel) {
return nullSel < 0 ? 0 : (nullSel > 1 ? 1 : nullSel);
}
/**
* merge this and other colStats.ndv, choose min
*/

View File

@ -1100,7 +1100,7 @@ class FilterEstimationTest {
Or or = new Or(greaterThanEqual, isNull);
Statistics stats = new Statistics(10, new HashMap<>());
stats.addColumnStats(a, builder.build());
FilterEstimation filterEstimation = new FilterEstimation();
FilterEstimation filterEstimation = new FilterEstimation(true);
Statistics result = filterEstimation.estimate(or, stats);
Assertions.assertEquals(result.getRowCount(), 10.0, 0.01);
}

View File

@ -289,11 +289,11 @@ public class HyperGraphBuilder {
Random random = new Random();
int randomIndex = random.nextInt(values.length);
DistributeType hint = values[randomIndex];
Plan hintJoin = ((LogicalJoin) join.withChildren(left, right)).withJoinType(joinType, null);
Plan hintJoin = ((LogicalJoin) join.withChildren(left, right)).withJoinTypeAndContext(joinType, null);
((LogicalJoin) hintJoin).setHint(new DistributeHint(hint));
return hintJoin;
}
return ((LogicalJoin) join.withChildren(left, right)).withJoinType(joinType, null);
return ((LogicalJoin) join.withChildren(left, right)).withJoinTypeAndContext(joinType, null);
}
private Optional<BitSet> findPlan(BitSet bitSet) {

View File

@ -0,0 +1,54 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("test_outerjoin_isnull_estimation") {
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
sql "DROP TABLE IF EXISTS test_outerjoin_isnull_estimation1"
sql """ CREATE TABLE `test_outerjoin_isnull_estimation1` (
c1 int, c2 int, c3 int
)ENGINE=OLAP
distributed by hash(c1) buckets 10
properties(
"replication_allocation" = "tag.location.default: 1"
);"""
sql "DROP TABLE IF EXISTS test_outerjoin_isnull_estimation2"
sql """ CREATE TABLE `test_outerjoin_isnull_estimation2` (
c1 int, c2 int, c3 int
)ENGINE=OLAP
distributed by hash(c1) buckets 10
properties(
"replication_allocation" = "tag.location.default: 1"
);"""
sql "insert into test_outerjoin_isnull_estimation1 values (1,1,1);"
sql "insert into test_outerjoin_isnull_estimation1 values (2,2,2);"
sql "insert into test_outerjoin_isnull_estimation2 values (3,3,3);"
sql "insert into test_outerjoin_isnull_estimation2 values (4,4,4);"
sql "analyze table test_outerjoin_isnull_estimation1 with full with sync;"
sql "analyze table test_outerjoin_isnull_estimation2 with full with sync;"
explain {
sql("physical plan select t1.c1, t1.c2 from test_outerjoin_isnull_estimation1 t1" +
" left join test_outerjoin_isnull_estimation1 t2 on t1.c1 = t2.c1 where t2.c2 is null;");
contains"stats=1, predicates=c2#4 IS NULL"
notContains"stats=0"
}
}