From 0910a57dbf131dc5f19f45b8b6f8a5b58da58357 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 6 Aug 2025 11:48:16 +0800 Subject: [PATCH] branch-2.1: [Fix](nereids) Fix incorrect results in GROUP BY with Modulo (%) operations #54153 (#54194) Cherry-picked from #54153 Co-authored-by: Jensen --- .../rules/rewrite/SimplifyAggGroupBy.java | 17 +++++++++++-- .../rules/rewrite/SimplifyAggGroupByTest.java | 24 +++++++++++++++++++ .../aggregate/aggregate_groupby_simplify.out | 9 +++++++ .../aggregate_groupby_simplify.groovy | 23 ++++++++++++++++++ 4 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 regression-test/data/nereids_p0/aggregate/aggregate_groupby_simplify.out create mode 100644 regression-test/suites/nereids_p0/aggregate/aggregate_groupby_simplify.groovy diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java index 6dc446d88c..37d4d4806f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupBy.java @@ -20,13 +20,18 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.TreeNode; +import org.apache.doris.nereids.trees.expressions.Add; import org.apache.doris.nereids.trees.expressions.BinaryArithmetic; +import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; +import com.google.common.annotations.VisibleForTesting; import com.google.common.collect.ImmutableSet; import java.util.List; @@ -40,11 +45,15 @@ import java.util.Set; * GROUP BY ClientIP */ public class SimplifyAggGroupBy extends OneRewriteRuleFactory { + private static final ImmutableSet> supportedFunctions + = ImmutableSet.of(Add.class, Subtract.class, Multiply.class, Divide.class); + @Override public Rule build() { return logicalAggregate() .when(agg -> agg.getGroupByExpressions().size() > 1 - && ExpressionUtils.allMatch(agg.getGroupByExpressions(), this::isBinaryArithmeticSlot)) + && ExpressionUtils.allMatch(agg.getGroupByExpressions(), + SimplifyAggGroupBy::isBinaryArithmeticSlot)) .then(agg -> { List groupByExpressions = agg.getGroupByExpressions(); ImmutableSet.Builder inputSlots @@ -61,13 +70,17 @@ public class SimplifyAggGroupBy extends OneRewriteRuleFactory { .toRule(RuleType.SIMPLIFY_AGG_GROUP_BY); } - private boolean isBinaryArithmeticSlot(TreeNode expr) { + @VisibleForTesting + protected static boolean isBinaryArithmeticSlot(TreeNode expr) { if (expr instanceof Slot) { return true; } if (!(expr instanceof BinaryArithmetic)) { return false; } + if (!supportedFunctions.contains(expr.getClass())) { + return false; + } return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal || ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java index 34c3b012e7..32c2cc4356 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/SimplifyAggGroupByTest.java @@ -18,10 +18,13 @@ package org.apache.doris.nereids.rules.rewrite; import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Mod; import org.apache.doris.nereids.trees.expressions.Multiply; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.Subtract; import org.apache.doris.nereids.trees.expressions.functions.agg.Count; import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs; import org.apache.doris.nereids.trees.expressions.literal.Literal; @@ -35,6 +38,7 @@ import org.apache.doris.nereids.util.PlanConstructor; import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; +import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import java.util.List; @@ -132,4 +136,24 @@ class SimplifyAggGroupByTest implements MemoPatternMatchSupported { logicalProject(logicalAggregate().when(a -> a.getGroupByExpressions().size() == 2)) ); } + + @Test + void testisBinaryArithmeticSlot() { + Slot id = scan1.getOutput().get(0); + + Mod mod = new Mod(id, Literal.of(2)); + Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(mod)); + + Add add = new Add(id, Literal.of(2)); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(add)); + + Subtract subtract = new Subtract(id, Literal.of(2)); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(subtract)); + + Multiply multiply = new Multiply(id, Literal.of(2)); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(multiply)); + + Divide divide = new Divide(id, Literal.of(2)); + Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(divide)); + } } diff --git a/regression-test/data/nereids_p0/aggregate/aggregate_groupby_simplify.out b/regression-test/data/nereids_p0/aggregate/aggregate_groupby_simplify.out new file mode 100644 index 0000000000..0a2425d53f --- /dev/null +++ b/regression-test/data/nereids_p0/aggregate/aggregate_groupby_simplify.out @@ -0,0 +1,9 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !aggregate -- +0 0 +0 1 +1 0 +1 1 +2 0 +2 1 + diff --git a/regression-test/suites/nereids_p0/aggregate/aggregate_groupby_simplify.groovy b/regression-test/suites/nereids_p0/aggregate/aggregate_groupby_simplify.groovy new file mode 100644 index 0000000000..0951d9dd2d --- /dev/null +++ b/regression-test/suites/nereids_p0/aggregate/aggregate_groupby_simplify.groovy @@ -0,0 +1,23 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("aggregate_groupby_simplify") { + sql "SET enable_nereids_planner=true" + sql "SET enable_fallback_to_original_planner=false" + + qt_aggregate "select number % 3 as a, number % 2 as b from numbers('number' = '10') group by a, b order by a, b;" +}