From dc497e11bbf51756f68829e544a127291ff8c5ba Mon Sep 17 00:00:00 2001 From: jakevin Date: Thu, 11 May 2023 13:58:54 +0800 Subject: [PATCH] [fix](Nereids) avoid to push top Project of JoinCluster in PushdownProjectThroughJoin (#19441) We shouldn't push top Project of JoinCluster in PushdownProjectThroughJoin like ``` * Project (id + 1) if this project is top project of Join Cluster * | * Join * / \ * Join Join * / .... * Join ``` --- .../join/PushdownProjectThroughInnerJoin.java | 161 ++++++++++-------- .../join/PushdownProjectThroughSemiJoin.java | 93 +++++----- .../PushdownProjectThroughInnerJoinTest.java | 61 ++++--- .../PushdownProjectThroughSemiJoinTest.java | 51 +++--- .../shape/q10.out | 23 +-- .../nereids_tpch_shape_sf1000_p0/shape/q3.out | 15 +- .../nereids_tpch_shape_sf1000_p0/shape/q5.out | 38 ++--- .../nereids_tpch_shape_sf1_p0/shape/q10.out | 23 +-- .../nereids_tpch_shape_sf1_p0/shape/q3.out | 15 +- .../nereids_tpch_shape_sf1_p0/shape/q5.out | 38 ++--- .../nereids_tpch_shape_sf500_p0/shape/q10.out | 23 +-- .../nereids_tpch_shape_sf500_p0/shape/q3.out | 15 +- .../nereids_tpch_shape_sf500_p0/shape/q5.out | 38 ++--- 13 files changed, 330 insertions(+), 264 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java index 6761153e17..db243da5fd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoin.java @@ -20,7 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.CBOUtils; -import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.rules.exploration.ExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; @@ -38,75 +38,102 @@ import java.util.Set; import java.util.stream.Collectors; /** - * rule for pushdown project through inner/outer join + * Rule for pushdown project through inner/outer join + * Just push down project inside join to avoid to push the top of Join-Cluster. + *
+ *    Project                   Join
+ *      |            ──►       /    \
+ *     Join               Project  Project
+ *    /   \                  |       |
+ *   A     B                 A       B
+ * 
*/ -public class PushdownProjectThroughInnerJoin extends OneExplorationRuleFactory { +public class PushdownProjectThroughInnerJoin implements ExplorationRuleFactory { public static final PushdownProjectThroughInnerJoin INSTANCE = new PushdownProjectThroughInnerJoin(); - /* - * Project Join - * | ──► / \ - * Join Project Project - * / \ | | - * A B A B - */ @Override - public Rule build() { - return logicalProject(logicalJoin()) - .whenNot(LogicalProject::isAllSlots) - .when(project -> project.child().getJoinType().isInnerJoin()) - .whenNot(project -> project.child().hasJoinHint()) - .then(project -> { - LogicalJoin join = project.child(); - Set aOutputExprIdSet = join.left().getOutputExprIdSet(); - Set bOutputExprIdSet = join.right().getOutputExprIdSet(); - - // reject hyper edge in Project. - if (!project.getProjects().stream().allMatch(expr -> { - Set inputSlotExprIds = expr.getInputSlotExprIds(); - return aOutputExprIdSet.containsAll(inputSlotExprIds) - || bOutputExprIdSet.containsAll(inputSlotExprIds); - })) { - return null; - } - - List aProjects = new ArrayList<>(); - List bProjects = new ArrayList<>(); - for (NamedExpression namedExpression : project.getProjects()) { - Set usedExprIds = namedExpression.getInputSlotExprIds(); - if (aOutputExprIdSet.containsAll(usedExprIds)) { - aProjects.add(namedExpression); - } else { - bProjects.add(namedExpression); - } - } - - boolean leftContains = aProjects.stream().anyMatch(e -> !(e instanceof Slot)); - boolean rightContains = bProjects.stream().anyMatch(e -> !(e instanceof Slot)); - // due to JoinCommute, we don't need to consider just right contains. - if (!leftContains) { - return null; - } - - Builder newAProject = ImmutableList.builder().addAll(aProjects); - Set aConditionSlots = CBOUtils.joinChildConditionSlots(join, true); - Set aProjectSlots = aProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet()); - aConditionSlots.stream().filter(slot -> !aProjectSlots.contains(slot)).forEach(newAProject::add); - Plan newLeft = CBOUtils.projectOrSelf(newAProject.build(), join.left()); - - if (!rightContains) { - Plan newJoin = join.withChildrenNoContext(newLeft, join.right()); - return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); - } - - Builder newBProject = ImmutableList.builder().addAll(bProjects); - Set bConditionSlots = CBOUtils.joinChildConditionSlots(join, false); - Set bProjectSlots = bProjects.stream().map(NamedExpression::toSlot).collect(Collectors.toSet()); - bConditionSlots.stream().filter(slot -> !bProjectSlots.contains(slot)).forEach(newBProject::add); - Plan newRight = CBOUtils.projectOrSelf(newBProject.build(), join.right()); - - Plan newJoin = join.withChildrenNoContext(newLeft, newRight); - return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); - }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN); + public List buildRules() { + return ImmutableList.of( + logicalJoin(logicalProject(innerLogicalJoin()), group()) + // Just pushdown project with non-column expr like (t.id + 1) + .whenNot(j -> j.left().isAllSlots()) + .whenNot(j -> j.left().child().hasJoinHint()) + .then(topJoin -> { + LogicalProject> project = topJoin.left(); + Plan newLeft = pushdownProject(project); + if (newLeft == null) { + return null; + } + return topJoin.withChildren(newLeft, topJoin.right()); + }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN), + logicalJoin(group(), logicalProject(innerLogicalJoin())) + // Just pushdown project with non-column expr like (t.id + 1) + .whenNot(j -> j.right().isAllSlots()) + .whenNot(j -> j.right().child().hasJoinHint()) + .then(topJoin -> { + LogicalProject> project = topJoin.right(); + Plan newRight = pushdownProject(project); + if (newRight == null) { + return null; + } + return topJoin.withChildren(topJoin.left(), newRight); + }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_INNER_JOIN) + ); } + + private Plan pushdownProject(LogicalProject> project) { + LogicalJoin join = project.child(); + Set aOutputExprIdSet = join.left().getOutputExprIdSet(); + Set bOutputExprIdSet = join.right().getOutputExprIdSet(); + + // reject hyper edge in Project. + if (!project.getProjects().stream().allMatch(expr -> { + Set inputSlotExprIds = expr.getInputSlotExprIds(); + return aOutputExprIdSet.containsAll(inputSlotExprIds) + || bOutputExprIdSet.containsAll(inputSlotExprIds); + })) { + return null; + } + + List aProjects = new ArrayList<>(); + List bProjects = new ArrayList<>(); + for (NamedExpression namedExpression : project.getProjects()) { + Set usedExprIds = namedExpression.getInputSlotExprIds(); + if (aOutputExprIdSet.containsAll(usedExprIds)) { + aProjects.add(namedExpression); + } else { + bProjects.add(namedExpression); + } + } + + boolean leftContains = aProjects.stream().anyMatch(e -> !(e instanceof Slot)); + boolean rightContains = bProjects.stream().anyMatch(e -> !(e instanceof Slot)); + // due to JoinCommute, we don't need to consider just right contains. + if (!leftContains) { + return null; + } + + Builder newAProject = ImmutableList.builder().addAll(aProjects); + Set aConditionSlots = CBOUtils.joinChildConditionSlots(join, true); + Set aProjectSlots = aProjects.stream().map(NamedExpression::toSlot) + .collect(Collectors.toSet()); + aConditionSlots.stream().filter(slot -> !aProjectSlots.contains(slot)).forEach(newAProject::add); + Plan newLeft = CBOUtils.projectOrSelf(newAProject.build(), join.left()); + + if (!rightContains) { + Plan newJoin = join.withChildrenNoContext(newLeft, join.right()); + return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); + } + + Builder newBProject = ImmutableList.builder().addAll(bProjects); + Set bConditionSlots = CBOUtils.joinChildConditionSlots(join, false); + Set bProjectSlots = bProjects.stream().map(NamedExpression::toSlot) + .collect(Collectors.toSet()); + bConditionSlots.stream().filter(slot -> !bProjectSlots.contains(slot)).forEach(newBProject::add); + Plan newRight = CBOUtils.projectOrSelf(newBProject.build(), join.right()); + + Plan newJoin = join.withChildrenNoContext(newLeft, newRight); + return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); + } + } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java index 79b2047af0..851c63dc2d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoin.java @@ -20,8 +20,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.CBOUtils; -import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.rules.exploration.ExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.GroupPlan; @@ -29,58 +28,68 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import com.google.common.collect.ImmutableList; + import java.util.ArrayList; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.stream.Collectors; /** - * rule for pushdown project through left-semi/anti join + * Rule for pushdown project through left-semi/anti join + * Just push down project inside join to avoid to push the top of Join-Cluster. + *
+ *     Join                     Join
+ *      |                        |
+ *    Project                   Join
+ *      |            ──►       /   \
+ *     Join                Project  B
+ *    /   \                   |
+ *   A     B                  A
+ * 
*/ -public class PushdownProjectThroughSemiJoin extends OneExplorationRuleFactory { +public class PushdownProjectThroughSemiJoin implements ExplorationRuleFactory { public static final PushdownProjectThroughSemiJoin INSTANCE = new PushdownProjectThroughSemiJoin(); - /* - * Project Join - * | ──► / \ - * Join Project B - * / \ | - * A B A - */ @Override - public Rule build() { - return logicalProject(logicalJoin()) - .when(project -> project.child().getJoinType().isLeftSemiOrAntiJoin()) - // Just pushdown project with non-column expr like (t.id + 1) - .whenNot(LogicalProject::isAllSlots) - .whenNot(project -> project.child().hasJoinHint()) - .then(project -> { - LogicalJoin join = project.child(); - Set conditionLeftSlots = CBOUtils.joinChildConditionSlots(join, true); + public List buildRules() { + return ImmutableList.of( + logicalJoin(logicalProject(logicalJoin()), group()) + .when(j -> j.left().child().getJoinType().isLeftSemiOrAntiJoin()) + // Just pushdown project with non-column expr like (t.id + 1) + .whenNot(j -> j.left().isAllSlots()) + .whenNot(j -> j.left().child().hasJoinHint()) + .then(topJoin -> { + LogicalProject> project = topJoin.left(); + Plan newLeft = pushdownProject(project); + return topJoin.withChildren(newLeft, topJoin.right()); + }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN), - List newProject = new ArrayList<>(project.getProjects()); - Set projectUsedSlots = project.getProjects().stream().map(NamedExpression::toSlot) - .collect(Collectors.toSet()); - conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)).forEach(newProject::add); - Plan newLeft = CBOUtils.projectOrSelf(newProject, join.left()); - - Plan newJoin = join.withChildrenNoContext(newLeft, join.right()); - return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); - }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN); + logicalJoin(group(), logicalProject(logicalJoin())) + .when(j -> j.right().child().getJoinType().isLeftSemiOrAntiJoin()) + // Just pushdown project with non-column expr like (t.id + 1) + .whenNot(j -> j.right().isAllSlots()) + .whenNot(j -> j.right().child().hasJoinHint()) + .then(topJoin -> { + LogicalProject> project = topJoin.right(); + Plan newRight = pushdownProject(project); + return topJoin.withChildren(topJoin.left(), newRight); + }).toRule(RuleType.PUSH_DOWN_PROJECT_THROUGH_SEMI_JOIN) + ); } - List sort(List projects, Plan sortPlan) { - List orderExprIds = sortPlan.getOutput().stream().map(Slot::getExprId).collect(Collectors.toList()); - // map { project input slot expr id -> project output expr } - Map map = projects.stream() - .collect(Collectors.toMap(expr -> expr.getInputSlots().iterator().next().getExprId(), expr -> expr)); - List newProjects = new ArrayList<>(); - for (ExprId exprId : orderExprIds) { - if (map.containsKey(exprId)) { - newProjects.add(map.get(exprId)); - } - } - return newProjects; + private Plan pushdownProject(LogicalProject> project) { + LogicalJoin join = project.child(); + Set conditionLeftSlots = CBOUtils.joinChildConditionSlots(join, true); + + List newProject = new ArrayList<>(project.getProjects()); + Set projectUsedSlots = project.getProjects().stream().map(NamedExpression::toSlot) + .collect(Collectors.toSet()); + conditionLeftSlots.stream().filter(slot -> !projectUsedSlots.contains(slot)) + .forEach(newProject::add); + Plan newLeft = CBOUtils.projectOrSelf(newProject, join.left()); + + Plan newJoin = join.withChildrenNoContext(newLeft, join.right()); + return CBOUtils.projectOrSelf(new ArrayList<>(project.getOutput()), newJoin); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java index 7d94fa876c..d739039676 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughInnerJoinTest.java @@ -40,6 +40,7 @@ import java.util.List; class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported { private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); @Test public void pushBothSide() { @@ -54,16 +55,20 @@ class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(1, 1)) .projectExprs(projectExprs) + .join(scan3, JoinType.INNER_JOIN, Pair.of(1, 1)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build()) + .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.buildRules()) .printlnOrigin() .printlnExploration() .matchesExploration( logicalJoin( - logicalProject().when(project -> project.getProjects().size() == 2), - logicalProject().when(project -> project.getProjects().size() == 2) + logicalJoin( + logicalProject().when(project -> project.getProjects().size() == 2), + logicalProject().when(project -> project.getProjects().size() == 2) + ), + logicalOlapScan() ) ); } @@ -81,18 +86,22 @@ class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) .projectExprs(projectExprs) + .join(scan3, JoinType.INNER_JOIN, Pair.of(1, 1)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build()) + .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.buildRules()) .printlnOrigin() .printlnExploration() .matchesExploration( - logicalProject( - logicalJoin( - logicalProject().when(project -> project.getProjects().size() == 3), - logicalProject().when(project -> project.getProjects().size() == 3) - ) + logicalJoin( + logicalProject( + logicalJoin( + logicalProject().when(project -> project.getProjects().size() == 3), + logicalProject().when(project -> project.getProjects().size() == 3) + ) + ), + logicalOlapScan() ) ); } @@ -108,26 +117,29 @@ class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) .projectExprs(projectExprs) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build()) + .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.buildRules()) .printlnOrigin() .printlnExploration() .matchesExploration( - logicalProject( - logicalJoin( - logicalProject() - .when(project -> - project.getProjects().get(0).toSql().equals("(id + name) AS `complex1`") - && project.getProjects().get(1).toSql().equals("id")), - logicalProject() - .when(project -> - project.getProjects().get(0).toSql().equals("(id + name) AS `complex2`") - && project.getProjects().get(1).toSql().equals("id")) - ) - ).when(project -> project.getProjects().get(0).toSql().equals("complex1") - && project.getProjects().get(1).toSql().equals("complex2") + logicalJoin( + logicalProject( + logicalJoin( + logicalProject() + .when(project -> + project.getProjects().get(0).toSql().equals("(id + name) AS `complex1`") + && project.getProjects().get(1).toSql().equals("id")), + logicalProject() + .when(project -> + project.getProjects().get(0).toSql().equals("(id + name) AS `complex2`") + && project.getProjects().get(1).toSql().equals("id")) + ) + ).when(project -> project.getProjects().get(0).toSql().equals("complex1") + && project.getProjects().get(1).toSql().equals("complex2")), + logicalOlapScan() ) ); } @@ -142,10 +154,11 @@ class PushdownProjectThroughInnerJoinTest implements MemoPatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) .projectExprs(projectExprs) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.build()) + .applyExploration(PushdownProjectThroughInnerJoin.INSTANCE.buildRules()) .checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size())); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java index b47910f748..862580208e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/PushdownProjectThroughSemiJoinTest.java @@ -39,6 +39,7 @@ import java.util.List; class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported { private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); @Test public void pushdownProject() { @@ -51,17 +52,21 @@ class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(1, 1)) .projectExprs(projectExprs) + .join(scan3, JoinType.INNER_JOIN, Pair.of(1, 1)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build()) + .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.buildRules()) .printlnOrigin() .printlnExploration() .matchesExploration( - leftSemiLogicalJoin( - logicalProject( + logicalJoin( + leftSemiLogicalJoin( + logicalProject( + logicalOlapScan() + ).when(project -> project.getProjects().size() == 2), logicalOlapScan() - ).when(project -> project.getProjects().size() == 2), + ), logicalOlapScan() ) ); @@ -78,21 +83,24 @@ class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) .projectExprs(projectExprs) + .join(scan3, JoinType.INNER_JOIN, Pair.of(1, 1)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build()) + .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.buildRules()) .printlnOrigin() .printlnExploration() .matchesExploration( - logicalProject( - leftSemiLogicalJoin( - logicalProject( + logicalJoin( + logicalProject( + leftSemiLogicalJoin( + logicalProject( + logicalOlapScan() + ).when(project -> project.getProjects().size() == 3), logicalOlapScan() - ).when(project -> project.getProjects().size() == 3), - logicalOlapScan() - ) - ).when(project -> project.getProjects().size() == 2) + ) + ).when(project -> project.getProjects().size() == 2), logicalOlapScan() + ) ); } @@ -105,21 +113,24 @@ class PushdownProjectThroughSemiJoinTest implements MemoPatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.LEFT_SEMI_JOIN, Pair.of(0, 0)) .projectExprs(projectExprs) + .join(scan3, JoinType.INNER_JOIN, Pair.of(0, 0)) .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) - .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.build()) + .applyExploration(PushdownProjectThroughSemiJoin.INSTANCE.buildRules()) .printlnOrigin() .printlnExploration() .matchesExploration( - logicalProject( - leftSemiLogicalJoin( - logicalProject() - .when(project -> project.getProjects().get(0).toSql().equals("(id + name) AS `complex`") - && project.getProjects().get(1).toSql().equals("id")), - logicalOlapScan() + logicalJoin( + logicalProject( + leftSemiLogicalJoin( + logicalProject() + .when(project -> project.getProjects().get(0).toSql().equals("(id + name) AS `complex`") + && project.getProjects().get(1).toSql().equals("id")), + logicalOlapScan() + ) + ).when(project -> project.getProjects().get(0).toSql().equals("complex")), logicalOlapScan() ) - ).when(project -> project.getProjects().get(0).toSql().equals("complex")) ); } } diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q10.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q10.out index 316a267223..cf30f47eb7 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q10.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q10.out @@ -7,17 +7,18 @@ PhysicalTopN --------hashAgg[LOCAL] ----------PhysicalProject ------------hashJoin[INNER_JOIN](customer.c_nationkey = nation.n_nationkey) ---------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) -----------------PhysicalProject -------------------PhysicalOlapScan[customer] -----------------PhysicalDistribute -------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) ---------------------PhysicalProject -----------------------filter((lineitem.l_returnflag = 'R')) -------------------------PhysicalOlapScan[lineitem] ---------------------PhysicalProject -----------------------filter((orders.o_orderdate < 1994-01-01)(orders.o_orderdate >= 1993-10-01)) -------------------------PhysicalOlapScan[orders] +--------------PhysicalProject +----------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) +------------------PhysicalProject +--------------------PhysicalOlapScan[customer] +------------------PhysicalDistribute +--------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) +----------------------PhysicalProject +------------------------filter((lineitem.l_returnflag = 'R')) +--------------------------PhysicalOlapScan[lineitem] +----------------------PhysicalProject +------------------------filter((orders.o_orderdate < 1994-01-01)(orders.o_orderdate >= 1993-10-01)) +--------------------------PhysicalOlapScan[orders] --------------PhysicalDistribute ----------------PhysicalProject ------------------PhysicalOlapScan[nation] diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q3.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q3.out index a8a3d52100..aff6e8dcd1 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q3.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q3.out @@ -10,12 +10,13 @@ PhysicalTopN --------------PhysicalProject ----------------filter((lineitem.l_shipdate > 1995-03-15)) ------------------PhysicalOlapScan[lineitem] ---------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) -----------------PhysicalProject -------------------filter((orders.o_orderdate < 1995-03-15)) ---------------------PhysicalOlapScan[orders] -----------------PhysicalDistribute +--------------PhysicalProject +----------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) ------------------PhysicalProject ---------------------filter((customer.c_mktsegment = 'BUILDING')) -----------------------PhysicalOlapScan[customer] +--------------------filter((orders.o_orderdate < 1995-03-15)) +----------------------PhysicalOlapScan[orders] +------------------PhysicalDistribute +--------------------PhysicalProject +----------------------filter((customer.c_mktsegment = 'BUILDING')) +------------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q5.out b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q5.out index 7554f864fe..313bf6c594 100644 --- a/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q5.out +++ b/regression-test/data/nereids_tpch_shape_sf1000_p0/shape/q5.out @@ -12,24 +12,24 @@ PhysicalQuickSort ------------------PhysicalOlapScan[customer] ----------------PhysicalDistribute ------------------PhysicalProject ---------------------hashJoin[INNER_JOIN](lineitem.l_suppkey = supplier.s_suppkey) -----------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) -------------------------PhysicalProject ---------------------------PhysicalOlapScan[lineitem] -------------------------PhysicalProject ---------------------------filter((orders.o_orderdate < 1995-01-01)(orders.o_orderdate >= 1994-01-01)) -----------------------------PhysicalOlapScan[orders] -----------------------PhysicalDistribute -------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey) -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[supplier] -----------------------------PhysicalDistribute -------------------------------hashJoin[INNER_JOIN](nation.n_regionkey = region.r_regionkey) ---------------------------------PhysicalProject -----------------------------------PhysicalOlapScan[nation] ---------------------------------PhysicalDistribute +--------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) +----------------------PhysicalProject +------------------------hashJoin[INNER_JOIN](lineitem.l_suppkey = supplier.s_suppkey) +--------------------------PhysicalProject +----------------------------PhysicalOlapScan[lineitem] +--------------------------PhysicalDistribute +----------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey) +------------------------------PhysicalProject +--------------------------------PhysicalOlapScan[supplier] +------------------------------PhysicalDistribute +--------------------------------hashJoin[INNER_JOIN](nation.n_regionkey = region.r_regionkey) ----------------------------------PhysicalProject -------------------------------------filter((region.r_name = 'ASIA')) ---------------------------------------PhysicalOlapScan[region] +------------------------------------PhysicalOlapScan[nation] +----------------------------------PhysicalDistribute +------------------------------------PhysicalProject +--------------------------------------filter((region.r_name = 'ASIA')) +----------------------------------------PhysicalOlapScan[region] +----------------------PhysicalProject +------------------------filter((orders.o_orderdate < 1995-01-01)(orders.o_orderdate >= 1994-01-01)) +--------------------------PhysicalOlapScan[orders] diff --git a/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q10.out b/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q10.out index 316a267223..cf30f47eb7 100644 --- a/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q10.out +++ b/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q10.out @@ -7,17 +7,18 @@ PhysicalTopN --------hashAgg[LOCAL] ----------PhysicalProject ------------hashJoin[INNER_JOIN](customer.c_nationkey = nation.n_nationkey) ---------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) -----------------PhysicalProject -------------------PhysicalOlapScan[customer] -----------------PhysicalDistribute -------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) ---------------------PhysicalProject -----------------------filter((lineitem.l_returnflag = 'R')) -------------------------PhysicalOlapScan[lineitem] ---------------------PhysicalProject -----------------------filter((orders.o_orderdate < 1994-01-01)(orders.o_orderdate >= 1993-10-01)) -------------------------PhysicalOlapScan[orders] +--------------PhysicalProject +----------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) +------------------PhysicalProject +--------------------PhysicalOlapScan[customer] +------------------PhysicalDistribute +--------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) +----------------------PhysicalProject +------------------------filter((lineitem.l_returnflag = 'R')) +--------------------------PhysicalOlapScan[lineitem] +----------------------PhysicalProject +------------------------filter((orders.o_orderdate < 1994-01-01)(orders.o_orderdate >= 1993-10-01)) +--------------------------PhysicalOlapScan[orders] --------------PhysicalDistribute ----------------PhysicalProject ------------------PhysicalOlapScan[nation] diff --git a/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q3.out b/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q3.out index a8a3d52100..aff6e8dcd1 100644 --- a/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q3.out +++ b/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q3.out @@ -10,12 +10,13 @@ PhysicalTopN --------------PhysicalProject ----------------filter((lineitem.l_shipdate > 1995-03-15)) ------------------PhysicalOlapScan[lineitem] ---------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) -----------------PhysicalProject -------------------filter((orders.o_orderdate < 1995-03-15)) ---------------------PhysicalOlapScan[orders] -----------------PhysicalDistribute +--------------PhysicalProject +----------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) ------------------PhysicalProject ---------------------filter((customer.c_mktsegment = 'BUILDING')) -----------------------PhysicalOlapScan[customer] +--------------------filter((orders.o_orderdate < 1995-03-15)) +----------------------PhysicalOlapScan[orders] +------------------PhysicalDistribute +--------------------PhysicalProject +----------------------filter((customer.c_mktsegment = 'BUILDING')) +------------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q5.out b/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q5.out index 7554f864fe..313bf6c594 100644 --- a/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q5.out +++ b/regression-test/data/nereids_tpch_shape_sf1_p0/shape/q5.out @@ -12,24 +12,24 @@ PhysicalQuickSort ------------------PhysicalOlapScan[customer] ----------------PhysicalDistribute ------------------PhysicalProject ---------------------hashJoin[INNER_JOIN](lineitem.l_suppkey = supplier.s_suppkey) -----------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) -------------------------PhysicalProject ---------------------------PhysicalOlapScan[lineitem] -------------------------PhysicalProject ---------------------------filter((orders.o_orderdate < 1995-01-01)(orders.o_orderdate >= 1994-01-01)) -----------------------------PhysicalOlapScan[orders] -----------------------PhysicalDistribute -------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey) -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[supplier] -----------------------------PhysicalDistribute -------------------------------hashJoin[INNER_JOIN](nation.n_regionkey = region.r_regionkey) ---------------------------------PhysicalProject -----------------------------------PhysicalOlapScan[nation] ---------------------------------PhysicalDistribute +--------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) +----------------------PhysicalProject +------------------------hashJoin[INNER_JOIN](lineitem.l_suppkey = supplier.s_suppkey) +--------------------------PhysicalProject +----------------------------PhysicalOlapScan[lineitem] +--------------------------PhysicalDistribute +----------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey) +------------------------------PhysicalProject +--------------------------------PhysicalOlapScan[supplier] +------------------------------PhysicalDistribute +--------------------------------hashJoin[INNER_JOIN](nation.n_regionkey = region.r_regionkey) ----------------------------------PhysicalProject -------------------------------------filter((region.r_name = 'ASIA')) ---------------------------------------PhysicalOlapScan[region] +------------------------------------PhysicalOlapScan[nation] +----------------------------------PhysicalDistribute +------------------------------------PhysicalProject +--------------------------------------filter((region.r_name = 'ASIA')) +----------------------------------------PhysicalOlapScan[region] +----------------------PhysicalProject +------------------------filter((orders.o_orderdate < 1995-01-01)(orders.o_orderdate >= 1994-01-01)) +--------------------------PhysicalOlapScan[orders] diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q10.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q10.out index 316a267223..cf30f47eb7 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q10.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q10.out @@ -7,17 +7,18 @@ PhysicalTopN --------hashAgg[LOCAL] ----------PhysicalProject ------------hashJoin[INNER_JOIN](customer.c_nationkey = nation.n_nationkey) ---------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) -----------------PhysicalProject -------------------PhysicalOlapScan[customer] -----------------PhysicalDistribute -------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) ---------------------PhysicalProject -----------------------filter((lineitem.l_returnflag = 'R')) -------------------------PhysicalOlapScan[lineitem] ---------------------PhysicalProject -----------------------filter((orders.o_orderdate < 1994-01-01)(orders.o_orderdate >= 1993-10-01)) -------------------------PhysicalOlapScan[orders] +--------------PhysicalProject +----------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) +------------------PhysicalProject +--------------------PhysicalOlapScan[customer] +------------------PhysicalDistribute +--------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) +----------------------PhysicalProject +------------------------filter((lineitem.l_returnflag = 'R')) +--------------------------PhysicalOlapScan[lineitem] +----------------------PhysicalProject +------------------------filter((orders.o_orderdate < 1994-01-01)(orders.o_orderdate >= 1993-10-01)) +--------------------------PhysicalOlapScan[orders] --------------PhysicalDistribute ----------------PhysicalProject ------------------PhysicalOlapScan[nation] diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q3.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q3.out index a8a3d52100..aff6e8dcd1 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q3.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q3.out @@ -10,12 +10,13 @@ PhysicalTopN --------------PhysicalProject ----------------filter((lineitem.l_shipdate > 1995-03-15)) ------------------PhysicalOlapScan[lineitem] ---------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) -----------------PhysicalProject -------------------filter((orders.o_orderdate < 1995-03-15)) ---------------------PhysicalOlapScan[orders] -----------------PhysicalDistribute +--------------PhysicalProject +----------------hashJoin[INNER_JOIN](customer.c_custkey = orders.o_custkey) ------------------PhysicalProject ---------------------filter((customer.c_mktsegment = 'BUILDING')) -----------------------PhysicalOlapScan[customer] +--------------------filter((orders.o_orderdate < 1995-03-15)) +----------------------PhysicalOlapScan[orders] +------------------PhysicalDistribute +--------------------PhysicalProject +----------------------filter((customer.c_mktsegment = 'BUILDING')) +------------------------PhysicalOlapScan[customer] diff --git a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q5.out b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q5.out index 7554f864fe..313bf6c594 100644 --- a/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q5.out +++ b/regression-test/data/nereids_tpch_shape_sf500_p0/shape/q5.out @@ -12,24 +12,24 @@ PhysicalQuickSort ------------------PhysicalOlapScan[customer] ----------------PhysicalDistribute ------------------PhysicalProject ---------------------hashJoin[INNER_JOIN](lineitem.l_suppkey = supplier.s_suppkey) -----------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) -------------------------PhysicalProject ---------------------------PhysicalOlapScan[lineitem] -------------------------PhysicalProject ---------------------------filter((orders.o_orderdate < 1995-01-01)(orders.o_orderdate >= 1994-01-01)) -----------------------------PhysicalOlapScan[orders] -----------------------PhysicalDistribute -------------------------PhysicalProject ---------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey) -----------------------------PhysicalProject -------------------------------PhysicalOlapScan[supplier] -----------------------------PhysicalDistribute -------------------------------hashJoin[INNER_JOIN](nation.n_regionkey = region.r_regionkey) ---------------------------------PhysicalProject -----------------------------------PhysicalOlapScan[nation] ---------------------------------PhysicalDistribute +--------------------hashJoin[INNER_JOIN](lineitem.l_orderkey = orders.o_orderkey) +----------------------PhysicalProject +------------------------hashJoin[INNER_JOIN](lineitem.l_suppkey = supplier.s_suppkey) +--------------------------PhysicalProject +----------------------------PhysicalOlapScan[lineitem] +--------------------------PhysicalDistribute +----------------------------hashJoin[INNER_JOIN](supplier.s_nationkey = nation.n_nationkey) +------------------------------PhysicalProject +--------------------------------PhysicalOlapScan[supplier] +------------------------------PhysicalDistribute +--------------------------------hashJoin[INNER_JOIN](nation.n_regionkey = region.r_regionkey) ----------------------------------PhysicalProject -------------------------------------filter((region.r_name = 'ASIA')) ---------------------------------------PhysicalOlapScan[region] +------------------------------------PhysicalOlapScan[nation] +----------------------------------PhysicalDistribute +------------------------------------PhysicalProject +--------------------------------------filter((region.r_name = 'ASIA')) +----------------------------------------PhysicalOlapScan[region] +----------------------PhysicalProject +------------------------filter((orders.o_orderdate < 1995-01-01)(orders.o_orderdate >= 1994-01-01)) +--------------------------PhysicalOlapScan[orders]