From dea40e7095c28f72b313c9071dfc0f4d6d4ecb5c Mon Sep 17 00:00:00 2001 From: jakevin Date: Tue, 21 Nov 2023 19:08:14 +0800 Subject: [PATCH] [fix](Nereids): NullSafeEqual should be in HashJoinCondition (#27127) Originally, we only put `EqualTo` in `HashJoinCondition`; we also need to allow `NullSafeEqual`. --- .../translator/PhysicalPlanTranslator.java | 4 +- .../post/RuntimeFilterGenerator.java | 14 ++--- .../processor/post/RuntimeFilterPruner.java | 2 +- .../rules/rewrite/EliminateOuterJoin.java | 23 ++++--- .../PushdownExpressionsInHashCondition.java | 7 +-- .../AbstractSelectMaterializedIndexRule.java | 5 +- .../doris/nereids/stats/FilterEstimation.java | 6 +- .../doris/nereids/stats/JoinEstimation.java | 28 ++++----- .../trees/expressions/EqualPredicate.java | 36 +++++++++++ .../nereids/trees/expressions/EqualTo.java | 4 +- .../trees/expressions/NullSafeEqual.java | 11 +--- .../nereids/trees/plans/algebra/Join.java | 7 +++ .../plans/physical/PhysicalHashJoin.java | 4 +- .../apache/doris/nereids/util/JoinUtils.java | 42 ++++--------- .../data/nereids_p0/join/test_join_15.out | 60 +++++++++++++++++-- .../nereids_p0/join/test_join_15.groovy | 6 +- 16 files changed, 160 insertions(+), 99 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index 6062332012..d4a8849aee 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -69,7 +69,7 @@ import org.apache.doris.nereids.stats.StatsErrorEstimator; import org.apache.doris.nereids.trees.UnaryNode; import org.apache.doris.nereids.trees.expressions.AggregateExpression; import
org.apache.doris.nereids.trees.expressions.CTEId; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -1138,7 +1138,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor execEqConjuncts = hashJoin.getHashJoinConjuncts().stream() - .map(EqualTo.class::cast) + .map(EqualPredicate.class::cast) .map(e -> JoinUtils.swapEqualToForChildrenOrder(e, hashJoin.left().getOutputSet())) .map(e -> ExpressionTranslator.translate(e, context)) .collect(Collectors.toList()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java index 579fe7485a..cff906df20 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterGenerator.java @@ -363,9 +363,10 @@ public class RuntimeFilterGenerator extends PlanPostProcessor { List legalTypes = Arrays.stream(TRuntimeFilterType.values()) .filter(type -> (type.getValue() & ctx.getSessionVariable().getRuntimeFilterType()) > 0) .collect(Collectors.toList()); - for (int i = 0; i < join.getHashJoinConjuncts().size(); i++) { + List hashJoinConjuncts = join.getEqualToConjuncts(); + for (int i = 0; i < hashJoinConjuncts.size(); i++) { EqualTo equalTo = ((EqualTo) JoinUtils.swapEqualToForChildrenOrder( - (EqualTo) join.getHashJoinConjuncts().get(i), join.left().getOutputSet())); + hashJoinConjuncts.get(i), join.left().getOutputSet())); for (TRuntimeFilterType type : legalTypes) { //bitmap rf is generated by nested loop join. 
if (type == TRuntimeFilterType.BITMAP) { @@ -487,7 +488,7 @@ public class RuntimeFilterGenerator extends PlanPostProcessor { || !(join.getHashJoinConjuncts().get(0) instanceof EqualTo)) { break; } else { - EqualTo equalTo = (EqualTo) join.getHashJoinConjuncts().get(0); + EqualTo equalTo = join.getEqualToConjuncts().get(0); equalTos.add(equalTo); equalCondToJoinMap.put(equalTo, join); } @@ -523,12 +524,11 @@ public class RuntimeFilterGenerator extends PlanPostProcessor { // check further whether the join upper side can bring equal set, which // indicating actually the same runtime filter build side // see above case 2 for reference - List conditions = curJoin.getHashJoinConjuncts(); boolean inSameEqualSet = false; - for (Expression e : conditions) { + for (EqualTo e : curJoin.getEqualToConjuncts()) { if (e instanceof EqualTo) { - SlotReference oneSide = (SlotReference) ((EqualTo) e).left(); - SlotReference anotherSide = (SlotReference) ((EqualTo) e).right(); + SlotReference oneSide = (SlotReference) e.left(); + SlotReference anotherSide = (SlotReference) e.right(); if (anotherSideSlotSet.contains(oneSide) && anotherSideSlotSet.contains(anotherSide)) { inSameEqualSet = true; break; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java index e32d12edac..b39bb8ec18 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/RuntimeFilterPruner.java @@ -85,7 +85,7 @@ public class RuntimeFilterPruner extends PlanPostProcessor { List exprIds = ctx.getTargetExprIdByFilterJoin(join); if (exprIds != null && !exprIds.isEmpty()) { boolean isEffective = false; - for (Expression expr : join.getHashJoinConjuncts()) { + for (Expression expr : join.getEqualToConjuncts()) { if (isEffectiveRuntimeFilter((EqualTo) expr, join)) { 
isEffective = true; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java index 440e5d73ae..1afd6a175f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateOuterJoin.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.IsNull; @@ -89,24 +90,20 @@ public class EliminateOuterJoin extends OneRewriteRuleFactory { * * TODO: is_not_null can also be inferred from A < B and so on */ - conjunctsChanged |= join.getHashJoinConjuncts().stream() + conjunctsChanged |= join.getEqualToConjuncts().stream() .map(EqualTo.class::cast) - .map(equalTo -> - (EqualTo) JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet())) - .map(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts) - ).anyMatch(Boolean::booleanValue); + .map(equalTo -> JoinUtils.swapEqualToForChildrenOrder(equalTo, join.left().getOutputSet())) + .anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts)); JoinUtils.JoinSlotCoverageChecker checker = new JoinUtils.JoinSlotCoverageChecker( join.left().getOutput(), join.right().getOutput()); - conjunctsChanged |= join.getOtherJoinConjuncts().stream().filter(EqualTo.class::isInstance) - .map(EqualTo.class::cast) - .filter(equalTo -> checker.isHashJoinCondition(equalTo)) - .map(equalTo -> (EqualTo) JoinUtils.swapEqualToForChildrenOrder(equalTo, + conjunctsChanged |= join.getOtherJoinConjuncts().stream() + .filter(EqualTo.class::isInstance) + 
.filter(equalTo -> checker.isHashJoinCondition((EqualPredicate) equalTo)) + .map(equalTo -> JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) equalTo, join.left().getOutputSet())) - .map(equalTo -> - createIsNotNullIfNecessary(equalTo, conjuncts)) - .anyMatch(Boolean::booleanValue); + .anyMatch(equalTo -> createIsNotNullIfNecessary(equalTo, conjuncts)); } if (conjunctsChanged) { return filter.withConjuncts(conjuncts.stream().collect(ImmutableSet.toImmutableSet())) @@ -135,7 +132,7 @@ public class EliminateOuterJoin extends OneRewriteRuleFactory { return joinType; } - private boolean createIsNotNullIfNecessary(EqualTo swapedEqualTo, Collection container) { + private boolean createIsNotNullIfNecessary(EqualPredicate swapedEqualTo, Collection container) { boolean containerChanged = false; if (swapedEqualTo.left().nullable()) { Not not = new Not(new IsNull(swapedEqualTo.left())); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java index 05da591526..df7acb4553 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownExpressionsInHashCondition.java @@ -20,7 +20,7 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; @@ -77,11 +77,10 @@ public class PushdownExpressionsInHashCondition extends OneRewriteRuleFactory 
{ Set rightProjectExprs = Sets.newHashSet(); Map exprReplaceMap = Maps.newHashMap(); join.getHashJoinConjuncts().forEach(conjunct -> { - Preconditions.checkArgument(conjunct instanceof EqualTo); + Preconditions.checkArgument(conjunct instanceof EqualPredicate); // sometimes: t1 join t2 on t2.a + 1 = t1.a + 2, so check the situation, but actually it // doesn't swap the two sides. - conjunct = JoinUtils.swapEqualToForChildrenOrder( - (EqualTo) conjunct, join.left().getOutputSet()); + conjunct = JoinUtils.swapEqualToForChildrenOrder((EqualPredicate) conjunct, join.left().getOutputSet()); generateReplaceMapAndProjectExprs(conjunct.child(0), exprReplaceMap, leftProjectExprs); generateReplaceMapAndProjectExprs(conjunct.child(1), exprReplaceMap, rightProjectExprs); }); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java index 012dec4c91..c1550cb5bd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java @@ -27,13 +27,12 @@ import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.CaseWhen; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.IsNull; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import 
org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.WhenClause; @@ -306,7 +305,7 @@ public abstract class AbstractSelectMaterializedIndexRule { @Override public PrefixIndexCheckResult visitComparisonPredicate(ComparisonPredicate cp, Map context) { - if (cp instanceof EqualTo || cp instanceof NullSafeEqual) { + if (cp instanceof EqualPredicate) { return check(cp, context, PrefixIndexCheckResult::createEqual); } else { return check(cp, context, PrefixIndexCheckResult::createNonEqual); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java index 055f3b88a0..e2d7f40622 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/stats/FilterEstimation.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.TreeNode; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; @@ -33,7 +34,6 @@ import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.Like; import org.apache.doris.nereids.trees.expressions.Not; -import org.apache.doris.nereids.trees.expressions.NullSafeEqual; import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.Slot; import 
org.apache.doris.nereids.trees.expressions.SlotReference; @@ -210,7 +210,7 @@ public class FilterEstimation extends ExpressionVisitor rightStats.findColumnStatistics(slot) != null - ); + private static EqualPredicate normalizeHashJoinCondition(EqualPredicate equal, Statistics leftStats, + Statistics rightStats) { + boolean changeOrder = equal.left().getInputSlots().stream() + .anyMatch(slot -> rightStats.findColumnStatistics(slot) != null); if (changeOrder) { - return new EqualTo(equalTo.right(), equalTo.left()); + return equal.commute(); } else { - return equalTo; + return equal; } } @@ -81,18 +81,18 @@ public class JoinEstimation { * In order to avoid error propagation, for unTrustEquations, we only use the biggest selectivity. */ List unTrustEqualRatio = Lists.newArrayList(); - List unTrustableCondition = Lists.newArrayList(); + List unTrustableCondition = Lists.newArrayList(); boolean leftBigger = leftStats.getRowCount() > rightStats.getRowCount(); double rightStatsRowCount = StatsMathUtil.nonZeroDivisor(rightStats.getRowCount()); double leftStatsRowCount = StatsMathUtil.nonZeroDivisor(leftStats.getRowCount()); - List trustableConditions = join.getHashJoinConjuncts().stream() - .map(expression -> (EqualTo) expression) + List trustableConditions = join.getHashJoinConjuncts().stream() + .map(expression -> (EqualPredicate) expression) .filter( expression -> { // since ndv is not accurate, if ndv/rowcount < almostUniqueThreshold, // this column is regarded as unique. 
double almostUniqueThreshold = 0.9; - EqualTo equal = normalizeHashJoinCondition(expression, leftStats, rightStats); + EqualPredicate equal = normalizeHashJoinCondition(expression, leftStats, rightStats); ColumnStatistic eqLeftColStats = ExpressionEstimation.estimate(equal.left(), leftStats); ColumnStatistic eqRightColStats = ExpressionEstimation.estimate(equal.right(), rightStats); boolean trustable = eqRightColStats.ndv / rightStatsRowCount > almostUniqueThreshold @@ -204,7 +204,7 @@ public class JoinEstimation { } private static double estimateSemiOrAntiRowCountBySlotsEqual(Statistics leftStats, - Statistics rightStats, Join join, EqualTo equalTo) { + Statistics rightStats, Join join, EqualPredicate equalTo) { Expression eqLeft = equalTo.left(); Expression eqRight = equalTo.right(); ColumnStatistic probColStats = leftStats.findColumnStatistics(eqLeft); @@ -261,7 +261,7 @@ public class JoinEstimation { double rowCount = Double.POSITIVE_INFINITY; for (Expression conjunct : join.getHashJoinConjuncts()) { double eqRowCount = estimateSemiOrAntiRowCountBySlotsEqual(leftStats, rightStats, - join, (EqualTo) conjunct); + join, (EqualPredicate) conjunct); if (rowCount > eqRowCount) { rowCount = eqRowCount; } @@ -336,7 +336,7 @@ public class JoinEstimation { private static Statistics updateJoinResultStatsByHashJoinCondition(Statistics innerStats, Join join) { Map updatedCols = new HashMap<>(); for (Expression expr : join.getHashJoinConjuncts()) { - EqualTo equalTo = (EqualTo) expr; + EqualPredicate equalTo = (EqualPredicate) expr; ColumnStatistic leftColStats = ExpressionEstimation.estimate(equalTo.left(), innerStats); ColumnStatistic rightColStats = ExpressionEstimation.estimate(equalTo.right(), innerStats); double minNdv = Math.min(leftColStats.ndv, rightColStats.ndv); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java new file mode 
100644 index 0000000000..3f61bd3cf6 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualPredicate.java @@ -0,0 +1,36 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions; + +import java.util.List; + +/** + * EqualPredicate + */ +public abstract class EqualPredicate extends ComparisonPredicate { + + protected EqualPredicate(List children, String symbol) { + super(children, symbol); + } + + @Override + public EqualPredicate commute() { + return null; + } +} + diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java index 065f6b9340..3faccff6d9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/EqualTo.java @@ -29,7 +29,7 @@ import java.util.List; /** * Equal to expression: a = b. 
*/ -public class EqualTo extends ComparisonPredicate implements PropagateNullable { +public class EqualTo extends EqualPredicate implements PropagateNullable { public EqualTo(Expression left, Expression right) { super(ImmutableList.of(left, right), "="); @@ -55,7 +55,7 @@ public class EqualTo extends ComparisonPredicate implements PropagateNullable { } @Override - public ComparisonPredicate commute() { + public EqualTo commute() { return new EqualTo(right(), left()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java index c2b63aebbd..48d05364fa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/NullSafeEqual.java @@ -29,13 +29,7 @@ import java.util.List; * Null safe equal expression: a <=> b. * Unlike normal equal to expression, null <=> null is true. */ -public class NullSafeEqual extends ComparisonPredicate implements AlwaysNotNullable { - /** - * Constructor of Null Safe Equal ComparisonPredicate. 
- * - * @param left left child of Null Safe Equal - * @param right right child of Null Safe Equal - */ +public class NullSafeEqual extends EqualPredicate implements AlwaysNotNullable { public NullSafeEqual(Expression left, Expression right) { super(ImmutableList.of(left, right), "<=>"); } @@ -61,8 +55,7 @@ public class NullSafeEqual extends ComparisonPredicate implements AlwaysNotNulla } @Override - public ComparisonPredicate commute() { + public NullSafeEqual commute() { return new NullSafeEqual(right(), left()); } - } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java index 77bf6c9148..3f96c4d11c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.trees.plans.algebra; +import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; import org.apache.doris.nereids.trees.plans.JoinHint; @@ -25,6 +26,7 @@ import org.apache.doris.nereids.trees.plans.JoinType; import java.util.List; import java.util.Optional; +import java.util.stream.Collectors; /** * Common interface for logical/physical join. 
@@ -34,6 +36,11 @@ public interface Join { List getHashJoinConjuncts(); + default List getEqualToConjuncts() { + return getHashJoinConjuncts().stream().filter(EqualTo.class::isInstance).map(EqualTo.class::cast) + .collect(Collectors.toList()); + } + List getOtherJoinConjuncts(); Optional getOnClauseCondition(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java index 0041812796..183ccaabfa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/PhysicalHashJoin.java @@ -25,7 +25,7 @@ import org.apache.doris.nereids.processor.post.RuntimeFilterContext; import org.apache.doris.nereids.processor.post.RuntimeFilterGenerator; import org.apache.doris.nereids.properties.LogicalProperties; import org.apache.doris.nereids.properties.PhysicalProperties; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference; @@ -213,7 +213,7 @@ public class PhysicalHashJoin< if (ConnectContext.get() != null && ConnectContext.get().getSessionVariable().expandRuntimeFilterByInnerJoin) { if (!this.equals(builderNode) && this.getJoinType() == JoinType.INNER_JOIN) { for (Expression expr : this.getHashJoinConjuncts()) { - EqualTo equalTo = (EqualTo) expr; + EqualPredicate equalTo = (EqualPredicate) expr; if (probeExpr.equals(equalTo.left())) { probExprList.add(equalTo.right()); } else if (probeExpr.equals(equalTo.right())) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java index bcf53ce29f..862bf02e46 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java @@ -24,7 +24,7 @@ import org.apache.doris.nereids.properties.DistributionSpec; import org.apache.doris.nereids.properties.DistributionSpecHash; import org.apache.doris.nereids.properties.DistributionSpecHash.ShuffleType; import org.apache.doris.nereids.properties.DistributionSpecReplicated; -import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.EqualPredicate; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Not; @@ -88,22 +88,6 @@ public class JoinUtils { rightExprIds = right.stream().map(Slot::getExprId).collect(Collectors.toSet()); } - JoinSlotCoverageChecker(Set left, Set right) { - leftExprIds = left; - rightExprIds = right; - } - - /** - * PushDownExpressionInHashConjuncts ensure the "slots" is only one slot. 
- */ - boolean isCoveredByLeftSlots(ExprId slot) { - return leftExprIds.contains(slot); - } - - boolean isCoveredByRightSlots(ExprId slot) { - return rightExprIds.contains(slot); - } - /** * consider following cases: * 1# A=1 => not for hash table @@ -112,25 +96,20 @@ public class JoinUtils { * 4# t1.a=t2.a or t1.b=t2.b not for hash table * 5# t1.a > 1 not for hash table * - * @param equalTo a conjunct in on clause condition + * @param equal a conjunct in on clause condition * @return true if the equal can be used as hash join condition */ - public boolean isHashJoinCondition(EqualTo equalTo) { - Set equalLeft = equalTo.left().getInputSlots(); - if (equalLeft.isEmpty()) { + public boolean isHashJoinCondition(EqualPredicate equal) { + Set equalLeftExprIds = equal.left().getInputSlotExprIds(); + if (equalLeftExprIds.isEmpty()) { return false; } - Set equalRight = equalTo.right().getInputSlots(); - if (equalRight.isEmpty()) { + Set equalRightExprIds = equal.right().getInputSlotExprIds(); + if (equalRightExprIds.isEmpty()) { return false; } - List equalLeftExprIds = equalLeft.stream() - .map(Slot::getExprId).collect(Collectors.toList()); - - List equalRightExprIds = equalRight.stream() - .map(Slot::getExprId).collect(Collectors.toList()); return leftExprIds.containsAll(equalLeftExprIds) && rightExprIds.containsAll(equalRightExprIds) || leftExprIds.containsAll(equalRightExprIds) && rightExprIds.containsAll(equalLeftExprIds); } @@ -147,9 +126,8 @@ public class JoinUtils { public static Pair, List> extractExpressionForHashTable(List leftSlots, List rightSlots, List onConditions) { JoinSlotCoverageChecker checker = new JoinSlotCoverageChecker(leftSlots, rightSlots); - Map> mapper = onConditions.stream() - .collect(Collectors.groupingBy( - expr -> (expr instanceof EqualTo) && checker.isHashJoinCondition((EqualTo) expr))); + Map> mapper = onConditions.stream().collect(Collectors.groupingBy( + expr -> (expr instanceof EqualPredicate) && 
checker.isHashJoinCondition((EqualPredicate) expr))); return Pair.of( mapper.getOrDefault(true, ImmutableList.of()), mapper.getOrDefault(false, ImmutableList.of()) @@ -205,7 +183,7 @@ public class JoinUtils { * The left child of origin predicate is t2.id and the right child of origin predicate is t1.id. * In this situation, the children of predicate need to be swap => t1.id=t2.id. */ - public static Expression swapEqualToForChildrenOrder(EqualTo equalTo, Set leftOutput) { + public static EqualPredicate swapEqualToForChildrenOrder(EqualPredicate equalTo, Set leftOutput) { if (leftOutput.containsAll(equalTo.left().getInputSlots())) { return equalTo; } else { diff --git a/regression-test/data/nereids_p0/join/test_join_15.out b/regression-test/data/nereids_p0/join/test_join_15.out index 5c9df35ba7..e535253a28 100644 --- a/regression-test/data/nereids_p0/join/test_join_15.out +++ b/regression-test/data/nereids_p0/join/test_join_15.out @@ -172,12 +172,20 @@ false true true false false 3 \N null 2019-09-09 \N 8.9 3 \N null 2019-09-09 \N 8.9 5 \N null \N 2019-09-09T00:00 8.9 5 \N null \N 2019-09-09T00:00 8.9 --- !hash_join -- +-- !hash_right_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 \N \N \N \N \N \N 3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N 5 \N null \N 2019-09-09T00:00 8.9 +-- !hash_left_join -- +1 \N null \N \N 8.9 \N \N \N \N \N \N +2 \N 2 \N \N 8.9 \N \N \N \N \N \N +3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N +5 \N null \N 2019-09-09T00:00 8.9 \N \N \N \N \N \N + +-- !hash_inner_join -- + -- !cross_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 @@ -226,12 +234,20 @@ false true true false false 5 \N null \N 2019-09-09T00:00 8.9 3 \N null 2019-09-09 \N 8.9 5 \N null \N 2019-09-09T00:00 8.9 5 \N null \N 2019-09-09T00:00 8.9 --- !hash_join -- +-- !hash_right_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 \N \N \N \N \N \N 3 \N null 2019-09-09 \N 8.9 \N \N \N 
\N \N \N 5 \N null \N 2019-09-09T00:00 8.9 +-- !hash_left_join -- +1 \N null \N \N 8.9 \N \N \N \N \N \N +2 \N 2 \N \N 8.9 \N \N \N \N \N \N +3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N +5 \N null \N 2019-09-09T00:00 8.9 \N \N \N \N \N \N + +-- !hash_inner_join -- + -- !cross_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 @@ -271,12 +287,20 @@ false true true false false 5 \N null \N 2019-09-09T00:00 8.9 3 \N null 2019-09-09 \N 8.9 5 \N null \N 2019-09-09T00:00 8.9 5 \N null \N 2019-09-09T00:00 8.9 --- !hash_join -- +-- !hash_right_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 \N \N \N \N \N \N 3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N 5 \N null \N 2019-09-09T00:00 8.9 +-- !hash_left_join -- +1 \N null \N \N 8.9 \N \N \N \N \N \N +2 \N 2 \N \N 8.9 \N \N \N \N \N \N +3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N +5 \N null \N 2019-09-09T00:00 8.9 \N \N \N \N \N \N + +-- !hash_inner_join -- + -- !cross_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 @@ -314,12 +338,20 @@ false true true false false 5 \N null \N 2019-09-09T00:00 8.9 2 \N 2 \N \N 8.9 5 \N null \N 2019-09-09T00:00 8.9 5 \N null \N 2019-09-09T00:00 8.9 --- !hash_join -- +-- !hash_right_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 \N \N \N \N \N \N 3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N 5 \N null \N 2019-09-09T00:00 8.9 +-- !hash_left_join -- +1 \N null \N \N 8.9 \N \N \N \N \N \N +2 \N 2 \N \N 8.9 \N \N \N \N \N \N +3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N +5 \N null \N 2019-09-09T00:00 8.9 \N \N \N \N \N \N + +-- !hash_inner_join -- + -- !cross_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 @@ -357,12 +389,20 @@ false true true false false 3 \N null 2019-09-09 \N 8.9 3 \N null 2019-09-09 \N 8.9 5 \N null \N 2019-09-09T00:00 8.9 5 \N null \N 2019-09-09T00:00 8.9 --- !hash_join -- +-- !hash_right_join -- \N \N \N \N \N 
\N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 \N \N \N \N \N \N 3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N 5 \N null \N 2019-09-09T00:00 8.9 +-- !hash_left_join -- +1 \N null \N \N 8.9 \N \N \N \N \N \N +2 \N 2 \N \N 8.9 \N \N \N \N \N \N +3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N +5 \N null \N 2019-09-09T00:00 8.9 \N \N \N \N \N \N + +-- !hash_inner_join -- + -- !cross_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 @@ -412,12 +452,20 @@ false true true false false 5 \N null \N 2019-09-09T00:00 8.9 3 \N null 2019-09-09 \N 8.9 5 \N null \N 2019-09-09T00:00 8.9 5 \N null \N 2019-09-09T00:00 8.9 --- !hash_join -- +-- !hash_right_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 \N \N \N \N \N \N 3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N 5 \N null \N 2019-09-09T00:00 8.9 +-- !hash_left_join -- +1 \N null \N \N 8.9 \N \N \N \N \N \N +2 \N 2 \N \N 8.9 \N \N \N \N \N \N +3 \N null 2019-09-09 \N 8.9 \N \N \N \N \N \N +5 \N null \N 2019-09-09T00:00 8.9 \N \N \N \N \N \N + +-- !hash_inner_join -- + -- !cross_join -- \N \N \N \N \N \N 1 \N null \N \N 8.9 \N \N \N \N \N \N 2 \N 2 \N \N 8.9 diff --git a/regression-test/suites/nereids_p0/join/test_join_15.groovy b/regression-test/suites/nereids_p0/join/test_join_15.groovy index 9778e45eb1..22e6f8a06a 100644 --- a/regression-test/suites/nereids_p0/join/test_join_15.groovy +++ b/regression-test/suites/nereids_p0/join/test_join_15.groovy @@ -192,7 +192,11 @@ suite("test_join_15", "nereids_p0") { order by a.k1, b.k1""" qt_right_join"""select * from ${null_table_1} a right join ${null_table_1} b on a.k${index}<=>b.k${index} order by a.k1, b.k1""" - qt_hash_join"""select * from ${null_table_1} a right join ${null_table_1} b on a.k${index}<=>b.k${index} and a.k2=b.k2 + qt_hash_right_join"""select * from ${null_table_1} a right join ${null_table_1} b on a.k${index}<=>b.k${index} and a.k2=b.k2 + order by a.k1, b.k1""" + qt_hash_left_join"""select * 
from ${null_table_1} a left join ${null_table_1} b on a.k${index}<=>b.k${index} and a.k2=b.k2 + order by a.k1, b.k1""" + qt_hash_inner_join"""select * from ${null_table_1} a inner join ${null_table_1} b on a.k${index}<=>b.k${index} and a.k2=b.k2 order by a.k1, b.k1""" qt_cross_join"""select * from ${null_table_1} a right join ${null_table_1} b on a.k${index}<=>b.k${index} and a.k2 !=b.k2 order by a.k1, b.k1"""