diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java index 2c9db0fdad..3a93298067 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/SlotRef.java @@ -84,7 +84,7 @@ public class SlotRef extends Expr { analysisDone(); } - // nerieds use this constructor to build aggFnParam + // nereids use this constructor to build aggFnParam public SlotRef(Type type, boolean nullable) { super(); // tuple id and slot id is meaningless here, nereids just use type and nullable diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java index d88e9b4591..8c62aab3d4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeat.java @@ -218,9 +218,16 @@ public class NormalizeRepeat extends OneAnalysisRuleFactory { List groupingSetExpressions = ExpressionUtils.flatExpressions(repeat.getGroupingSets()); Set commonGroupingSetExpressions = repeat.getCommonGroupingSetExpressions(); + // nullable will be different from grouping set and output expressions, + // so we can not use the slot in grouping set,but use the equivalent slot in output expressions. + List outputs = repeat.getOutputExpressions(); + Map normalizeToSlotMap = Maps.newLinkedHashMap(); for (Expression expression : sourceExpressions) { Optional pushDownTriplet; + if (expression instanceof NamedExpression && outputs.contains(expression)) { + expression = outputs.get(outputs.indexOf(expression)); + } if (groupingSetExpressions.contains(expression)) { boolean isCommonGroupingSetExpression = commonGroupingSetExpressions.contains(expression); pushDownTriplet = toGroupingSetExpressionPushDownTriplet( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java new file mode 100644 index 0000000000..f32f5959d9 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/NormalizeRepeatTest.java @@ -0,0 +1,56 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.analysis; + +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat; +import org.apache.doris.nereids.trees.plans.logical.RelationUtil; +import org.apache.doris.nereids.types.IntegerType; +import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +public class NormalizeRepeatTest implements PatternMatchSupported { + + @Test + public void testKeepNullableAfterNormalizeRepeat() { + SlotReference slot1 = new SlotReference("id", IntegerType.INSTANCE, false); + SlotReference slot2 = slot1.withNullable(true); + SlotReference slot3 = new SlotReference("name", StringType.INSTANCE, false); + Alias alias = new Alias(new Sum(slot3), "sum(name)"); + Plan plan = new LogicalRepeat<>( + ImmutableList.of(ImmutableList.of(slot1)), + ImmutableList.of(slot2, alias), + new LogicalOlapScan(RelationUtil.newRelationId(), PlanConstructor.newOlapTable(0, "t", 0)) + ); + PlanChecker.from(MemoTestUtils.createCascadesContext(plan)) + .applyTopDown(new NormalizeRepeat()) + .matches( + logicalRepeat().when(repeat -> repeat.getOutputExpressions().get(0).nullable()) + ); + } +} diff --git a/regression-test/data/nereids_syntax_p0/grouping_sets.out b/regression-test/data/nereids_syntax_p0/grouping_sets.out index 78b5a0284a..38b19938d0 100644 --- a/regression-test/data/nereids_syntax_p0/grouping_sets.out +++ b/regression-test/data/nereids_syntax_p0/grouping_sets.out @@ -214,3 +214,13 @@ 2 \N 3 \N +-- !select7 -- +1 +2 +3 +4 +1 +2 +3 +4 + diff --git a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy index 8343d7923c..5218a4215c 100644 --- a/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy +++ b/regression-test/suites/nereids_syntax_p0/grouping_sets.groovy @@ -157,4 +157,79 @@ suite("test_nereids_grouping_sets") { order_qt_select """ select k1, sum(k2) from (select k1, k2, grouping(k1), grouping(k2) from groupingSetsTableNotNullable group by grouping sets((k1), (k2)))a group by k1 """ + + sql """ + drop table if exists grouping_subquery_table; + """ + + sql """ + create table grouping_subquery_table ( a int not null, b int not null ) + ENGINE=OLAP + DISTRIBUTED BY HASH(a) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "in_memory" = "false", + "storage_format" = "V2" + ); + """ + + sql """ + insert into grouping_subquery_table values + (1, 1), (1, 2), (1, 3), (1, 4), + (2, 1), (2, 2), (2, 3), (2, 4), + (3, 1), (3, 2), (3, 3), (3, 4), + (4, 1), (4, 2), (4, 3), (4, 4); + """ + + qt_select7 """ + SELECT + a + FROM + ( + with base_table as ( + SELECT + `a`, + sum(`b`) as `sum(b)` + FROM + ( + SELECT + inv.a, + sum(inv.b) as b + FROM + grouping_subquery_table inv + group by + inv.a + ) T + GROUP BY + `a` + ), + grouping_sum_table as ( + select + `a`, + sum(`sum(b)`) as `sum(b)` + from + base_table + group by + grouping sets ( + (`base_table`.`a`) + ) + ) + select + * + from + ( + select + `a`, + `sum(b)` + from + base_table + union all + select + `a`, + `sum(b)` + from + grouping_sum_table + ) T + ) T2; + """ }