[feature](Nereids): push down topN through join (#24720)

Push TopN through Join.

JoinType just can be left/right outer join or cross join, because data of their one child can't be filtered.

new TopN is (original limit + original offset, 0) as limit and offset.
This commit is contained in:
jakevin
2023-10-07 14:58:53 +08:00
committed by GitHub
parent 47694c5b36
commit 3c9ff7af39
12 changed files with 312 additions and 69 deletions

View File

@ -73,6 +73,7 @@ import org.apache.doris.nereids.rules.rewrite.InferJoinNotNull;
import org.apache.doris.nereids.rules.rewrite.InferPredicates;
import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct;
import org.apache.doris.nereids.rules.rewrite.LeadingJoin;
import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN;
import org.apache.doris.nereids.rules.rewrite.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion;
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
@ -90,9 +91,9 @@ import org.apache.doris.nereids.rules.rewrite.PushProjectIntoOneRowRelation;
import org.apache.doris.nereids.rules.rewrite.PushProjectThroughUnion;
import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.PushdownLimit;
import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughWindow;
import org.apache.doris.nereids.rules.rewrite.ReorderJoin;
import org.apache.doris.nereids.rules.rewrite.ReplaceLimitNode;
import org.apache.doris.nereids.rules.rewrite.RewriteCteChildren;
import org.apache.doris.nereids.rules.rewrite.SemiJoinCommute;
import org.apache.doris.nereids.rules.rewrite.SimplifyAggGroupBy;
@ -275,9 +276,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
// we should refactor like AggregateStrategies, e.g. LimitStrategies,
// generate one PhysicalLimit if current distribution is gather or two
// PhysicalLimits with gather exchange
new ReplaceLimitNode(),
new LimitSortToTopN(),
new SplitLimit(),
new PushdownLimit(),
new PushdownTopNThroughJoin(),
new PushdownTopNThroughWindow(),
new CreatePartitionTopNFromWindow()
)

View File

@ -165,8 +165,6 @@ public enum RuleType {
COLUMN_PRUNING(RuleTypeClass.REWRITE),
ELIMINATE_SORT(RuleTypeClass.REWRITE),
PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW(RuleTypeClass.REWRITE),
PUSHDOWN_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
PUSHDOWN_MIN_MAX_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSHDOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSHDOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE),
@ -248,7 +246,12 @@ public enum RuleType {
PUSH_LIMIT_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_WINDOW(RuleTypeClass.REWRITE),
PUSH_LIMIT_INTO_SORT(RuleTypeClass.REWRITE),
LIMIT_SORT_TO_TOP_N(RuleTypeClass.REWRITE),
// topN push down
PUSH_TOP_N_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
PUSH_TOP_N_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
PUSH_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
// adjust nullable
ADJUST_NULLABLE(RuleTypeClass.REWRITE),
ADJUST_CONJUNCTS_RETURN_TYPE(RuleTypeClass.REWRITE),

View File

@ -19,18 +19,35 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* Eliminate limit = 0.
*/
public class EliminateLimit extends OneRewriteRuleFactory {
public class EliminateLimit implements RewriteRuleFactory {
@Override
public Rule build() {
return logicalLimit()
.when(limit -> limit.getLimit() == 0)
.thenApply(ctx -> new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(),
ctx.root.getOutput()))
.toRule(RuleType.ELIMINATE_LIMIT);
public List<Rule> buildRules() {
return ImmutableList.of(
logicalLimit()
.when(limit -> limit.getLimit() == 0)
.thenApply(ctx -> new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(),
ctx.root.getOutput()))
.toRule(RuleType.ELIMINATE_LIMIT),
logicalLimit(logicalOneRowRelation())
.then(limit -> limit.getLimit() > 0 && limit.getOffset() == 0
? limit.child() : new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(),
limit.child().getOutput()))
.toRule(RuleType.ELIMINATE_LIMIT_ON_ONE_ROW_RELATION),
logicalLimit(logicalEmptyRelation())
.then(UnaryNode::child)
.toRule(RuleType.ELIMINATE_LIMIT_ON_EMPTY_RELATION)
);
}
}

