diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java index 9dfb92a3f2..65998416fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java @@ -120,6 +120,7 @@ import org.apache.doris.nereids.rules.rewrite.PushProjectThroughUnion; import org.apache.doris.nereids.rules.rewrite.ReorderJoin; import org.apache.doris.nereids.rules.rewrite.RewriteCteChildren; import org.apache.doris.nereids.rules.rewrite.SplitLimit; +import org.apache.doris.nereids.rules.rewrite.SumLiteralRewrite; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAgg; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAggProject; import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoin; @@ -388,7 +389,10 @@ public class Rewriter extends AbstractBatchJobExecutor { custom(RuleType.ELIMINATE_SORT, EliminateSort::new), bottomUp(new EliminateEmptyRelation()) ), - + topic("agg rewrite", + // these rules should be put after mv optimization to avoid mv matching fail + topDown(new SumLiteralRewrite()) + ), // this rule batch must keep at the end of rewrite to do some plan check topic("Final rewrite and check", custom(RuleType.CHECK_DATA_TYPES, CheckDataTypes::new), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index fc47d10487..cc66c27fbc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -53,7 +53,7 @@ public enum RuleType { BINDING_SLOT_WITH_PATHS_SCAN(RuleTypeClass.REWRITE), COUNT_LITERAL_REWRITE(RuleTypeClass.REWRITE), - + SUM_LITERAL_REWRITE(RuleTypeClass.REWRITE), 
REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT(RuleTypeClass.REWRITE), FILL_UP_HAVING_AGGREGATE(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java new file mode 100644 index 0000000000..5ded4bc9a7 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.common.Pair;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Multiply;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.thrift.annotation.Nullable;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+/**
+ * Rewrites {@code sum(slot +/- literal)} into {@code sum(slot) +/- literal * count(slot)}.
+ *
+ * <p>Sums over the same slot that differ only in the literal then share a single
+ * {@code sum(slot)} and a single {@code count(slot)}. The original output values are
+ * recomputed by a {@link LogicalProject} placed on top of the rewritten aggregate.
+ *
+ * <p>An output is rewritten only when all of the following hold:
+ * <ul>
+ *   <li>it is a non-distinct {@code sum};</li>
+ *   <li>its argument is {@code slot + literal} or {@code slot - literal};</li>
+ *   <li>the literal is of an integer-like or float-like type.</li>
+ * </ul>
+ * Aggregates that have a source repeat are never rewritten.
+ */
+public class SumLiteralRewrite extends OneRewriteRuleFactory {
+    @Override
+    public Rule build() {
+        return logicalAggregate()
+                .whenNot(agg -> agg.getSourceRepeat().isPresent())
+                .then(agg -> {
+                    // aggregate output -> (summed slot, literal added to / subtracted from it)
+                    Map<NamedExpression, Pair<Expression, Literal>> sumLiteralMap = new HashMap<>();
+                    for (NamedExpression namedExpression : agg.getOutputs()) {
+                        Pair<NamedExpression, Pair<Expression, Literal>> pel = extractSumLiteral(namedExpression);
+                        if (pel == null) {
+                            continue;
+                        }
+                        sumLiteralMap.put(pel.first, pel.second);
+                    }
+                    if (sumLiteralMap.isEmpty()) {
+                        // nothing to rewrite: returning null leaves the plan untouched
+                        return null;
+                    }
+                    return rewriteSumLiteral(agg, sumLiteralMap);
+                }).toRule(RuleType.SUM_LITERAL_REWRITE);
+    }
+
+    /**
+     * Builds {@code Project(Aggregate)} replacing {@code agg}.
+     *
+     * @param agg the aggregate being rewritten
+     * @param sumLiteralMap rewritable outputs of {@code agg}, mapped to their (slot, literal) pair
+     * @return a project that exposes the same expression ids as the outputs of {@code agg}
+     */
+    private Plan rewriteSumLiteral(
+            LogicalAggregate<?> agg, Map<NamedExpression, Pair<Expression, Literal>> sumLiteralMap) {
+        // outputs that are not rewritten stay in the aggregate unchanged
+        Set<NamedExpression> newAggOutput = new HashSet<>();
+        for (NamedExpression expr : agg.getOutputExpressions()) {
+            if (!sumLiteralMap.containsKey(expr)) {
+                newAggOutput.add(expr);
+            }
+        }
+
+        Map<Expression, Slot> exprToSum = new HashMap<>();
+        Map<Expression, Slot> exprToCount = new HashMap<>();
+
+        // aggregate functions already present in the output, so an existing sum(slot) or
+        // count(slot) is reused instead of being computed twice
+        Map<AggregateFunction, NamedExpression> existedAggFunc = new HashMap<>();
+        for (NamedExpression e : agg.getOutputExpressions()) {
+            if (e.children().size() == 1 && e.child(0) instanceof AggregateFunction) {
+                existedAggFunc.put((AggregateFunction) e.child(0), e);
+            }
+        }
+
+        // distinct slots that need a sum/count pair
+        Set<Expression> countSumExpr = new HashSet<>();
+        for (Pair<Expression, Literal> pair : sumLiteralMap.values()) {
+            countSumExpr.add(pair.first);
+        }
+
+        for (Expression e : countSumExpr) {
+            NamedExpression namedSum = constructSum(e, existedAggFunc);
+            NamedExpression namedCount = constructCount(e, existedAggFunc);
+            exprToSum.put(e, namedSum.toSlot());
+            exprToCount.put(e, namedCount.toSlot());
+            newAggOutput.add(namedSum);
+            newAggOutput.add(namedCount);
+        }
+
+        LogicalAggregate<?> newAgg = agg.withAggOutput(ImmutableList.copyOf(newAggOutput));
+
+        List<NamedExpression> newProjects = constructProjectExpression(agg, sumLiteralMap, exprToSum, exprToCount);
+
+        return new LogicalProject<>(newProjects, newAgg);
+    }
+
+    /**
+     * Builds the project list on top of the rewritten aggregate, in the order of the original
+     * aggregate outputs. Untouched outputs are forwarded as slots; rewritten ones become
+     * {@code sum +/- literal * count} under an alias that keeps the original expression id and name.
+     */
+    private List<NamedExpression> constructProjectExpression(
+            LogicalAggregate<?> agg, Map<NamedExpression, Pair<Expression, Literal>> sumLiteralMap,
+            Map<Expression, Slot> exprToSum, Map<Expression, Slot> exprToCount) {
+        List<NamedExpression> newProjects = new ArrayList<>();
+        for (NamedExpression namedExpr : agg.getOutputExpressions()) {
+            if (!sumLiteralMap.containsKey(namedExpr)) {
+                newProjects.add(namedExpr.toSlot());
+                continue;
+            }
+            Expression originExpr = sumLiteralMap.get(namedExpr).first;
+            Literal literal = sumLiteralMap.get(namedExpr).second;
+            Expression newExpr;
+            // namedExpr is alias(sum(slot +/- literal)); extractSumLiteral only admits Add or Subtract
+            if (namedExpr.child(0).child(0) instanceof Add) {
+                newExpr = new Add(exprToSum.get(originExpr),
+                        new Multiply(literal, exprToCount.get(originExpr)));
+            } else {
+                newExpr = new Subtract(exprToSum.get(originExpr),
+                        new Multiply(literal, exprToCount.get(originExpr)));
+            }
+            newProjects.add(new Alias(namedExpr.getExprId(), newExpr, namedExpr.getName()));
+        }
+        return newProjects;
+    }
+
+    /**
+     * Returns the existing output computing {@code sum(child)}, or a fresh alias of it.
+     */
+    private NamedExpression constructSum(Expression child, Map<AggregateFunction, NamedExpression> existedAggFunc) {
+        // NOTE(review): the single-argument constructor is used, so flags carried by the original
+        // Sum are not propagated to the rebuilt one - confirm this is acceptable for all callers.
+        Sum sum = new Sum(child);
+        NamedExpression existing = existedAggFunc.get(sum);
+        return existing != null ? existing : new Alias(sum);
+    }
+
+    /**
+     * Returns the existing output computing {@code count(child)}, or a fresh alias of it.
+     */
+    private NamedExpression constructCount(Expression child, Map<AggregateFunction, NamedExpression> existedAggFunc) {
+        Count count = new Count(child);
+        NamedExpression existing = existedAggFunc.get(count);
+        return existing != null ? existing : new Alias(count);
+    }
+
+    /**
+     * Recognizes an output of the form {@code alias(sum(slot +/- literal))}.
+     *
+     * @return the output paired with its (slot, literal), or {@code null} when the output must
+     *         not be rewritten
+     */
+    private @Nullable Pair<NamedExpression, Pair<Expression, Literal>> extractSumLiteral(
+            NamedExpression namedExpression) {
+        if (namedExpression.children().size() != 1) {
+            return null;
+        }
+        Expression func = namedExpression.child(0);
+        if (!(func instanceof Sum)) {
+            return null;
+        }
+        if (((Sum) func).isDistinct()) {
+            // sum(DISTINCT x + c) de-duplicates on the value of (x + c), whereas
+            // sum(x) + c * count(x) counts every row, so the identity does not hold
+            return null;
+        }
+        Expression child = func.child(0);
+        if (!(child instanceof Add) && !(child instanceof Subtract)) {
+            return null;
+        }
+
+        Expression left = ((BinaryArithmetic) child).left();
+        Expression right = ((BinaryArithmetic) child).right();
+        if (!(right.isLiteral() && left instanceof Slot)) {
+            // right now, only support slot +/- literal
+            return null;
+        }
+        if (!(right.getDataType().isIntegerLikeType() || right.getDataType().isFloatLikeType())) {
+            // only support integer or float types
+            return null;
+        }
+        return Pair.of(namedExpression, Pair.of(left, (Literal) right));
+    }
+}
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
new file mode 100644
index 0000000000..97dda3f8cb
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite;
+
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.Subtract;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+/**
+ * Unit test for {@link SumLiteralRewrite}.
+ */
+class SumLiteralRewriteTest implements MemoPatternMatchSupported {
+    private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+
+    /**
+     * Aggregates one plain sum plus four sums of {@code slot +/- literal} over the same slot and
+     * asserts that the rewritten aggregate holds exactly two aggregate functions.
+     */
+    @Test
+    void testSimpleAddSum() {
+        Slot column = scan1.getOutput().get(0);
+        Alias plainSum = new Alias(new Sum(column));
+        Alias plusOne = sumOfPlus(column, 1);
+        Alias plusTwo = sumOfPlus(column, 2);
+        Alias minusOne = sumOfMinus(column, 1);
+        Alias minusTwo = sumOfMinus(column, 2);
+
+        LogicalAggregate<LogicalOlapScan> aggregate = new LogicalAggregate<>(
+                ImmutableList.of(scan1.getOutput().get(0)),
+                ImmutableList.of(plainSum, plusOne, plusTwo, minusOne, minusTwo),
+                scan1);
+
+        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
+                .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build()))
+                .printlnTree()
+                .matches(logicalAggregate().when(rewritten -> rewritten.getAggregateFunctions().size() == 2));
+    }
+
+    /** Builds {@code alias(sum(slot + literal))}. */
+    private static Alias sumOfPlus(Slot slot, int literal) {
+        return new Alias(new Sum(new Add(slot, Literal.of(literal))));
+    }
+
+    /** Builds {@code alias(sum(slot - literal))}. */
+    private static Alias sumOfMinus(Slot slot, int literal) {
+        return new Alias(new Sum(new Subtract(slot, Literal.of(literal))));
+    }
+}
diff --git a/regression-test/data/nereids_clickbench_shape_p0/query30.out b/regression-test/data/nereids_clickbench_shape_p0/query30.out
index 8a3f753f3d..bad1a26f51 100644
--- a/regression-test/data/nereids_clickbench_shape_p0/query30.out
+++ b/regression-test/data/nereids_clickbench_shape_p0/query30.out
@@ -1,9 +1,10 @@
 -- This file is automatically generated. You should know what you did if you want to edit this
 -- !ckbench_shape_30 --
 PhysicalResultSink
