From e17809a684da983a51dac7012592ee77540d0957 Mon Sep 17 00:00:00 2001
From: starocean999 <40539150+starocean999@users.noreply.github.com>
Date: Wed, 10 Jan 2024 09:57:29 +0800
Subject: [PATCH] [fix](nereids)logicalhaving is in wrong place after logicalagg and logicalwindow (#29463)

---
 .../doris/nereids/jobs/executor/Analyzer.java   |   2 +
 .../apache/doris/nereids/rules/RuleType.java    |   1 +
 .../rules/analysis/FillUpMissingSlots.java      |   5 -
 .../rules/analysis/HavingToFilter.java          |  34 +++
 .../rules/analysis/NormalizeAggregate.java      | 276 ++++++++++--------
 .../aggregate/agg_window_project.out            |   4 +
 .../aggregate/agg_window_project.groovy         |  22 ++
 7 files changed, 214 insertions(+), 130 deletions(-)
 create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/HavingToFilter.java
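The query shape that exposes the problem is the one added to the regression suite at the end of this patch: an aggregate whose result feeds a window function, plus a HAVING clause on a grouping column. Standard SQL evaluates HAVING before window functions, so the predicate has to filter groups between the aggregate and the window projection; per the commit title, the LogicalHaving node previously ended up in the wrong place relative to LogicalAggregate and the window projection. The query and expected result below are copied from the new test case and its .out file, only reformatted:

    SELECT a, c,
           sum(sum(b)) OVER (PARTITION BY c ORDER BY c
                             ROWS BETWEEN UNBOUNDED PRECEDING AND CURRENT ROW)
    FROM test_window_table2
    GROUP BY a, c
    HAVING a > 1;
    -- expected (select4 in agg_window_project.out):
    --   2    1    23.0000000000
    --   2    2    23.0000000000
    -- the a = "1" groups are removed by HAVING before the window sum is computed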
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
index c88cf1c432..95fb019ad5 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
@@ -33,6 +33,7 @@ import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
 import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
 import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
 import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
+import org.apache.doris.nereids.rules.analysis.HavingToFilter;
 import org.apache.doris.nereids.rules.analysis.LeadingJoin;
 import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
 import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
@@ -161,6 +162,7 @@ public class Analyzer extends AbstractBatchJobExecutor {
                 bottomUp(new CheckAnalysis()),
                 topDown(new EliminateGroupByConstant()),
                 topDown(new NormalizeAggregate()),
+                topDown(new HavingToFilter()),
                 bottomUp(new SemiJoinCommute()),
                 bottomUp(
                         new CollectSubQueryAlias(),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index cede463c92..61efe9e32f 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -63,6 +63,7 @@ public enum RuleType {
     RESOLVE_PROJECT_ALIAS(RuleTypeClass.REWRITE),
     RESOLVE_AGGREGATE_ALIAS(RuleTypeClass.REWRITE),
     PROJECT_TO_GLOBAL_AGGREGATE(RuleTypeClass.REWRITE),
+    HAVING_TO_FILTER(RuleTypeClass.REWRITE),
     ONE_ROW_RELATION_EXTRACT_AGGREGATE(RuleTypeClass.REWRITE),
     PROJECT_WITH_DISTINCT_TO_AGGREGATE(RuleTypeClass.REWRITE),
     AVG_DISTINCT_TO_SUM_DIV_COUNT(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
index a500d9bb66..c8efc1a891 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/FillUpMissingSlots.java
@@ -30,7 +30,6 @@ import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.algebra.Aggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
 import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
@@ -174,10 +173,6 @@ public class FillUpMissingSlots implements AnalysisRuleFactory {
                         return new LogicalProject<>(ImmutableList.copyOf(project.getOutput()),
                                 having.withChildren(new LogicalProject<>(projects, project.child())));
                     })
-                ),
-                // Convert having to filter
-                RuleType.FILL_UP_HAVING_PROJECT.build(
-                        logicalHaving().then(having -> new LogicalFilter<>(having.getConjuncts(), having.child()))
                 )
         );
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/HavingToFilter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/HavingToFilter.java
new file mode 100644
index 0000000000..9751c61ebe
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/HavingToFilter.java
@@ -0,0 +1,34 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.analysis;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
+
+/**
+ * HavingToFilter
+ */
+public class HavingToFilter extends OneAnalysisRuleFactory {
+    @Override
+    public Rule build() {
+        return logicalHaving()
+                .then(having -> new LogicalFilter<>(having.getConjuncts(), having.child()))
+                .toRule(RuleType.HAVING_TO_FILTER);
+    }
+}
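FillUpMissingSlots no longer rewrites LogicalHaving into LogicalFilter itself; that unconditional conversion now lives in the dedicated HavingToFilter rule above, which Analyzer.java registers as a top-down pass right after NormalizeAggregate. By the time HavingToFilter fires, NormalizeAggregate has already decided where the LogicalHaving node belongs, so this rule only swaps the node kind and keeps the conjuncts and child unchanged. A minimal sketch of the plain (non-window) case, assuming a hypothetical table t(a int, b int):

    -- no window functions, so the HAVING stays on top of the normalized aggregate
    SELECT a, sum(b)
    FROM t
    GROUP BY a
    HAVING a > 0;
    -- after NormalizeAggregate:  LogicalHaving(a > 0) -> LogicalProject -> LogicalAggregate
    -- after HavingToFilter:      LogicalFilter(a > 0) -> LogicalProject -> LogicalAggregate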
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
index a7eb7c7e5c..0e7dc51bab 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeAggregate.java
@@ -21,7 +21,7 @@ import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot;
 import org.apache.doris.nereids.rules.rewrite.NormalizeToSlot.NormalizeToSlotContext;
-import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.rules.rewrite.RewriteRuleFactory;
 import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
@@ -34,6 +34,8 @@ import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionVisitor;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
@@ -46,6 +48,7 @@ import com.google.common.collect.Sets;
 
 import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -97,141 +100,164 @@ import java.util.stream.Collectors;
  *
  * More example could get from UT {NormalizeAggregateTest}
  */
-public class NormalizeAggregate extends OneRewriteRuleFactory implements NormalizeToSlot {
+public class NormalizeAggregate implements RewriteRuleFactory, NormalizeToSlot {
     @Override
-    public Rule build() {
-        return logicalAggregate().whenNot(LogicalAggregate::isNormalized).then(aggregate -> {
-            // The LogicalAggregate node may contain window agg functions and usual agg functions
-            // we call window agg functions as window-agg and usual agg functions as trival-agg for short
-            // This rule simplify LogicalAggregate node by:
-            // 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
-            // 2. create a new LogicalAggregate with normalized group by exprs and trival-aggs
-            // 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
-            // Push down exprs:
-            // 1. all group by exprs
-            // 2. child contains subquery expr in trival-agg
-            // 3. child contains window expr in trival-agg
-            // 4. all input slots of trival-agg
-            // 5. expr(including subquery) in distinct trival-agg
-            // Normalize LogicalAggregate's output.
-            // 1. normalize group by exprs by outputs of bottom LogicalProject
-            // 2. normalize trival-aggs by outputs of bottom LogicalProject
-            // 3. build normalized agg outputs
-            // Pull up exprs:
-            // normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
-            // 1. simple slots
-            // 2. aliases
-            //    a. alias with no aggs child
-            //    b. alias with trival-agg child
-            //    c. alias with window-agg
+    public List<Rule> buildRules() {
+        return ImmutableList.of(
+                logicalHaving(logicalAggregate().whenNot(LogicalAggregate::isNormalized))
+                        .then(having -> normalizeAgg(having.child(), Optional.of(having)))
+                        .toRule(RuleType.NORMALIZE_AGGREGATE),
+                logicalAggregate().whenNot(LogicalAggregate::isNormalized)
+                        .then(aggregate -> normalizeAgg(aggregate, Optional.empty()))
+                        .toRule(RuleType.NORMALIZE_AGGREGATE));
+    }
 
-            // Push down exprs:
-            // collect group by exprs
-            Set<Expression> groupingByExprs =
-                    ImmutableSet.copyOf(aggregate.getGroupByExpressions());
+    private LogicalPlan normalizeAgg(LogicalAggregate<Plan> aggregate, Optional<LogicalHaving<?>> having) {
+        // The LogicalAggregate node may contain window agg functions and usual agg functions
+        // we call window agg functions as window-agg and usual agg functions as trival-agg for short
+        // This rule simplify LogicalAggregate node by:
+        // 1. Push down some exprs from old LogicalAggregate node to a new child LogicalProject Node,
+        // 2. create a new LogicalAggregate with normalized group by exprs and trival-aggs
+        // 3. Pull up normalized old LogicalAggregate's output exprs to a new parent LogicalProject Node
+        // Push down exprs:
+        // 1. all group by exprs
+        // 2. child contains subquery expr in trival-agg
+        // 3. child contains window expr in trival-agg
+        // 4. all input slots of trival-agg
+        // 5. expr(including subquery) in distinct trival-agg
+        // Normalize LogicalAggregate's output.
+        // 1. normalize group by exprs by outputs of bottom LogicalProject
+        // 2. normalize trival-aggs by outputs of bottom LogicalProject
+        // 3. build normalized agg outputs
+        // Pull up exprs:
+        // normalize all output exprs in old LogicalAggregate to build a parent project node, typically includes:
+        // 1. simple slots
+        // 2. aliases
+        //    a. alias with no aggs child
+        //    b. alias with trival-agg child
+        //    c. alias with window-agg
 
-            // collect all trival-agg
-            List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
-            List<AggregateFunction> aggFuncs = Lists.newArrayList();
-            aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
+        // Push down exprs:
+        // collect group by exprs
+        Set<Expression> groupingByExprs =
+                ImmutableSet.copyOf(aggregate.getGroupByExpressions());
 
-            // split non-distinct agg child as two part
-            // TRUE part 1: need push down itself, if it contains subqury or window expression
-            // FALSE part 2: need push down its input slots, if it DOES NOT contain subqury or window expression
-            Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
-                    .filter(aggFunc -> !aggFunc.isDistinct())
-                    .flatMap(agg -> agg.children().stream())
-                    .collect(Collectors.groupingBy(
-                            child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
-                            Collectors.toSet()));
+        // collect all trival-agg
+        List<NamedExpression> aggregateOutput = aggregate.getOutputExpressions();
+        List<AggregateFunction> aggFuncs = Lists.newArrayList();
+        aggregateOutput.forEach(o -> o.accept(CollectNonWindowedAggFuncs.INSTANCE, aggFuncs));
 
-            // split distinct agg child as two parts
-            // TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
-            // FALSE part 2: need push down its input slots, if it is SlotReference or Literal
-            Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
-                    .filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
-                    .collect(Collectors.groupingBy(
-                            child -> !(child instanceof SlotReference || child instanceof Literal),
-                            Collectors.toSet()));
+        // split non-distinct agg child as two part
+        // TRUE part 1: need push down itself, if it contains subqury or window expression
+        // FALSE part 2: need push down its input slots, if it DOES NOT contain subqury or window expression
+        Map<Boolean, Set<Expression>> categorizedNoDistinctAggsChildren = aggFuncs.stream()
+                .filter(aggFunc -> !aggFunc.isDistinct())
+                .flatMap(agg -> agg.children().stream())
+                .collect(Collectors.groupingBy(
+                        child -> child.containsType(SubqueryExpr.class, WindowExpression.class),
+                        Collectors.toSet()));
 
-            Set<Expression> needPushSelf = Sets.union(
-                    categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
-                    categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
-            Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
-                    categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
-                    categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
+        // split distinct agg child as two parts
+        // TRUE part 1: need push down itself, if it is NOT SlotReference or Literal
+        // FALSE part 2: need push down its input slots, if it is SlotReference or Literal
+        Map<Boolean, Set<Expression>> categorizedDistinctAggsChildren = aggFuncs.stream()
+                .filter(aggFunc -> aggFunc.isDistinct()).flatMap(agg -> agg.children().stream())
+                .collect(Collectors.groupingBy(
+                        child -> !(child instanceof SlotReference || child instanceof Literal),
+                        Collectors.toSet()));
 
-            Set<Alias> existsAlias =
-                    ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
+        Set<Expression> needPushSelf = Sets.union(
+                categorizedNoDistinctAggsChildren.getOrDefault(true, new HashSet<>()),
+                categorizedDistinctAggsChildren.getOrDefault(true, new HashSet<>()));
+        Set<Slot> needPushInputSlots = ExpressionUtils.getInputSlotSet(Sets.union(
+                categorizedNoDistinctAggsChildren.getOrDefault(false, new HashSet<>()),
+                categorizedDistinctAggsChildren.getOrDefault(false, new HashSet<>())));
 
-            // push down 3 kinds of exprs, these pushed exprs will be used to normalize agg output later
-            // 1. group by exprs
-            // 2. trivalAgg children
-            // 3. trivalAgg input slots
-            Set<Expression> allPushDownExprs =
-                    Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
-            NormalizeToSlotContext bottomSlotContext =
-                    NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
-            Set<NamedExpression> pushedGroupByExprs =
-                    bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
-            Set<NamedExpression> pushedTrivalAggChildren =
-                    bottomSlotContext.pushDownToNamedExpression(needPushSelf);
-            Set<NamedExpression> pushedTrivalAggInputSlots =
-                    bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
-            Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
-                    Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));
+        Set<Alias> existsAlias =
+                ExpressionUtils.mutableCollect(aggregateOutput, Alias.class::isInstance);
 
-            // create bottom project
-            Plan bottomPlan;
-            if (!bottomProjects.isEmpty()) {
-                bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
-                        aggregate.child());
+        // push down 3 kinds of exprs, these pushed exprs will be used to normalize agg output later
+        // 1. group by exprs
+        // 2. trivalAgg children
+        // 3. trivalAgg input slots
+        Set<Expression> allPushDownExprs =
+                Sets.union(groupingByExprs, Sets.union(needPushSelf, needPushInputSlots));
+        NormalizeToSlotContext bottomSlotContext =
+                NormalizeToSlotContext.buildContext(existsAlias, allPushDownExprs);
+        Set<NamedExpression> pushedGroupByExprs =
+                bottomSlotContext.pushDownToNamedExpression(groupingByExprs);
+        Set<NamedExpression> pushedTrivalAggChildren =
+                bottomSlotContext.pushDownToNamedExpression(needPushSelf);
+        Set<NamedExpression> pushedTrivalAggInputSlots =
+                bottomSlotContext.pushDownToNamedExpression(needPushInputSlots);
+        Set<NamedExpression> bottomProjects = Sets.union(pushedGroupByExprs,
+                Sets.union(pushedTrivalAggChildren, pushedTrivalAggInputSlots));
+
+        // create bottom project
+        Plan bottomPlan;
+        if (!bottomProjects.isEmpty()) {
+            bottomPlan = new LogicalProject<>(ImmutableList.copyOf(bottomProjects),
+                    aggregate.child());
+        } else {
+            bottomPlan = aggregate.child();
+        }
+
+        // use group by context to normalize agg functions to process
+        // sql like: select sum(a + 1) from t group by a + 1
+        //
+        // before normalize:
+        // agg(output: sum(a[#0] + 1)[#2], group_by: alias(a + 1)[#1])
+        // +-- project(a[#0], (a[#0] + 1)[#1])
+        //
+        // after normalize:
+        // agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
+        // +-- project((a[#0] + 1)[#1])
+
+        // normalize group by exprs by bottomProjects
+        List<Expression> normalizedGroupExprs =
+                bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
+
+        // normalize trival-aggs by bottomProjects
+        List<AggregateFunction> normalizedAggFuncs =
+                bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
+
+        // build normalized agg output
+        NormalizeToSlotContext normalizedAggFuncsToSlotContext =
+                NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
+
+        // agg output include 2 parts
+        // pushedGroupByExprs and normalized agg functions
+        List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
+                .addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
+                .addAll(normalizedAggFuncsToSlotContext
+                        .pushDownToNamedExpression(normalizedAggFuncs))
+                .build();
+
+        // create new agg node
+        LogicalAggregate<?> newAggregate =
+                aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
+
+        // create upper projects by normalize all output exprs in old LogicalAggregate
+        List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
+                bottomSlotContext, normalizedAggFuncsToSlotContext);
+
+        // create a parent project node
+        LogicalProject<Plan> project = new LogicalProject<>(upperProjects, newAggregate);
+        if (having.isPresent()) {
+            if (upperProjects.stream().anyMatch(expr -> expr.anyMatch(WindowExpression.class::isInstance))) {
+                // when project contains window functions, in order to get the correct result
+                // push having through project to make it the parent node of logicalAgg
+                return project.withChildren(ImmutableList.of(
+                        new LogicalHaving<>(
+                                ExpressionUtils.replace(having.get().getConjuncts(), project.getAliasToProducer()),
+                                project.child()
+                        )));
             } else {
-                bottomPlan = aggregate.child();
+                return (LogicalPlan) having.get().withChildren(project);
             }
-
-            // use group by context to normalize agg functions to process
-            // sql like: select sum(a + 1) from t group by a + 1
-            //
-            // before normalize:
-            // agg(output: sum(a[#0] + 1)[#2], group_by: alias(a + 1)[#1])
-            // +-- project(a[#0], (a[#0] + 1)[#1])
-            //
-            // after normalize:
-            // agg(output: sum(alias(a + 1)[#1])[#2], group_by: alias(a + 1)[#1])
-            // +-- project((a[#0] + 1)[#1])
-
-            // normalize group by exprs by bottomProjects
-            List<Expression> normalizedGroupExprs =
-                    bottomSlotContext.normalizeToUseSlotRef(groupingByExprs);
-
-            // normalize trival-aggs by bottomProjects
-            List<AggregateFunction> normalizedAggFuncs =
-                    bottomSlotContext.normalizeToUseSlotRef(aggFuncs);
-
-            // build normalized agg output
-            NormalizeToSlotContext normalizedAggFuncsToSlotContext =
-                    NormalizeToSlotContext.buildContext(existsAlias, normalizedAggFuncs);
-
-            // agg output include 2 parts
-            // pushedGroupByExprs and normalized agg functions
-            List<NamedExpression> normalizedAggOutput = ImmutableList.<NamedExpression>builder()
-                    .addAll(pushedGroupByExprs.stream().map(NamedExpression::toSlot).iterator())
-                    .addAll(normalizedAggFuncsToSlotContext
-                            .pushDownToNamedExpression(normalizedAggFuncs))
-                    .build();
-
-            // create new agg node
-            LogicalAggregate<?> newAggregate =
-                    aggregate.withNormalized(normalizedGroupExprs, normalizedAggOutput, bottomPlan);
-
-            // create upper projects by normalize all output exprs in old LogicalAggregate
-            List<NamedExpression> upperProjects = normalizeOutput(aggregateOutput,
-                    bottomSlotContext, normalizedAggFuncsToSlotContext);
-
-            // create a parent project node
-            return new LogicalProject<>(upperProjects, newAggregate);
-        }).toRule(RuleType.NORMALIZE_AGGREGATE);
+        } else {
+            return project;
+        }
     }
 
     private List<NamedExpression> normalizeOutput(List<NamedExpression> aggregateOutput,
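The branch on WindowExpression above is the substance of the fix. When the rebuilt upper project contains window expressions, the LogicalHaving is re-attached underneath that project, directly on top of the new LogicalAggregate, with its conjuncts rewritten through the project's alias map (project.getAliasToProducer()); otherwise the having simply stays above the project, as before. For the regression query added below, the intended plan shape is roughly the following sketch (node arguments abbreviated, scan node omitted):

    -- LogicalProject( a, c, sum(sum(b)) OVER (PARTITION BY c ...) )
    --   +-- LogicalHaving( a > 1 )    -- later rewritten to LogicalFilter by HavingToFilter
    --         +-- LogicalAggregate( GROUP BY a, c; outputs a, c, sum(b) )
    --               +-- scan of test_window_table2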
diff --git a/regression-test/data/nereids_p0/aggregate/agg_window_project.out b/regression-test/data/nereids_p0/aggregate/agg_window_project.out
index 60f8b9a6f9..dcdb0f25ec 100644
--- a/regression-test/data/nereids_p0/aggregate/agg_window_project.out
+++ b/regression-test/data/nereids_p0/aggregate/agg_window_project.out
@@ -8,3 +8,7 @@
 
 -- !select3 --
 1
+-- !select4 --
+2	1	23.0000000000
+2	2	23.0000000000
+
diff --git a/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy b/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
index 794ccf3002..605f4da5ff 100644
--- a/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
+++ b/regression-test/suites/nereids_p0/aggregate/agg_window_project.groovy
@@ -75,4 +75,26 @@ suite("agg_window_project") {
             test_window_table;"""
 
     sql "DROP TABLE IF EXISTS test_window_table;"
+
+    sql "DROP TABLE IF EXISTS test_window_table2;"
+    sql """
+        create table test_window_table2
+        (
+            a varchar(100) null,
+            b decimalv3(18,10) null,
+            c int,
+        ) ENGINE=OLAP
+        DUPLICATE KEY(`a`)
+        DISTRIBUTED BY HASH(`a`) BUCKETS 1
+        PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1"
+        );
+
+    """
+
+    sql """insert into test_window_table2 values("1", 1, 1),("1", 1, 2),("1", 2, 1),("1", 2, 2),("2", 11, 1),("2", 11, 2),("2", 12, 1),("2", 12, 2);"""
+
+    order_qt_select4 """select a, c, sum(sum(b)) over(partition by c order by c rows between unbounded preceding and current row) from test_window_table2 group by a, c having a > 1;"""
+
+    sql "DROP TABLE IF EXISTS test_window_table2;"
 }