From 2047b9416f5dd78db7c33024afdcf106cdd8d3d2 Mon Sep 17 00:00:00 2001 From: feiniaofeiafei <53502832+feiniaofeiafei@users.noreply.github.com> Date: Thu, 29 Feb 2024 22:17:43 +0800 Subject: [PATCH] [Fix](nereids) Only rewrite the slots that appear both in trival-agg func and grouping sets (#31600) * [Fix](nereids) Only rewrite the slots that appear both in trival-agg func and grouping sets * [Fix](nereids) Only rewrite the slots that appear both in trival-agg func and grouping sets --------- Co-authored-by: feiniaofeiafei --- .../rules/analysis/NormalizeRepeat.java | 84 +++++++++++++++---- ...th_appear_in_agg_fun_and_grouping_sets.out | 61 ++++++++++++++ ...appear_in_agg_fun_and_grouping_sets.groovy | 30 +++++++ 3 files changed, 160 insertions(+), 15 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 3c893ce4be..8437dc40b0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -29,8 +29,10 @@ import org.apache.doris.nereids.trees.expressions.OrderExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.VirtualSlotReference; +import org.apache.doris.nereids.trees.expressions.WindowExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.scalar.GroupingScalarFunction; +import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.algebra.Repeat; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; @@ -47,6 +49,7 @@ import com.google.common.collect.Sets; import com.google.common.collect.Sets.SetView; import org.jetbrains.annotations.NotNull; +import java.util.ArrayList; import java.util.Collection; import java.util.HashSet; import java.util.List; @@ -267,8 +270,11 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { private LogicalAggregate dealSlotAppearBothInAggFuncAndGroupingSets( @NotNull LogicalAggregate aggregate) { LogicalRepeat repeat = (LogicalRepeat) aggregate.child(); - Set aggUsedSlots = aggregate.getOutputExpressions().stream() - .flatMap(e -> e.>collect(AggregateFunction.class::isInstance).stream()) + + List aggregateFunctions = Lists.newArrayList(); + aggregate.getOutputExpressions().forEach( + o -> o.accept(PlanUtils.CollectNonWindowedAggFuncs.INSTANCE, aggregateFunctions)); + Set aggUsedSlots = aggregateFunctions.stream() .flatMap(e -> e.>collect(SlotReference.class::isInstance).stream()) .collect(ImmutableSet.toImmutableSet()); Set groupingSetsUsedSlot = repeat.getGroupingSets().stream() @@ -308,20 +314,68 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { .build()); aggregate = aggregate.withChildren(ImmutableList.of(repeat)); - // modify aggregate functions' parameter slot reference to new copied slots List newOutputExpressions = aggregate.getOutputExpressions().stream() - .map(output -> (NamedExpression) output.rewriteDownShortCircuit(expr -> { - if (expr instanceof AggregateFunction) { - return expr.rewriteDownShortCircuit(e -> { - if (e instanceof Slot && slotMapping.containsKey(e)) { - return slotMapping.get(e).toSlot(); - } - return e; - }); - } - return expr; - }) - ).collect(Collectors.toList()); + .map(e -> (NamedExpression) e.accept(RewriteAggFuncWithoutWindowAggFunc.INSTANCE, + slotMapping)) + .collect(Collectors.toList()); return aggregate.withAggOutput(newOutputExpressions); } + + /** + * This class use the map(slotMapping) to rewrite all slots in trival-agg. + * The purpose of this class is to only rewrite the slots in trival-agg and not to rewrite the slots in window-agg. + */ + private static class RewriteAggFuncWithoutWindowAggFunc + extends DefaultExpressionRewriter> { + + private static final RewriteAggFuncWithoutWindowAggFunc + INSTANCE = new RewriteAggFuncWithoutWindowAggFunc(); + + private RewriteAggFuncWithoutWindowAggFunc() {} + + @Override + public Expression visitAggregateFunction(AggregateFunction aggregateFunction, Map slotMapping) { + return aggregateFunction.rewriteDownShortCircuit(e -> { + if (e instanceof Slot && slotMapping.containsKey(e)) { + return slotMapping.get(e).toSlot(); + } + return e; + }); + } + + @Override + public Expression visitWindow(WindowExpression windowExpression, Map slotMapping) { + List newChildren = new ArrayList<>(); + Expression function = windowExpression.getFunction(); + Expression oldFuncChild = function.child(0); + boolean hasNewChildren = false; + if (oldFuncChild != null) { + Expression newFuncChild; + newFuncChild = function.child(0).accept(this, slotMapping); + hasNewChildren = (newFuncChild != oldFuncChild); + newChildren.add(hasNewChildren + ? function.withChildren(ImmutableList.of(newFuncChild)) : function); + } else { + newChildren.add(function); + } + for (Expression partitionKey : windowExpression.getPartitionKeys()) { + Expression newChild = partitionKey.accept(this, slotMapping); + if (newChild != partitionKey) { + hasNewChildren = true; + } + newChildren.add(newChild); + } + for (Expression orderKey : windowExpression.getOrderKeys()) { + Expression newChild = orderKey.accept(this, slotMapping); + if (newChild != orderKey) { + hasNewChildren = true; + } + newChildren.add(newChild); + } + if (windowExpression.getWindowFrame().isPresent()) { + newChildren.add(windowExpression.getWindowFrame().get()); + } + return hasNewChildren ? windowExpression.withChildren(newChildren) : windowExpression; + } + } } diff --git a/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out index 901226f854..2c96648dac 100644 --- a/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out +++ b/regression-test/data/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.out @@ -64,3 +64,64 @@ yeah 4 5 +-- !select6 -- +\N +-86 +-48 +-12 +82 +89 +16054 +19196 + +-- !select7 -- +\N +-48 +-43 +82 +89 +35195 + +-- !select8 -- +\N +\N +-86 +-86 +-48 +-12 +82 +89 +16054 +19196 + +-- !select9 -- +\N +\N +\N +-129 +-129 +-129 +-96 +-96 +-12 +164 +164 +178 +178 +16054 +19196 +35195 +35275 + +-- !select10 -- +\N +\N +-172 +-172 +-96 +-24 +164 +178 +32108 +38392 + diff --git a/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy index 865ce3b5f5..bee63217a9 100644 --- a/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy +++ b/regression-test/suites/nereids_rules_p0/grouping_sets/slot_both_appear_in_agg_fun_and_grouping_sets.groovy @@ -60,4 +60,34 @@ suite("slot_both_appear_in_agg_fun_and_grouping_sets") { select sum(rank() over (partition by col_text_undef_signed order by col_int_undef_signed)) as col1 from table_10_undef_undef4 group by grouping sets((col_int_undef_signed)) order by 1; """ + + qt_select6 """ + select sum(sum(col_int_undef_signed)) over (partition by sum(col_int_undef_signed) + order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by + grouping sets ((col_int_undef_signed)) order by 1; + """ + + qt_select7 """ + select sum(sum(col_int_undef_signed)) over (partition by sum(col_int_undef_signed) + order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by + grouping sets ((col_text_undef_signed)) order by 1; + """ + + qt_select8 """ + select sum(sum(col_int_undef_signed)) over (partition by sum(col_int_undef_signed) + order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by + grouping sets ((col_text_undef_signed,col_int_undef_signed)) order by 1; + """ + + qt_select9 """ + select sum(sum(col_int_undef_signed)) over (partition by sum(col_int_undef_signed) + order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by + grouping sets ((col_text_undef_signed,col_int_undef_signed), (col_text_undef_signed), ()) order by 1; + """ + + qt_select10 """ + select sum(col_int_undef_signed + sum(col_int_undef_signed)) over (partition by sum(col_int_undef_signed) + order by sum(col_int_undef_signed)) from table_10_undef_undef4 group by + grouping sets ((col_text_undef_signed,col_int_undef_signed)) order by 1; + """ }