From 7c56c17ecc2a7cf30a909f3a163e93c373bb433e Mon Sep 17 00:00:00 2001 From: feiniaofeiafei <53502832+feiniaofeiafei@users.noreply.github.com> Date: Wed, 8 May 2024 17:06:40 +0800 Subject: [PATCH] [Fix](nereids) fix NormalizeRepeat, change the outputExpression rewrite logic (#34196) In NormalizeRepeat, three parts of the outputExpression of LogicalRepeat need to be pushed down and outputted by bottom project: flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction. In the original code, use these three parts to rewrite the outputExpressions of LogicalRepeat to slots.This can cause problems in some cases, for example: ```sql SELECT ROUND( SUM(pk + 1) - 3) col_alias1, pk + 1 AS col_alias3 FROM table_20_undef_partitions2_keys3_properties4_distributed_by53 GROUP BY GROUPING SETS ((pk), ()) ; ``` The three parts expression needed to be pushed down are: pk, pk+1. The original code use pk+1 to rewrite the pk + 1 AS col_alias3 to slot. But the pk+1 is not in the list of grouping outputs, and then report error. This pr change the rewrite process, divide the expression needed to be pushed down into 2 parts: one is (flattenGroupingSetExpr) and the other one is (argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction). and use the flattenGroupingSetExpr rewrite all LogicalRepeat outputExpressions, and use the argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction to rewrite only the agg function arguments and the grouping scalar function. So, in the above sql, the pk + 1 AS col_alias3 will not be rewritten to slot, and can be computed. --- .../rules/analysis/NormalizeRepeat.java | 204 +++++++++++------- .../rules/rewrite/NormalizeToSlot.java | 11 + .../functions/BuiltinFunctionBuilder.java | 1 + .../grouping_sets/grouping_normalize_test.out | 9 + .../grouping_normalize_test.groovy | 42 ++++ 5 files changed, 195 insertions(+), 72 deletions(-) create mode 100644 regression-test/data/nereids_rules_p0/grouping_sets/grouping_normalize_test.out create mode 100644 regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index e90528a9d4..169d5a901a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -43,19 +43,18 @@ import org.apache.doris.nereids.util.PlanUtils.CollectNonWindowedAggFuncs; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableSet; +import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; import org.jetbrains.annotations.NotNull; import java.util.ArrayList; -import java.util.Collection; import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; import javax.annotation.Nullable; /** NormalizeRepeat @@ -117,19 +116,34 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { } private LogicalAggregate normalizeRepeat(LogicalRepeat repeat) { - Set needToSlots = collectNeedToSlotExpressions(repeat); - NormalizeToSlotContext context = buildContext(repeat, needToSlots); + Set needToSlotsGroupingExpr = collectNeedToSlotGroupingExpr(repeat); + NormalizeToSlotContext groupingExprContext = buildContext(repeat, needToSlotsGroupingExpr); + Map groupingExprMap = groupingExprContext.getNormalizeToSlotMap(); + Set existsAlias = getExistsAlias(repeat, groupingExprMap); + Set needToSlotsArgs = collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(repeat); + NormalizeToSlotContext argsContext = NormalizeToSlotContext.buildContext(existsAlias, needToSlotsArgs); // normalize grouping sets to List> - List> normalizedGroupingSets = repeat.getGroupingSets() - .stream() - .map(groupingSet -> (List) (List) context.normalizeToUseSlotRef(groupingSet)) - .collect(ImmutableList.toImmutableList()); + ImmutableList.Builder> normalizedGroupingSetBuilder = ImmutableList.builder(); + for (List groupingSet : repeat.getGroupingSets()) { + List normalizedSet = (List) (List) groupingExprContext.normalizeToUseSlotRef(groupingSet); + normalizedGroupingSetBuilder.add(normalizedSet); + } + List> normalizedGroupingSets = normalizedGroupingSetBuilder.build(); - // replace the arguments of grouping scalar function to virtual slots - // replace some complex expression to slot, e.g. `a + 1` - List normalizedAggOutput = context.normalizeToUseSlotRef( - repeat.getOutputExpressions(), this::normalizeGroupingScalarFunction); + // use argsContext + // rewrite the arguments of grouping scalar function to slots + // rewrite grouping scalar function to virtual slots + // rewrite the arguments of agg function to slots + List normalizedAggOutput = Lists.newArrayList(); + for (Expression expr : repeat.getOutputExpressions()) { + Expression rewrittenExpr = expr.rewriteDownShortCircuit( + e -> normalizeAggFuncChildrenAndGroupingScalarFunc(argsContext, e)); + normalizedAggOutput.add((NamedExpression) rewrittenExpr); + } + + // use groupingExprContext rewrite the normalizedAggOutput + normalizedAggOutput = groupingExprContext.normalizeToUseSlotRef(normalizedAggOutput); Set virtualSlotsInFunction = ExpressionUtils.collect(normalizedAggOutput, VirtualSlotReference.class::isInstance); @@ -156,7 +170,12 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { .addAll(allVirtualSlots) .build(); - Set pushedProject = context.pushDownToNamedExpression(needToSlots); + // 3 parts need push down: + // flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction + Set needToSlots = Sets.union(needToSlotsArgs, needToSlotsGroupingExpr); + NormalizeToSlotContext fullContext = argsContext.mergeContext(groupingExprContext); + Set pushedProject = fullContext.pushDownToNamedExpression(needToSlots); + Plan normalizedChild = pushDownProject(pushedProject, repeat.child()); LogicalRepeat normalizedRepeat = repeat.withNormalizedExpr( @@ -170,42 +189,43 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { Optional.of(normalizedRepeat), normalizedRepeat); } - private Set collectNeedToSlotExpressions(LogicalRepeat repeat) { - // 3 parts need push down: - // flattenGroupingSetExpr, argumentsOfGroupingScalarFunction, argumentsOfAggregateFunction - - Set flattenGroupingSetExpr = ImmutableSet.copyOf( + private Set collectNeedToSlotGroupingExpr(LogicalRepeat repeat) { + // grouping sets should be pushed down, e.g. grouping sets((k + 1)), + // we should push down the `k + 1` to the bottom plan + return ImmutableSet.copyOf( ExpressionUtils.flatExpressions(repeat.getGroupingSets())); + } + private Set collectNeedToSlotArgsOfGroupingScalarFuncAndAggFunc(LogicalRepeat repeat) { Set groupingScalarFunctions = ExpressionUtils.collect( repeat.getOutputExpressions(), GroupingScalarFunction.class::isInstance); - - ImmutableSet argumentsOfGroupingScalarFunction = groupingScalarFunctions.stream() - .flatMap(function -> function.getArguments().stream()) - .collect(ImmutableSet.toImmutableSet()); + ImmutableSet.Builder argumentsSetBuilder = ImmutableSet.builder(); + for (GroupingScalarFunction function : groupingScalarFunctions) { + argumentsSetBuilder.addAll(function.getArguments()); + } + ImmutableSet argumentsOfGroupingScalarFunction = argumentsSetBuilder.build(); List aggregateFunctions = CollectNonWindowedAggFuncs.collect(repeat.getOutputExpressions()); + ImmutableSet.Builder argumentsOfAggregateFunctionBuilder = ImmutableSet.builder(); + for (AggregateFunction function : aggregateFunctions) { + for (Expression arg : function.getArguments()) { + if (arg instanceof OrderExpression) { + argumentsOfAggregateFunctionBuilder.add(arg.child(0)); + } else { + argumentsOfAggregateFunctionBuilder.add(arg); + } + } + } + ImmutableSet argumentsOfAggregateFunction = argumentsOfAggregateFunctionBuilder.build(); - ImmutableSet argumentsOfAggregateFunction = aggregateFunctions.stream() - .flatMap(function -> function.getArguments().stream().map(arg -> { - if (arg instanceof OrderExpression) { - return arg.child(0); - } else { - return arg; - } - })) - .collect(ImmutableSet.toImmutableSet()); - - ImmutableSet needPushDown = ImmutableSet.builder() + return ImmutableSet.builder() // grouping sets should be pushed down, e.g. grouping sets((k + 1)), // we should push down the `k + 1` to the bottom plan - .addAll(flattenGroupingSetExpr) // e.g. grouping_id(k + 1), we should push down the `k + 1` to the bottom plan .addAll(argumentsOfGroupingScalarFunction) // e.g. sum(k + 1), we should push down the `k + 1` to the bottom plan .addAll(argumentsOfAggregateFunction) .build(); - return needPushDown; } private Plan pushDownProject(Set pushedExprs, Plan originBottomPlan) { @@ -250,8 +270,13 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { return Optional.of(new NormalizeToSlotTriplet(expression, newSlot, originTriplet.pushedExpr)); } - private Expression normalizeGroupingScalarFunction(NormalizeToSlotContext context, Expression expr) { - if (expr instanceof GroupingScalarFunction) { + private Expression normalizeAggFuncChildrenAndGroupingScalarFunc(NormalizeToSlotContext context, Expression expr) { + if (expr instanceof AggregateFunction) { + AggregateFunction function = (AggregateFunction) expr; + List normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); + function = function.withChildren(normalizedRealExpressions); + return function; + } else if (expr instanceof GroupingScalarFunction) { GroupingScalarFunction function = (GroupingScalarFunction) expr; List normalizedRealExpressions = context.normalizeToUseSlotRef(function.getArguments()); function = function.withChildren(normalizedRealExpressions); @@ -262,6 +287,20 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { } } + private Set getExistsAlias(LogicalRepeat repeat, + Map groupingExprMap) { + Set existsAlias = Sets.newHashSet(); + Set aliases = ExpressionUtils.collect(repeat.getOutputExpressions(), Alias.class::isInstance); + existsAlias.addAll(aliases); + for (NormalizeToSlotTriplet triplet : groupingExprMap.values()) { + if (triplet.pushedExpr instanceof Alias) { + Alias alias = (Alias) triplet.pushedExpr; + existsAlias.add(alias); + } + } + return existsAlias; + } + /* * compute slots that appear both in agg func and grouping sets, * copy the slots and output in the project below the repeat as new copied slots, @@ -278,56 +317,77 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { private LogicalAggregate dealSlotAppearBothInAggFuncAndGroupingSets( @NotNull LogicalAggregate aggregate) { LogicalRepeat repeat = (LogicalRepeat) aggregate.child(); - - List aggregateFunctions = - CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions()); - Set aggUsedSlots = aggregateFunctions.stream() - .flatMap(e -> e.>collect(SlotReference.class::isInstance).stream()) - .collect(ImmutableSet.toImmutableSet()); - Set groupingSetsUsedSlot = repeat.getGroupingSets().stream() - .flatMap(Collection::stream) - .flatMap(e -> e.>collect(SlotReference.class::isInstance).stream()) - .collect(Collectors.toSet()); - - Set resSet = new HashSet<>(aggUsedSlots); - resSet.retainAll(groupingSetsUsedSlot); - if (resSet.isEmpty()) { + Map commonSlotToAliasMap = getCommonSlotToAliasMap(repeat, aggregate); + if (commonSlotToAliasMap.isEmpty()) { return aggregate; } - Map slotMapping = resSet.stream().collect( - Collectors.toMap(key -> key, Alias::new) - ); - Set newAliases = new HashSet<>(slotMapping.values()); - List newSlots = newAliases.stream() - .map(Alias::toSlot) - .collect(Collectors.toList()); - // modify repeat child to a new project with more projections + Set newAliases = new HashSet<>(commonSlotToAliasMap.values()); List originSlots = repeat.child().getOutput(); - ImmutableList immList = + ImmutableList newProjects = ImmutableList.builder().addAll(originSlots).addAll(newAliases).build(); - LogicalProject newProject = new LogicalProject<>(immList, repeat.child()); - repeat = repeat.withChildren(ImmutableList.of(newProject)); + LogicalProject newLogicalProject = new LogicalProject<>(newProjects, repeat.child()); + repeat = repeat.withChildren(ImmutableList.of(newLogicalProject)); // modify repeat outputs List originRepeatSlots = repeat.getOutput(); - repeat = repeat.withAggOutput(ImmutableList - .builder() - .addAll(originRepeatSlots.stream().filter(slot -> ! (slot instanceof VirtualSlotReference)) - .collect(Collectors.toList())) + List virtualSlots = Lists.newArrayList(); + List nonVirtualSlots = Lists.newArrayList(); + for (Slot slot : originRepeatSlots) { + if (slot instanceof VirtualSlotReference) { + virtualSlots.add(slot); + } else { + nonVirtualSlots.add(slot); + } + } + List newSlots = Lists.newArrayList(); + for (Alias alias : newAliases) { + newSlots.add(alias.toSlot()); + } + repeat = repeat.withAggOutput(ImmutableList.builder() + .addAll(nonVirtualSlots) .addAll(newSlots) - .addAll(originRepeatSlots.stream().filter(slot -> (slot instanceof VirtualSlotReference)) - .collect(Collectors.toList())) + .addAll(virtualSlots) .build()); aggregate = aggregate.withChildren(ImmutableList.of(repeat)); - List newOutputExpressions = aggregate.getOutputExpressions().stream() - .map(e -> (NamedExpression) e.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE, - slotMapping)) - .collect(Collectors.toList()); + ImmutableList.Builder newOutputExpressionBuilder = ImmutableList.builder(); + for (NamedExpression expression : aggregate.getOutputExpressions()) { + NamedExpression newExpression = (NamedExpression) expression + .accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE, commonSlotToAliasMap); + newOutputExpressionBuilder.add(newExpression); + } + List newOutputExpressions = newOutputExpressionBuilder.build(); return aggregate.withAggOutput(newOutputExpressions); } + private Map getCommonSlotToAliasMap(LogicalRepeat repeat, LogicalAggregate aggregate) { + List aggregateFunctions = + CollectNonWindowedAggFuncs.collect(aggregate.getOutputExpressions()); + ImmutableSet.Builder aggUsedSlotBuilder = ImmutableSet.builder(); + for (AggregateFunction function : aggregateFunctions) { + aggUsedSlotBuilder.addAll(function.>collect(SlotReference.class::isInstance)); + } + ImmutableSet aggUsedSlots = aggUsedSlotBuilder.build(); + + ImmutableSet.Builder groupingSetsUsedSlotBuilder = ImmutableSet.builder(); + for (List groupingSet : repeat.getGroupingSets()) { + for (Expression expr : groupingSet) { + groupingSetsUsedSlotBuilder.addAll(expr.>collect(SlotReference.class::isInstance)); + } + } + ImmutableSet groupingSetsUsedSlot = groupingSetsUsedSlotBuilder.build(); + + Set resSet = new HashSet<>(aggUsedSlots); + resSet.retainAll(groupingSetsUsedSlot); + Map commonSlotToAliasMap = Maps.newHashMap(); + for (Slot key : resSet) { + Alias alias = new Alias(key); + commonSlotToAliasMap.put(key, alias); + } + return commonSlotToAliasMap; + } + /** * This class use the map(slotMapping) to rewrite all slots in trival-agg. * The purpose of this class is to only rewrite the slots in trival-agg and not to rewrite the slots in window-agg. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java index ea2fb8f4be..0e1fba1fc4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/NormalizeToSlot.java @@ -51,6 +51,17 @@ public interface NormalizeToSlot { this.normalizeToSlotMap = normalizeToSlotMap; } + public Map getNormalizeToSlotMap() { + return normalizeToSlotMap; + } + + public NormalizeToSlotContext mergeContext(NormalizeToSlotContext context) { + Map newMap = Maps.newHashMap(); + newMap.putAll(this.normalizeToSlotMap); + newMap.putAll(context.getNormalizeToSlotMap()); + return new NormalizeToSlotContext(newMap); + } + /** * build normalization context by follow step. * 1. collect all exists alias by input parameters existsAliases build a reverted map: expr -> alias diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java index a071328ec5..2be4887050 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java @@ -145,4 +145,5 @@ public class BuiltinFunctionBuilder extends FunctionBuilder { .map(constructor -> new BuiltinFunctionBuilder(functionClass, (Constructor) constructor)) .collect(ImmutableList.toImmutableList()); } + } diff --git a/regression-test/data/nereids_rules_p0/grouping_sets/grouping_normalize_test.out b/regression-test/data/nereids_rules_p0/grouping_sets/grouping_normalize_test.out new file mode 100644 index 0000000000..41e0576eca --- /dev/null +++ b/regression-test/data/nereids_rules_p0/grouping_sets/grouping_normalize_test.out @@ -0,0 +1,9 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !test -- +-1 -2 2 +-1 -2 2 +1 0 4 +1 1 4 +29 -3 32 +41 1 \N + diff --git a/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy new file mode 100644 index 0000000000..8310685c9c --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/grouping_sets/grouping_normalize_test.groovy @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("grouping_normalize_test"){ + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + sql """ + DROP TABLE IF EXISTS grouping_normalize_test + """ + sql """ + CREATE TABLE `grouping_normalize_test` ( + `pk` INT NULL, + `col_int_undef_signed` INT NULL, + `col_int_undef_signed2` INT NULL + ) ENGINE=OLAP + DUPLICATE KEY(`pk`) + COMMENT 'OLAP' + DISTRIBUTED BY HASH(`pk`) BUCKETS 10 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1" + ); + """ + sql "insert into grouping_normalize_test values(1,3,5),(3,5,5),(31,2,5),(1,3,6),(3,6,2)" + qt_test """ + SELECT ROUND( SUM(pk + 1) - 3) col_alias1, MAX( DISTINCT col_int_undef_signed - 5) AS col_alias2, pk + 1 AS col_alias3 + FROM grouping_normalize_test GROUP BY GROUPING SETS ((col_int_undef_signed,col_int_undef_signed2,pk),()) order by 1,2,3; + """ +} \ No newline at end of file