[refactor](Nereids): refactor PushdownLimit (#17355)
This commit is contained in:
@ -52,7 +52,6 @@ import org.apache.doris.nereids.rules.rewrite.logical.InferFilterNotNull;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.InferJoinNotNull;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.InferPredicates;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.InnerToCrossJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.LimitPushDown;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeFilters;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeProjects;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.MergeSetOperations;
|
||||
@ -60,6 +59,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.NormalizeAggregate;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanPartition;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PruneOlapScanTablet;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushFilterInsideJoin;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.PushdownLimit;
|
||||
import org.apache.doris.nereids.rules.rewrite.logical.ReorderJoin;
|
||||
|
||||
import java.util.List;
|
||||
@ -191,7 +191,7 @@ public class NereidsRewriter extends BatchRewriteJob {
|
||||
new PruneOlapScanTablet(),
|
||||
new EliminateAggregate(),
|
||||
new MergeSetOperations(),
|
||||
new LimitPushDown(),
|
||||
new PushdownLimit(),
|
||||
new BuildAggForUnion()
|
||||
)),
|
||||
|
||||
|
||||
@ -192,6 +192,8 @@ public enum RuleType {
|
||||
PUSH_LIMIT_THROUGH_JOIN(RuleTypeClass.REWRITE),
|
||||
PUSH_LIMIT_THROUGH_PROJECT_JOIN(RuleTypeClass.REWRITE),
|
||||
PUSH_LIMIT_THROUGH_UNION(RuleTypeClass.REWRITE),
|
||||
PUSH_LIMIT_THROUGH_ONE_ROW_RELATION(RuleTypeClass.REWRITE),
|
||||
PUSH_LIMIT_THROUGH_EMPTY_RELATION(RuleTypeClass.REWRITE),
|
||||
|
||||
// adjust nullable
|
||||
ADJUST_NULLABLE_ON_AGGREGATE(RuleTypeClass.REWRITE),
|
||||
|
||||
@ -20,10 +20,9 @@ package org.apache.doris.nereids.rules.rewrite.logical;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
|
||||
import org.apache.doris.nereids.trees.UnaryNode;
|
||||
import org.apache.doris.nereids.trees.plans.Plan;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.EmptyRelation;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.Limit;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.OneRowRelation;
|
||||
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
|
||||
@ -40,14 +39,20 @@ import java.util.List;
|
||||
* <p>
|
||||
* Limit can't be push down if it has a valid offset info.
|
||||
*/
|
||||
public class LimitPushDown implements RewriteRuleFactory {
|
||||
public class PushdownLimit implements RewriteRuleFactory {
|
||||
|
||||
@Override
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
// limit -> join
|
||||
logicalLimit(logicalJoin(any(), any())).whenNot(Limit::hasValidOffset)
|
||||
.then(limit -> limit.withChildren(pushLimitThroughJoin(limit, limit.child())))
|
||||
.then(limit -> {
|
||||
Plan newJoin = pushLimitThroughJoin(limit, limit.child());
|
||||
if (newJoin == null || limit.child().children().equals(newJoin.children())) {
|
||||
return null;
|
||||
}
|
||||
return limit.withChildren(newJoin);
|
||||
})
|
||||
.toRule(RuleType.PUSH_LIMIT_THROUGH_JOIN),
|
||||
|
||||
// limit -> project -> join
|
||||
@ -55,9 +60,11 @@ public class LimitPushDown implements RewriteRuleFactory {
|
||||
.then(limit -> {
|
||||
LogicalProject<LogicalJoin<Plan, Plan>> project = limit.child();
|
||||
LogicalJoin<Plan, Plan> join = project.child();
|
||||
return limit.withChildren(
|
||||
project.withChildren(
|
||||
pushLimitThroughJoin(limit, join)));
|
||||
Plan newJoin = pushLimitThroughJoin(limit, join);
|
||||
if (newJoin == null || join.children().equals(newJoin.children())) {
|
||||
return null;
|
||||
}
|
||||
return limit.withChildren(project.withChildren(newJoin));
|
||||
}).toRule(RuleType.PUSH_LIMIT_THROUGH_PROJECT_JOIN),
|
||||
|
||||
// limit -> union
|
||||
@ -67,11 +74,22 @@ public class LimitPushDown implements RewriteRuleFactory {
|
||||
LogicalUnion union = limit.child();
|
||||
ImmutableList<Plan> newUnionChildren = union.children()
|
||||
.stream()
|
||||
.map(child -> addLimit(limit, child))
|
||||
.map(child -> limit.withChildren(child))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
if (union.children().equals(newUnionChildren)) {
|
||||
return null;
|
||||
}
|
||||
return limit.withChildren(union.withChildren(newUnionChildren));
|
||||
})
|
||||
.toRule(RuleType.PUSH_LIMIT_THROUGH_UNION)
|
||||
.toRule(RuleType.PUSH_LIMIT_THROUGH_UNION),
|
||||
logicalLimit(logicalOneRowRelation())
|
||||
.then(limit -> limit.getLimit() > 0
|
||||
? limit.child() : new LogicalEmptyRelation(limit.child().getOutput()))
|
||||
.toRule(RuleType.PUSH_LIMIT_THROUGH_ONE_ROW_RELATION),
|
||||
logicalLimit(logicalEmptyRelation())
|
||||
.then(UnaryNode::child)
|
||||
.toRule(RuleType.PUSH_LIMIT_THROUGH_EMPTY_RELATION),
|
||||
new MergeLimits().build()
|
||||
);
|
||||
}
|
||||
|
||||
@ -79,53 +97,22 @@ public class LimitPushDown implements RewriteRuleFactory {
|
||||
switch (join.getJoinType()) {
|
||||
case LEFT_OUTER_JOIN:
|
||||
return join.withChildren(
|
||||
addLimit(limit, join.left()),
|
||||
limit.withChildren(join.left()),
|
||||
join.right()
|
||||
);
|
||||
case RIGHT_OUTER_JOIN:
|
||||
return join.withChildren(
|
||||
join.left(),
|
||||
addLimit(limit, join.right())
|
||||
limit.withChildren(join.right())
|
||||
);
|
||||
case CROSS_JOIN:
|
||||
return join.withChildren(
|
||||
addLimit(limit, join.left()),
|
||||
addLimit(limit, join.right())
|
||||
limit.withChildren(join.left()),
|
||||
limit.withChildren(join.right())
|
||||
);
|
||||
case INNER_JOIN:
|
||||
if (join.hasJoinCondition()) {
|
||||
return join;
|
||||
} else {
|
||||
return join.withChildren(
|
||||
addLimit(limit, join.left()),
|
||||
addLimit(limit, join.right())
|
||||
);
|
||||
}
|
||||
default:
|
||||
// don't push limit.
|
||||
return join;
|
||||
}
|
||||
}
|
||||
|
||||
private Plan addLimit(LogicalLimit<? extends Plan> pushdownLimit, Plan plan) {
|
||||
if (plan instanceof LogicalLimit) {
|
||||
// Avoid adding duplicate limits on top of the plan, otherwise would result in dead loop
|
||||
// when applying the rule multiple times.
|
||||
LogicalLimit<? extends Plan> limit = (LogicalLimit<? extends Plan>) plan;
|
||||
// plan is pure limit and limit value > push down limit value
|
||||
if (!limit.hasValidOffset() && limit.getLimit() > pushdownLimit.getLimit()) {
|
||||
// replace limit.
|
||||
return pushdownLimit.withChildren(limit.child());
|
||||
} else {
|
||||
// return input plan.
|
||||
return plan;
|
||||
}
|
||||
} else if (plan instanceof OneRowRelation) {
|
||||
return pushdownLimit.getLimit() > 0 ? plan : new LogicalEmptyRelation(plan.getOutput());
|
||||
} else if (plan instanceof EmptyRelation) {
|
||||
return plan;
|
||||
} else {
|
||||
return pushdownLimit.withChildren(plan);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -48,7 +48,7 @@ import java.util.Set;
|
||||
import java.util.function.Function;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
class LimitPushDownTest extends TestWithFeService implements MemoPatternMatchSupported {
|
||||
class PushdownLimitTest extends TestWithFeService implements MemoPatternMatchSupported {
|
||||
private Plan scanScore = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.score);
|
||||
private Plan scanStudent = new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.student);
|
||||
|
||||
@ -173,7 +173,7 @@ class LimitPushDownTest extends TestWithFeService implements MemoPatternMatchSup
|
||||
logicalJoin(
|
||||
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
|
||||
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
|
||||
).when(j -> j.getJoinType() == JoinType.INNER_JOIN)
|
||||
)
|
||||
)
|
||||
)
|
||||
);
|
||||
@ -182,7 +182,7 @@ class LimitPushDownTest extends TestWithFeService implements MemoPatternMatchSup
|
||||
logicalJoin(
|
||||
logicalLimit(logicalOlapScan().when(s -> s.equals(scanScore))),
|
||||
logicalLimit(logicalOlapScan().when(s -> s.equals(scanStudent)))
|
||||
).when(j -> j.getJoinType() == JoinType.INNER_JOIN)
|
||||
)
|
||||
)
|
||||
);
|
||||
}
|
||||
@ -241,7 +241,8 @@ class LimitPushDownTest extends TestWithFeService implements MemoPatternMatchSup
|
||||
Plan plan = generatePlan(joinType, hasProject);
|
||||
PlanChecker.from(MemoTestUtils.createConnectContext())
|
||||
.analyze(plan)
|
||||
.applyTopDown(new LimitPushDown())
|
||||
.applyTopDown(new InnerToCrossJoin())
|
||||
.applyTopDown(new PushdownLimit())
|
||||
.matchesFromRoot(pattern);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user