[feature](Nereids): push down topN through join (#24720)

Push TopN through Join.

Only left/right outer joins and cross joins are supported, because in these join types the rows of one child are never filtered out by the join.

The pushed-down TopN uses (original limit + original offset) as its limit and 0 as its offset.
This commit is contained in:
jakevin
2023-10-07 14:58:53 +08:00
committed by GitHub
parent 47694c5b36
commit 3c9ff7af39
12 changed files with 312 additions and 69 deletions

View File

@ -73,6 +73,7 @@ import org.apache.doris.nereids.rules.rewrite.InferJoinNotNull;
import org.apache.doris.nereids.rules.rewrite.InferPredicates;
import org.apache.doris.nereids.rules.rewrite.InferSetOperatorDistinct;
import org.apache.doris.nereids.rules.rewrite.LeadingJoin;
import org.apache.doris.nereids.rules.rewrite.LimitSortToTopN;
import org.apache.doris.nereids.rules.rewrite.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.MergeOneRowRelationIntoUnion;
import org.apache.doris.nereids.rules.rewrite.MergeProjects;
@ -90,9 +91,9 @@ import org.apache.doris.nereids.rules.rewrite.PushProjectIntoOneRowRelation;
import org.apache.doris.nereids.rules.rewrite.PushProjectThroughUnion;
import org.apache.doris.nereids.rules.rewrite.PushdownFilterThroughProject;
import org.apache.doris.nereids.rules.rewrite.PushdownLimit;
import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughJoin;
import org.apache.doris.nereids.rules.rewrite.PushdownTopNThroughWindow;
import org.apache.doris.nereids.rules.rewrite.ReorderJoin;
import org.apache.doris.nereids.rules.rewrite.ReplaceLimitNode;
import org.apache.doris.nereids.rules.rewrite.RewriteCteChildren;
import org.apache.doris.nereids.rules.rewrite.SemiJoinCommute;
import org.apache.doris.nereids.rules.rewrite.SimplifyAggGroupBy;
@ -275,9 +276,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
// we should refactor like AggregateStrategies, e.g. LimitStrategies,
// generate one PhysicalLimit if current distribution is gather or two
// PhysicalLimits with gather exchange
new ReplaceLimitNode(),
new LimitSortToTopN(),
new SplitLimit(),
new PushdownLimit(),
new PushdownTopNThroughJoin(),
new PushdownTopNThroughWindow(),
new CreatePartitionTopNFromWindow()
)

View File

