[opt](Nereids) refine group by elimination column prune (#30953)
This commit is contained in:
@ -18,7 +18,9 @@
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.jobs.JobContext;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.rules.rewrite.ColumnPruning.PruneContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Alias;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
@ -38,6 +40,7 @@ import org.apache.doris.nereids.trees.plans.logical.OutputPrunable;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
|
||||
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
@ -173,16 +176,18 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
|
||||
private Plan pruneAggregate(Aggregate agg, PruneContext context) {
|
||||
// first try to prune group by and aggregate functions
|
||||
Aggregate prunedOutputAgg = pruneOutput(agg, agg.getOutputs(), agg::pruneOutputs, context);
|
||||
Set<Integer> enableNereidsRules = ConnectContext.get().getSessionVariable().getEnableNereidsRules();
|
||||
Aggregate fillUpAggr;
|
||||
|
||||
List<Expression> groupByExpressions = prunedOutputAgg.getGroupByExpressions();
|
||||
List<NamedExpression> outputExpressions = prunedOutputAgg.getOutputExpressions();
|
||||
if (!enableNereidsRules.contains(RuleType.ELIMINATE_GROUP_BY_KEY.type())) {
|
||||
fillUpAggr = fillUpGroupByToOutput(prunedOutputAgg)
|
||||
.map(fullOutput -> prunedOutputAgg.withAggOutput(fullOutput))
|
||||
.orElse(prunedOutputAgg);
|
||||
} else {
|
||||
fillUpAggr = fillUpGroupByAndOutput(prunedOutputAgg);
|
||||
}
|
||||
|
||||
// then fill up group by
|
||||
Aggregate fillUpOutputRepeat = fillUpGroupByToOutput(groupByExpressions, outputExpressions)
|
||||
.map(fullOutput -> prunedOutputAgg.withAggOutput(fullOutput))
|
||||
.orElse(prunedOutputAgg);
|
||||
|
||||
return pruneChildren(fillUpOutputRepeat);
|
||||
return pruneChildren(fillUpAggr);
|
||||
}
|
||||
|
||||
private Plan skipPruneThisAndFirstLevelChildren(Plan plan) {
|
||||
@ -193,8 +198,9 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
|
||||
return pruneChildren(plan, requireAllOutputOfChildren);
|
||||
}
|
||||
|
||||
private static Optional<List<NamedExpression>> fillUpGroupByToOutput(
|
||||
List<Expression> groupBy, List<NamedExpression> output) {
|
||||
private static Optional<List<NamedExpression>> fillUpGroupByToOutput(Aggregate prunedOutputAgg) {
|
||||
List<Expression> groupBy = prunedOutputAgg.getGroupByExpressions();
|
||||
List<NamedExpression> output = prunedOutputAgg.getOutputExpressions();
|
||||
|
||||
if (output.containsAll(groupBy)) {
|
||||
return Optional.empty();
|
||||
@ -209,6 +215,34 @@ public class ColumnPruning extends DefaultPlanRewriter<PruneContext> implements
|
||||
.build());
|
||||
}
|
||||
|
||||
private static Aggregate fillUpGroupByAndOutput(Aggregate prunedOutputAgg) {
|
||||
List<Expression> groupBy = prunedOutputAgg.getGroupByExpressions();
|
||||
List<NamedExpression> output = prunedOutputAgg.getOutputExpressions();
|
||||
|
||||
if (!(prunedOutputAgg instanceof LogicalAggregate)) {
|
||||
return prunedOutputAgg;
|
||||
}
|
||||
// add back group by keys which eliminated by rule ELIMINATE_GROUP_BY_KEY
|
||||
// if related output expressions are not in pruned output list.
|
||||
List<NamedExpression> remainedOutputExprs = Lists.newArrayList(output);
|
||||
remainedOutputExprs.removeAll(groupBy);
|
||||
|
||||
List<NamedExpression> newOutputList = Lists.newArrayList();
|
||||
newOutputList.addAll((List) groupBy);
|
||||
newOutputList.addAll(remainedOutputExprs);
|
||||
|
||||
if (!(prunedOutputAgg instanceof LogicalAggregate)) {
|
||||
return prunedOutputAgg.withAggOutput(newOutputList);
|
||||
} else {
|
||||
List<Expression> newGroupByExprList = newOutputList.stream().filter(e ->
|
||||
!(prunedOutputAgg.getAggregateFunctions().contains(e)
|
||||
|| e instanceof Alias && prunedOutputAgg.getAggregateFunctions()
|
||||
.contains(((Alias) e).child()))
|
||||
).collect(Collectors.toList());
|
||||
return ((LogicalAggregate) prunedOutputAgg).withGroupByAndOutput(newGroupByExprList, newOutputList);
|
||||
}
|
||||
}
|
||||
|
||||
/** prune output */
|
||||
public static <P extends Plan> P pruneOutput(P plan, List<NamedExpression> originOutput,
|
||||
Function<List<NamedExpression>, P> withPrunedOutput, PruneContext context) {
|
||||
|
||||
@ -50,10 +50,7 @@ public class EliminateGroupByKey extends OneRewriteRuleFactory {
|
||||
List<FdItem> uniqueFdItems = new ArrayList<>();
|
||||
List<FdItem> nonUniqueFdItems = new ArrayList<>();
|
||||
if (agg.getGroupByExpressions().isEmpty()
|
||||
|| agg.getGroupByExpressions().equals(agg.getOutputExpressions())
|
||||
|| !agg.getGroupByExpressions().stream().allMatch(e -> e instanceof SlotReference)
|
||||
|| agg.getGroupByExpressions().stream().allMatch(e ->
|
||||
((SlotReference) e).getColumn().isPresent() && ((SlotReference) e).getTable().isPresent())) {
|
||||
|| !agg.getGroupByExpressions().stream().allMatch(e -> e instanceof SlotReference)) {
|
||||
return null;
|
||||
}
|
||||
ImmutableSet<FdItem> fdItems = childPlan.getLogicalProperties().getFunctionalDependencies().getFdItems();
|
||||
|
||||
Reference in New Issue
Block a user