From 2af831de335bd0ffc7d18d4b8d9bed43addd2b5b Mon Sep 17 00:00:00 2001 From: zhengshiJ <32082872+zhengshiJ@users.noreply.github.com> Date: Wed, 28 Dec 2022 10:42:21 +0800 Subject: [PATCH] [Fix](Nereids)fix group by binding error, resulting in incorrect results (#15328) Original: group by is bound to the outputExpression of the current node. Problem: When the name of the new reference of outputExpression is the same as the child's output column, the child's output column should be used for group by, but at this time, the new reference of the node's outputExpression will be used for group by, resulting in an error Now: Give priority to the child's output for group by binding. If the child does not have a corresponding column, use the outputExpression of this node for binding --- .../rules/analysis/BindSlotReference.java | 23 +++++++-- .../rules/analysis/NormalizeRepeat.java | 2 +- .../rewrite/logical/NormalizeAggregate.java | 3 +- .../nereids/trees/plans/GroupingSetsTest.java | 49 +++++++++++++++++++ .../data/nereids_syntax_p0/grouping_sets.out | 34 +++++++++++++ .../nereids_syntax_p0/grouping_sets.groovy | 28 +++++++++++ 6 files changed, 133 insertions(+), 6 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java index 3590d07483..d438db12ea 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java @@ -167,18 +167,26 @@ public class BindSlotReference implements AnalysisRuleFactory { LogicalAggregate agg = ctx.root; List output = bind(agg.getOutputExpressions(), agg.children(), agg, ctx.cascadesContext); + + // The columns referenced in group by are first obtained from the child's output, + // and then from the node's output + Map childOutputsToExpr = agg.child().getOutput().stream() + .collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr)); Map aliasNameToExpr = output.stream() .filter(ne -> ne instanceof Alias) .map(Alias.class::cast) .collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr)); + aliasNameToExpr.entrySet().stream() + .forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue())); + List replacedGroupBy = agg.getGroupByExpressions().stream() .map(groupBy -> { if (groupBy instanceof UnboundSlot) { UnboundSlot unboundSlot = (UnboundSlot) groupBy; if (unboundSlot.getNameParts().size() == 1) { String name = unboundSlot.getNameParts().get(0); - if (aliasNameToExpr.containsKey(name)) { - return aliasNameToExpr.get(name); + if (childOutputsToExpr.containsKey(name)) { + return childOutputsToExpr.get(name); } } } @@ -197,10 +205,17 @@ public class BindSlotReference implements AnalysisRuleFactory { List output = bind(repeat.getOutputExpressions(), repeat.children(), repeat, ctx.cascadesContext); + // The columns referenced in group by are first obtained from the child's output, + // and then from the node's output + Map childOutputsToExpr = repeat.child().getOutput().stream() + .collect(Collectors.toMap(Slot::getName, Slot::toSlot, (oldExpr, newExpr) -> oldExpr)); Map aliasNameToExpr = output.stream() .filter(ne -> ne instanceof Alias) .map(Alias.class::cast) .collect(Collectors.toMap(Alias::getName, UnaryNode::child, (oldExpr, newExpr) -> oldExpr)); + aliasNameToExpr.entrySet().stream() + .forEach(e -> childOutputsToExpr.putIfAbsent(e.getKey(), e.getValue())); + List> replacedGroupingSets = repeat.getGroupingSets().stream() .map(groupBy -> groupBy.stream().map(expr -> { @@ -208,8 +223,8 @@ public class BindSlotReference implements AnalysisRuleFactory { UnboundSlot unboundSlot = (UnboundSlot) expr; if (unboundSlot.getNameParts().size() == 1) { String name = unboundSlot.getNameParts().get(0); - if (aliasNameToExpr.containsKey(name)) { - return aliasNameToExpr.get(name); + if (childOutputsToExpr.containsKey(name)) { + return childOutputsToExpr.get(name); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index 8c62aab3d4..7818fe2bb7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -200,7 +200,7 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { } private Plan pushDownProject(Set pushedExprs, Plan originBottomPlan) { - if (!pushedExprs.equals(originBottomPlan.getOutputSet())) { + if (!pushedExprs.equals(originBottomPlan.getOutputSet()) && !pushedExprs.isEmpty()) { return new LogicalProject<>(ImmutableList.copyOf(pushedExprs), originBottomPlan); } return originBottomPlan; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java index d69cff4d58..b5a1ddc32b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/NormalizeAggregate.java @@ -93,7 +93,8 @@ public class NormalizeAggregate extends OneRewriteRuleFactory implements Normali aggregateFunctionToSlotContext.pushDownToNamedExpression(normalizedAggregateFunctions); List normalizedGroupBy = - (List) groupByAndArgumentToSlotContext.normalizeToUseSlotRef(aggregate.getGroupByExpressions()); + (List) groupByAndArgumentToSlotContext + .normalizeToUseSlotRef(aggregate.getGroupByExpressions()); // we can safely add all groupBy and aggregate functions to output, because we will // add a project on it, and the upper project can protect the scope of visible of slot diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java index cec372311e..a30364b116 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/GroupingSetsTest.java @@ -183,4 +183,53 @@ public class GroupingSetsTest extends TestWithFeService { PlanChecker.from(connectContext) .checkPlannerResult("select if(k1 = 1, 2, k1) k_if from t1"); } + + @Test + public void test1() { + PlanChecker.from(connectContext) + .checkPlannerResult("select coalesce(col1, 'all') as col1, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test1_1() { + PlanChecker.from(connectContext) + .checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test1_2() { + PlanChecker.from(connectContext) + .checkPlannerResult("select coalesce(col1, 'all') as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());"); + } + + @Test + public void test2() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(1 = null, 'all', 2) as col1, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test2_1() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(col1 = null, 'all', 2) as col1, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test2_2() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),());"); + } + + @Test + public void test2_3() { + PlanChecker.from(connectContext) + .checkPlannerResult("select if(col1 = null, 'all', 2) as col2, count(*) as cnt from" + + " (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),());"); + } } diff --git a/regression-test/data/nereids_syntax_p0/grouping_sets.out b/regression-test/data/nereids_syntax_p0/grouping_sets.out index 38b19938d0..6c18f54e2b 100644 --- a/regression-test/data/nereids_syntax_p0/grouping_sets.out +++ b/regression-test/data/nereids_syntax_p0/grouping_sets.out @@ -224,3 +224,37 @@ 3 4 +-- !select1 -- +a 1 +all 1 +all 2 + +-- !select2 -- +a 1 +all 1 +all 2 + +-- !select3 -- +\N 2 +a 1 +all 1 + +-- !select4 -- +2 1 +2 1 +2 2 + +-- !select5 -- +2 1 +2 1 +2 2 + +-- !select6 -- +2 1 +2 1 +2 2 + +-- !select7 -- +2 1 +2 1 +2 2 diff --git a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy index 5218a4215c..8e6cc6e5c7 100644 --- a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy +++ b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy @@ -232,4 +232,32 @@ suite("test_nereids_grouping_sets") { ) T ) T2; """ + + order_qt_select1 """ + select coalesce(col1, 'all') as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select2 """ + select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select3 """ + select coalesce(col1, 'all') as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col2),()); + """ + + order_qt_select4 """ + select if(1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select5 """ + select if(col1 = null, 'all', 2) as col1, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select6 """ + select if(1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ + + order_qt_select7 """ + select if(col1 = null, 'all', 2) as col2, count(*) as cnt from (select null as col1 union all select 'a' as col1 ) t group by grouping sets ((col1),()); + """ }