@ -165,8 +165,6 @@ public enum RuleType {
COLUMN_PRUNING(RuleTypeClass.REWRITE),
ELIMINATE_SORT(RuleTypeClass.REWRITE),
PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW(RuleTypeClass.REWRITE),
PUSHDOWN_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
PUSHDOWN_MIN_MAX_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSHDOWN_SUM_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSHDOWN_COUNT_THROUGH_JOIN(RuleTypeClass.REWRITE),
@ -248,7 +246,12 @@ public enum RuleType {
PUSH_LIMIT_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_WINDOW(RuleTypeClass.REWRITE),
PUSH_LIMIT_INTO_SORT(RuleTypeClass.REWRITE),
LIMIT_SORT_TO_TOP_N(RuleTypeClass.REWRITE),
// topN push down
PUSH_TOP_N_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_TOP_N_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
PUSH_TOP_N_THROUGH_PROJECT_WINDOW(RuleTypeClass.REWRITE),
PUSH_TOP_N_THROUGH_WINDOW(RuleTypeClass.REWRITE),
// adjust nullable
ADJUST_NULLABLE(RuleTypeClass.REWRITE),
ADJUST_CONJUNCTS_RETURN_TYPE(RuleTypeClass.REWRITE),

View File

@ -19,18 +19,35 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* Eliminate limit = 0.
*/
public class EliminateLimit extends OneRewriteRuleFactory {
public class EliminateLimit implements RewriteRuleFactory {
@Override
public Rule build() {
return logicalLimit()
.when(limit -> limit.getLimit() == 0)
.thenApply(ctx -> new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(),
ctx.root.getOutput()))
.toRule(RuleType.ELIMINATE_LIMIT);
public List<Rule> buildRules() {
return ImmutableList.of(
logicalLimit()
.when(limit -> limit.getLimit() == 0)
.thenApply(ctx -> new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(),
ctx.root.getOutput()))
.toRule(RuleType.ELIMINATE_LIMIT),
logicalLimit(logicalOneRowRelation())
.then(limit -> limit.getLimit() > 0 && limit.getOffset() == 0
? limit.child() : new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(),
limit.child().getOutput()))
.toRule(RuleType.ELIMINATE_LIMIT_ON_ONE_ROW_RELATION),
logicalLimit(logicalEmptyRelation())
.then(UnaryNode::child)
.toRule(RuleType.ELIMINATE_LIMIT_ON_EMPTY_RELATION)
);
}
}

View File

@ -19,10 +19,7 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
@ -35,7 +32,7 @@ import java.util.List;
/**
* rule to eliminate limit node by replace to other nodes.
*/
public class ReplaceLimitNode implements RewriteRuleFactory {
public class LimitSortToTopN implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
@ -47,8 +44,8 @@ public class ReplaceLimitNode implements RewriteRuleFactory {
limit.getLimit(),
limit.getOffset(),
sort.child(0));
}).toRule(RuleType.PUSH_LIMIT_INTO_SORT),
//limit->proj->sort ==> proj->topN
}).toRule(RuleType.LIMIT_SORT_TO_TOP_N),
// limit -> proj -> sort ==> proj -> topN
logicalLimit(logicalProject(logicalSort()))
.then(limit -> {
LogicalProject<LogicalSort<Plan>> project = limit.child();
@ -58,15 +55,7 @@ public class ReplaceLimitNode implements RewriteRuleFactory {
limit.getOffset(),
sort.child(0));
return project.withChildren(Lists.newArrayList(topN));
}).toRule(RuleType.PUSH_LIMIT_INTO_SORT),
logicalLimit(logicalOneRowRelation())
.then(limit -> limit.getLimit() > 0 && limit.getOffset() == 0
? limit.child() : new LogicalEmptyRelation(StatementScopeIdGenerator.newRelationId(),
limit.child().getOutput()))
.toRule(RuleType.ELIMINATE_LIMIT_ON_ONE_ROW_RELATION),
logicalLimit(logicalEmptyRelation())
.then(UnaryNode::child)
.toRule(RuleType.ELIMINATE_LIMIT_ON_EMPTY_RELATION)
}).toRule(RuleType.LIMIT_SORT_TO_TOP_N)
);
}
}

View File

