[refactor](Nereids): refactor PushdownLimit (#17355)

This commit is contained in:
jakevin
2023-03-07 12:04:20 +08:00
committed by GitHub
parent b0e3156f51
commit b8c9875adb
4 changed files with 41 additions and 51 deletions

View File

@ -52,7 +52,6 @@ import org.apache.doris.nereids.rules.rewrite.logical.InferFilterNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.InferJoinNotNull;
import org.apache.doris.nereids.rules.rewrite.logical.InferPredicates;
import org.apache.doris.nereids.rules.rewrite.logical.InnerToCrossJoin;
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
import org.apache.doris.nereids.rules.rewrite.logical.MergeSetOperations;
@ -60,6 +59,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanTablet;
import org.apache.doris.nereids.rules.rewrite.logical.PushFilterInsideJoin;
import org.apache.doris.nereids.rules.rewrite.logical.PushdownLimit;
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
import java.util.List;
@ -191,7 +191,7 @@ public class NereidsRewriter extends BatchRewriteJob {
new PruneOlapScanTablet(),
new EliminateAggregate(),
new MergeSetOperations(),
new LimitPushDown(),
new PushdownLimit(),
new BuildAggForUnion()
)),

View File

@ -192,6 +192,8 @@ public enum RuleType {
PUSH_LIMIT_THROUGH_JOIN(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_ONE_ROW_RELATION(RuleTypeClass.REWRITE),
PUSH_LIMIT_THROUGH_EMPTY_RELATION(RuleTypeClass.REWRITE),
// adjust nullable
ADJUST_NULLABLE_ON_AGGREGATE(RuleTypeClass.REWRITE),

View File

@ -20,10 +20,9 @@ package org.apache.doris.nereids.rules.rewrite.logical;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
import org.apache.doris.nereids.trees.UnaryNode;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
import org.apache.doris.nereids.trees.plans.algebra.Limit;
import org.apache.doris.nereids.trees.plans.algebra.OneRowRelation;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
@ -40,14 +39,20 @@ import java.util.List;
* <p>
* Limit can't be push down if it has a valid offset info.
*/
public class LimitPushDown implements RewriteRuleFactory {
public class PushdownLimit implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
// limit -> join
logicalLimit(logicalJoin(any(), any())).whenNot(Limit::hasValidOffset)
.then(limit -> limit.withChildren(pushLimitThroughJoin(limit, limit.child())))
.then(limit -> {
Plan newJoin = pushLimitThroughJoin(limit, limit.child());
if (newJoin == null || limit.child().children().equals(newJoin.children())) {
return null;
}
return limit.withChildren(newJoin);
})
.toRule(RuleType.PUSH_LIMIT_THROUGH_JOIN),
// limit -> project -> join
@ -55,9 +60,11 @@ public class LimitPushDown implements RewriteRuleFactory {
.then(limit -> {
LogicalProject<LogicalJoin<Plan, Plan>> project = limit.child();
LogicalJoin<Plan, Plan> join = project.child();
return limit.withChildren(
project.withChildren(
pushLimitThroughJoin(limit, join)));
Plan newJoin = pushLimitThroughJoin(limit, join);
if (newJoin == null || join.children().equals(newJoin.children())) {
return null;
}
return limit.withChildren(project.withChildren(newJoin));
}).toRule(RuleType.PUSH_LIMIT_THROUGH_PROJECT_JOIN),
// limit -> union
@ -67,11 +74,22 @@ public class LimitPushDown implements RewriteRuleFactory {
LogicalUnion union = limit.child();
ImmutableList<Plan> newUnionChildren = union.children()
.stream()
.map(child -> addLimit(limit, child))
.map(child -> limit.withChildren(child))
.collect(ImmutableList.toImmutableList());
if (union.children().equals(newUnionChildren)) {
return null;
}
return limit.withChildren(union.withChildren(newUnionChildren));
})
.toRule(RuleType.PUSH_LIMIT_THROUGH_UNION)
.toRule(RuleType.PUSH_LIMIT_THROUGH_UNION),
logicalLimit(logicalOneRowRelation())
.then(limit -> limit.getLimit() > 0
? limit.child() : new LogicalEmptyRelation(limit.child().getOutput()))
.toRule(RuleType.PUSH_LIMIT_THROUGH_ONE_ROW_RELATION),
logicalLimit(logicalEmptyRelation())
.then(UnaryNode::child)
.toRule(RuleType.PUSH_LIMIT_THROUGH_EMPTY_RELATION),
new MergeLimits().build()
);
}
@ -79,53 +97,22 @@ public class LimitPushDown implements RewriteRuleFactory {
switch (join.getJoinType()) {
case LEFT_OUTER_JOIN:
return join.withChildren(
addLimit(limit, join.left()),
limit.withChildren(join.left()),
join.right()
);
case RIGHT_OUTER_JOIN:
return join.withChildren(
join.left(),
addLimit(limit, join.right())
limit.withChildren(join.right())
);
case CROSS_JOIN:
return join.withChildren(
addLimit(limit, join.left()),
addLimit(limit, join.right())
limit.withChildren(join.left()),
limit.withChildren(join.right())
);
case INNER_JOIN:
if (join.hasJoinCondition()) {
return join;
} else {
return join.withChildren(
addLimit(limit, join.left()),
addLimit(limit, join.right())
);
}
default:
// don't push limit.
return join;
}
}
private Plan addLimit(LogicalLimit<? extends Plan> pushdownLimit, Plan plan) {
if (plan instanceof LogicalLimit) {
// Avoid adding duplicate limits on top of the plan, otherwise would result in dead loop
// when applying the rule multiple times.
LogicalLimit<? extends Plan> limit = (LogicalLimit<? extends Plan>) plan;
// plan is pure limit and limit value > push down limit value
if (!limit.hasValidOffset() && limit.getLimit() > pushdownLimit.getLimit()) {
// replace limit.
return pushdownLimit.withChildren(limit.child());
} else {
// return input plan.
return plan;
}
} else if (plan instanceof OneRowRelation) {
return pushdownLimit.getLimit() > 0 ? plan : new LogicalEmptyRelation(plan.getOutput());
} else if (plan instanceof EmptyRelation) {
return plan;
} else {
return pushdownLimit.withChildren(plan);
return null;
}
}
}

View File

@ -48,7 +48,7 @@ import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
class LimitPushDownTest extends TestWithFeService implements MemoPatternMatchSupported {
class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSupported {
private Plan scanScore = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.score);
private Plan scanStudent = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
@ -173,7 +173,7 @@ class LimitPushDownTest extends TestWithFeService implements MemoPatternMatchSup
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
).when(j -> j.getJoinType() == JoinType.INNER_JOIN)
)
)
)
);
@ -182,7 +182,7 @@ class LimitPushDownTest extends TestWithFeService implements MemoPatternMatchSup
logicalJoin(
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
).when(j -> j.getJoinType() == JoinType.INNER_JOIN)
)
)
);
}
@ -241,7 +241,8 @@ class LimitPushDownTest extends TestWithFeService implements MemoPatternMatchSup
Plan plan = generatePlan(joinType, hasProject);
PlanChecker.from(MemoTestUtils.createConnectContext())
.analyze(plan)
.applyTopDown(new LimitPushDown())
.applyTopDown(new InnerToCrossJoin())
.applyTopDown(new PushdownLimit())
.matchesFromRoot(pattern);
}