From 24d236b210d956cde48a187958057a578838f03a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E8=B0=A2=E5=81=A5?= Date: Mon, 1 Jul 2024 14:57:15 +0800 Subject: [PATCH] [feat](Nereids) Optimize Sum Literal Rewriting by Excluding Single Instances (#35559) (#37047) pick from master #35559 This PR introduces a change in the method removeOneSumLiteral to enhance the performance of sum literal rewriting in SQL queries. The modification ensures that sum literals appearing only once, such as in expressions like select count(id1 + 1), count(id2 + 1) from t, are not rewritten. --- .../rules/rewrite/SumLiteralRewrite.java | 25 +++++++++++++-- .../rules/rewrite/SumLiteralRewriteTest.java | 31 +++++++++++++++++++ 2 files changed, 54 insertions(+), 2 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java index c99071a714..dcc64ce2c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewrite.java @@ -44,6 +44,7 @@ import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Map; +import java.util.Map.Entry; import java.util.Objects; import java.util.Set; @@ -64,13 +65,33 @@ public class SumLiteralRewrite extends OneRewriteRuleFactory { } sumLiteralMap.put(pel.first, pel.second); } - if (sumLiteralMap.isEmpty()) { + Map> validSumLiteralMap = + removeOneSumLiteral(sumLiteralMap); + if (validSumLiteralMap.isEmpty()) { return null; } - return rewriteSumLiteral(agg, sumLiteralMap); + return rewriteSumLiteral(agg, validSumLiteralMap); }).toRule(RuleType.SUM_LITERAL_REWRITE); } + // when there only one sum literal like select count(id1 + 1), count(id2 + 1) from t, we don't rewrite them. + private Map> removeOneSumLiteral( + Map> sumLiteralMap) { + Map countSum = new HashMap<>(); + for (Entry> e : sumLiteralMap.entrySet()) { + Expression expr = e.getValue().first.expr; + countSum.merge(expr, 1, Integer::sum); + } + Map> validSumLiteralMap = new HashMap<>(); + for (Entry> e : sumLiteralMap.entrySet()) { + Expression expr = e.getValue().first.expr; + if (countSum.get(expr) > 1) { + validSumLiteralMap.put(e.getKey(), e.getValue()); + } + } + return validSumLiteralMap; + } + private Plan rewriteSumLiteral( LogicalAggregate agg, Map> sumLiteralMap) { Set newAggOutput = new HashSet<>(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java index cb2cc77627..19ea7b864f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SumLiteralRewriteTest.java @@ -112,4 +112,35 @@ class SumLiteralRewriteTest implements MemoPatternMatchSupported { .printlnTree() .matches(logicalAggregate().when(p -> p.getOutputs().size() == 4)); } + + @Test + void testSumOnce() { + Slot slot1 = scan1.getOutput().get(0); + Alias add1 = new Alias(new Sum(false, true, new Add(slot1, Literal.of(1)))); + LogicalAggregate agg = new LogicalAggregate<>( + ImmutableList.of(scan1.getOutput().get(0)), ImmutableList.of(add1), scan1); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build())) + .printlnTree() + .matches(logicalAggregate().when(p -> p.getOutputs().size() == 1)); + + Slot slot2 = new Alias(scan1.getOutput().get(0)).toSlot(); + Alias add2 = new Alias(new Sum(false, true, new Add(slot2, Literal.of(2)))); + agg = new LogicalAggregate<>( + ImmutableList.of(scan1.getOutput().get(0)), ImmutableList.of(add1, add2), scan1); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build())) + .printlnTree() + .matches(logicalAggregate().when(p -> p.getOutputs().size() == 2)); + + Alias add3 = new Alias(new Sum(false, true, new Add(slot1, Literal.of(3)))); + Alias add4 = new Alias(new Sum(false, true, new Add(slot1, Literal.of(4)))); + agg = new LogicalAggregate<>( + ImmutableList.of(scan1.getOutput().get(0)), ImmutableList.of(add1, add2, add3, add4), scan1); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(ImmutableList.of(new SumLiteralRewrite().build())) + .printlnTree() + .matches(logicalAggregate().when(p -> p.getOutputs().size() == 3)); + + } }