diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughJoin.java index 98d1cdd347..f2dde7ba2a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PushDownTopNDistinctThroughJoin.java @@ -43,15 +43,15 @@ public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory { public List buildRules() { return ImmutableList.of( // topN -> join - logicalTopN(logicalAggregate(logicalJoin()).when(a -> a.isDistinct())) - // TODO: complex orderby + logicalTopN(logicalAggregate(logicalJoin()).when(LogicalAggregate::isDistinct)) + // TODO: complex order by .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr) .allMatch(Slot.class::isInstance)) .then(topN -> { LogicalAggregate> distinct = topN.child(); LogicalJoin join = distinct.child(); Plan newJoin = pushTopNThroughJoin(topN, join); - if (newJoin == null || topN.child().children().equals(newJoin.children())) { + if (newJoin == null || join.children().equals(newJoin.children())) { return null; } return topN.withChildren(distinct.withChildren(newJoin)); @@ -59,8 +59,8 @@ public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory { .toRule(RuleType.PUSH_DOWN_TOP_N_DISTINCT_THROUGH_JOIN), // topN -> project -> join - logicalTopN(logicalAggregate(logicalProject(logicalJoin()).when(p -> p.isAllSlots())) - .when(a -> a.isDistinct())) + logicalTopN(logicalAggregate(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots)) + .when(LogicalAggregate::isDistinct)) .when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr) .allMatch(Slot.class::isInstance)) .then(topN -> { @@ -89,38 +89,90 @@ public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory { } private Plan 
pushTopNThroughJoin(LogicalTopN topN, LogicalJoin join) { - List groupBySlots = ((LogicalAggregate) topN.child()).getGroupByExpressions().stream() - .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList()); + Set groupBySlots = ((LogicalAggregate) topN.child()).getGroupByExpressions().stream() + .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toSet()); switch (join.getJoinType()) { - case LEFT_OUTER_JOIN: - if (join.left().getOutputSet().containsAll(groupBySlots)) { - LogicalTopN left = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, + case LEFT_OUTER_JOIN: { + List pushedOrderKeys = getPushedOrderKeys(groupBySlots, + join.left().getOutputSet(), topN.getOrderKeys()); + if (!pushedOrderKeys.isEmpty()) { + LogicalTopN left = topN.withLimitOrderKeyAndChild( + topN.getLimit() + topN.getOffset(), 0, pushedOrderKeys, PlanUtils.distinct(join.left())); return join.withChildren(left, join.right()); } return null; - case RIGHT_OUTER_JOIN: - if (join.right().getOutputSet().containsAll(groupBySlots)) { - LogicalTopN right = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, + } + case RIGHT_OUTER_JOIN: { + List pushedOrderKeys = getPushedOrderKeys(groupBySlots, + join.right().getOutputSet(), topN.getOrderKeys()); + if (!pushedOrderKeys.isEmpty()) { + LogicalTopN right = topN.withLimitOrderKeyAndChild( + topN.getLimit() + topN.getOffset(), 0, pushedOrderKeys, PlanUtils.distinct(join.right())); return join.withChildren(join.left(), right); } return null; - case CROSS_JOIN: - if (join.left().getOutputSet().containsAll(groupBySlots)) { - LogicalTopN left = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, + } + case CROSS_JOIN: { + Plan leftChild = join.left(); + Plan rightChild = join.right(); + List leftPushedOrderKeys = getPushedOrderKeys(groupBySlots, + join.left().getOutputSet(), topN.getOrderKeys()); + if (!leftPushedOrderKeys.isEmpty()) { + leftChild = topN.withLimitOrderKeyAndChild( + topN.getLimit() + 
topN.getOffset(), 0, leftPushedOrderKeys, PlanUtils.distinct(join.left())); - return join.withChildren(left, join.right()); - } else if (join.right().getOutputSet().containsAll(groupBySlots)) { - LogicalTopN right = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0, + } + List rightPushedOrderKeys = getPushedOrderKeys(groupBySlots, + join.right().getOutputSet(), topN.getOrderKeys()); + if (!rightPushedOrderKeys.isEmpty()) { + rightChild = topN.withLimitOrderKeyAndChild( + topN.getLimit() + topN.getOffset(), 0, rightPushedOrderKeys, PlanUtils.distinct(join.right())); - return join.withChildren(join.left(), right); - } else { + } + if (leftChild == join.left() && rightChild == join.right()) { return null; + } else { + return join.withChildren(leftChild, rightChild); } + } default: // don't push limit. return null; } } + + /** + * return pushed order-keys. If top-n distinct cannot be pushed, return empty list. + */ + private List getPushedOrderKeys(Set groupBySlots, Set joinChildSlot, + List orderKeys) { + // NOTICE: Currently, we have implemented strict restrictions to ensure that the distinct columns are + // a superset of the output from the corresponding child of the join operator. In the future, we can relax + // this restriction and only require that there is overlap between the output of the corresponding child of + // the join operator and the distinct columns. + // However, this would require changes to the optimized plan, converting the pushed-down aggregation distinct + // to the window function "row number". Partition by distinct columns, and a filtering condition of + // "row number = 1" would be added. + if (!groupBySlots.containsAll(joinChildSlot)) { + return ImmutableList.of(); + } + // we must check the order of order keys. the slot of non-join-output should not appear before join's output + // otherwise, we will get a wrong result if we push top-n under join. 
+ ImmutableList.Builder pushedOrderKeys = ImmutableList.builder(); + boolean notFound = false; + for (OrderKey orderKey : orderKeys) { + if (joinChildSlot.contains(orderKey.getExpr())) { + if (notFound) { + return ImmutableList.of(); + } else { + pushedOrderKeys.add(orderKey); + } + } else { + notFound = true; + } + } + return pushedOrderKeys.build(); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java index 7f7f9b7a40..63e3dc9c0b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalTopN.java @@ -136,6 +136,13 @@ public class LogicalTopN extends LogicalUnary(orderKeys, limit, offset, child); } + public LogicalTopN withLimitOrderKeyAndChild(long limit, long offset, List orderKeys, Plan child) { + Preconditions.checkArgument(children.size() == 1, + "LogicalTopN should have 1 child, but input is %s", children.size()); + return new LogicalTopN<>(orderKeys, limit, offset, + Optional.empty(), Optional.of(getLogicalProperties()), child); + } + @Override public LogicalTopN withChildren(List children) { Preconditions.checkArgument(children.size() == 1, diff --git a/regression-test/data/nereids_rules_p0/limit_push_down/order_push_down.out b/regression-test/data/nereids_rules_p0/limit_push_down/order_push_down.out index b74605e19e..d1ab585baa 100644 --- a/regression-test/data/nereids_rules_p0/limit_push_down/order_push_down.out +++ b/regression-test/data/nereids_rules_p0/limit_push_down/order_push_down.out @@ -147,10 +147,7 @@ PhysicalResultSink ------hashAgg[GLOBAL] --------hashAgg[LOCAL] ----------hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() -------------PhysicalTopN[MERGE_SORT] ---------------PhysicalTopN[LOCAL_SORT] -----------------hashAgg[LOCAL] 
-------------------PhysicalOlapScan[t1] +------------PhysicalOlapScan[t1] ------------PhysicalOlapScan[t2] -- !limit_distinct -- @@ -160,10 +157,7 @@ PhysicalResultSink ------hashAgg[GLOBAL] --------hashAgg[LOCAL] ----------hashJoin[LEFT_OUTER_JOIN] hashCondition=((t1.id = t2.id)) otherCondition=() -------------PhysicalTopN[MERGE_SORT] ---------------PhysicalTopN[LOCAL_SORT] -----------------hashAgg[LOCAL] -------------------PhysicalOlapScan[t1] +------------PhysicalOlapScan[t1] ------------PhysicalOlapScan[t2] -- !limit_window --