View File

@ -19,10 +19,7 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
@ -35,7 +32,7 @@ import java.util.List;
/**
* rule to eliminate limit node by replace to other nodes.
*/
public class ReplaceLimitNode implements RewriteRuleFactory {
public class LimitSortToTopN implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
@ -47,8 +44,8 @@ public class ReplaceLimitNode implements RewriteRuleFactory {
limit.getLimit(),
limit.getOffset(),
sort.child(0));
}).toRule(RuleType.PUSH_LIMIT_INTO_SORT),
//limit->proj->sort ==> proj->topN
}).toRule(RuleType.LIMIT_SORT_TO_TOP_N),
// limit -> proj -> sort ==> proj -> topN
logicalLimit(logicalProject(logicalSort()))
.then(limit -> {
LogicalProject<LogicalSort<Plan>> project = limit.child();
@ -58,15 +55,7 @@ public class ReplaceLimitNode implements RewriteRuleFactory {
limit.getOffset(),
sort.child(0));
return project.withChildren(Lists.newArrayList(topN));
}).toRule(RuleType.PUSH_LIMIT_INTO_SORT),
logicalLimit(logicalOneRowRelation())
.then(limit -> limit.getLimit() > 0 && limit.getOffset() == 0
? limit.child() : new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(),
limit.child().getOutput()))
.toRule(RuleType.ELIMINATE_LIMIT_ON_ONE_ROW_RELATION),
logicalLimit(logicalEmptyRelation())
.then(UnaryNode::child)
.toRule(RuleType.ELIMINATE_LIMIT_ON_EMPTY_RELATION)
}).toRule(RuleType.LIMIT_SORT_TO_TOP_N)
);
}
}

View File

@ -41,27 +41,26 @@ public class PushdownFilterThroughProject implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT.build(logicalFilter(logicalProject())
.whenNot(filter -> filter.child().getProjects().stream().anyMatch(
expr -> expr.anyMatch(WindowExpression.class::isInstance)))
.then(PushdownFilterThroughProject::pushdownFilterThroughProject)),
// filter(project(limit)) will change to filter(limit(project)) by PushdownProjectThroughLimit,
// then we should change filter(limit(project)) to project(filter(limit))
RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT
.build(logicalFilter(logicalLimit(logicalProject()))
.whenNot(filter -> filter.child().child().getProjects().stream()
.anyMatch(expr -> expr
.anyMatch(WindowExpression.class::isInstance)))
.then(filter -> {
LogicalLimit<LogicalProject<Plan>> limit = filter.child();
LogicalProject<Plan> project = limit.child();
logicalFilter(logicalProject())
.whenNot(filter -> filter.child().getProjects().stream().anyMatch(
expr -> expr.anyMatch(WindowExpression.class::isInstance)))
.then(PushdownFilterThroughProject::pushdownFilterThroughProject)
.toRule(RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT),
// filter(project(limit)) will change to filter(limit(project)) by PushdownProjectThroughLimit,
// then we should change filter(limit(project)) to project(filter(limit))
logicalFilter(logicalLimit(logicalProject()))
.whenNot(filter -> filter.child().child().getProjects().stream()
.anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance)))
.then(filter -> {
LogicalLimit<LogicalProject<Plan>> limit = filter.child();
LogicalProject<Plan> project = limit.child();
return project.withProjectsAndChild(project.getProjects(),
new LogicalFilter<>(
ExpressionUtils.replace(filter.getConjuncts(),
project.getAliasToProducer()),
limit.withChildren(project.child())));
}))
return project.withProjectsAndChild(project.getProjects(),
new LogicalFilter<>(
ExpressionUtils.replace(filter.getConjuncts(),
project.getAliasToProducer()),
limit.withChildren(project.child())));
}).toRule(RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT)
);
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
/**
* <pre>
* Before:
* project
* │
@ -42,6 +43,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
* │
* ▼
* plan node
* </pre>
*/
public class PushdownProjectThroughLimit extends OneRewriteRuleFactory {
@ -50,9 +52,7 @@ public class PushdownProjectThroughLimit extends OneRewriteRuleFactory {
return logicalProject(logicalLimit()).thenApply(ctx -> {
LogicalProject<LogicalLimit<Plan>> logicalProject = ctx.root;
LogicalLimit<Plan> logicalLimit = logicalProject.child();
return new LogicalLimit<>(logicalLimit.getLimit(), logicalLimit.getOffset(),
logicalLimit.getPhase(), logicalProject.withProjectsAndChild(logicalProject.getProjects(),
logicalLimit.child()));
return logicalLimit.withChildren(logicalProject.withChildren(logicalLimit.child()));
}).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_LIMIT);
}
}