---hashAgg[GLOBAL]
-----PhysicalDistribute[DistributionSpecGather]
-------hashAgg[LOCAL]
---------PhysicalProject
-----------PhysicalOlapScan[hits]
+--PhysicalProject
+----hashAgg[GLOBAL]
+------PhysicalDistribute[DistributionSpecGather]
+--------hashAgg[LOCAL]
+----------PhysicalProject
+------------PhysicalOlapScan[hits]
 
diff --git a/regression-test/data/nereids_rules_p0/sumRewrite.out b/regression-test/data/nereids_rules_p0/sumRewrite.out
new file mode 100644
index 0000000000..ddb4b90175
--- /dev/null
+++ b/regression-test/data/nereids_rules_p0/sumRewrite.out
@@ -0,0 +1,142 @@
+-- This file is automatically generated.
You should know what you did if you want to edit this +-- !sum_add_const$ -- +138 + +-- !sum_add_const_alias$ -- +138 + +-- !sum_add_const_where$ -- +138 + +-- !sum_add_const_group_by$ -- +10 43 +6 27 +7 31 +8 17 +9 20 + +-- !sum_add_const_having$ -- +10 43 +6 27 +7 31 +8 17 +9 20 + +-- !sum_sub_const$ -- +106 + +-- !sum_sub_const_alias$ -- +106 + +-- !sum_sub_const_where$ -- +106 + +-- !sum_sub_const_group_by$ -- +10 35 +6 19 +7 23 +8 13 +9 16 + +-- !sum_sub_const_having$ -- +10 35 +6 19 +7 23 +8 13 +9 16 + +-- !sum_add_const_empty_table$ -- +\N + +-- !sum_add_const_empty_table_group_by$ -- + +-- !sum_sub_const_empty_table$ -- +\N + +-- !sum_sub_const_empty_table_group_by$ -- + +-- !float_sum_add_const$ -- +79.60000002384186 + +-- !float_sum_add_const_alias$ -- +79.60000002384186 + +-- !float_sum_add_const_where$ -- +79.60000002384186 + +-- !float_sum_add_const_group_by$ -- +10 24.0 +6 7.300000071525574 +7 11.700000047683716 +8 16.09999990463257 +9 20.5 + +-- !float_sum_add_const_having$ -- +10 24.0 +6 7.300000071525574 +7 11.700000047683716 +8 16.09999990463257 +9 20.5 + +-- !float_sum_sub_const$ -- +39.60000002384186 + +-- !float_sum_sub_const_alias$ -- +39.60000002384186 + +-- !float_sum_sub_const_where$ -- +39.60000002384186 + +-- !float_sum_sub_const_group_by$ -- +10 16.0 +6 -0.6999999284744263 +7 3.700000047683716 +8 8.099999904632568 +9 12.5 + +-- !float_sum_sub_const_having$ -- +10 16.0 +7 3.700000047683716 +8 8.099999904632568 +9 12.5 + +-- !decimal_sum_add_const_precision_1$ -- +2670.55 + +-- !decimal_sum_add_const_precision_2$ -- +2672.55 + +-- !decimal_sum_add_const_precision_3$ -- +10 694.19 +6 434.03 +7 474.07 +8 514.11 +9 554.15 + +-- !decimal_sum_add_const_precision_4$ -- +10 694.636 +6 434.476 +7 474.516 +8 514.556 +9 554.596 + +-- !decimal_sum_sub_const_precision_1$ -- +2630.55 + +-- !decimal_sum_sub_const_precision_2$ -- +2628.55 + +-- !decimal_sum_sub_const_precision_3$ -- +10 686.19 +6 426.03 +7 466.07 +8 506.11 +9 546.15 + +-- 
!decimal_sum_sub_const_precision_4$ -- +10 685.744 +6 425.584 +7 465.624 +8 505.664 +9 545.704 + diff --git a/regression-test/suites/nereids_rules_p0/sumRewrite.groovy b/regression-test/suites/nereids_rules_p0/sumRewrite.groovy new file mode 100644 index 0000000000..6a6e8a02c8 --- /dev/null +++ b/regression-test/suites/nereids_rules_p0/sumRewrite.groovy @@ -0,0 +1,118 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+
+// Regression suite for rewriting sum(col +/- const): each query below is run with the
+// Nereids planner and its ordered result is compared against the recorded .out file.
+suite("sumRewrite") {
+    // force the Nereids planner and forbid falling back to the original planner
+    sql "SET enable_nereids_planner=true"
+    sql "set runtime_filter_mode=OFF"
+    sql "SET enable_fallback_to_original_planner=false"
+    sql """
+        DROP TABLE IF EXISTS sr
+    """
+    // note: despite the names, `id` is nullable while `null_id` is declared NOT NULL
+    sql """
+        CREATE TABLE IF NOT EXISTS sr(
+            `id` int NULL,
+            `null_id` int not NULL,
+            `f_id` float NULL,
+            `d_id` decimal(10,2),
+        ) ENGINE = OLAP
+        DISTRIBUTED BY HASH(id) BUCKETS 4
+        PROPERTIES (
+            "replication_allocation" = "tag.location.default: 1"
+        );
+    """
+
+    // ten rows, two of them with a NULL `id`
+    sql """
+INSERT INTO sr (id, null_id, f_id, d_id) VALUES
+(11, 6, 1.1, 210.01),
+(12, 6, 2.2, 220.02),
+(13, 7, 3.3, 230.03),
+(14, 7, 4.4, 240.04),
+(15, 8, 5.5, 250.05),
+(null, 8, 6.6, 260.06),
+(null, 9, 7.7, 270.07),
+(18, 9, 8.8, 280.08),
+(19, 10, 9.9, 290.09),
+(20, 10, 10.1, 400.10);
+"""
+
+    // int column: sum(col + const)
+    order_qt_sum_add_const$ """ select sum(id + 2) from sr """
+
+    order_qt_sum_add_const_alias$ """ select sum(id + 2) as result from sr """
+
+    order_qt_sum_add_const_where$ """ select sum(id + 2) from sr where id is not null """
+
+    order_qt_sum_add_const_group_by$ """ select null_id, sum(id + 2) from sr group by null_id """
+
+    order_qt_sum_add_const_having$ """ select null_id, sum(id + 2) from sr group by null_id having sum(id + 2) > 5 """
+
+    // int column: sum(col - const)
+    order_qt_sum_sub_const$ """ select sum(id - 2) from sr """
+
+    order_qt_sum_sub_const_alias$ """ select sum(id - 2) as result from sr """
+
+    order_qt_sum_sub_const_where$ """ select sum(id - 2) from sr where id is not null """
+
+    order_qt_sum_sub_const_group_by$ """ select null_id, sum(id - 2) from sr group by null_id """
+
+    order_qt_sum_sub_const_having$ """ select null_id, sum(id - 2) from sr group by null_id having sum(id - 2) > 0 """
+
+    // empty input (the `1=0` predicate filters out every row)
+    order_qt_sum_add_const_empty_table$ """ select sum(id + 2) from sr where 1=0 """
+
+    order_qt_sum_add_const_empty_table_group_by$ """ select null_id, sum(id + 2) from sr where 1=0 group by null_id """
+
+    order_qt_sum_sub_const_empty_table$ """ select sum(id - 2) from sr where 1=0 """
+
+    order_qt_sum_sub_const_empty_table_group_by$ """ select null_id, sum(id - 2) from sr where 1=0 group by null_id """
+
+    // tests on the float-typed column
+    order_qt_float_sum_add_const$ """ select sum(f_id + 2) from sr """
+
+    order_qt_float_sum_add_const_alias$ """ select sum(f_id + 2) as result from sr """
+
+    order_qt_float_sum_add_const_where$ """ select sum(f_id + 2) from sr where f_id is not null """
+
+    order_qt_float_sum_add_const_group_by$ """ select null_id, sum(f_id + 2) from sr group by null_id """
+
+    order_qt_float_sum_add_const_having$ """ select null_id, sum(f_id + 2) from sr group by null_id having sum(f_id + 2) > 5 """
+
+    order_qt_float_sum_sub_const$ """ select sum(f_id - 2) from sr """
+
+    order_qt_float_sum_sub_const_alias$ """ select sum(f_id - 2) as result from sr """
+
+    order_qt_float_sum_sub_const_where$ """ select sum(f_id - 2) from sr where f_id is not null """
+
+    order_qt_float_sum_sub_const_group_by$ """ select null_id, sum(f_id - 2) from sr group by null_id """
+
+    order_qt_float_sum_sub_const_having$ """ select null_id, sum(f_id - 2) from sr group by null_id having sum(f_id - 2) > 0 """
+
+    // test how precision changes affect sum plus a constant (decimal column; currently disabled)
+    // order_qt_decimal_sum_add_const_precision_1$ """ select sum(d_id + 2) from sr """
+
+    // order_qt_decimal_sum_add_const_precision_2$ """ select sum(d_id + 2.2) from sr """
+
+    // order_qt_decimal_sum_add_const_precision_3$ """ select null_id, sum(d_id + 2) from sr group by null_id """
+
+    // order_qt_decimal_sum_add_const_precision_4$ """ select null_id, sum(d_id + 2.223) from sr group by null_id """
+
+    // test how precision changes affect sum minus a constant (decimal column; currently disabled)
+    // order_qt_decimal_sum_sub_const_precision_1$ """ select sum(d_id - 2) from sr """
+
+    // order_qt_decimal_sum_sub_const_precision_2$ """ select sum(d_id - 2.2) from sr """
+
+    // order_qt_decimal_sum_sub_const_precision_3$ """ select null_id, sum(d_id - 2) from sr group by null_id """
+
+    // order_qt_decimal_sum_sub_const_precision_4$ """ select null_id, sum(d_id - 2.223) from sr group by null_id """
+}
\ No newline at end of file