[opt](nereids) support partitionTopn for multi window exprs (#39687)

## Proposed changes

pick from https://github.com/apache/doris/pull/38393

Co-authored-by: xiongzhongjian <xiongzhongjian@selectdb.com>
This commit is contained in:
xzj7019
2024-08-22 10:34:36 +08:00
committed by GitHub
parent 021982fc71
commit 8f580b523f
5 changed files with 371 additions and 153 deletions

View File

@ -17,32 +17,16 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
/**
* Push down the 'partitionTopN' into the 'window'.
* It will convert the filter condition to the 'limit value' and push down below the 'window'.
@ -89,82 +73,17 @@ public class CreatePartitionTopNFromWindow extends OneRewriteRuleFactory {
return filter;
}
List<NamedExpression> windowExprs = window.getWindowExpressions();
if (windowExprs.size() != 1) {
Pair<WindowExpression, Long> windowFuncPair = window.getPushDownWindowFuncAndLimit(filter, Long.MAX_VALUE);
if (windowFuncPair == null) {
return filter;
} else if (windowFuncPair.second == -1) {
// limit -1 indicating a empty relation case
return new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(), filter.getOutput());
} else {
Plan newWindow = window.pushPartitionLimitThroughWindow(windowFuncPair.first,
windowFuncPair.second, false);
return filter.withChildren(newWindow);
}
NamedExpression windowExpr = windowExprs.get(0);
if (windowExpr.children().size() != 1 || !(windowExpr.child(0) instanceof WindowExpression)) {
return filter;
}
// Check the filter conditions. Now, we currently only support simple conditions of the form
// 'column </ <=/ = constant'. We will extract some related conjuncts and do some check.
Set<Expression> conjuncts = filter.getConjuncts();
Set<Expression> relatedConjuncts = extractRelatedConjuncts(conjuncts, windowExpr.getExprId());
boolean hasPartitionLimit = false;
long partitionLimit = Long.MAX_VALUE;
for (Expression conjunct : relatedConjuncts) {
Preconditions.checkArgument(conjunct instanceof BinaryOperator);
BinaryOperator op = (BinaryOperator) conjunct;
Expression leftChild = op.children().get(0);
Expression rightChild = op.children().get(1);
Preconditions.checkArgument(leftChild instanceof SlotReference
&& rightChild instanceof IntegerLikeLiteral);
long limitVal = ((IntegerLikeLiteral) rightChild).getLongValue();
// Adjust the value for 'limitVal' based on the comparison operators.
if (conjunct instanceof LessThan) {
limitVal--;
}
if (limitVal < 0) {
return new LogicalEmptyRelation(ctx.statementContext.getNextRelationId(), filter.getOutput());
}
if (hasPartitionLimit) {
partitionLimit = Math.min(partitionLimit, limitVal);
} else {
partitionLimit = limitVal;
hasPartitionLimit = true;
}
}
if (!hasPartitionLimit) {
return filter;
}
Optional<Plan> newWindow = window.pushPartitionLimitThroughWindow(partitionLimit, false);
if (!newWindow.isPresent()) {
return filter;
}
return filter.withChildren(newWindow.get());
}).toRule(RuleType.CREATE_PARTITION_TOPN_FOR_WINDOW);
}
private Set<Expression> extractRelatedConjuncts(Set<Expression> conjuncts, ExprId slotRefID) {
Predicate<Expression> condition = conjunct -> {
if (!(conjunct instanceof BinaryOperator)) {
return false;
}
BinaryOperator op = (BinaryOperator) conjunct;
Expression leftChild = op.children().get(0);
Expression rightChild = op.children().get(1);
if (!(conjunct instanceof LessThan || conjunct instanceof LessThanEqual || conjunct instanceof EqualTo)) {
return false;
}
// TODO: Now, we only support the column on the left side.
if (!(leftChild instanceof SlotReference) || !(rightChild instanceof IntegerLikeLiteral)) {
return false;
}
return ((SlotReference) leftChild).getExprId() == slotRefID;
};
return conjuncts.stream()
.filter(condition)
.collect(ImmutableSet.toImmutableSet());
}
}

View File

@ -17,8 +17,10 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Limit;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
@ -31,7 +33,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
/**
* Rules to push {@link org.apache.doris.nereids.trees.plans.logical.LogicalLimit} down.
@ -72,11 +73,17 @@ public class PushDownLimit implements RewriteRuleFactory {
.then(limit -> {
LogicalWindow<Plan> window = limit.child();
long partitionLimit = limit.getLimit() + limit.getOffset();
Optional<Plan> newWindow = window.pushPartitionLimitThroughWindow(partitionLimit, true);
if (!newWindow.isPresent()) {
if (partitionLimit <= 0) {
return limit;
}
return limit.withChildren(newWindow.get());
Pair<WindowExpression, Long> windowFuncLongPair = window
.getPushDownWindowFuncAndLimit(null, partitionLimit);
if (windowFuncLongPair == null) {
return limit;
}
Plan newWindow = window.pushPartitionLimitThroughWindow(windowFuncLongPair.first,
windowFuncLongPair.second, true);
return limit.withChildren(newWindow);
}).toRule(RuleType.PUSH_LIMIT_THROUGH_WINDOW),
// limit -> project -> window
@ -85,11 +92,17 @@ public class PushDownLimit implements RewriteRuleFactory {
LogicalProject<LogicalWindow<Plan>> project = limit.child();
LogicalWindow<Plan> window = project.child();
long partitionLimit = limit.getLimit() + limit.getOffset();
Optional<Plan> newWindow = window.pushPartitionLimitThroughWindow(partitionLimit, true);
if (!newWindow.isPresent()) {
if (partitionLimit <= 0) {
return limit;
}
return limit.withChildren(project.withChildren(newWindow.get()));
Pair<WindowExpression, Long> windowFuncLongPair = window
.getPushDownWindowFuncAndLimit(null, partitionLimit);
if (windowFuncLongPair == null) {
return limit;
}
Plan newWindow = window.pushPartitionLimitThroughWindow(windowFuncLongPair.first,
windowFuncLongPair.second, true);
return limit.withChildren(project.withChildren(newWindow));
}).toRule(RuleType.PUSH_LIMIT_THROUGH_PROJECT_WINDOW),
// limit -> union

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.properties.OrderKey;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
@ -33,7 +34,6 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
/**
* PushdownTopNThroughWindow push down the TopN through the Window and generate the PartitionTopN.
@ -54,11 +54,14 @@ public class PushDownTopNThroughWindow implements RewriteRuleFactory {
return topn;
}
long partitionLimit = topn.getLimit() + topn.getOffset();
Optional<Plan> newWindow = window.pushPartitionLimitThroughWindow(partitionLimit, true);
if (!newWindow.isPresent()) {
Pair<WindowExpression, Long> windowFuncLongPair = window
.getPushDownWindowFuncAndLimit(null, partitionLimit);
if (windowFuncLongPair == null) {
return topn;
}
return topn.withChildren(newWindow.get());
Plan newWindow = window.pushPartitionLimitThroughWindow(windowFuncLongPair.first,
windowFuncLongPair.second, true);
return topn.withChildren(newWindow);
}).toRule(RuleType.PUSH_DOWN_TOP_N_THROUGH_WINDOW),
// topn -> projection -> window
@ -74,11 +77,14 @@ public class PushDownTopNThroughWindow implements RewriteRuleFactory {
return topn;
}
long partitionLimit = topn.getLimit() + topn.getOffset();
Optional<Plan> newWindow = window.pushPartitionLimitThroughWindow(partitionLimit, true);
if (!newWindow.isPresent()) {
Pair<WindowExpression, Long> windowFuncLongPair = window
.getPushDownWindowFuncAndLimit(null, partitionLimit);
if (windowFuncLongPair == null) {
return topn;
}
return topn.withChildren(project.withChildren(newWindow.get()));
Plan newWindow = window.pushPartitionLimitThroughWindow(windowFuncLongPair.first,
windowFuncLongPair.second, true);
return topn.withChildren(project.withChildren(newWindow));
}).toRule(RuleType.PUSH_DOWN_TOP_N_THROUGH_PROJECT_WINDOW)
);
}

View File

@ -17,19 +17,27 @@
package org.apache.doris.nereids.trees.plans.logical;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.FdItem;
import org.apache.doris.nereids.properties.FunctionalDependencies;
import org.apache.doris.nereids.properties.FunctionalDependencies.Builder;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.trees.expressions.BinaryOperator;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.WindowFrame;
import org.apache.doris.nereids.trees.expressions.functions.window.DenseRank;
import org.apache.doris.nereids.trees.expressions.functions.window.Rank;
import org.apache.doris.nereids.trees.expressions.functions.window.RowNumber;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.algebra.Window;
@ -44,6 +52,8 @@ import com.google.common.collect.ImmutableSet;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.function.Predicate;
/**
* logical node to deal with window functions;
@ -170,62 +180,172 @@ public class LogicalWindow<CHILD_TYPE extends Plan> extends LogicalUnary<CHILD_T
return Objects.hash(windowExpressions, isChecked);
}
/**
* Get push down window function candidate and corresponding partition limit.
*
* @param filter
* For partition topN filter cases, it means the topN filter;
* For partition limit cases, it will be null.
* @param partitionLimit
* For partition topN filter cases, it means the filter boundary,
* e.g, 100 for the case rn <= 100;
* For partition limit cases, it means the limit.
* @return
* Return null means invalid cases or the opt option is disabled,
* else return the chosen window function and the chosen partition limit.
* A special limit -1 means the case can be optimized as empty relation.
*/
public Pair<WindowExpression, Long> getPushDownWindowFuncAndLimit(LogicalFilter<?> filter, long partitionLimit) {
if (!ConnectContext.get().getSessionVariable().isEnablePartitionTopN()) {
return null;
}
// We have already done such optimization rule, so just ignore it.
if (child(0) instanceof LogicalPartitionTopN
|| (child(0) instanceof LogicalFilter
&& child(0).child(0) != null
&& child(0).child(0) instanceof LogicalPartitionTopN)) {
return null;
}
// Check the window function. There are some restrictions for window function:
// 1. The window function should be one of the 'row_number()', 'rank()', 'dense_rank()'.
// 2. The window frame should be 'UNBOUNDED' to 'CURRENT'.
// 3. The 'PARTITION' key and 'ORDER' key can not be empty at the same time.
WindowExpression chosenWindowFunc = null;
long chosenPartitionLimit = Long.MAX_VALUE;
long chosenRowNumberPartitionLimit = Long.MAX_VALUE;
boolean hasRowNumber = false;
for (NamedExpression windowExpr : windowExpressions) {
if (windowExpr == null || windowExpr.children().size() != 1
|| !(windowExpr.child(0) instanceof WindowExpression)) {
continue;
}
WindowExpression windowFunc = (WindowExpression) windowExpr.child(0);
// Check the window function name.
if (!(windowFunc.getFunction() instanceof RowNumber
|| windowFunc.getFunction() instanceof Rank
|| windowFunc.getFunction() instanceof DenseRank)) {
continue;
}
// Check the partition key and order key.
if (windowFunc.getPartitionKeys().isEmpty() && windowFunc.getOrderKeys().isEmpty()) {
continue;
}
// Check the window type and window frame.
Optional<WindowFrame> windowFrame = windowFunc.getWindowFrame();
if (windowFrame.isPresent()) {
WindowFrame frame = windowFrame.get();
if (!(frame.getLeftBoundary().getFrameBoundType() == WindowFrame.FrameBoundType.UNBOUNDED_PRECEDING
&& frame.getRightBoundary().getFrameBoundType() == WindowFrame.FrameBoundType.CURRENT_ROW)) {
continue;
}
} else {
continue;
}
// Check filter conditions.
if (filter != null) {
// We currently only support simple conditions of the form 'column </ <=/ = constant'.
// We will extract some related conjuncts and do some check.
boolean hasPartitionLimit = false;
long curPartitionLimit = Long.MAX_VALUE;
Set<Expression> conjuncts = filter.getConjuncts();
Set<Expression> relatedConjuncts = extractRelatedConjuncts(conjuncts, windowExpr.getExprId());
for (Expression conjunct : relatedConjuncts) {
// Pre-checking has been done in former extraction
BinaryOperator op = (BinaryOperator) conjunct;
Expression rightChild = op.children().get(1);
long limitVal = ((IntegerLikeLiteral) rightChild).getLongValue();
// Adjust the value for 'limitVal' based on the comparison operators.
if (conjunct instanceof LessThan) {
limitVal--;
}
if (limitVal < 0) {
// Set return limit value as -1 for indicating a empty relation opt case
chosenPartitionLimit = -1;
chosenRowNumberPartitionLimit = -1;
break;
}
if (hasPartitionLimit) {
curPartitionLimit = Math.min(curPartitionLimit, limitVal);
} else {
curPartitionLimit = limitVal;
hasPartitionLimit = true;
}
}
if (chosenPartitionLimit == -1) {
chosenWindowFunc = windowFunc;
break;
} else if (windowFunc.getFunction() instanceof RowNumber) {
// choose row_number first any way
// if multiple exists, choose the one with minimal limit value
if (curPartitionLimit < chosenRowNumberPartitionLimit) {
chosenRowNumberPartitionLimit = curPartitionLimit;
chosenWindowFunc = windowFunc;
hasRowNumber = true;
}
} else if (!hasRowNumber) {
// if no row_number, choose the one with minimal limit value
if (curPartitionLimit < chosenPartitionLimit) {
chosenPartitionLimit = curPartitionLimit;
chosenWindowFunc = windowFunc;
}
}
} else {
// limit
chosenWindowFunc = windowFunc;
chosenPartitionLimit = partitionLimit;
if (windowFunc.getFunction() instanceof RowNumber) {
break;
}
}
}
if (chosenWindowFunc == null || (chosenPartitionLimit == Long.MAX_VALUE
&& chosenRowNumberPartitionLimit == Long.MAX_VALUE)) {
return null;
} else {
return Pair.of(chosenWindowFunc, hasRowNumber ? chosenRowNumberPartitionLimit : chosenPartitionLimit);
}
}
/**
* pushPartitionLimitThroughWindow is used to push the partitionLimit through the window
* and generate the partitionTopN. If the window can not meet the requirement,
* it will return null. So when we use this function, we need check the null in the outside.
*/
public Optional<Plan> pushPartitionLimitThroughWindow(long partitionLimit, boolean hasGlobalLimit) {
if (!ConnectContext.get().getSessionVariable().isEnablePartitionTopN()) {
return Optional.empty();
}
// We have already done such optimization rule, so just ignore it.
if (child(0) instanceof LogicalPartitionTopN) {
return Optional.empty();
}
// Check the window function. There are some restrictions for window function:
// 1. The number of window function should be 1.
// 2. The window function should be one of the 'row_number()', 'rank()', 'dense_rank()'.
// 3. The window frame should be 'UNBOUNDED' to 'CURRENT'.
// 4. The 'PARTITION' key and 'ORDER' key can not be empty at the same time.
if (windowExpressions.size() != 1) {
return Optional.empty();
}
NamedExpression windowExpr = windowExpressions.get(0);
if (windowExpr.children().size() != 1 || !(windowExpr.child(0) instanceof WindowExpression)) {
return Optional.empty();
}
WindowExpression windowFunc = (WindowExpression) windowExpr.child(0);
// Check the window function name.
if (!(windowFunc.getFunction() instanceof RowNumber
|| windowFunc.getFunction() instanceof Rank
|| windowFunc.getFunction() instanceof DenseRank)) {
return Optional.empty();
}
// Check the partition key and order key.
if (windowFunc.getPartitionKeys().isEmpty() && windowFunc.getOrderKeys().isEmpty()) {
return Optional.empty();
}
// Check the window type and window frame.
Optional<WindowFrame> windowFrame = windowFunc.getWindowFrame();
if (windowFrame.isPresent()) {
WindowFrame frame = windowFrame.get();
if (!(frame.getLeftBoundary().getFrameBoundType() == WindowFrame.FrameBoundType.UNBOUNDED_PRECEDING
&& frame.getRightBoundary().getFrameBoundType() == WindowFrame.FrameBoundType.CURRENT_ROW)) {
return Optional.empty();
}
} else {
return Optional.empty();
}
public Plan pushPartitionLimitThroughWindow(WindowExpression windowFunc,
long partitionLimit, boolean hasGlobalLimit) {
LogicalWindow<?> window = (LogicalWindow<?>) withChildren(new LogicalPartitionTopN<>(windowFunc, hasGlobalLimit,
partitionLimit, child(0)));
return window;
}
return Optional.ofNullable(window);
private Set<Expression> extractRelatedConjuncts(Set<Expression> conjuncts, ExprId slotRefID) {
Predicate<Expression> condition = conjunct -> {
if (!(conjunct instanceof BinaryOperator)) {
return false;
}
BinaryOperator op = (BinaryOperator) conjunct;
Expression leftChild = op.children().get(0);
Expression rightChild = op.children().get(1);
if (!(conjunct instanceof LessThan || conjunct instanceof LessThanEqual || conjunct instanceof EqualTo)) {
return false;
}
// TODO: Now, we only support the column on the left side.
if (!(leftChild instanceof SlotReference) || !(rightChild instanceof IntegerLikeLiteral)) {
return false;
}
return ((SlotReference) leftChild).getExprId() == slotRefID;
};
return conjuncts.stream()
.filter(condition)
.collect(ImmutableSet.toImmutableSet());
}
private boolean isUnique(NamedExpression namedExpression) {