diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java index 2bc2e1d888..122f73e463 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java @@ -18,7 +18,6 @@ package org.apache.doris.nereids.rules.rewrite.logical; import org.apache.doris.common.Pair; -import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; @@ -41,6 +40,7 @@ import com.google.common.collect.Lists; import java.util.ArrayList; import java.util.HashSet; import java.util.List; +import java.util.Map; import java.util.Set; import java.util.stream.Collectors; @@ -203,7 +203,7 @@ public class ReorderJoin extends OneRewriteRuleFactory { } Builder builder = ImmutableList.builder(); - // recursively hanlde multiJoin children. + // recursively handle multiJoin children. for (Plan child : multiJoin.children()) { if (child instanceof MultiJoin) { MultiJoin childMultiJoin = (MultiJoin) child; @@ -215,60 +215,56 @@ public class ReorderJoin extends OneRewriteRuleFactory { MultiJoin multiJoinHandleChildren = multiJoin.withChildren(builder.build()); if (!multiJoinHandleChildren.getJoinType().isInnerOrCrossJoin()) { - List leftFilter = Lists.newArrayList(); - List rightFilter = Lists.newArrayList(); - List remainingFilter = Lists.newArrayList(); + List remainingFilter; + Plan left; Plan right; if (multiJoinHandleChildren.getJoinType().isLeftJoin()) { + right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1); + Set rightOutputSet = right.getOutputSet(); + Map> split = multiJoin.getJoinFilter().stream() + .collect(Collectors.partitioningBy(expr -> + ExpressionUtils.isIntersecting(rightOutputSet, expr.getInputSlots()) + )); + remainingFilter = split.get(true); + List pushedFilter = split.get(false); left = multiJoinToJoin(new MultiJoin( multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1), - leftFilter, + pushedFilter, JoinType.INNER_JOIN, ExpressionUtils.EMPTY_CONDITION)); - right = multiJoinHandleChildren.child(multiJoinHandleChildren.arity() - 1); } else if (multiJoinHandleChildren.getJoinType().isRightJoin()) { left = multiJoinHandleChildren.child(0); + Set leftOutputSet = left.getOutputSet(); + Map> split = multiJoin.getJoinFilter().stream() + .collect(Collectors.partitioningBy(expr -> + ExpressionUtils.isIntersecting(leftOutputSet, expr.getInputSlots()) + )); + remainingFilter = split.get(true); + List pushedFilter = split.get(false); right = multiJoinToJoin(new MultiJoin( multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()), - rightFilter, + pushedFilter, JoinType.INNER_JOIN, ExpressionUtils.EMPTY_CONDITION)); } else { - left = multiJoinToJoin(new MultiJoin( - multiJoinHandleChildren.children().subList(0, multiJoinHandleChildren.arity() - 1), - leftFilter, - JoinType.INNER_JOIN, - ExpressionUtils.EMPTY_CONDITION)); - right = multiJoinToJoin(new MultiJoin( - multiJoinHandleChildren.children().subList(1, multiJoinHandleChildren.arity()), - rightFilter, - JoinType.INNER_JOIN, - ExpressionUtils.EMPTY_CONDITION)); - } - - // split filter - for (Expression expr : multiJoinHandleChildren.getJoinFilter()) { - Set exprInputSlots = expr.getInputSlots(); - Preconditions.checkState(!exprInputSlots.isEmpty()); - - if (left.getOutputSet().containsAll(exprInputSlots)) { - leftFilter.add(expr); - } else if (right.getOutputSet().containsAll(exprInputSlots)) { - rightFilter.add(expr); - } else if (multiJoin.getOutputSet().containsAll(exprInputSlots)) { - remainingFilter.add(expr); - } else { - NereidsPlanner.LOG.error("invalid expression, exist slot that multiJoin don't contains."); - throw new RuntimeException("invalid expression, exist slot that multiJoin don't contains."); - } + remainingFilter = multiJoin.getJoinFilter(); + Preconditions.checkState(multiJoinHandleChildren.arity() == 2); + List children = multiJoinHandleChildren.children().stream().map(child -> { + if (child instanceof MultiJoin) { + return multiJoinToJoin((MultiJoin) child); + } else { + return child; + } + }).collect(Collectors.toList()); + left = children.get(0); + right = children.get(1); } return PlanUtils.filterOrSelf(remainingFilter, new LogicalJoin<>( multiJoinHandleChildren.getJoinType(), - ExpressionUtils.EMPTY_CONDITION, - multiJoinHandleChildren.getNotInnerJoinConditions(), - PlanUtils.filterOrSelf(leftFilter, left), PlanUtils.filterOrSelf(rightFilter, right))); + ExpressionUtils.EMPTY_CONDITION, multiJoinHandleChildren.getNotInnerJoinConditions(), + left, right)); } // following this multiJoin just contain INNER/CROSS. diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java index da386f8911..57cd260cb4 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoinTest.java @@ -96,20 +96,31 @@ class ReorderJoinTest implements PatternMatchSupported { @Test public void testRightSemiJoin() { - ImmutableList plans = ImmutableList.of( - new LogicalPlanBuilder(scan1) - .hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0)) - .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) - .filter(new EqualTo(scan3.getOutput().get(0), scan2.getOutput().get(0))) - .build(), - new LogicalPlanBuilder(scan1) - .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) - .hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0)) - .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) - .build() - ); + LogicalPlan plan1 = new LogicalPlanBuilder(scan1) + .hashJoinUsing(scan2, JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0)) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .filter(new EqualTo(scan3.getOutput().get(0), scan2.getOutput().get(0))) + .build(); + check(ImmutableList.of(plan1)); + + LogicalPlan plan2 = new LogicalPlanBuilder(scan2) + .hashJoinUsing( + new LogicalPlanBuilder(scan1) + .hashJoinEmptyOn(scan3, JoinType.CROSS_JOIN) + .build(), + JoinType.RIGHT_SEMI_JOIN, Pair.of(0, 0) + ) + .filter(new EqualTo(scan3.getOutput().get(0), scan1.getOutput().get(0))) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), plan2) + .rewrite() + .matchesFromRoot( + rightSemiLogicalJoin( + leafPlan(), + innerLogicalJoin() + ) + ); - check(plans); } @Test @@ -171,7 +182,7 @@ class ReorderJoinTest implements PatternMatchSupported { } } - /** + /* * join * crossjoin / \ * / \ join D diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java index 5beb12445c..b4eb11c2c5 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/MultiJoinTest.java @@ -104,4 +104,20 @@ public class MultiJoinTest extends SqlTestBase { .printlnTree(); } } + + @Test + void testOuterJoin() { + String sql = "SELECT * FROM T1 LEFT OUTER JOIN T2 ON T1.id = T2.id, T3 WHERE T2.score > 0"; + PlanChecker.from(connectContext) + .analyze(sql) + .applyBottomUp(new ReorderJoin()) + .printlnTree() + .matches( + crossLogicalJoin( + leftOuterLogicalJoin() + .when(join -> join.getOtherJoinConjuncts().size() == 1), + logicalOlapScan() + ) + ); + } }