View File

@ -0,0 +1,108 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
* Push down TopN through Outer Join into left child .....
*/
public class PushdownTopNThroughJoin implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
// topN -> join
logicalTopN(logicalJoin())
// TODO: complex orderby
.when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
.allMatch(Slot.class::isInstance))
.then(topN -> {
LogicalJoin<Plan, Plan> join = topN.child();
Plan newJoin = pushLimitThroughJoin(topN, join);
if (newJoin == null || topN.child().children().equals(newJoin.children())) {
return null;
}
return topN.withChildren(newJoin);
})
.toRule(RuleType.PUSH_TOP_N_THROUGH_JOIN),
// topN -> project -> join
logicalTopN(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots))
// TODO: complex project
.when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
.allMatch(Slot.class::isInstance))
.then(topN -> {
LogicalProject<LogicalJoin<Plan, Plan>> project = topN.child();
LogicalJoin<Plan, Plan> join = project.child();
Plan newJoin = pushLimitThroughJoin(topN, join);
if (newJoin == null || join.children().equals(newJoin.children())) {
return null;
}
return topN.withChildren(project.withChildren(newJoin));
}).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_JOIN)
);
}
private Plan pushLimitThroughJoin(LogicalTopN<? extends Plan> topN, LogicalJoin<Plan, Plan> join) {
switch (join.getJoinType()) {
case LEFT_OUTER_JOIN:
Set<Slot> rightOutputSet = join.right().getOutputSet();
if (topN.getOrderKeys().stream().map(OrderKey::getExpr)
.anyMatch(e -> Utils.isIntersecting(rightOutputSet, e.getInputSlots()))) {
return null;
}
return join.withChildren(topN.withChildren(join.left()), join.right());
case RIGHT_OUTER_JOIN:
Set<Slot> leftOutputSet = join.left().getOutputSet();
if (topN.getOrderKeys().stream().map(OrderKey::getExpr)
.anyMatch(e -> Utils.isIntersecting(leftOutputSet, e.getInputSlots()))) {
return null;
}
return join.withChildren(join.left(), topN.withChildren(join.right()));
case CROSS_JOIN:
List<Slot> orderbySlots = topN.getOrderKeys().stream().map(OrderKey::getExpr)
.flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
if (join.left().getOutputSet().containsAll(orderbySlots)) {
return join.withChildren(topN.withChildren(join.left()), join.right());
} else if (join.right().getOutputSet().containsAll(orderbySlots)) {
return join.withChildren(join.left(), topN.withChildren(join.right()));
} else {
return null;
}
default:
// don't push limit.
return null;
}
}
}

View File

@ -59,7 +59,7 @@ public class PushdownTopNThroughWindow implements RewriteRuleFactory {
return topn;
}
return topn.withChildren(newWindow.get());
}).toRule(RuleType.PUSHDOWN_TOP_N_THROUGH_WINDOW),
}).toRule(RuleType.PUSH_TOP_N_THROUGH_WINDOW),
// topn -> projection -> window
logicalTopN(logicalProject(logicalWindow())).then(topn -> {
@ -79,7 +79,7 @@ public class PushdownTopNThroughWindow implements RewriteRuleFactory {
return topn;
}
return topn.withChildren(project.withChildren(newWindow.get()));
}).toRule(RuleType.PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW)
}).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_WINDOW)
);
}