@ -41,27 +41,26 @@ public class PushdownFilterThroughProject implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT.build(logicalFilter(logicalProject())
.whenNot(filter -> filter.child().getProjects().stream().anyMatch(
expr -> expr.anyMatch(WindowExpression.class::isInstance)))
.then(PushdownFilterThroughProject::pushdownFilterThroughProject)),
// filter(project(limit)) will change to filter(limit(project)) by PushdownProjectThroughLimit,
// then we should change filter(limit(project)) to project(filter(limit))
RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT
.build(logicalFilter(logicalLimit(logicalProject()))
.whenNot(filter -> filter.child().child().getProjects().stream()
.anyMatch(expr -> expr
.anyMatch(WindowExpression.class::isInstance)))
.then(filter -> {
LogicalLimit<LogicalProject<Plan>> limit = filter.child();
LogicalProject<Plan> project = limit.child();
logicalFilter(logicalProject())
.whenNot(filter -> filter.child().getProjects().stream().anyMatch(
expr -> expr.anyMatch(WindowExpression.class::isInstance)))
.then(PushdownFilterThroughProject::pushdownFilterThroughProject)
.toRule(RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT),
// filter(project(limit)) will change to filter(limit(project)) by PushdownProjectThroughLimit,
// then we should change filter(limit(project)) to project(filter(limit))
logicalFilter(logicalLimit(logicalProject()))
.whenNot(filter -> filter.child().child().getProjects().stream()
.anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance)))
.then(filter -> {
LogicalLimit<LogicalProject<Plan>> limit = filter.child();
LogicalProject<Plan> project = limit.child();
return project.withProjectsAndChild(project.getProjects(),
new LogicalFilter<>(
ExpressionUtils.replace(filter.getConjuncts(),
project.getAliasToProducer()),
limit.withChildren(project.child())));
}))
return project.withProjectsAndChild(project.getProjects(),
new LogicalFilter<>(
ExpressionUtils.replace(filter.getConjuncts(),
project.getAliasToProducer()),
limit.withChildren(project.child())));
}).toRule(RuleType.PUSHDOWN_FILTER_THROUGH_PROJECT_UNDER_LIMIT)
);
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
/**
* <pre>
* Before:
* project
* │
@ -42,6 +43,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
* │
* ▼
* plan node
* </pre>
*/
public class PushdownProjectThroughLimit extends OneRewriteRuleFactory {
@ -50,9 +52,7 @@ public class PushdownProjectThroughLimit extends OneRewriteRuleFactory {
return logicalProject(logicalLimit()).thenApply(ctx -> {
LogicalProject<LogicalLimit<Plan>> logicalProject = ctx.root;
LogicalLimit<Plan> logicalLimit = logicalProject.child();
return new LogicalLimit<>(logicalLimit.getLimit(), logicalLimit.getOffset(),
logicalLimit.getPhase(), logicalProject.withProjectsAndChild(logicalProject.getProjects(),
logicalLimit.child()));
return logicalLimit.withChildren(logicalProject.withChildren(logicalLimit.child()));
}).toRule(RuleType.PUSHDOWN_PROJECT_THROUGH_LIMIT);
}
}

View File

@ -0,0 +1,108 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.util.Utils;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
/**
 * Push a TopN down through a join, into the join child whose output slots the order keys use.
 *
 * <p>Supported shapes are {@code topN -> join} and {@code topN -> project -> join} (the project
 * must consist of slots only). Supported join types:
 * <ul>
 *   <li>LEFT OUTER JOIN: a copy of the TopN is placed on the left child, provided that no order
 *       key uses a slot of the right child.</li>
 *   <li>RIGHT OUTER JOIN: a copy of the TopN is placed on the right child, provided that no order
 *       key uses a slot of the left child.</li>
 *   <li>CROSS JOIN: a copy of the TopN is placed on whichever child outputs all slots used by the
 *       order keys (the left child is tried first).</li>
 * </ul>
 *
 * <p>The original TopN is kept above the join. The pushed-down copy uses
 * {@code limit = original limit + original offset} and {@code offset = 0}, so the child still
 * produces every row that the upper TopN may need to skip or return.
 */
