diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java index bb5e32fe8c..2509c114c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java @@ -72,6 +72,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinAggTransposeProjec import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinCommute; import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinLogicalJoinTranspose; import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinLogicalJoinTransposeProject; +import org.apache.doris.nereids.rules.rewrite.logical.SimplifyAggGroupBy; import org.apache.doris.nereids.rules.rewrite.logical.SplitLimit; import com.google.common.collect.ImmutableList; @@ -155,6 +156,9 @@ public class NereidsRewriter extends BatchRewriteJob { // e.g. sum(sum(c1)) over(partition by avg(c1)) new NormalizeAggregate(), new CheckAndStandardizeWindowFunctionAndFrame() + ), + topDown( + new SimplifyAggGroupBy() ) ), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 0d73777333..9db6e43f5d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -95,6 +95,7 @@ public enum RuleType { EXTRACT_AND_NORMALIZE_WINDOW_EXPRESSIONS(RuleTypeClass.REWRITE), CHECK_AND_STANDARDIZE_WINDOW_FUNCTION_AND_FRAME(RuleTypeClass.REWRITE), AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE), + SIMPLIFY_AGG_GROUP_BY(RuleTypeClass.REWRITE), DISTINCT_AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE), LOGICAL_SUB_QUERY_ALIAS_TO_LOGICAL_PROJECT(RuleTypeClass.REWRITE), ELIMINATE_GROUP_BY_CONSTANT(RuleTypeClass.REWRITE), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupBy.java new file mode 100644 index 0000000000..b7fedba8b4 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupBy.java @@ -0,0 +1,68 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory; +import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; + +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Simplify Aggregate group by Multiple to One. For example + *

+ * GROUP BY ClientIP, ClientIP - 1, ClientIP - 2, ClientIP - 3 + * --> + * GROUP BY ClientIP + */ +public class SimplifyAggGroupBy extends OneRewriteRuleFactory { + @Override + public Rule build() { + return logicalAggregate() + .when(agg -> agg.getGroupByExpressions().size() > 1) + .when(agg -> agg.getGroupByExpressions().stream().allMatch(this::isBinaryArithmeticSlot)) + .then(agg -> { + Set slots = agg.getGroupByExpressions().stream() + .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toSet()); + if (slots.size() != 1) { + return null; + } + return agg.withGroupByAndOutput(ImmutableList.copyOf(slots), agg.getOutputExpressions()); + }) + .toRule(RuleType.SIMPLIFY_AGG_GROUP_BY); + } + + private boolean isBinaryArithmeticSlot(Expression expr) { + if (expr instanceof Slot) { + return true; + } + if (!(expr instanceof BinaryArithmetic)) { + return false; + } + return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal + || ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java index 8a4d509507..0d2772d768 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java @@ -60,6 +60,10 @@ public abstract class Expression extends AbstractTreeNode implements super(Optional.empty(), children); } + public Alias alias(String alias) { + return new Alias(this, alias); + } + /** * check input data types */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 606f13f4ca..26c07f8526 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -111,12 +111,6 @@ public class ExpressionUtils { } } - public static Set extractToSet(Expression predicate) { - Set result = Sets.newHashSet(); - extract(predicate.getClass(), predicate, result); - return result; - } - public static Optional optionalAnd(List expressions) { if (expressions.isEmpty()) { return Optional.empty(); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupByTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupByTest.java new file mode 100644 index 0000000000..68512c2923 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupByTest.java @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Multiply; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs; +import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoPatternMatchSupported; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Test; + +import java.util.List; + +class SimplifyAggGroupByTest implements MemoPatternMatchSupported { + private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + + @Test + void test() { + Slot id = scan1.getOutput().get(0); + List output = ImmutableList.of( + id, + new Add(id, Literal.of(1)).alias("id1"), + new Add(id, Literal.of(2)).alias("id2"), + new Add(id, Literal.of(3)).alias("id3"), + new Count().alias("count") + ); + List groupBy = ImmutableList.of( + id, + new Add(id, Literal.of(1)), + new Add(id, Literal.of(2)), + new Add(id, Literal.of(3)) + ); + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .agg(groupBy, output) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(new SimplifyAggGroupBy()) + .matchesFromRoot( + logicalAggregate().when(a -> a.getGroupByExpressions().size() == 1) + ); + } + + @Test + void testSqrt() { + Slot id = scan1.getOutput().get(0); + List output = ImmutableList.of( + id, + new Multiply(id, id).alias("sqrt"), + new Count().alias("count") + ); + List groupBy = ImmutableList.of( + id, + new Multiply(id, id) + ); + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .agg(groupBy, output) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(new SimplifyAggGroupBy()) + .matchesFromRoot( + logicalAggregate().when(a -> a.equals(agg)) + ); + } + + @Test + void testAbs() { + Slot id = scan1.getOutput().get(0); + List output = ImmutableList.of( + id, + new Abs(id).alias("abs"), + new Count().alias("count") + ); + List groupBy = ImmutableList.of( + id, + new Abs(id) + ); + LogicalPlan agg = new LogicalPlanBuilder(scan1) + .agg(groupBy, output) + .build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), agg) + .applyTopDown(new SimplifyAggGroupBy()) + .matchesFromRoot( + logicalAggregate().when(a -> a.equals(agg)) + ); + } +}