[fix](Nereids) push down topn distinct through join by mistake (#31396)
Do not push down topn distinct through a join when the corresponding child of the join outputs more columns than the aggregate's distinct columns. For example, for LEFT_OUTER_JOIN: the left child of the join outputs c1, c2, c3; the distinct columns are c1, c2; and the topn is limit 2. If we push down topn distinct, the join could produce a result like this:
```
c1 c2 c3 ...
1  2  1
1  2  2
```
and the final result we get is:
```
c1 c2
1  2
```
This is wrong, because we need 2 rows but only 1 is returned.
This commit is contained in:
@ -43,15 +43,15 @@ public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory {
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
// topN -> join
|
||||
logicalTopN(logicalAggregate(logicalJoin()).when(a -> a.isDistinct()))
|
||||
// TODO: complex orderby
|
||||
logicalTopN(logicalAggregate(logicalJoin()).when(LogicalAggregate::isDistinct))
|
||||
// TODO: complex order by
|
||||
.when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
|
||||
.allMatch(Slot.class::isInstance))
|
||||
.then(topN -> {
|
||||
LogicalAggregate<LogicalJoin<Plan, Plan>> distinct = topN.child();
|
||||
LogicalJoin<Plan, Plan> join = distinct.child();
|
||||
Plan newJoin = pushTopNThroughJoin(topN, join);
|
||||
if (newJoin == null || topN.child().children().equals(newJoin.children())) {
|
||||
if (newJoin == null || join.children().equals(newJoin.children())) {
|
||||
return null;
|
||||
}
|
||||
return topN.withChildren(distinct.withChildren(newJoin));
|
||||
@ -59,8 +59,8 @@ public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory {
|
||||
.toRule(RuleType.PUSH_DOWN_TOP_N_DISTINCT_THROUGH_JOIN),
|
||||
|
||||
// topN -> project -> join
|
||||
logicalTopN(logicalAggregate(logicalProject(logicalJoin()).when(p -> p.isAllSlots()))
|
||||
.when(a -> a.isDistinct()))
|
||||
logicalTopN(logicalAggregate(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots))
|
||||
.when(LogicalAggregate::isDistinct))
|
||||
.when(topN -> topN.getOrderKeys().stream().map(OrderKey::getExpr)
|
||||
.allMatch(Slot.class::isInstance))
|
||||
.then(topN -> {
|
||||
@ -89,38 +89,90 @@ public class PushDownTopNDistinctThroughJoin implements RewriteRuleFactory {
|
||||
}
|
||||
|
||||
private Plan pushTopNThroughJoin(LogicalTopN<? extends Plan> topN, LogicalJoin<Plan, Plan> join) {
|
||||
List<Slot> groupBySlots = ((LogicalAggregate<?>) topN.child()).getGroupByExpressions().stream()
|
||||
.flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
|
||||
Set<Slot> groupBySlots = ((LogicalAggregate<?>) topN.child()).getGroupByExpressions().stream()
|
||||
.flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toSet());
|
||||
switch (join.getJoinType()) {
|
||||
case LEFT_OUTER_JOIN:
|
||||
if (join.left().getOutputSet().containsAll(groupBySlots)) {
|
||||
LogicalTopN<Plan> left = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0,
|
||||
case LEFT_OUTER_JOIN: {
|
||||
List<OrderKey> pushedOrderKeys = getPushedOrderKeys(groupBySlots,
|
||||
join.left().getOutputSet(), topN.getOrderKeys());
|
||||
if (!pushedOrderKeys.isEmpty()) {
|
||||
LogicalTopN<Plan> left = topN.withLimitOrderKeyAndChild(
|
||||
topN.getLimit() + topN.getOffset(), 0, pushedOrderKeys,
|
||||
PlanUtils.distinct(join.left()));
|
||||
return join.withChildren(left, join.right());
|
||||
}
|
||||
return null;
|
||||
case RIGHT_OUTER_JOIN:
|
||||
if (join.right().getOutputSet().containsAll(groupBySlots)) {
|
||||
LogicalTopN<Plan> right = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0,
|
||||
}
|
||||
case RIGHT_OUTER_JOIN: {
|
||||
List<OrderKey> pushedOrderKeys = getPushedOrderKeys(groupBySlots,
|
||||
join.right().getOutputSet(), topN.getOrderKeys());
|
||||
if (!pushedOrderKeys.isEmpty()) {
|
||||
LogicalTopN<Plan> right = topN.withLimitOrderKeyAndChild(
|
||||
topN.getLimit() + topN.getOffset(), 0, pushedOrderKeys,
|
||||
PlanUtils.distinct(join.right()));
|
||||
return join.withChildren(join.left(), right);
|
||||
}
|
||||
return null;
|
||||
case CROSS_JOIN:
|
||||
if (join.left().getOutputSet().containsAll(groupBySlots)) {
|
||||
LogicalTopN<Plan> left = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0,
|
||||
}
|
||||
case CROSS_JOIN: {
|
||||
Plan leftChild = join.left();
|
||||
Plan rightChild = join.right();
|
||||
List<OrderKey> leftPushedOrderKeys = getPushedOrderKeys(groupBySlots,
|
||||
join.left().getOutputSet(), topN.getOrderKeys());
|
||||
if (!leftPushedOrderKeys.isEmpty()) {
|
||||
leftChild = topN.withLimitOrderKeyAndChild(
|
||||
topN.getLimit() + topN.getOffset(), 0, leftPushedOrderKeys,
|
||||
PlanUtils.distinct(join.left()));
|
||||
return join.withChildren(left, join.right());
|
||||
} else if (join.right().getOutputSet().containsAll(groupBySlots)) {
|
||||
LogicalTopN<Plan> right = topN.withLimitChild(topN.getLimit() + topN.getOffset(), 0,
|
||||
}
|
||||
List<OrderKey> rightPushedOrderKeys = getPushedOrderKeys(groupBySlots,
|
||||
join.right().getOutputSet(), topN.getOrderKeys());
|
||||
if (!rightPushedOrderKeys.isEmpty()) {
|
||||
rightChild = topN.withLimitOrderKeyAndChild(
|
||||
topN.getLimit() + topN.getOffset(), 0, rightPushedOrderKeys,
|
||||
PlanUtils.distinct(join.right()));
|
||||
return join.withChildren(join.left(), right);
|
||||
} else {
|
||||
}
|
||||
if (leftChild == join.left() && rightChild == join.right()) {
|
||||
return null;
|
||||
} else {
|
||||
return join.withChildren(leftChild, rightChild);
|
||||
}
|
||||
}
|
||||
default:
|
||||
// don't push limit.
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* return pushed order-keys. If top-n distinct cannot be pushed, return empty list.
|
||||
*/
|
||||
private List<OrderKey> getPushedOrderKeys(Set<Slot> groupBySlots, Set<Slot> joinChildSlot,
|
||||
List<OrderKey> orderKeys) {
|
||||
// NOTICE: Currently, we have implemented strict restrictions to ensure that the distinct columns is
|
||||
// a superset of the output from the corresponding child of the join operator. In the future, we can relax
|
||||
// this restriction and only require that there is overlap between the output of the corresponding child of
|
||||
// the join operator and the distinct columns.
|
||||
// However, this would require changes to the optimized plan, converting the pushed-down aggregation distinct
|
||||
// to the window function "row number". Partition by distinct columns, and a filtering condition of
|
||||
// "row number = 1" would be added.
|
||||
if (!groupBySlots.containsAll(joinChildSlot)) {
|
||||
return ImmutableList.of();
|
||||
}
|
||||
// we must check the order of order keys. the slot of non-join-output should not appear before join's output
|
||||
// other-wise, we will get wrong result, if we push top-n under join.
|
||||
ImmutableList.Builder<OrderKey> pushedOrderKeys = ImmutableList.builder();
|
||||
boolean notFound = false;
|
||||
for (OrderKey orderKey : orderKeys) {
|
||||
if (joinChildSlot.contains(orderKey.getExpr())) {
|
||||
if (notFound) {
|
||||
return ImmutableList.of();
|
||||
} else {
|
||||
pushedOrderKeys.add(orderKey);
|
||||
}
|
||||
} else {
|
||||
notFound = true;
|
||||
}
|
||||
}
|
||||
return pushedOrderKeys.build();
|
||||
}
|
||||
}
|
||||
|
||||
@ -136,6 +136,13 @@ public class LogicalTopN<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_TYP
|
||||
return new LogicalTopN<>(orderKeys, limit, offset, child);
|
||||
}
|
||||
|
||||
public LogicalTopN<Plan> withLimitOrderKeyAndChild(long limit, long offset, List<OrderKey> orderKeys, Plan child) {
|
||||
Preconditions.checkArgument(children.size() == 1,
|
||||
"LogicalTopN should have 1 child, but input is %s", children.size());
|
||||
return new LogicalTopN<>(orderKeys, limit, offset,
|
||||
Optional.empty(), Optional.of(getLogicalProperties()), child);
|
||||
}
|
||||
|
||||
@Override
|
||||
public LogicalTopN<Plan> withChildren(List<Plan> children) {
|
||||
Preconditions.checkArgument(children.size() == 1,
|
||||
|
||||
Reference in New Issue
Block a user