public class PushdownTopNThroughJoin implements RewriteRuleFactory {
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                // topN -> join
                logicalTopN(logicalJoin())
                        // TODO: complex orderby
                        .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
                                .allMatch(Slot.class::isInstance))
                        .then(topN -> {
                            LogicalJoin<Plan, Plan> join = topN.child();
                            Plan newJoin = pushLimitThroughJoin(topN, join);
                            // Returning null leaves the plan untouched. The children comparison
                            // guards against pushing the same TopN again on a later pass.
                            // NOTE(review): this presumably relies on plan equality treating the
                            // already pushed TopN as equal to the new one -- confirm.
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                return null;
                            }
                            return topN.withChildren(newJoin);
                        })
                        .toRule(RuleType.PUSH_TOP_N_THROUGH_JOIN),

                // topN -> project -> join
                logicalTopN(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots))
                        // TODO: complex project
                        .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
                                .allMatch(Slot.class::isInstance))
                        .then(topN -> {
                            LogicalProject<LogicalJoin<Plan, Plan>> project = topN.child();
                            LogicalJoin<Plan, Plan> join = project.child();
                            Plan newJoin = pushLimitThroughJoin(topN, join);
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                return null;
                            }
                            return topN.withChildren(project.withChildren(newJoin));
                        }).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_JOIN)
        );
    }

    /**
     * Build a copy of {@code join} with a TopN inserted on top of one of its children.
     *
     * @param topN the TopN located above the join
     * @param join the join to push through
     * @return the new join, or {@code null} when the join type is not supported or the order keys
     *         do not allow the push-down
     */
    private Plan pushLimitThroughJoin(LogicalTopN<? extends Plan> topN, LogicalJoin<Plan, Plan> join) {
        switch (join.getJoinType()) {
            case LEFT_OUTER_JOIN:
                if (orderKeysIntersect(topN, join.right().getOutputSet())) {
                    return null;
                }
                return join.withChildren(pushedDownTopN(topN, join.left()), join.right());
            case RIGHT_OUTER_JOIN:
                if (orderKeysIntersect(topN, join.left().getOutputSet())) {
                    return null;
                }
                return join.withChildren(join.left(), pushedDownTopN(topN, join.right()));
            case CROSS_JOIN:
                List<Slot> orderbySlots = topN.getOrderKeys().stream().map(OrderKey::getExpr)
                        .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
                if (join.left().getOutputSet().containsAll(orderbySlots)) {
                    return join.withChildren(pushedDownTopN(topN, join.left()), join.right());
                } else if (join.right().getOutputSet().containsAll(orderbySlots)) {
                    return join.withChildren(join.left(), pushedDownTopN(topN, join.right()));
                } else {
                    return null;
                }
            default:
                // don't push limit.
                return null;
        }
    }

    /** Whether any order key of {@code topN} uses at least one slot contained in {@code slots}. */
    private boolean orderKeysIntersect(LogicalTopN<? extends Plan> topN, Set<Slot> slots) {
        return topN.getOrderKeys().stream().map(OrderKey::getExpr)
                .anyMatch(e -> Utils.isIntersecting(slots, e.getInputSlots()));
    }

    /**
     * Create the TopN that is placed below the join.
     *
     * <p>The offset must not be applied below the join: if the child skipped {@code offset} rows
     * and the upper TopN skipped {@code offset} rows again, rows belonging to the result would be
     * lost. The child therefore keeps the first {@code limit + offset} rows and the upper TopN
     * alone applies the offset.
     */
    private Plan pushedDownTopN(LogicalTopN<? extends Plan> topN, Plan child) {
        return new LogicalTopN<>(topN.getOrderKeys(), topN.getLimit() + topN.getOffset(), 0, child);
    }
}

View File

@ -59,7 +59,7 @@ public class PushdownTopNThroughWindow implements RewriteRuleFactory {
return topn;
}
return topn.withChildren(newWindow.get());
}).toRule(RuleType.PUSHDOWN_TOP_N_THROUGH_WINDOW),
}).toRule(RuleType.PUSH_TOP_N_THROUGH_WINDOW),
// topn -> projection -> window
logicalTopN(logicalProject(logicalWindow())).then(topn -> {
@ -79,7 +79,7 @@ public class PushdownTopNThroughWindow implements RewriteRuleFactory {
return topn;
}
return topn.withChildren(project.withChildren(newWindow.get()));
}).toRule(RuleType.PUSHDOWN_TOP_N_THROUGH_PROJECTION_WINDOW)
}).toRule(RuleType.PUSH_TOP_N_THROUGH_PROJECT_WINDOW)
);
}

View File

@ -351,10 +351,6 @@ public class ExpressionUtils {
return builder.build();
}
public static boolean isAllLiteral(Expression... children) {
return Arrays.stream(children).allMatch(c -> c instanceof Literal);
}
public static boolean isAllLiteral(List<Expression> children) {
return children.stream().allMatch(c -> c instanceof Literal);
}