diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index fc3ef04fba..4fe9df664c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -55,6 +55,7 @@ import org.apache.doris.nereids.rules.rewrite.EliminateAssertNumRows; import org.apache.doris.nereids.rules.rewrite.EliminateDedupJoinCondition; import org.apache.doris.nereids.rules.rewrite.EliminateEmptyRelation; import org.apache.doris.nereids.rules.rewrite.EliminateFilter; +import org.apache.doris.nereids.rules.rewrite.EliminateGroupBy; import org.apache.doris.nereids.rules.rewrite.EliminateJoinByFK; import org.apache.doris.nereids.rules.rewrite.EliminateJoinCondition; import org.apache.doris.nereids.rules.rewrite.EliminateLimit; @@ -276,6 +277,10 @@ public class Rewriter extends AbstractBatchJobExecutor { topDown(new BuildAggForUnion()) ), + topic("Eliminate GroupBy", + topDown(new EliminateGroupBy()) + ), + topic("Eager aggregation", topDown( new PushDownSumThroughJoin(), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java index c2ba22b5dc..f013740c17 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/FunctionalDependencies.java @@ -220,7 +220,8 @@ public class FunctionalDependencies { .map(s -> replaceMap.getOrDefault(s, s)) .collect(Collectors.toSet()); slotSets = slotSets.stream() - .map(set -> set.stream().map(replaceMap::get).collect(ImmutableSet.toImmutableSet())) + .map(set -> set.stream().map(s -> replaceMap.getOrDefault(s, s)) + .collect(ImmutableSet.toImmutableSet())) .collect(Collectors.toSet()); } diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 8edf2d079a..32794946dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -207,6 +207,7 @@ public enum RuleType { ELIMINATE_NOT_NULL(RuleTypeClass.REWRITE), ELIMINATE_UNNECESSARY_PROJECT(RuleTypeClass.REWRITE), ELIMINATE_OUTER_JOIN(RuleTypeClass.REWRITE), + ELIMINATE_GROUP_BY(RuleTypeClass.REWRITE), ELIMINATE_DEDUP_JOIN_CONDITION(RuleTypeClass.REWRITE), ELIMINATE_NULL_AWARE_LEFT_ANTI_JOIN(RuleTypeClass.REWRITE), ELIMINATE_ASSERT_NUM_ROWS(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java new file mode 100644 index 0000000000..3b95e9b44e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java @@ -0,0 +1,83 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.IsNull;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.util.PlanUtils;
+
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Eliminate GroupBy.
+ *
+ * <p>Replaces an aggregate by a projection over its child when the child's functional dependencies
+ * report the GROUP BY slots as unique and not null, in which case each group is taken to hold
+ * exactly one row. On a single-row group:
+ * <ul>
+ *   <li>{@code sum(slot)}, {@code min(slot)} and {@code max(slot)} are replaced by {@code slot};</li>
+ *   <li>{@code count(slot)} is replaced by {@code if(slot IS NULL, 0, 1)}.</li>
+ * </ul>
+ *
+ * <p>The rewrite is skipped (the {@code then} callback returns {@code null}) when the GROUP BY slots
+ * are not unique and not null, or when any aggregate function is not a one-argument
+ * sum/count/min/max over a plain slot.
+ *
+ * <p>NOTE(review): the replacement expressions may not have the same data type as the aggregate
+ * they replace (for example if SUM widens its argument, or if COUNT is a wider integer type than
+ * the literals produced by {@code Literal.of(0)} / {@code Literal.of(1)}) -- confirm that a later
+ * stage coerces the types.
+ */
+public class EliminateGroupBy extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule build() {
+        return logicalAggregate()
+                // A scalar aggregate (no GROUP BY keys) always produces exactly one row, so it must
+                // never become a projection. allMatch() alone is vacuously true for an empty list,
+                // hence the explicit emptiness check instead of relying on how isUniqueAndNotNull
+                // treats an empty slot set.
+                .when(agg -> !agg.getGroupByExpressions().isEmpty()
+                        && agg.getGroupByExpressions().stream().allMatch(expr -> expr instanceof Slot))
+                .then(agg -> {
+                    Set<Slot> groupBySlots = agg.getGroupByExpressions().stream()
+                            .map(e -> (Slot) e)
+                            .collect(Collectors.toSet());
+                    Plan child = agg.child();
+                    boolean unique = child.getLogicalProperties().getFunctionalDependencies()
+                            .isUniqueAndNotNull(groupBySlots);
+                    if (!unique) {
+                        return null;
+                    }
+                    for (AggregateFunction f : agg.getAggregateFunctions()) {
+                        if (!isEliminable(f)) {
+                            return null;
+                        }
+                    }
+
+                    List<NamedExpression> newOutput = agg.getOutputExpressions().stream()
+                            .map(EliminateGroupBy::rewriteOutput)
+                            .collect(Collectors.toList());
+                    return PlanUtils.projectOrSelf(newOutput, child);
+                }).toRule(RuleType.ELIMINATE_GROUP_BY);
+    }
+
+    /** Returns true for a one-argument sum/count/min/max whose argument is a plain slot. */
+    private static boolean isEliminable(AggregateFunction f) {
+        boolean supportedFunction = f instanceof Sum || f instanceof Count
+                || f instanceof Min || f instanceof Max;
+        return supportedFunction && f.arity() == 1 && f.child(0) instanceof Slot;
+    }
+
+    /**
+     * Rewrites one output expression of the aggregate for a single-row group.
+     *
+     * <p>Only an {@link Alias} whose direct child is an {@link AggregateFunction} is rewritten; any
+     * other expression is returned unchanged.
+     * NOTE(review): this presumes aggregate functions only appear as the direct child of an alias
+     * in the aggregate's output -- verify against the rules that run before this one.
+     *
+     * @param ne an output expression of the aggregate
+     * @return the expression to put in the replacing projection
+     * @throws IllegalStateException if the aliased function is not sum, min, max or count
+     */
+    private static NamedExpression rewriteOutput(NamedExpression ne) {
+        if (!(ne instanceof Alias && ne.child(0) instanceof AggregateFunction)) {
+            return ne;
+        }
+        AggregateFunction f = (AggregateFunction) ne.child(0);
+        if (f instanceof Sum || f instanceof Min || f instanceof Max) {
+            // The new alias reuses the original ExprId and name, and wraps the function's argument.
+            return new Alias(ne.getExprId(), f.child(0), ne.getName());
+        } else if (f instanceof Count) {
+            // count(slot) over one row is 0 when the value is NULL and 1 otherwise.
+            return (NamedExpression) ne.withChildren(
+                    new If(new IsNull(f.child(0)), Literal.of(0), Literal.of(1)));
+        }
+        throw new IllegalStateException("Unexpected aggregate function: " + f);
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java
new file mode 100644
index 0000000000..14f62d25f7
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java
@@ -0,0 +1,100 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.
See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.utframe.TestWithFeService;
+
+import org.junit.jupiter.api.Test;
+
+/**
+ * Checks that aggregates over table {@code t}, grouped by its unique key {@code id}, are rewritten
+ * into a projection whose first item is {@code id} and whose second item replaces the aggregate.
+ */
+class EliminateGroupByTest extends TestWithFeService implements MemoPatternMatchSupported {
+
+    @Override
+    protected void runBeforeAll() throws Exception {
+        String createTableSql = "create table eliminate_group_by.t (\n"
+                + "id int not null,\n"
+                + "name varchar(128),\n"
+                + "age int, sex int"
+                + ")\n"
+                + "UNIQUE KEY(id)\n"
+                + "distributed by hash(id) buckets 10\n"
+                + "properties('replication_num' = '1');";
+
+        createDatabase("eliminate_group_by");
+        createTable(createTableSql);
+        connectContext.setDatabase("default_cluster:eliminate_group_by");
+    }
+
+    @Test
+    void eliminateMax() {
+        // expected shape: select id, age as max(age) from t;
+        assertSecondProject("select id, max(age) from t group by id", "age AS `max(age)`");
+    }
+
+    @Test
+    void eliminateMin() {
+        // expected shape: select id, age as min(age) from t;
+        assertSecondProject("select id, min(age) from t group by id", "age AS `min(age)`");
+    }
+
+    @Test
+    void eliminateSum() {
+        // expected shape: select id, age as sum(age) from t;
+        assertSecondProject("select id, sum(age) from t group by id", "age AS `sum(age)`");
+    }
+
+    @Test
+    void eliminateCount() {
+        // expected shape: select id, if(age is null, 0, 1) from t;
+        assertSecondProject("select id, count(age) from t group by id",
+                "if(age IS NULL, 0, 1) AS `if(age IS NULL, 0, 1)`");
+    }
+
+    /**
+     * Analyzes and rewrites {@code sql}, then requires a logical project whose first item prints as
+     * {@code id} and whose second item prints as {@code expectedSecondProject}.
+     */
+    private void assertSecondProject(String sql, String expectedSecondProject) {
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(
+                        logicalProject().when(project -> {
+                            // Check the group-by key first; the second item is only inspected if it matches.
+                            if (!project.getProjects().get(0).toSql().equals("id")) {
+                                return false;
+                            }
+                            return project.getProjects().get(1).toSql().equals(expectedSecondProject);
+                        })
+                );
+    }
+}