[feat](Nereids): rewrite sum literal to sum and count (#32244)

sum(v + 2) => sum(v) + 2*count(v)
sum(v - 2) => sum(v) - 2*count(v)
This commit is contained in:
谢健
2024-03-22 10:27:38 +08:00
committed by yiguolei
parent 8f3f9a53be
commit 4de8775e17
7 changed files with 511 additions and 7 deletions

View File

@ -120,6 +120,7 @@ import org.apache.doris.nereids.rules.rewrite.PushProjectThroughUnion;
import org.apache.doris.nereids.rules.rewrite.ReorderJoin;
import org.apache.doris.nereids.rules.rewrite.RewriteCteChildren;
import org.apache.doris.nereids.rules.rewrite.SplitLimit;
import org.apache.doris.nereids.rules.rewrite.SumLiteralRewrite;
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAgg;
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinAggProject;
import org.apache.doris.nereids.rules.rewrite.TransposeSemiJoinLogicalJoin;
@ -388,7 +389,10 @@ public class Rewriter extends AbstractBatchJobExecutor {
custom(RuleType.ELIMINATE_SORT, EliminateSort::new),
bottomUp(new EliminateEmptyRelation())
),
topic("agg rewrite",
// these rules must be placed after MV optimization; otherwise MV matching may fail
topDown(new SumLiteralRewrite())
),
// this rule batch must keep at the end of rewrite to do some plan check
topic("Final rewrite and check",
custom(RuleType.CHECK_DATA_TYPES, CheckDataTypes::new),

View File

@ -53,7 +53,7 @@ public enum RuleType {
BINDING_SLOT_WITH_PATHS_SCAN(RuleTypeClass.REWRITE),
COUNT_LITERAL_REWRITE(RuleTypeClass.REWRITE),
SUM_LITERAL_REWRITE(RuleTypeClass.REWRITE),
REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT(RuleTypeClass.REWRITE),
FILL_UP_HAVING_AGGREGATE(RuleTypeClass.REWRITE),

View File

@ -0,0 +1,185 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import com.google.common.collect.ImmutableList;
import org.apache.thrift.annotation.Nullable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Rewrites a sum over "slot +/- literal" into a sum and a count over the slot:
 * <pre>
 * sum(v + 2) ==> sum(v) + 2 * count(v)
 * sum(v - 2) ==> sum(v) - 2 * count(v)
 * </pre>
 * The aggregate is replaced by a project on top of a new aggregate. The new aggregate
 * computes sum(v) and count(v) once per distinct slot, so several sums that differ only
 * in the literal share the same two aggregate functions.
 *
 * <p>Only non-distinct sums whose argument is {@code slot +/- literal} with an
 * integer-like or float-like literal are rewritten. {@code sum(DISTINCT v + 2)} is left
 * untouched because it is not equal to {@code sum(v) + 2 * count(v)}.
 */
public class SumLiteralRewrite extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalAggregate()
                .whenNot(agg -> agg.getSourceRepeat().isPresent())
                .then(agg -> {
                    // output expression -> (slot inside the sum, literal added/subtracted)
                    Map<NamedExpression, Pair<Expression, Literal>> sumLiteralMap = new HashMap<>();
                    for (NamedExpression namedExpression : agg.getOutputs()) {
                        Pair<NamedExpression, Pair<Expression, Literal>> pel = extractSumLiteral(namedExpression);
                        if (pel == null) {
                            continue;
                        }
                        sumLiteralMap.put(pel.first, pel.second);
                    }
                    if (sumLiteralMap.isEmpty()) {
                        // no output matches the pattern, nothing to rewrite
                        return null;
                    }
                    return rewriteSumLiteral(agg, sumLiteralMap);
                }).toRule(RuleType.SUM_LITERAL_REWRITE);
    }

    /**
     * Builds {@code project(newAgg)}: the new aggregate keeps every output that is not rewritten
     * and gains sum(slot) / count(slot) for each rewritten slot; the project restores the original
     * output list.
     *
     * @param agg the aggregate to rewrite
     * @param sumLiteralMap the outputs of {@code agg} to rewrite, mapped to their (slot, literal) pair
     * @return the project that replaces {@code agg}
     */
    private Plan rewriteSumLiteral(
            LogicalAggregate<?> agg, Map<NamedExpression, Pair<Expression, Literal>> sumLiteralMap) {
        // The list keeps the new outputs in a deterministic order (original output order first,
        // then the synthesized sum/count pairs); the set only guards against duplicates.
        List<NamedExpression> newAggOutput = new ArrayList<>();
        Set<NamedExpression> addedOutputs = new HashSet<>();
        Map<AggregateFunction, NamedExpression> existedAggFunc = new HashMap<>();
        for (NamedExpression expr : agg.getOutputExpressions()) {
            if (!sumLiteralMap.containsKey(expr) && addedOutputs.add(expr)) {
                newAggOutput.add(expr);
            }
            // remember aggregate functions that are already computed so they can be reused
            if (expr.children().size() == 1 && expr.child(0) instanceof AggregateFunction) {
                existedAggFunc.put((AggregateFunction) expr.child(0), expr);
            }
        }

        Map<Expression, Slot> exprToSum = new HashMap<>();
        Map<Expression, Slot> exprToCount = new HashMap<>();
        // Walk the original outputs (not the hash map) so that the synthesized sum/count pairs
        // are created and appended in a stable order.
        for (NamedExpression expr : agg.getOutputExpressions()) {
            Pair<Expression, Literal> sumLiteral = sumLiteralMap.get(expr);
            if (sumLiteral == null || exprToSum.containsKey(sumLiteral.first)) {
                continue;
            }
            Expression sumChild = sumLiteral.first;
            NamedExpression namedSum = constructSum(sumChild, existedAggFunc);
            NamedExpression namedCount = constructCount(sumChild, existedAggFunc);
            exprToSum.put(sumChild, namedSum.toSlot());
            exprToCount.put(sumChild, namedCount.toSlot());
            if (addedOutputs.add(namedSum)) {
                newAggOutput.add(namedSum);
            }
            if (addedOutputs.add(namedCount)) {
                newAggOutput.add(namedCount);
            }
        }

        LogicalAggregate<?> newAgg = agg.withAggOutput(ImmutableList.copyOf(newAggOutput));
        List<NamedExpression> newProjects = constructProjectExpression(agg, sumLiteralMap, exprToSum, exprToCount);
        return new LogicalProject<>(newProjects, newAgg);
    }

    /**
     * Builds the project list that reproduces the original outputs of {@code agg}, in the original
     * order. Untouched outputs are forwarded as slots; rewritten ones become
     * {@code sum +/- literal * count} and keep the original expression id and name.
     */
    private List<NamedExpression> constructProjectExpression(
            LogicalAggregate<?> agg, Map<NamedExpression, Pair<Expression, Literal>> sumLiteralMap,
            Map<Expression, Slot> exprToSum, Map<Expression, Slot> exprToCount) {
        List<NamedExpression> newProjects = new ArrayList<>();
        for (NamedExpression namedExpr : agg.getOutputExpressions()) {
            Pair<Expression, Literal> sumLiteral = sumLiteralMap.get(namedExpr);
            if (sumLiteral == null) {
                newProjects.add(namedExpr.toSlot());
                continue;
            }
            Expression originExpr = sumLiteral.first;
            Literal literal = sumLiteral.second;
            Expression literalTimesCount = new Multiply(literal, exprToCount.get(originExpr));
            Expression newExpr;
            // namedExpr is alias(sum(arithmetic)); extractSumLiteral only admits Add or Subtract
            if (namedExpr.child(0).child(0) instanceof Add) {
                newExpr = new Add(exprToSum.get(originExpr), literalTimesCount);
            } else {
                newExpr = new Subtract(exprToSum.get(originExpr), literalTimesCount);
            }
            newProjects.add(new Alias(namedExpr.getExprId(), newExpr, namedExpr.getName()));
        }
        return newProjects;
    }

    /**
     * Returns the output that already computes {@code sum(child)} if there is one,
     * otherwise a new alias over {@code sum(child)}.
     */
    private NamedExpression constructSum(Expression child, Map<AggregateFunction, NamedExpression> existedAggFunc) {
        Sum sum = new Sum(child);
        NamedExpression existing = existedAggFunc.get(sum);
        if (existing != null) {
            return existing;
        }
        return new Alias(sum);
    }

    /**
     * Returns the output that already computes {@code count(child)} if there is one,
     * otherwise a new alias over {@code count(child)}.
     */
    private NamedExpression constructCount(Expression child, Map<AggregateFunction, NamedExpression> existedAggFunc) {
        Count count = new Count(child);
        NamedExpression existing = existedAggFunc.get(count);
        if (existing != null) {
            return existing;
        }
        return new Alias(count);
    }

    /**
     * Matches {@code namedExpression} against {@code alias(sum(slot +/- literal))}.
     *
     * @return the expression paired with (slot, literal), or null when it does not match
     */
    private @Nullable Pair<NamedExpression, Pair<Expression, Literal>> extractSumLiteral(
            NamedExpression namedExpression) {
        if (namedExpression.children().size() != 1) {
            return null;
        }
        Expression func = namedExpression.child(0);
        if (!(func instanceof Sum)) {
            return null;
        }
        if (((Sum) func).isDistinct()) {
            // sum(DISTINCT v + c) != sum(v) + c * count(v): the rewrite would drop the DISTINCT
            return null;
        }
        Expression child = func.child(0);
        if (!(child instanceof Add) && !(child instanceof Subtract)) {
            return null;
        }
        Expression left = ((BinaryArithmetic) child).left();
        Expression right = ((BinaryArithmetic) child).right();
        if (!(right.isLiteral() && left instanceof Slot)) {
            // right now, only support slot +/- literal
            return null;
        }
        if (!(right.getDataType().isIntegerLikeType() || right.getDataType().isFloatLikeType())) {
            // only support integer or float types
            return null;
        }
        return Pair.of(namedExpression, Pair.of(left, (Literal) right));
    }
}

View File

@ -0,0 +1,54 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
import org.apache.doris.nereids.util.MemoPatternMatchSupported;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
/**
 * Unit test for {@link SumLiteralRewrite}.
 */
class SumLiteralRewriteTest implements MemoPatternMatchSupported {
    private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);

    /**
     * Aggregates one plain sum plus four sums of "slot +/- literal" over the same slot and checks
     * that the rewritten aggregate is left with exactly two aggregate functions.
     */
    @Test
    void testSimpleAddSum() {
        Slot column = scan1.getOutput().get(0);
        // aliases are created in the order: sum, +1, +2, -1, -2
        LogicalAggregate<?> aggregate = new LogicalAggregate<>(
                ImmutableList.of(scan1.getOutput().get(0)),
                ImmutableList.of(
                        new Alias(new Sum(column)),
                        new Alias(new Sum(new Add(column, Literal.of(1)))),
                        new Alias(new Sum(new Add(column, Literal.of(2)))),
                        new Alias(new Sum(new Subtract(column, Literal.of(1)))),
                        new Alias(new Sum(new Subtract(column, Literal.of(2))))),
                scan1);

        PlanChecker.from(MemoTestUtils.createConnectContext(), aggregate)
                .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build()))
                .printlnTree()
                .matches(logicalAggregate().when(rewritten -> rewritten.getAggregateFunctions().size() == 2));
    }
}

