diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index e90528a9d4..169d5a901a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -43,19 +43,18 @@ import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; import org.jetbrains.annotations.NotNull; import java.util.ArrayList; -import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; import javax.annotation.Nullable; /** NormalizeRepeat @@ -117,19 +116,34 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { } private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { - Set needToSlots = collectNeedToSlotExpressions(repeat); - NormalizeToSlotContext context = buildContext(repeat, needToSlots); + Set needToSlotsGroupingExpr = collectNeedToSlotGroupingExpr(repeat); + NormalizeToSlotContext groupingExprContext = buildContext(repeat, needToSlotsGroupingExpr); + Map groupingExprMap = groupingExprContext.getNormalizeToSlotMap(); + Set existsAlias = getExistsAlias(repeat, groupingExprMap); + Set needToSlotsArgs = collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(repeat); + NormalizeToSlotContext argsContext = NormalizeToSlotContext.buildContext(existsAlias, needToSlotsArgs); // normalize grouping sets to List> - List> normalizedGroupingSets = repeat.getGroupingSets() - .stream() - .map(groupingSet -> (List) (List) context.normalizeToUseSlotRef(groupingSet)) - .collect(ImmutableList.toImmutableList()); + ImmutableList.Builder> normalizedGroupingSetBuilder = ImmutableList.builder(); + for (List groupingSet : repeat.getGroupingSets()) { + List normalizedSet = (List) (List) groupingExprContext.normalizeToUseSlotRef(groupingSet); + normalizedGroupingSetBuilder.add(normalizedSet); + } + List> normalizedGroupingSets = normalizedGroupingSetBuilder.build(); - // replace the arguments of grouping scalar function to virtual slots - // replace some complex expression to slot, e.g. `a + 1` - List normalizedAggOutput = context.normalizeToUseSlotRef( - repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction); + // use argsContext + // rewrite the arguments of grouping scalar function to slots + // rewrite grouping scalar function to virtual slots + // rewrite the arguments of agg function to slots + List normalizedAggOutput = Lists.newArrayList(); + for (Expression expr : repeat.getOutputExpressions()) { + Expression rewrittenExpr = expr.rewriteDownShortCircuit( + e -> normalizeAggFuncChildrenAndGroupingScalarFunc(argsContext, e)); + normalizedAggOutput.add((NamedExpression) rewrittenExpr); + } + + // use groupingExprContext rewrite the normalizedAggOutput + normalizedAggOutput = groupingExprContext.normalizeToUseSlotRef(normalizedAggOutput); Set virtualSlotsInFunction = ExpressionUtils.collect(normalizedAggOutput, VirtualSlotReference.class::isInstance); @@ -156,7 +170,12 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { .addAll(allVirtualSlots) .build(); - Set pushedProject = context.pushDownToNamedExpression(needToSlots); + // 3 parts need push down: + // flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction + Set needToSlots = Sets.union(needToSlotsArgs, needToSlotsGroupingExpr); + NormalizeToSlotContext fullContext = argsContext.mergeContext(groupingExprContext); + Set pushedProject = fullContext.pushDownToNamedExpression(needToSlots); + Plan normalizedChild = pushDownProject(pushedProject, repeat.child()); LogicalRepeat normalizedRepeat = repeat.withNormalizedExpr( @@ -170,42 +189,43 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { Optional.of(normalizedRepeat), normalizedRepeat); } - private Set collectNeedToSlotExpressions(LogicalRepeat repeat) { - // 3 parts need push down: - // flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction - - Set flattenGroupingSetExpr = ImmutableSet.copyOf( + private Set collectNeedToSlotGroupingExpr(LogicalRepeat repeat) { + // grouping sets should be pushed down, e.g. grouping sets((k + 1)), + // we should push down the `k + 1` to the bottom plan + return ImmutableSet.copyOf( ExpressionUtils.flatExpressions(repeat.getGroupingSets())); + } + private Set collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(LogicalRepeat repeat) { Set groupingScalarFunctions = ExpressionUtils.collect( repeat.getOutputExpressions(), GroupingScalarFunction.class::isInstance); - - ImmutableSet argumentsOfGroupingScalarFunction = groupingScalarFunctions.stream() - .flatMap(function -> function.getArguments().stream()) - .collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder argumentsSetBuilder = ImmutableSet.builder(); + for (GroupingScalarFunction function : groupingScalarFunctions) { + argumentsSetBuilder.addAll(function.getArguments()); + } + ImmutableSet argumentsOfGroupingScalarFunction = argumentsSetBuilder.build(); List aggregateFunctions = CollectNonWindowedAggFuncs.collect(repeat.getOutputExpressions()); + ImmutableSet.Builder argumentsOfAggregateFunctionBuilder = ImmutableSet.builder(); + for (AggregateFunction function : aggregateFunctions) { + for (Expression arg : function.getArguments()) { + if (arg instanceof OrderExpression) { + argumentsOfAggregateFunctionBuilder.add(arg.child(0)); + } else { + argumentsOfAggregateFunctionBuilder.add(arg); + } + } + } + ImmutableSet argumentsOfAggregateFunction = argumentsOfAggregateFunctionBuilder.build(); - ImmutableSet argumentsOfAggregateFunction = aggregateFunctions.stream() - .flatMap(function -> function.getArguments().stream().map(arg -> { - if (arg instanceof OrderExpression) { - return arg.child(0); - } else { - return arg; - } - })) - .collect(ImmutableSet.toImmutableSet()); - - ImmutableSet needPushDown = ImmutableSet.builder() + return ImmutableSet.builder() // grouping sets should be pushed down, e.g. grouping sets((k + 1)), // we should push down the `k + 1` to the bottom plan - .addAll(flattenGroupingSetExpr) // e.g. grouping_id(k + 1), we should push down the `k + 1` to the bottom plan .addAll(argumentsOfGroupingScalarFunction) // e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan .addAll(argumentsOfAggregateFunction) .build(); - return needPushDown; } private Plan pushDownProject(Set pushedExprs, Plan originBottomPlan) { @@ -250,8 +270,13 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { return Optional.of(new NormalizeToSlotTriplet(expression, newSlot, originTriplet.pushedExpr)); } - private Expression normalizeGroupingScalarFunction(NormalizeToSlotContext context, Expression expr) { - if (expr instanceof GroupingScalarFunction) { + private Expression normalizeAggFuncChildrenAndGroupingScalarFunc(NormalizeToSlotContext context, Expression expr) { + if (expr instanceof AggregateFunction) { + AggregateFunction function = (AggregateFunction) expr; + List normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); + function = function.withChildren(normalizedRealExpressions); + return function; + } else if (expr instanceof GroupingScalarFunction) { GroupingScalarFunction function = (GroupingScalarFunction) expr; List normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); function = function.withChildren(normalizedRealExpressions); @@ -262,6 +287,20 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { } } + private Set getExistsAlias(LogicalRepeat repeat, + Map groupingExprMap) { + Set existsAlias = Sets.newHashSet(); + Set aliases = ExpressionUtils.collect(repeat.getOutputExpressions(), Alias.class::isInstance); + existsAlias.addAll(aliases); + for (NormalizeToSlotTriplet triplet : groupingExprMap.values()) { + if (triplet.pushedExpr instanceof Alias) { + Alias alias = (Alias) triplet.pushedExpr; + existsAlias.add(alias); + } + } + return existsAlias; + } + /* * compute slots that appear both in agg func and grouping sets, * copy the slots and output in the project below the repeat as new copied slots, @@ -278,56 +317,77 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { private LogicalAggregate dealSlotAppearBothInAggFuncAndGroupingSets( @NotNull LogicalAggregate aggregate) { LogicalRepeat repeat = (LogicalRepeat) aggregate.child(); - - List aggregateFunctions = - CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions()); - Set aggUsedSlots = aggregateFunctions.stream() - .flatMap(e -> e.>collect(SlotReference.class::isInstance).stream()) - .collect(ImmutableSet.toImmutableSet()); - Set groupingSetsUsedSlot = repeat.getGroupingSets().stream() - .flatMap(Collection::stream) - .flatMap(e -> e.>collect(SlotReference.class::isInstance).stream()) - .collect(Collectors.toSet()); - - Set resSet = new HashSet<>(aggUsedSlots); - resSet.retainAll(groupingSetsUsedSlot); - if (resSet.isEmpty()) { + Map commonSlotToAliasMap = getCommonSlotToAliasMap(repeat, aggregate); + if (commonSlotToAliasMap.isEmpty()) { return aggregate; } - Map slotMapping = resSet.stream().collect( - Collectors.toMap(key -> key, Alias::new) - ); - Set newAliases = new HashSet<>(slotMapping.values()); - List newSlots = newAliases.stream() - .map(Alias::toSlot) - .collect(Collectors.toList()); - // modify repeat child to a new project with more projections + Set newAliases = new HashSet<>(commonSlotToAliasMap.values()); List originSlots = repeat.child().getOutput(); - ImmutableList immList = + ImmutableList newProjects = ImmutableList.builder().addAll(originSlots).addAll(newAliases).build(); - LogicalProject newProject = new LogicalProject<>(immList, repeat.child()); - repeat = repeat.withChildren(ImmutableList.of(newProject)); + LogicalProject newLogicalProject = new LogicalProject<>(newProjects, repeat.child()); + repeat = repeat.withChildren(ImmutableList.of(newLogicalProject)); // modify repeat outputs List originRepeatSlots = repeat.getOutput(); - repeat = repeat.withAggOutput(ImmutableList - .builder() - .addAll(originRepeatSlots.stream().filter(slot -> ! (slot instanceof VirtualSlotReference)) - .collect(Collectors.toList())) + List virtualSlots = Lists.newArrayList(); + List nonVirtualSlots = Lists.newArrayList(); + for (Slot slot : originRepeatSlots) { + if (slot instanceof VirtualSlotReference) { + virtualSlots.add(slot); + } else { + nonVirtualSlots.add(slot); + } + } + List newSlots = Lists.newArrayList(); + for (Alias alias : newAliases) { + newSlots.add(alias.toSlot()); + } + repeat = repeat.withAggOutput(ImmutableList.builder() + .addAll(nonVirtualSlots) .addAll(newSlots) - .addAll(originRepeatSlots.stream().filter(slot -> (slot instanceof VirtualSlotReference)) - .collect(Collectors.toList())) + .addAll(virtualSlots) .build()); aggregate = aggregate.withChildren(ImmutableList.of(repeat)); - List newOutputExpressions = aggregate.getOutputExpressions().stream() - .map(e -> (NamedExpression) e.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE, - slotMapping)) - .collect(Collectors.toList()); + ImmutableList.Builder newOutputExpressionBuilder = ImmutableList.builder(); + for (NamedExpression expression : aggregate.getOutputExpressions()) { + NamedExpression newExpression = (NamedExpression) expression + .accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE, commonSlotToAliasMap); + newOutputExpressionBuilder.add(newExpression); + } + List newOutputExpressions = newOutputExpressionBuilder.build(); return aggregate.withAggOutput(newOutputExpressions); } + private Map getCommonSlotToAliasMap(LogicalRepeat repeat, LogicalAggregate aggregate) { + List aggregateFunctions = + CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions()); + ImmutableSet.Builder aggUsedSlotBuilder = ImmutableSet.builder(); + for (AggregateFunction function : aggregateFunctions) { + aggUsedSlotBuilder.addAll(function.>collect(SlotReference.class::isInstance)); + } + ImmutableSet aggUsedSlots = aggUsedSlotBuilder.build(); + + ImmutableSet.Builder groupingSetsUsedSlotBuilder = ImmutableSet.builder(); + for (List groupingSet : repeat.getGroupingSets()) { + for (Expression expr : groupingSet) { + groupingSetsUsedSlotBuilder.addAll(expr.>collect(SlotReference.class::isInstance)); + } + } + ImmutableSet groupingSetsUsedSlot = groupingSetsUsedSlotBuilder.build(); + + Set resSet = new HashSet<>(aggUsedSlots); + resSet.retainAll(groupingSetsUsedSlot); + Map commonSlotToAliasMap = Maps.newHashMap(); + for (Slot key : resSet) { + Alias alias = new Alias(key); + commonSlotToAliasMap.put(key, alias); + } + return commonSlotToAliasMap; + } + /** * This class use the map(slotMapping) to rewrite all slots in trival-agg. * The purpose of this class is to only rewrite the slots in trival-agg and not to rewrite the slots in window-agg. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java index ea2fb8f4be..0e1fba1fc4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java @@ -51,6 +51,17 @@ public interface NormalizeToSlot { this.normalizeToSlotMap = normalizeToSlotMap; } + public Map getNormalizeToSlotMap() { + return normalizeToSlotMap; + } + + public NormalizeToSlotContext mergeContext(NormalizeToSlotContext context) { + Map newMap = Maps.newHashMap(); + newMap.putAll(this.normalizeToSlotMap); + newMap.putAll(context.getNormalizeToSlotMap()); + return new NormalizeToSlotContext(newMap); + } + /** * build normalization context by follow step. * 1. collect all exists alias by input parameters existsAliases build a reverted map: expr -> alias diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java index a071328ec5..2be4887050 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java @@ -145,4 +145,5 @@ public class BuiltinFunctionBuilder extends FunctionBuilder { .map(constructor -> new BuiltinFunctionBuilder(functionClass, (Constructor) constructor)) .collect(ImmutableList.toImmutableList()); } + } diff --git a/regression-test/data/nereids_rules_p0/grouping_sets/grouping_normalize_test.out b/regression-test/data/nereids_rules_p0/grouping_sets/grouping_normalize_test.out new file mode 100644 index 0000000000..41e0576eca --- /dev/null +++ b/regression-test/data/nereids_rules_p0/grouping_sets/grouping_normalize_test.out @@ -0,0 +1,9 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !test -- +-1 -2 2 +-1 -2 2 +1 0 4 +1 1 4 +29 -3 32 +41 1 \N + diff --git a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy new file mode 100644 index 0000000000..8310685c9c --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("grouping_normalize_test"){ + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql """ + DROP TABLE IF EXISTS grouping_normalize_test + """ + sql """ + CREATE TABLE `grouping_normalize_test` ( + `pk` INT NULL, + `col_int_undef_signed` INT NULL, + `col_int_undef_signed2` INT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`pk`) + COMMENT 'OLAP' + DISTRIBUTED BY HASH(`pk`) BUCKETS 10 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql "insert into grouping_normalize_test values(1,3,5),(3,5,5),(31,2,5),(1,3,6),(3,6,2)" + qt_test """ + SELECT ROUND( SUM(pk + 1) - 3) col_alias1, MAX( DISTINCT col_int_undef_signed - 5) AS col_alias2, pk + 1 AS col_alias3 + FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed,col_int_undef_signed2,pk),()) order by 1,2,3; + """ +} \ No newline at end of file