[opt](Nereids) refine group by elimination column prune (#30953)

This commit is contained in:
xzj7019
2024-02-19 14:13:32 +08:00
committed by yiguolei
parent bb4575a392
commit 5ac4b6a137
2 changed files with 45 additions and 14 deletions

View File

@ -18,7 +18,9 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.rewrite.ColumnPruning.PruneContext;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
@ -38,6 +40,7 @@ import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@ -173,16 +176,18 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
/**
 * Prunes the outputs of an aggregate, restores group-by keys into the pruned output,
 * and finally prunes the aggregate's children.
 *
 * NOTE(review): this listing looks like a rendered diff with removed and added lines
 * interleaved (two-argument fillUpGroupByToOutput call and two consecutive returns below);
 * confirm against the real file which statements are live.
 *
 * @param agg the aggregate to prune
 * @param context pruning context forwarded to pruneOutput
 * @return the result of pruneChildren on the filled-up aggregate
 */
private Plan pruneAggregate(Aggregate agg, PruneContext context) {
// first try to prune group by and aggregate functions
Aggregate prunedOutputAgg = pruneOutput(agg, agg.getOutputs(), agg::pruneOutputs, context);
// NOTE(review): ConnectContext.get() is dereferenced without a null check;
// confirm it can never be null on this code path.
Set<Integer> enableNereidsRules = ConnectContext.get().getSessionVariable().getEnableNereidsRules();
Aggregate fillUpAggr;
List<Expression> groupByExpressions = prunedOutputAgg.getGroupByExpressions();
List<NamedExpression> outputExpressions = prunedOutputAgg.getOutputExpressions();
// Pick the fill-up strategy depending on whether the ELIMINATE_GROUP_BY_KEY rule type id
// is present in the session's enabled-rule set.
if (!enableNereidsRules.contains(RuleType.ELIMINATE_GROUP_BY_KEY.type())) {
// Rule id absent: only the output list is replaced (withAggOutput), and only when
// fillUpGroupByToOutput returns a non-empty Optional.
fillUpAggr = fillUpGroupByToOutput(prunedOutputAgg)
.map(fullOutput -> prunedOutputAgg.withAggOutput(fullOutput))
.orElse(prunedOutputAgg);
} else {
// Rule id present: delegate to fillUpGroupByAndOutput, which may rebuild both lists.
fillUpAggr = fillUpGroupByAndOutput(prunedOutputAgg);
}
// then fill up group by
// NOTE(review): the statements down to the first return use the two-argument
// fillUpGroupByToOutput form and look like the pre-change version of this method;
// the second return would be unreachable if both were live. Confirm.
Aggregate fillUpOutputRepeat = fillUpGroupByToOutput(groupByExpressions, outputExpressions)
.map(fullOutput -> prunedOutputAgg.withAggOutput(fullOutput))
.orElse(prunedOutputAgg);
return pruneChildren(fillUpOutputRepeat);
return pruneChildren(fillUpAggr);
}
private Plan skipPruneThisAndFirstLevelChildren(Plan plan) {
@ -193,8 +198,9 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
return pruneChildren(plan, requireAllOutputOfChildren);
}
private static Optional<List<NamedExpression>> fillUpGroupByToOutput(
List<Expression> groupBy, List<NamedExpression> output) {
private static Optional<List<NamedExpression>> fillUpGroupByToOutput(Aggregate prunedOutputAgg) {
List<Expression> groupBy = prunedOutputAgg.getGroupByExpressions();
List<NamedExpression> output = prunedOutputAgg.getOutputExpressions();
if (output.containsAll(groupBy)) {
return Optional.empty();
@ -209,6 +215,34 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
.build());
}
/**
 * Rebuilds the output list and the group-by list of a pruned aggregate so that every
 * group-by expression is present in the output again.
 *
 * <p>The new output is the group-by expressions first, followed by the pruned outputs that
 * are not group-by expressions. The new group-by list is every entry of that output which is
 * neither an aggregate function nor an alias whose direct child is an aggregate function.
 *
 * @param prunedOutputAgg the aggregate whose outputs have already been pruned
 * @return {@code prunedOutputAgg} unchanged when it is not a {@code LogicalAggregate};
 *         otherwise the result of {@code withGroupByAndOutput} with the rebuilt lists
 */
private static Aggregate fillUpGroupByAndOutput(Aggregate prunedOutputAgg) {
    // Only LogicalAggregate is rewritten here; any other Aggregate is handed back untouched.
    if (!(prunedOutputAgg instanceof LogicalAggregate)) {
        return prunedOutputAgg;
    }
    List<Expression> groupBy = prunedOutputAgg.getGroupByExpressions();
    List<NamedExpression> output = prunedOutputAgg.getOutputExpressions();

    // add back group by keys which eliminated by rule ELIMINATE_GROUP_BY_KEY
    // if related output expressions are not in pruned output list.
    List<NamedExpression> remainedOutputExprs = Lists.newArrayList(output);
    remainedOutputExprs.removeAll(groupBy);

    List<NamedExpression> newOutputList = Lists.newArrayList();
    // Raw cast kept from the original: it assumes every group-by expression is a
    // NamedExpression; TODO confirm that invariant holds for all callers.
    newOutputList.addAll((List) groupBy);
    newOutputList.addAll(remainedOutputExprs);

    // Look the aggregate functions up once instead of up to twice per output entry.
    Set<?> aggregateFunctions = prunedOutputAgg.getAggregateFunctions();
    List<Expression> newGroupByExprList = newOutputList.stream()
            .filter(e -> !(aggregateFunctions.contains(e)
                    || e instanceof Alias && aggregateFunctions.contains(((Alias) e).child())))
            .collect(Collectors.toList());
    return ((LogicalAggregate) prunedOutputAgg).withGroupByAndOutput(newGroupByExprList, newOutputList);
}
/** prune output */
public static <P extends Plan> P pruneOutput(P plan, List<NamedExpression> originOutput,
Function<List<NamedExpression>, P> withPrunedOutput, PruneContext context) {

View File

@ -50,10 +50,7 @@ public class EliminateGroupByKey extends OneRewriteRuleFactory {
List<FdItem> uniqueFdItems = new ArrayList<>();
List<FdItem> nonUniqueFdItems = new ArrayList<>();
if (agg.getGroupByExpressions().isEmpty()
|| agg.getGroupByExpressions().equals(agg.getOutputExpressions())
|| !agg.getGroupByExpressions().stream().allMatch(e -> e instanceof SlotReference)
|| agg.getGroupByExpressions().stream().allMatch(e ->
((SlotReference) e).getColumn().isPresent() && ((SlotReference) e).getTable().isPresent())) {
|| !agg.getGroupByExpressions().stream().allMatch(e -> e instanceof SlotReference)) {
return null;
}
ImmutableSet<FdItem> fdItems = childPlan.getLogicalProperties().getFunctionalDependencies().getFdItems();