View File

@ -1,9 +1,10 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !ckbench_shape_30 --
PhysicalResultSink
--hashAgg[GLOBAL]
----PhysicalDistribute[DistributionSpecGather]
------hashAgg[LOCAL]
--------PhysicalProject
----------PhysicalOlapScan[hits]
--PhysicalProject
----hashAgg[GLOBAL]
------PhysicalDistribute[DistributionSpecGather]
--------hashAgg[LOCAL]
----------PhysicalProject
------------PhysicalOlapScan[hits]

View File

@ -0,0 +1,142 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !sum_add_const$ --
138
-- !sum_add_const_alias$ --
138
-- !sum_add_const_where$ --
138
-- !sum_add_const_group_by$ --
10 43
6 27
7 31
8 17
9 20
-- !sum_add_const_having$ --
10 43
6 27
7 31
8 17
9 20
-- !sum_sub_const$ --
106
-- !sum_sub_const_alias$ --
106
-- !sum_sub_const_where$ --
106
-- !sum_sub_const_group_by$ --
10 35
6 19
7 23
8 13
9 16
-- !sum_sub_const_having$ --
10 35
6 19
7 23
8 13
9 16
-- !sum_add_const_empty_table$ --
\N
-- !sum_add_const_empty_table_group_by$ --
-- !sum_sub_const_empty_table$ --
\N
-- !sum_sub_const_empty_table_group_by$ --
-- !float_sum_add_const$ --
79.60000002384186
-- !float_sum_add_const_alias$ --
79.60000002384186
-- !float_sum_add_const_where$ --
79.60000002384186
-- !float_sum_add_const_group_by$ --
10 24.0
6 7.300000071525574
7 11.700000047683716
8 16.09999990463257
9 20.5
-- !float_sum_add_const_having$ --
10 24.0
6 7.300000071525574
7 11.700000047683716
8 16.09999990463257
9 20.5
-- !float_sum_sub_const$ --
39.60000002384186
-- !float_sum_sub_const_alias$ --
39.60000002384186
-- !float_sum_sub_const_where$ --
39.60000002384186
-- !float_sum_sub_const_group_by$ --
10 16.0
6 -0.6999999284744263
7 3.700000047683716
8 8.099999904632568
9 12.5
-- !float_sum_sub_const_having$ --
10 16.0
7 3.700000047683716
8 8.099999904632568
9 12.5
-- !decimal_sum_add_const_precision_1$ --
2670.55
-- !decimal_sum_add_const_precision_2$ --
2672.55
-- !decimal_sum_add_const_precision_3$ --
10 694.19
6 434.03
7 474.07
8 514.11
9 554.15
-- !decimal_sum_add_const_precision_4$ --
10 694.636
6 434.476
7 474.516
8 514.556
9 554.596
-- !decimal_sum_sub_const_precision_1$ --
2630.55
-- !decimal_sum_sub_const_precision_2$ --
2628.55
-- !decimal_sum_sub_const_precision_3$ --
10 686.19
6 426.03
7 466.07
8 506.11
9 546.15
-- !decimal_sum_sub_const_precision_4$ --
10 685.744
6 425.584
7 465.624
8 505.664
9 545.704

