diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java index dcb330cd28..284ac52e14 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/ColumnPruning.java @@ -18,7 +18,9 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.jobs.JobContext; +import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.rewrite.ColumnPruning.PruneContext; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; @@ -38,6 +40,7 @@ import org.apache.doris.nereids.trees.plans.logical.OutputPrunable; import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter; import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanRewriter; import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; @@ -173,16 +176,18 @@ public class ColumnPruning extends DefaultPlanRewriter implements private Plan pruneAggregate(Aggregate agg, PruneContext context) { // first try to prune group by and aggregate functions Aggregate prunedOutputAgg = pruneOutput(agg, agg.getOutputs(), agg::pruneOutputs, context); + Set enableNereidsRules = ConnectContext.get().getSessionVariable().getEnableNereidsRules(); + Aggregate fillUpAggr; - List groupByExpressions = prunedOutputAgg.getGroupByExpressions(); - List outputExpressions = prunedOutputAgg.getOutputExpressions(); + if (!enableNereidsRules.contains(RuleType.ELIMINATE_GROUP_BY_KEY.type())) { + fillUpAggr = fillUpGroupByToOutput(prunedOutputAgg) + .map(fullOutput -> prunedOutputAgg.withAggOutput(fullOutput)) + .orElse(prunedOutputAgg); + } else { + fillUpAggr = fillUpGroupByAndOutput(prunedOutputAgg); + } - // then fill up group by - Aggregate fillUpOutputRepeat = fillUpGroupByToOutput(groupByExpressions, outputExpressions) - .map(fullOutput -> prunedOutputAgg.withAggOutput(fullOutput)) - .orElse(prunedOutputAgg); - - return pruneChildren(fillUpOutputRepeat); + return pruneChildren(fillUpAggr); } private Plan skipPruneThisAndFirstLevelChildren(Plan plan) { @@ -193,8 +198,9 @@ public class ColumnPruning extends DefaultPlanRewriter implements return pruneChildren(plan, requireAllOutputOfChildren); } - private static Optional> fillUpGroupByToOutput( - List groupBy, List output) { + private static Optional> fillUpGroupByToOutput(Aggregate prunedOutputAgg) { + List groupBy = prunedOutputAgg.getGroupByExpressions(); + List output = prunedOutputAgg.getOutputExpressions(); if (output.containsAll(groupBy)) { return Optional.empty(); @@ -209,6 +215,34 @@ public class ColumnPruning extends DefaultPlanRewriter implements .build()); } + private static Aggregate fillUpGroupByAndOutput(Aggregate prunedOutputAgg) { + List groupBy = prunedOutputAgg.getGroupByExpressions(); + List output = prunedOutputAgg.getOutputExpressions(); + + if (!(prunedOutputAgg instanceof LogicalAggregate)) { + return prunedOutputAgg; + } + // add back group by keys which eliminated by rule ELIMINATE_GROUP_BY_KEY + // if related output expressions are not in pruned output list. + List remainedOutputExprs = Lists.newArrayList(output); + remainedOutputExprs.removeAll(groupBy); + + List newOutputList = Lists.newArrayList(); + newOutputList.addAll((List) groupBy); + newOutputList.addAll(remainedOutputExprs); + + if (!(prunedOutputAgg instanceof LogicalAggregate)) { + return prunedOutputAgg.withAggOutput(newOutputList); + } else { + List newGroupByExprList = newOutputList.stream().filter(e -> + !(prunedOutputAgg.getAggregateFunctions().contains(e) + || e instanceof Alias && prunedOutputAgg.getAggregateFunctions() + .contains(((Alias) e).child())) + ).collect(Collectors.toList()); + return ((LogicalAggregate) prunedOutputAgg).withGroupByAndOutput(newGroupByExprList, newOutputList); + } + } + /** prune output */ public static

P pruneOutput(P plan, List originOutput, Function, P> withPrunedOutput, PruneContext context) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java index d922252beb..69a34a680e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java @@ -50,10 +50,7 @@ public class EliminateGroupByKey extends OneRewriteRuleFactory { List uniqueFdItems = new ArrayList<>(); List nonUniqueFdItems = new ArrayList<>(); if (agg.getGroupByExpressions().isEmpty() - || agg.getGroupByExpressions().equals(agg.getOutputExpressions()) - || !agg.getGroupByExpressions().stream().allMatch(e -> e instanceof SlotReference) - || agg.getGroupByExpressions().stream().allMatch(e -> - ((SlotReference) e).getColumn().isPresent() && ((SlotReference) e).getTable().isPresent())) { + || !agg.getGroupByExpressions().stream().allMatch(e -> e instanceof SlotReference)) { return null; } ImmutableSet fdItems = childPlan.getLogicalProperties().getFunctionalDependencies().getFdItems();