diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
index bb5e32fe8c..2509c114c4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/batch/NereidsRewriter.java
@@ -72,6 +72,7 @@ import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinAggTransposeProjec
import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinCommute;
import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinLogicalJoinTranspose;
import org.apache.doris.nereids.rules.rewrite.logical.SemiJoinLogicalJoinTransposeProject;
+import org.apache.doris.nereids.rules.rewrite.logical.SimplifyAggGroupBy;
import org.apache.doris.nereids.rules.rewrite.logical.SplitLimit;
import com.google.common.collect.ImmutableList;
@@ -155,6 +156,9 @@ public class NereidsRewriter extends BatchRewriteJob {
// e.g. sum(sum(c1)) over(partition by avg(c1))
new NormalizeAggregate(),
new CheckAndStandardizeWindowFunctionAndFrame()
+ ),
+ topDown(
+ new SimplifyAggGroupBy()
)
),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 0d73777333..9db6e43f5d 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -95,6 +95,7 @@ public enum RuleType {
EXTRACT_AND_NORMALIZE_WINDOW_EXPRESSIONS(RuleTypeClass.REWRITE),
CHECK_AND_STANDARDIZE_WINDOW_FUNCTION_AND_FRAME(RuleTypeClass.REWRITE),
AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
+ SIMPLIFY_AGG_GROUP_BY(RuleTypeClass.REWRITE),
DISTINCT_AGGREGATE_DISASSEMBLE(RuleTypeClass.REWRITE),
LOGICAL_SUB_QUERY_ALIAS_TO_LOGICAL_PROJECT(RuleTypeClass.REWRITE),
ELIMINATE_GROUP_BY_CONSTANT(RuleTypeClass.REWRITE),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupBy.java
new file mode 100644
index 0000000000..b7fedba8b4
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupBy.java
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Simplify Aggregate group by Multiple to One. For example
+ *
+ * GROUP BY ClientIP, ClientIP - 1, ClientIP - 2, ClientIP - 3
+ * -->
+ * GROUP BY ClientIP
+ */
+public class SimplifyAggGroupBy extends OneRewriteRuleFactory {
+ @Override
+ public Rule build() {
+ return logicalAggregate()
+ .when(agg -> agg.getGroupByExpressions().size() > 1)
+ .when(agg -> agg.getGroupByExpressions().stream().allMatch(this::isBinaryArithmeticSlot))
+ .then(agg -> {
+ Set slots = agg.getGroupByExpressions().stream()
+ .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toSet());
+ if (slots.size() != 1) {
+ return null;
+ }
+ return agg.withGroupByAndOutput(ImmutableList.copyOf(slots), agg.getOutputExpressions());
+ })
+ .toRule(RuleType.SIMPLIFY_AGG_GROUP_BY);
+ }
+
+ private boolean isBinaryArithmeticSlot(Expression expr) {
+ if (expr instanceof Slot) {
+ return true;
+ }
+ if (!(expr instanceof BinaryArithmetic)) {
+ return false;
+ }
+ return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal
+ || ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal;
+ }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
index 8a4d509507..0d2772d768 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java
@@ -60,6 +60,10 @@ public abstract class Expression extends AbstractTreeNode implements
super(Optional.empty(), children);
}
+ public Alias alias(String alias) {
+ return new Alias(this, alias);
+ }
+
/**
* check input data types
*/
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 606f13f4ca..26c07f8526 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -111,12 +111,6 @@ public class ExpressionUtils {
}
}
- public static Set extractToSet(Expression predicate) {
- Set result = Sets.newHashSet();
- extract(predicate.getClass(), predicate, result);
- return result;
- }
-
public static Optional optionalAnd(List expressions) {
if (expressions.isEmpty()) {
return Optional.empty();
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupByTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupByTest.java
new file mode 100644
index 0000000000..68512c2923
--- /dev/null
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/SimplifyAggGroupByTest.java
@@ -0,0 +1,113 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.trees.expressions.Add;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Multiply;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
+import org.apache.doris.nereids.util.LogicalPlanBuilder;
+import org.apache.doris.nereids.util.MemoPatternMatchSupported;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.nereids.util.PlanConstructor;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+
+class SimplifyAggGroupByTest implements MemoPatternMatchSupported {
+ private static final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0);
+
+ @Test
+ void test() {
+ Slot id = scan1.getOutput().get(0);
+ List output = ImmutableList.of(
+ id,
+ new Add(id, Literal.of(1)).alias("id1"),
+ new Add(id, Literal.of(2)).alias("id2"),
+ new Add(id, Literal.of(3)).alias("id3"),
+ new Count().alias("count")
+ );
+ List groupBy = ImmutableList.of(
+ id,
+ new Add(id, Literal.of(1)),
+ new Add(id, Literal.of(2)),
+ new Add(id, Literal.of(3))
+ );
+ LogicalPlan agg = new LogicalPlanBuilder(scan1)
+ .agg(groupBy, output)
+ .build();
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(new SimplifyAggGroupBy())
+ .matchesFromRoot(
+ logicalAggregate().when(a -> a.getGroupByExpressions().size() == 1)
+ );
+ }
+
+ @Test
+ void testSqrt() {
+ Slot id = scan1.getOutput().get(0);
+ List output = ImmutableList.of(
+ id,
+ new Multiply(id, id).alias("sqrt"),
+ new Count().alias("count")
+ );
+ List groupBy = ImmutableList.of(
+ id,
+ new Multiply(id, id)
+ );
+ LogicalPlan agg = new LogicalPlanBuilder(scan1)
+ .agg(groupBy, output)
+ .build();
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(new SimplifyAggGroupBy())
+ .matchesFromRoot(
+ logicalAggregate().when(a -> a.equals(agg))
+ );
+ }
+
+ @Test
+ void testAbs() {
+ Slot id = scan1.getOutput().get(0);
+ List output = ImmutableList.of(
+ id,
+ new Abs(id).alias("abs"),
+ new Count().alias("count")
+ );
+ List groupBy = ImmutableList.of(
+ id,
+ new Abs(id)
+ );
+ LogicalPlan agg = new LogicalPlanBuilder(scan1)
+ .agg(groupBy, output)
+ .build();
+ PlanChecker.from(MemoTestUtils.createConnectContext(), agg)
+ .applyTopDown(new SimplifyAggGroupBy())
+ .matchesFromRoot(
+ logicalAggregate().when(a -> a.equals(agg))
+ );
+ }
+}