View File

@ -0,0 +1,118 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
// Regression suite "sumRewrite": runs sum(col + const) / sum(col - const) queries with the
// Nereids planner (no fallback to the original planner) over int and float columns, with and
// without WHERE / GROUP BY / HAVING, and on an empty input.
suite("sumRewrite") {
sql "SET enable_nereids_planner=true"
sql "set runtime_filter_mode=OFF"
sql "SET enable_fallback_to_original_planner=false"
sql """
DROP TABLE IF EXISTS sr
"""
// Note the column names: `id` is nullable, while `null_id` is declared NOT NULL.
sql """
CREATE TABLE IF NOT EXISTS sr(
`id` int NULL,
`null_id` int not NULL,
`f_id` float NULL,
`d_id` decimal(10,2),
) ENGINE = OLAP
DISTRIBUTED BY HASH(id) BUCKETS 4
PROPERTIES (
"replication_allocation" = "tag.location.default: 1"
);
"""
// Ten rows; two of them have a NULL `id`.
sql """
INSERT INTO sr (id, null_id, f_id, d_id) VALUES
(11, 6, 1.1, 210.01),
(12, 6, 2.2, 220.02),
(13, 7, 3.3, 230.03),
(14, 7, 4.4, 240.04),
(15, 8, 5.5, 250.05),
(null, 8, 6.6, 260.06),
(null, 9, 7.7, 270.07),
(18, 9, 8.8, 280.08),
(19, 10, 9.9, 290.09),
(20, 10, 10.1, 400.10);
"""
// Tests on the nullable int column: sum(id + const)
order_qt_sum_add_const$ """ select sum(id + 2) from sr """
order_qt_sum_add_const_alias$ """ select sum(id + 2) as result from sr """
order_qt_sum_add_const_where$ """ select sum(id + 2) from sr where id is not null """
order_qt_sum_add_const_group_by$ """ select null_id, sum(id + 2) from sr group by null_id """
order_qt_sum_add_const_having$ """ select null_id, sum(id + 2) from sr group by null_id having sum(id + 2) > 5 """
// Tests on the nullable int column: sum(id - const)
order_qt_sum_sub_const$ """ select sum(id - 2) from sr """
order_qt_sum_sub_const_alias$ """ select sum(id - 2) as result from sr """
order_qt_sum_sub_const_where$ """ select sum(id - 2) from sr where id is not null """
order_qt_sum_sub_const_group_by$ """ select null_id, sum(id - 2) from sr group by null_id """
order_qt_sum_sub_const_having$ """ select null_id, sum(id - 2) from sr group by null_id having sum(id - 2) > 0 """
// Empty input (the filter 1=0 removes every row), with and without GROUP BY
order_qt_sum_add_const_empty_table$ """ select sum(id + 2) from sr where 1=0 """
order_qt_sum_add_const_empty_table_group_by$ """ select null_id, sum(id + 2) from sr where 1=0 group by null_id """
order_qt_sum_sub_const_empty_table$ """ select sum(id - 2) from sr where 1=0 """
order_qt_sum_sub_const_empty_table_group_by$ """ select null_id, sum(id - 2) from sr where 1=0 group by null_id """
// Tests on the float column
order_qt_float_sum_add_const$ """ select sum(f_id + 2) from sr """
order_qt_float_sum_add_const_alias$ """ select sum(f_id + 2) as result from sr """
order_qt_float_sum_add_const_where$ """ select sum(f_id + 2) from sr where f_id is not null """
order_qt_float_sum_add_const_group_by$ """ select null_id, sum(f_id + 2) from sr group by null_id """
order_qt_float_sum_add_const_having$ """ select null_id, sum(f_id + 2) from sr group by null_id having sum(f_id + 2) > 5 """
order_qt_float_sum_sub_const$ """ select sum(f_id - 2) from sr """
order_qt_float_sum_sub_const_alias$ """ select sum(f_id - 2) as result from sr """
order_qt_float_sum_sub_const_where$ """ select sum(f_id - 2) from sr where f_id is not null """
order_qt_float_sum_sub_const_group_by$ """ select null_id, sum(f_id - 2) from sr group by null_id """
order_qt_float_sum_sub_const_having$ """ select null_id, sum(f_id - 2) from sr group by null_id having sum(f_id - 2) > 0 """
// NOTE(review): the decimal cases below are disabled, yet the generated expected-output
// file still contains results for these tags; confirm whether they should be re-enabled.
// Tests how precision changes affect sum plus a constant
// order_qt_decimal_sum_add_const_precision_1$ """ select sum(d_id + 2) from sr """
// order_qt_decimal_sum_add_const_precision_2$ """ select sum(d_id + 2.2) from sr """
// order_qt_decimal_sum_add_const_precision_3$ """ select null_id, sum(d_id + 2) from sr group by null_id """
// order_qt_decimal_sum_add_const_precision_4$ """ select null_id, sum(d_id + 2.223) from sr group by null_id """
// Tests how precision changes affect sum minus a constant
// order_qt_decimal_sum_sub_const_precision_1$ """ select sum(d_id - 2) from sr """
// order_qt_decimal_sum_sub_const_precision_2$ """ select sum(d_id - 2.2) from sr """
// order_qt_decimal_sum_sub_const_precision_3$ """ select null_id, sum(d_id - 2) from sr group by null_id """
// order_qt_decimal_sum_sub_const_precision_4$ """ select null_id, sum(d_id - 2.223) from sr group by null_id """
}