branch-2.1: [Fix](nereids) Fix incorrect results in GROUP BY with Modulo (%) operations #54153 (#54194)

Cherry-picked from #54153

Co-authored-by: Jensen <czjourney@163.com>
This commit is contained in:
github-actions[bot]
2025-08-06 11:48:16 +08:00
committed by GitHub
parent ebbfe1dfa2
commit 0910a57dbf
4 changed files with 71 additions and 2 deletions

View File

@ -20,13 +20,18 @@ package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.TreeNode;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.ImmutableSet;
import java.util.List;
@ -40,11 +45,15 @@ import java.util.Set;
* GROUP BY ClientIP
*/
public class SimplifyAggGroupBy extends OneRewriteRuleFactory {
private static final ImmutableSet<Class<? extends Expression>> supportedFunctions
= ImmutableSet.of(Add.class, Subtract.class, Multiply.class, Divide.class);
@Override
public Rule build() {
return logicalAggregate()
.when(agg -> agg.getGroupByExpressions().size() > 1
&& ExpressionUtils.allMatch(agg.getGroupByExpressions(), this::isBinaryArithmeticSlot))
&& ExpressionUtils.allMatch(agg.getGroupByExpressions(),
SimplifyAggGroupBy::isBinaryArithmeticSlot))
.then(agg -> {
List<Expression> groupByExpressions = agg.getGroupByExpressions();
ImmutableSet.Builder<Expression> inputSlots
@ -61,13 +70,17 @@ public class SimplifyAggGroupBy extends OneRewriteRuleFactory {
.toRule(RuleType.SIMPLIFY_AGG_GROUP_BY);
}
private boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) {
@VisibleForTesting
protected static boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) {
if (expr instanceof Slot) {
return true;
}
if (!(expr instanceof BinaryArithmetic)) {
return false;
}
if (!supportedFunctions.contains(expr.getClass())) {
return false;
}
return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal
|| ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal;
}

View File

@ -18,10 +18,13 @@
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Mod;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
@ -35,6 +38,7 @@ import org.apache.doris.nereids.util.PlanConstructor;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
@ -132,4 +136,24 @@ class SimplifyAggGroupByTest implements MemoPatternMatchSupported {
logicalProject(logicalAggregate().when(a -> a.getGroupByExpressions().size() == 2))
);
}
@Test
void testisBinaryArithmeticSlot() {
Slot id = scan1.getOutput().get(0);
Mod mod = new Mod(id, Literal.of(2));
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(mod));
Add add = new Add(id, Literal.of(2));
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(add));
Subtract subtract = new Subtract(id, Literal.of(2));
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(subtract));
Multiply multiply = new Multiply(id, Literal.of(2));
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(multiply));
Divide divide = new Divide(id, Literal.of(2));
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(divide));
}
}

View File

@ -0,0 +1,9 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !aggregate --
0 0
0 1
1 0
1 1
2 0
2 1

View File

@ -0,0 +1,23 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("aggregate_groupby_simplify") {
sql "SET enable_nereids_planner=true"
sql "SET enable_fallback_to_original_planner=false"
qt_aggregate "select number % 3 as a, number % 2 as b from numbers('number' = '10') group by a, b order by a, b;"
}