[fix](Nereids) push down topn distinct through join by mistake (#31396)

Top-n distinct should not be pushed down through a join when the
corresponding child of the join outputs columns that are not among
the aggregate's distinct columns.

for example for LEFT_OUTER_JOIN:

The left child of the join outputs: c1, c2, c3.
The distinct columns are: c1, c2.
topn: limit 2

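If we push down the top-n distinct, the distinct below the join is computed over all of the left child's output columns (c1, c2, c3), so the result of the join could look like this: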

```
c1    c2    c3, ...
1     2     1
1     2     2
```

and the final result we get is:

```
c1    c2
1     2
```

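This is wrong: the two rows kept by the pushed-down limit collapse into one after the final distinct on (c1, c2), so we return only 1 row although 2 are required.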
This commit is contained in:
morrySnow
2024-02-28 13:12:12 +08:00
committed by yiguolei
parent 79dd4e24ff
commit d88caca44a
3 changed files with 82 additions and 29 deletions

View File

@ -43,15 +43,15 @@ public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
// topN -> join
logicalTopN(logicalAggregate(logicalJoin()).when(a -> a.isDistinct()))
// TODO: complex orderby
logicalTopN(logicalAggregate(logicalJoin()).when(LogicalAggregate::isDistinct))
// TODO: complex order by
.when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
.allMatch(Slot.class::isInstance))
.then(topN -> {
LogicalAggregate<LogicalJoin<Plan, Plan>> distinct = topN.child();
LogicalJoin<Plan, Plan> join = distinct.child();
Plan newJoin = pushTopNThroughJoin(topN, join);
if (newJoin == null || topN.child().children().equals(newJoin.children())) {
if (newJoin == null || join.children().equals(newJoin.children())) {
return null;
}
return topN.withChildren(distinct.withChildren(newJoin));
@ -59,8 +59,8 @@ public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory {
.toRule(RuleType.PUSH_DOWN_TOP_N_DISTINCT_THROUGH_JOIN),
// topN -> project -> join
logicalTopN(logicalAggregate(logicalProject(logicalJoin()).when(p -> p.isAllSlots()))
.when(a -> a.isDistinct()))
logicalTopN(logicalAggregate(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots))
.when(LogicalAggregate::isDistinct))
.when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
.allMatch(Slot.class::isInstance))
.then(topN -> {
@ -89,38 +89,90 @@ public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory {
}
/**
 * Try to push the top-n (whose child is a distinct aggregate) below the given join.
 *
 * <p>The group-by slots of the aggregate under {@code topN} are collected first. For
 * LEFT_OUTER_JOIN only the left child is considered, for RIGHT_OUTER_JOIN only the right
 * child, and for CROSS_JOIN both children are considered independently. Any other join
 * type is left untouched.
 *
 * <p>A pushed top-n is built with {@code limit + offset} as its limit and 0 as its offset,
 * on top of {@code PlanUtils.distinct(child)}.
 *
 * <p>NOTE(review): this span appears to be rendered from a diff, so lines of the previous
 * implementation (the {@code containsAll} / {@code withLimitChild} variants) seem to be
 * interleaved with the current one (the {@code getPushedOrderKeys} /
 * {@code withLimitOrderKeyAndChild} variants) -- confirm against the real file.
 *
 * @param topN the top-n whose child is the distinct aggregate
 * @param join the join to push through
 * @return the join with rewritten children, or null when nothing can be pushed
 */
private Plan pushTopNThroughJoin(LogicalTopN<? extends Plan> topN, LogicalJoin<Plan, Plan> join) {
List<Slot> groupBySlots = ((LogicalAggregate<?>) topN.child()).getGroupByExpressions().stream()
.flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
Set<Slot> groupBySlots = ((LogicalAggregate<?>) topN.child()).getGroupByExpressions().stream()
.flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toSet());
switch (join.getJoinType()) {
case LEFT_OUTER_JOIN:
if (join.left().getOutputSet().containsAll(groupBySlots)) {
LogicalTopN<Plan> left = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0,
case LEFT_OUTER_JOIN: {
// an empty key list from getPushedOrderKeys means "do not push to this side"
List<OrderKey> pushedOrderKeys = getPushedOrderKeys(groupBySlots,
join.left().getOutputSet(), topN.getOrderKeys());
if (!pushedOrderKeys.isEmpty()) {
LogicalTopN<Plan> left = topN.withLimitOrderKeyAndChild(
topN.getLimit() + topN.getOffset(), 0, pushedOrderKeys,
PlanUtils.distinct(join.left()));
return join.withChildren(left, join.right());
}
return null;
case RIGHT_OUTER_JOIN:
if (join.right().getOutputSet().containsAll(groupBySlots)) {
LogicalTopN<Plan> right = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0,
}
case RIGHT_OUTER_JOIN: {
List<OrderKey> pushedOrderKeys = getPushedOrderKeys(groupBySlots,
join.right().getOutputSet(), topN.getOrderKeys());
if (!pushedOrderKeys.isEmpty()) {
LogicalTopN<Plan> right = topN.withLimitOrderKeyAndChild(
topN.getLimit() + topN.getOffset(), 0, pushedOrderKeys,
PlanUtils.distinct(join.right()));
return join.withChildren(join.left(), right);
}
return null;
case CROSS_JOIN:
if (join.left().getOutputSet().containsAll(groupBySlots)) {
LogicalTopN<Plan> left = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0,
}
case CROSS_JOIN: {
// each side is checked on its own; leftChild/rightChild stay the original
// children unless a top-n is pushed onto that side
Plan leftChild = join.left();
Plan rightChild = join.right();
List<OrderKey> leftPushedOrderKeys = getPushedOrderKeys(groupBySlots,
join.left().getOutputSet(), topN.getOrderKeys());
if (!leftPushedOrderKeys.isEmpty()) {
leftChild = topN.withLimitOrderKeyAndChild(
topN.getLimit() + topN.getOffset(), 0, leftPushedOrderKeys,
PlanUtils.distinct(join.left()));
return join.withChildren(left, join.right());
} else if (join.right().getOutputSet().containsAll(groupBySlots)) {
LogicalTopN<Plan> right = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0,
}
List<OrderKey> rightPushedOrderKeys = getPushedOrderKeys(groupBySlots,
join.right().getOutputSet(), topN.getOrderKeys());
if (!rightPushedOrderKeys.isEmpty()) {
rightChild = topN.withLimitOrderKeyAndChild(
topN.getLimit() + topN.getOffset(), 0, rightPushedOrderKeys,
PlanUtils.distinct(join.right()));
return join.withChildren(join.left(), right);
} else {
}
// reference comparison: both children unchanged means nothing was pushed
if (leftChild == join.left() && rightChild == join.right()) {
return null;
} else {
return join.withChildren(leftChild, rightChild);
}
}
default:
// don't push limit.
return null;
}
}
/**
 * Compute the order keys that can accompany a top-n distinct pushed onto one join child.
 *
 * @param groupBySlots slots used by the distinct (group-by) columns
 * @param joinChildSlot output slots of the join child being considered
 * @param orderKeys order keys of the original top-n
 * @return the order keys to use for the pushed top-n; an empty list means the
 *         top-n distinct cannot be pushed to this child
 */
private List<OrderKey> getPushedOrderKeys(Set<Slot> groupBySlots, Set<Slot> joinChildSlot,
        List<OrderKey> orderKeys) {
    // NOTICE: for now the restriction is strict: every output slot of the join child must be
    // one of the distinct columns. It could be relaxed later to only require an overlap between
    // the child's output and the distinct columns, but the optimized plan would then have to
    // change: the pushed-down distinct aggregation would become a "row number" window function
    // partitioned by the distinct columns, followed by a "row number = 1" filter.
    boolean childCoveredByDistinct = groupBySlots.containsAll(joinChildSlot);
    if (!childCoveredByDistinct) {
        return ImmutableList.of();
    }

    // Only a leading run of order keys that belong to the join child may be pushed. Count it.
    int total = orderKeys.size();
    int prefixLength = 0;
    while (prefixLength < total
            && joinChildSlot.contains(orderKeys.get(prefixLength).getExpr())) {
        prefixLength++;
    }

    // Position prefixLength (if any) holds a key that is not from this child. A key of this
    // child appearing after it would be ordered behind a foreign key, and pushing the top-n
    // under the join would then give a wrong result, so refuse to push.
    for (int i = prefixLength + 1; i < total; i++) {
        if (joinChildSlot.contains(orderKeys.get(i).getExpr())) {
            return ImmutableList.of();
        }
    }
    return ImmutableList.copyOf(orderKeys.subList(0, prefixLength));
}
}

View File

@ -136,6 +136,13 @@ public class LogicalTopN<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYP
return new LogicalTopN<>(orderKeys, limit, offset, child);
}
/**
 * Create a new top-n with the given limit, offset, order keys and child.
 *
 * <p>The logical properties of this node are intentionally NOT handed to the new node:
 * the new child may produce a different output than this node's child, so reusing the
 * already-computed properties would attach stale information to the new node. The
 * four-argument constructor is used instead, the same one used for the limit/child-only
 * variant of this method.
 *
 * @param limit limit of the new top-n
 * @param offset offset of the new top-n
 * @param orderKeys order keys of the new top-n
 * @param child child plan of the new top-n
 * @return a new top-n over {@code child}
 */
public LogicalTopN<Plan> withLimitOrderKeyAndChild(long limit, long offset, List<OrderKey> orderKeys, Plan child) {
    // sanity check on this node's own children, kept from the original implementation
    Preconditions.checkArgument(children.size() == 1,
            "LogicalTopN should have 1 child, but input is %s", children.size());
    return new LogicalTopN<>(orderKeys, limit, offset, child);
}
@Override
public LogicalTopN<Plan> withChildren(List<Plan> children) {
Preconditions.checkArgument(children.size() == 1,