From 3c0fd5029de6ccb19f361066b9e7ec1818571e60 Mon Sep 17 00:00:00 2001 From: jakevin Date: Tue, 24 Oct 2023 14:52:30 +0800 Subject: [PATCH] [fix](Nereids): support complex project in PushdownTopNThroughJoin (#25748) --- .../rewrite/PushdownTopNThroughJoin.java | 15 +++- .../rewrite/PushdownTopNThroughJoinTest.java | 69 ++++++++++++++++++- 2 files changed, 81 insertions(+), 3 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java index e72c7a051e..b025d40b6d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoin.java @@ -29,6 +29,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Set; import java.util.stream.Collectors; /** @@ -55,14 +56,24 @@ public class PushdownTopNThroughJoin implements RewriteRuleFactory { .toRule(RuleType.PUSH_TOP_N_THROUGH_JOIN), // topN -> project -> join - logicalTopN(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots)) - // TODO: complex project + logicalTopN(logicalProject(logicalJoin())) .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr) .allMatch(Slot.class::isInstance)) .then(topN -> { LogicalProject> project = topN.child(); LogicalJoin join = project.child(); + // If orderby exprs aren't all in the output of the project, we can't push down. + // topN(order by: slot(a+1)) + // - project(a+1, b) + // TODO: in the future, we also can push down it. + Set outputSet = project.child().getOutputSet(); + if (!topN.getOrderKeys().stream().map(OrderKey::getExpr) + .flatMap(e -> e.getInputSlots().stream()) + .allMatch(outputSet::contains)) { + return null; + } + Plan newJoin = pushLimitThroughJoin(topN, join); if (newJoin == null || join.children().equals(newJoin.children())) { return null; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoinTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoinTest.java index b44ca08a2d..c033ca46bc 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoinTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/PushdownTopNThroughJoinTest.java @@ -18,9 +18,13 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.util.LogicalPlanBuilder; import org.apache.doris.nereids.util.MemoPatternMatchSupported; import org.apache.doris.nereids.util.PlanChecker; @@ -30,6 +34,8 @@ import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; +import java.util.List; + class PushdownTopNThroughJoinTest extends TestWithFeService implements MemoPatternMatchSupported { private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); @@ -85,6 +91,29 @@ class PushdownTopNThroughJoinTest extends TestWithFeService implements MemoPatte ); } + @Test + void testProject1() { + List projectExpres = ImmutableList.of(scan1.getOutput().get(0), + new Cast(scan1.getOutput().get(1), VarcharType.SYSTEM_DEFAULT).alias("cast")); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) + .projectExprs(projectExpres) + .topN(10, 0, ImmutableList.of(0)) + .build(); + PlanChecker.from(connectContext, plan) + .applyTopDown(new PushdownTopNThroughJoin()) + .matches( + logicalTopN( + logicalProject( + logicalJoin( + logicalTopN().when(l -> l.getLimit() == 10 && l.getOffset() == 0), + logicalOlapScan() + ) + ) + ) + ); + } + @Test void testJoinSql() { PlanChecker.from(connectContext) @@ -103,7 +132,45 @@ class PushdownTopNThroughJoinTest extends TestWithFeService implements MemoPatte } @Test - void badCase() { + void testProjectSql() { + PlanChecker.from(connectContext) + .analyze( + "select t1.k1, cast(t1.k2 as varchar) from t1 left join t2 on t1.k1 = t2.k1 order by t1.k1 limit 10") + .rewrite() + .matches( + logicalTopN( + logicalProject( + logicalJoin( + logicalTopN().when(l -> l.getLimit() == 10 && l.getOffset() == 0), + logicalProject(logicalOlapScan()) + ) + ) + ) + ); + } + + @Test + void rejectTopNUseProjectComplexExpr() { + List projectExpres = ImmutableList.of( + (new Add(scan1.getOutput().get(0), scan1.getOutput().get(1))).alias("add") + ); + LogicalPlan plan = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) + .projectExprs(projectExpres) + .topN(10, 0, ImmutableList.of(0)) + .build(); + PlanChecker.from(connectContext, plan) + .applyTopDown(new PushdownTopNThroughJoin()) + .matches( + logicalJoin( + logicalOlapScan(), + logicalOlapScan() + ) + ); + } + + @Test + void rejectWrongJoinType() { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0)) .topN(10, 0, ImmutableList.of(0))