diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 028a0428d4..ee6d19b315 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -48,7 +48,7 @@ import org.apache.doris.nereids.rules.rewrite.CollectProjectAboveConsumer; import org.apache.doris.nereids.rules.rewrite.ColumnPruning; import org.apache.doris.nereids.rules.rewrite.ConvertInnerOrCrossJoin; import org.apache.doris.nereids.rules.rewrite.CountDistinctRewrite; -import org.apache.doris.nereids.rules.rewrite.CountLiteralToCountStar; +import org.apache.doris.nereids.rules.rewrite.CountLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.CreatePartitionTopNFromWindow; import org.apache.doris.nereids.rules.rewrite.DeferMaterializeTopNResult; import org.apache.doris.nereids.rules.rewrite.EliminateAggregate; @@ -194,7 +194,7 @@ public class Rewriter extends AbstractBatchJobExecutor { topDown( new SimplifyAggGroupBy(), new NormalizeAggregate(), - new CountLiteralToCountStar(), + new CountLiteralRewrite(), new NormalizeSort() ), topic("Window analysis", diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 093dbc28b3..172862f679 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -59,7 +59,7 @@ public enum RuleType { BINDING_INSERT_TARGET_TABLE(RuleTypeClass.REWRITE), BINDING_INSERT_FILE(RuleTypeClass.REWRITE), - COUNT_LITERAL_TO_COUNT_STAR(RuleTypeClass.REWRITE), + COUNT_LITERAL_REWRITE(RuleTypeClass.REWRITE), REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT(RuleTypeClass.REWRITE), FILL_UP_HAVING_AGGREGATE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralToCountStar.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java similarity index 55% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralToCountStar.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java index fc08273b84..dfe13b388f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralToCountStar.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewrite.java @@ -23,6 +23,9 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.collect.Lists; @@ -30,22 +33,43 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; /** * count(1) ==> count(*) + * count(null) ==> 0 */ -public class CountLiteralToCountStar extends OneRewriteRuleFactory { +public class CountLiteralRewrite extends OneRewriteRuleFactory { @Override public Rule build() { return logicalAggregate().then( agg -> { List newExprs = Lists.newArrayListWithCapacity(agg.getOutputExpressions().size()); - if (rewriteCountLiteral(agg.getOutputExpressions(), newExprs)) { - return agg.withAggOutput(newExprs); + if (!rewriteCountLiteral(agg.getOutputExpressions(), newExprs)) { + // no need to rewrite + return agg; + } + + Map> projectsAndAggFunc = newExprs.stream() + .collect(Collectors.partitioningBy(Expression::isConstant)); + + if (projectsAndAggFunc.get(false).isEmpty()) { + // if there is no group by keys and other agg func, don't rewrite + return null; + } else { + // if there is group by keys, put count(null) in projects, such as + // project(0 as count(null)) + // --Aggregate(k1, group by k1) + Plan plan = agg.withAggOutput(projectsAndAggFunc.get(false)); + if (!projectsAndAggFunc.get(true).isEmpty()) { + projectsAndAggFunc.get(false).stream().map(NamedExpression::toSlot) + .forEach(projectsAndAggFunc.get(true)::add); + plan = new LogicalProject<>(projectsAndAggFunc.get(true), plan); + } + return plan; } - return agg; } - ).toRule(RuleType.COUNT_LITERAL_TO_COUNT_STAR); + ).toRule(RuleType.COUNT_LITERAL_REWRITE); } private boolean rewriteCountLiteral(List oldExprs, List newExprs) { @@ -55,7 +79,7 @@ public class CountLiteralToCountStar extends OneRewriteRuleFactory { Set oldAggFuncSet = expr.collect(AggregateFunction.class::isInstance); oldAggFuncSet.stream() .filter(this::isCountLiteral) - .forEach(c -> replaced.put(c, new Count())); + .forEach(c -> replaced.put(c, rewrite((Count) c))); expr = expr.rewriteUp(s -> replaced.getOrDefault(s, s)); changed |= !replaced.isEmpty(); newExprs.add((NamedExpression) expr); @@ -66,6 +90,14 @@ public class CountLiteralToCountStar extends OneRewriteRuleFactory { private boolean isCountLiteral(AggregateFunction aggFunc) { return !aggFunc.isDistinct() && aggFunc instanceof Count - && aggFunc.children().stream().allMatch(e -> e.isLiteral() && !e.isNullLiteral()); + && aggFunc.children().size() == 1 + && aggFunc.child(0).isLiteral(); + } + + private Expression rewrite(Count count) { + if (count.child(0).isNullLiteral()) { + return new BigIntLiteral(0); + } + return new Count(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index 6922f81ca6..bdd776ffe9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -23,6 +23,7 @@ import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.AbstractTreeNode; import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait; import org.apache.doris.nereids.trees.expressions.functions.Nondeterministic; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.shape.LeafExpression; @@ -224,6 +225,10 @@ public abstract class Expression extends AbstractTreeNode implements * Whether the expression is a constant. */ public boolean isConstant() { + if (this instanceof AggregateFunction) { + // agg_fun(literal) is not constant, the result depends on the group by keys + return false; + } if (this instanceof LeafExpression) { return this instanceof Literal; } else { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewriteTest.java new file mode 100644 index 0000000000..05a2774196 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/CountLiteralRewriteTest.java @@ -0,0 +1,73 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Test; + +/** + * CountLiteralRewriteTest + */ +class CountLiteralRewriteTest extends TestWithFeService implements MemoPatternMatchSupported { + + @Override + protected void runBeforeAll() throws Exception { + createDatabase("test"); + + createTable("create table test.student (\n" + "id int not null,\n" + "name varchar(128),\n" + + "age int, sex int)\n" + "distributed by hash(id) buckets 10\n" + + "properties('replication_num' = '1');"); + connectContext.setDatabase("default_cluster:test"); + } + + @Test + void testCountLiteral() { + PlanChecker.from(connectContext) + .analyze("select count(1) as c from student group by id") + .rewrite() + .matches(logicalAggregate() + .when(agg -> agg.getOutputExpressions().stream() + .allMatch(expr -> expr.anyMatch(c -> !(c instanceof Count) || ((Count) c).isCountStar())))) + .printlnTree(); + PlanChecker.from(connectContext) + .analyze("select count(1), sum(id) from student") + .rewrite() + .matches(logicalAggregate() + .when(agg -> agg.getOutputExpressions().stream() + .allMatch(expr -> expr.anyMatch(c -> !(c instanceof Count) || ((Count) c).isCountStar())))) + .printlnTree(); + } + + @Test + void testCountNull() { + PlanChecker.from(connectContext) + .analyze("select count(null) as c from student group by id") + .rewrite() + .matches(logicalAggregate().when(agg -> agg.getExpressions().stream().noneMatch(Count.class::isInstance))) + .printlnTree(); + PlanChecker.from(connectContext) + .analyze("select count(null) as c, sum(id) from student") + .rewrite() + .matches(logicalAggregate().when(agg -> agg.getExpressions().stream().noneMatch(Count.class::isInstance))) + .printlnTree(); + } +}