diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index 546d7f3e74..96b3dd3e0a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -39,7 +39,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalPro import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort; import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN; import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; -import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuter; +import org.apache.doris.nereids.rules.rewrite.logical.EliminateOuterJoin; import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters; import org.apache.doris.nereids.rules.rewrite.logical.MergeLimits; import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects; @@ -79,7 +79,7 @@ public class RuleSet { new PushdownExpressionsInHashCondition(), new PushdownProjectThroughLimit(), new PushdownFilterThroughProject(), - EliminateOuter.INSTANCE, + new EliminateOuterJoin(), new MergeProjects(), new MergeFilters(), new MergeLimits()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 2ee123c482..cc38568c32 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -107,7 +107,7 @@ public enum RuleType { // Eliminate plan ELIMINATE_LIMIT(RuleTypeClass.REWRITE), ELIMINATE_FILTER(RuleTypeClass.REWRITE), - ELIMINATE_OUTER(RuleTypeClass.REWRITE), + ELIMINATE_OUTER_JOIN(RuleTypeClass.REWRITE), FIND_HASH_CONDITION_FOR_JOIN(RuleTypeClass.REWRITE), MATERIALIZED_INDEX_AGG_SCAN(RuleTypeClass.REWRITE), MATERIALIZED_INDEX_AGG_FILTER_SCAN(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuter.java deleted file mode 100644 index 10a03df464..0000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuter.java +++ /dev/null @@ -1,91 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.rewrite.logical; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; -import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.plans.GroupPlan; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.ExpressionUtils; - -import com.google.common.collect.ImmutableMap; - -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; - -/** - * Eliminate outer. - */ -public class EliminateOuter extends OneRewriteRuleFactory { - public static EliminateOuter INSTANCE = new EliminateOuter(); - - // right nullable - public static Map ELIMINATE_LEFT_MAP = ImmutableMap.of( - JoinType.LEFT_OUTER_JOIN, JoinType.INNER_JOIN, - JoinType.FULL_OUTER_JOIN, JoinType.RIGHT_OUTER_JOIN - ); - - // left nullable - public static Map ELIMINATE_RIGHT_MAP = ImmutableMap.of( - JoinType.RIGHT_OUTER_JOIN, JoinType.INNER_JOIN, - JoinType.FULL_OUTER_JOIN, JoinType.LEFT_OUTER_JOIN - ); - - @Override - public Rule build() { - return logicalFilter(logicalJoin()) - .when(filter -> filter.child().getJoinType().isOuterJoin()) - .then(filter -> { - List predicates = ExpressionUtils.extractConjunction(filter.getPredicates()); - Set notNullSlots = new HashSet<>(); - for (Expression predicate : predicates) { - // TODO: more case. - if (predicate instanceof ComparisonPredicate) { - notNullSlots.addAll(predicate.getInputSlots()); - } - } - LogicalJoin join = filter.child(); - JoinType joinType = join.getJoinType(); - if (!joinType.isLeftOuterJoin() && ExpressionUtils.isIntersecting(join.left().getOutputSet(), - notNullSlots)) { - joinType = ELIMINATE_RIGHT_MAP.get(joinType); - } - if (!joinType.isRightOuterJoin() && ExpressionUtils.isIntersecting(join.right().getOutputSet(), - notNullSlots)) { - joinType = ELIMINATE_LEFT_MAP.get(joinType); - } - - if (joinType == join.getJoinType()) { - return null; - } - - return new LogicalFilter<>(filter.getPredicates(), - new LogicalJoin<>(joinType, - join.getHashJoinConjuncts(), join.getOtherJoinConjuncts(), - join.left(), join.right())); - }).toRule(RuleType.ELIMINATE_OUTER); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuterJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuterJoin.java new file mode 100644 index 0000000000..f88a1ab6b0 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuterJoin.java @@ -0,0 +1,98 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; +import org.apache.doris.nereids.rules.expression.rewrite.rules.FoldConstantRule; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.Maps; + +import java.util.List; +import java.util.Map; + +/** + * Eliminate outer join. + */ +public class EliminateOuterJoin extends OneRewriteRuleFactory { + + @Override + public Rule build() { + return logicalFilter( + logicalJoin().when(join -> join.getJoinType().isOuterJoin()) + ).then(filter -> { + LogicalJoin join = filter.child(); + List conjuncts = filter.getConjuncts(); + List leftPredicates = ExpressionUtils.extractCoveredConjunction(conjuncts, + join.left().getOutputSet()); + List rightPredicates = ExpressionUtils.extractCoveredConjunction(conjuncts, + join.right().getOutputSet()); + boolean canFilterLeftNull = canFilterNull(leftPredicates); + boolean canFilterRightNull = canFilterNull(rightPredicates); + JoinType newJoinType = tryEliminateOuterJoin(join.getJoinType(), canFilterLeftNull, + canFilterRightNull); + if (newJoinType == join.getJoinType()) { + return filter; + } else { + return filter.withChildren(join.withJoinType(newJoinType)); + } + }).toRule(RuleType.ELIMINATE_OUTER_JOIN); + } + + private JoinType tryEliminateOuterJoin(JoinType joinType, boolean canFilterLeftNull, boolean canFilterRightNull) { + if (joinType.isRightOuterJoin() && canFilterLeftNull) { + return JoinType.INNER_JOIN; + } + if (joinType.isLeftOuterJoin() && canFilterRightNull) { + return JoinType.INNER_JOIN; + } + if (joinType.isFullOuterJoin() && canFilterLeftNull && canFilterRightNull) { + return JoinType.INNER_JOIN; + } + if (joinType.isFullOuterJoin() && canFilterLeftNull) { + return JoinType.LEFT_OUTER_JOIN; + } + if (joinType.isFullOuterJoin() && canFilterRightNull) { + return JoinType.RIGHT_OUTER_JOIN; + } + return joinType; + } + + private boolean canFilterNull(List predicates) { + Literal nullLiteral = Literal.of(null); + for (Expression predicate : predicates) { + Map replaceMap = Maps.newHashMap(); + predicate.getInputSlots().forEach(slot -> replaceMap.put(slot, nullLiteral)); + Expression evalExpr = FoldConstantRule.INSTANCE.rewrite(ExpressionUtils.replace(predicate, replaceMap), + new ExpressionRewriteContext(null)); + if (nullLiteral.equals(evalExpr) || BooleanLiteral.FALSE.equals(evalExpr)) { + return true; + } + } + return false; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 2904d1d7cd..d710cd3f4e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -254,4 +254,8 @@ public class LogicalJoin(joinType, hashJoinConjuncts, otherJoinConjuncts, children.get(0), children.get(1), joinReorderContext); } + + public LogicalJoin withJoinType(JoinType joinType) { + return new LogicalJoin<>(joinType, hashJoinConjuncts, otherJoinConjuncts, left(), right(), joinReorderContext); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 42b73e573e..9bb36489c9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -300,4 +300,18 @@ public class ExpressionUtils { public static boolean isAllNullLiteral(List children) { return children.stream().allMatch(c -> c instanceof NullLiteral); } + + /** + * extract the predicate that is covered by `slots` + */ + public static List extractCoveredConjunction(List predicates, Set slots) { + List coveredPredicates = Lists.newArrayList(); + for (Expression predicate : predicates) { + if (slots.containsAll(predicate.getInputSlots())) { + coveredPredicates.add(predicate); + } + } + return coveredPredicates; + } } + diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuterTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuterTest.java index 8959031c7f..ad3feb5457 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuterTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/EliminateOuterTest.java @@ -39,12 +39,12 @@ class EliminateOuterTest implements PatternMatchSupported { @Test void testEliminateLeft() { LogicalPlan plan = new LogicalPlanBuilder(scan1) - .hashJoinUsing(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(1, 1)) // t1.id = t2.id - .filter(new GreaterThan(scan2.getOutput().get(1), Literal.of(1))) + .hashJoinUsing(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .filter(new GreaterThan(scan2.getOutput().get(0), Literal.of(1))) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(EliminateOuter.INSTANCE) + .applyTopDown(new EliminateOuterJoin()) .matchesFromRoot( logicalFilter( logicalJoin().when(join -> join.getJoinType().isInnerJoin()) @@ -55,12 +55,12 @@ class EliminateOuterTest implements PatternMatchSupported { @Test void testEliminateRight() { LogicalPlan plan = new LogicalPlanBuilder(scan1) - .hashJoinUsing(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(1, 1)) // t1.id = t2.id - .filter(new GreaterThan(scan1.getOutput().get(1), Literal.of(1))) + .hashJoinUsing(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .filter(new GreaterThan(scan1.getOutput().get(0), Literal.of(1))) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(EliminateOuter.INSTANCE) + .applyTopDown(new EliminateOuterJoin()) .matchesFromRoot( logicalFilter( logicalJoin().when(join -> join.getJoinType().isInnerJoin()) @@ -71,14 +71,14 @@ class EliminateOuterTest implements PatternMatchSupported { @Test void testEliminateBoth() { LogicalPlan plan = new LogicalPlanBuilder(scan1) - .hashJoinUsing(scan2, JoinType.FULL_OUTER_JOIN, Pair.of(1, 1)) // t1.id = t2.id + .hashJoinUsing(scan2, JoinType.FULL_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id .filter(new And( - new GreaterThan(scan2.getOutput().get(1), Literal.of(1)), - new GreaterThan(scan1.getOutput().get(1), Literal.of(1)))) + new GreaterThan(scan2.getOutput().get(0), Literal.of(1)), + new GreaterThan(scan1.getOutput().get(0), Literal.of(1)))) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyTopDown(EliminateOuter.INSTANCE) + .applyTopDown(new EliminateOuterJoin()) .matchesFromRoot( logicalFilter( logicalJoin().when(join -> join.getJoinType().isInnerJoin())