View File

@ -351,10 +351,6 @@ public class ExpressionUtils {
return builder.build();
}
public static boolean isAllLiteral(Expression... children) {
return Arrays.stream(children).allMatch(c -> c instanceof Literal);
}
public static boolean isAllLiteral(List<Expression> children) {
return children.stream().allMatch(c -> c instanceof Literal);
}

View File

@ -65,7 +65,7 @@ import java.util.stream.Collectors;
class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSupported {
private final LogicalOlapScan scanScore = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.score);
private Plan scanStudent = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student);
private final LogicalOlapScan scanStudent = new LogicalOlapScan(StatementScopeIdGenerator.newRelationId(), PlanConstructor.student);
@Override
protected void runBeforeAll() throws Exception {
@ -114,7 +114,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
}
@Test
public void testPushLimitThroughLeftJoin() {
void testPushLimitThroughLeftJoin() {
test(JoinType.LEFT_OUTER_JOIN, true,
logicalLimit(
logicalProject(
@ -136,7 +136,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
}
@Test
public void testPushLimitThroughRightJoin() {
void testPushLimitThroughRightJoin() {
// after use RelationUtil to allocate relation id, the id will increase when getNextId() called.
test(JoinType.RIGHT_OUTER_JOIN, true,
logicalLimit(
@ -159,7 +159,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
}
@Test
public void testPushLimitThroughCrossJoin() {
void testPushLimitThroughCrossJoin() {
test(JoinType.CROSS_JOIN, true,
logicalLimit(
logicalProject(
@ -181,7 +181,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
}
@Test
public void testPushLimitThroughInnerJoin() {
void testPushLimitThroughInnerJoin() {
test(JoinType.INNER_JOIN, true,
logicalLimit(
logicalProject(
@ -203,7 +203,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
}
@Test
public void testTranslate() {
void testTranslate() {
PlanChecker.from(connectContext).checkPlannerResult("select * from t1 left join t2 on t1.k1=t2.k1 limit 5",
planner -> {
List<PlanFragment> fragments = planner.getFragments();
@ -227,7 +227,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
}
@Test
public void testLimitPushSort() {
void testLimitPushSort() {
PlanChecker.from(connectContext)
.analyze("select k1 from t1 order by k1 limit 1")
.rewrite()
@ -235,7 +235,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
}
@Test
public void testLimitPushUnion() {
void testLimitPushUnion() {
PlanChecker.from(connectContext)
.analyze("select k1 from t1 "
+ "union all select k2 from t2 "
@ -262,7 +262,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
}
@Test
public void testLimitPushWindow() {
void testLimitPushWindow() {
ConnectContext context = MemoTestUtils.createConnectContext();
context.getSessionVariable().setEnablePartitionTopN(true);
NamedExpression grade = scanScore.getOutput().get(2).toSlot();
@ -304,7 +304,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
}
@Test
public void testTopNPushWindow() {
void testTopNPushWindow() {
ConnectContext context = MemoTestUtils.createConnectContext();
context.getSessionVariable().setEnablePartitionTopN(true);
NamedExpression grade = scanScore.getOutput().get(2).toSlot();
@ -322,7 +322,7 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
List<OrderKey> orderKey = ImmutableList.of(
new OrderKey(windowAlias1.toSlot(), true, true)
);
LogicalSort<LogicalWindow> sort = new LogicalSort<>(orderKey, window);
LogicalSort<Plan> sort = new LogicalSort<>(orderKey, window);
LogicalPlan plan = new LogicalPlanBuilder(sort)
.limit(100)
@ -364,8 +364,8 @@ class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSup
LogicalJoin<? extends Plan, ? extends Plan> join = new LogicalJoin<>(
joinType,
joinConditions,
new LogicalOlapScan(((LogicalOlapScan) scanScore).getRelationId(), PlanConstructor.score),
new LogicalOlapScan(((LogicalOlapScan) scanStudent).getRelationId(), PlanConstructor.student)
new LogicalOlapScan(scanScore.getRelationId(), PlanConstructor.score),
new LogicalOlapScan(scanStudent.getRelationId(), PlanConstructor.student)
);
if (hasProject) {

View File

@ -0,0 +1,120 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.util.LogicalPlanBuilder;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.utframe.TestWithFeService;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
class PushdownTopNThroughJoinTest extends TestWithFeService implements MemoPatternMatchSupported {
private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
private static final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0);
@Override
protected void runBeforeAll() throws Exception {
createDatabase("test");
connectContext.setDatabase("default_cluster:test");
createTable("CREATE TABLE `t1` (\n"
+ " `k1` int(11) NOT NULL,\n"
+ " `k2` int(11) NOT NULL\n"
+ ") ENGINE=OLAP\n"
+ "COMMENT 'OLAP'\n"
+ "DISTRIBUTED BY HASH(`k1`) BUCKETS 3\n"
+ "PROPERTIES (\n"
+ "\"replication_allocation\" = \"tag.location.default: 1\",\n"
+ "\"in_memory\" = \"false\",\n"
+ "\"storage_format\" = \"V2\",\n"
+ "\"disable_auto_compaction\" = \"false\"\n"
+ ");");
createTable("CREATE TABLE `t2` (\n"
+ " `k1` int(11) NULL,\n"
+ " `k2` int(11) NULL\n"
+ ") ENGINE=OLAP\n"
+ "COMMENT 'OLAP'\n"
+ "DISTRIBUTED BY HASH(`k1`) BUCKETS 3\n"
+ "PROPERTIES (\n"
+ "\"replication_allocation\" = \"tag.location.default: 1\",\n"
+ "\"in_memory\" = \"false\",\n"
+ "\"storage_format\" = \"V2\",\n"
+ "\"disable_auto_compaction\" = \"false\"\n"
+ ");");
}
@Test
void testJoin() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0))
.topN(10, 0, ImmutableList.of(0))
.build();
PlanChecker.from(connectContext, plan)
.applyTopDown(new PushdownTopNThroughJoin())
.matches(
logicalTopN(
logicalJoin(
logicalTopN().when(l -> l.getLimit() == 10 && l.getOffset() == 0),
logicalOlapScan()
)
)
);
}
@Test
void testJoinSql() {
PlanChecker.from(connectContext)
.analyze("select * from t1 left join t2 on t1.k1 = t2.k1 order by t1.k1 limit 10")
.rewrite()
.matches(
logicalTopN(
logicalProject(
logicalJoin(
logicalTopN().when(l -> l.getLimit() == 10 && l.getOffset() == 0),
logicalOlapScan()
)
)
)
);
}
@Test
void badCase() {
LogicalPlan plan = new LogicalPlanBuilder(scan1)
.join(scan2, JoinType.RIGHT_OUTER_JOIN, Pair.of(0, 0))
.topN(10, 0, ImmutableList.of(0))
.build();
PlanChecker.from(connectContext, plan)
.applyTopDown(new PushdownTopNThroughJoin())
.matches(
logicalJoin(
logicalOlapScan(),
logicalOlapScan()
)
);
}
}

View File

@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -142,6 +143,14 @@ public class LogicalPlanBuilder {
return limit(limit, 0);
}
public LogicalPlanBuilder topN(long limit, long offset, List<Integer> orderKeySlotsIndex) {
List<OrderKey> orderKeys = orderKeySlotsIndex.stream()
.map(i -> new OrderKey(this.plan.getOutput().get(i), false, false))
.collect(Collectors.toList());
LogicalTopN<Plan> topNPlan = new LogicalTopN<>(orderKeys, limit, offset, this.plan);
return from(topNPlan);
}
public LogicalPlanBuilder filter(Expression conjunct) {
return filter(ImmutableSet.copyOf(ExpressionUtils.extractConjunction(conjunct)));
}