branch-2.1: [Fix](nereids) Fix incorrect results in GROUP BY with Modulo (%) operations #54153 (#54194)
Cherry-picked from #54153 Co-authored-by: Jensen <czjourney@163.com>
This commit is contained in:
committed by
GitHub
parent
ebbfe1dfa2
commit
0910a57dbf
@ -20,13 +20,18 @@ package org.apache.doris.nereids.rules.rewrite;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.TreeNode;
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.BinaryArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.Divide;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Multiply;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.Subtract;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.util.ExpressionUtils;
|
||||
import org.apache.doris.nereids.util.Utils;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
|
||||
import java.util.List;
|
||||
@ -40,11 +45,15 @@ import java.util.Set;
|
||||
* GROUP BY ClientIP
|
||||
*/
|
||||
public class SimplifyAggGroupBy extends OneRewriteRuleFactory {
|
||||
private static final ImmutableSet<Class<? extends Expression>> supportedFunctions
|
||||
= ImmutableSet.of(Add.class, Subtract.class, Multiply.class, Divide.class);
|
||||
|
||||
@Override
|
||||
public Rule build() {
|
||||
return logicalAggregate()
|
||||
.when(agg -> agg.getGroupByExpressions().size() > 1
|
||||
&& ExpressionUtils.allMatch(agg.getGroupByExpressions(), this::isBinaryArithmeticSlot))
|
||||
&& ExpressionUtils.allMatch(agg.getGroupByExpressions(),
|
||||
SimplifyAggGroupBy::isBinaryArithmeticSlot))
|
||||
.then(agg -> {
|
||||
List<Expression> groupByExpressions = agg.getGroupByExpressions();
|
||||
ImmutableSet.Builder<Expression> inputSlots
|
||||
@ -61,13 +70,17 @@ public class SimplifyAggGroupBy extends OneRewriteRuleFactory {
|
||||
.toRule(RuleType.SIMPLIFY_AGG_GROUP_BY);
|
||||
}
|
||||
|
||||
private boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) {
|
||||
@VisibleForTesting
|
||||
protected static boolean isBinaryArithmeticSlot(TreeNode<Expression> expr) {
|
||||
if (expr instanceof Slot) {
|
||||
return true;
|
||||
}
|
||||
if (!(expr instanceof BinaryArithmetic)) {
|
||||
return false;
|
||||
}
|
||||
if (!supportedFunctions.contains(expr.getClass())) {
|
||||
return false;
|
||||
}
|
||||
return ExpressionUtils.isSlotOrCastOnSlot(expr.child(0)).isPresent() && expr.child(1) instanceof Literal
|
||||
|| ExpressionUtils.isSlotOrCastOnSlot(expr.child(1)).isPresent() && expr.child(0) instanceof Literal;
|
||||
}
|
||||
|
||||
@ -18,10 +18,13 @@
|
||||
package org.apache.doris.nereids.rules.rewrite;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Add;
|
||||
import org.apache.doris.nereids.trees.expressions.Divide;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.Mod;
|
||||
import org.apache.doris.nereids.trees.expressions.Multiply;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.Slot;
|
||||
import org.apache.doris.nereids.trees.expressions.Subtract;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.Abs;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
@ -35,6 +38,7 @@ import org.apache.doris.nereids.util.PlanConstructor;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
@ -132,4 +136,24 @@ class SimplifyAggGroupByTest implements MemoPatternMatchSupported {
|
||||
logicalProject(logicalAggregate().when(a -> a.getGroupByExpressions().size() == 2))
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
void testisBinaryArithmeticSlot() {
|
||||
Slot id = scan1.getOutput().get(0);
|
||||
|
||||
Mod mod = new Mod(id, Literal.of(2));
|
||||
Assertions.assertFalse(SimplifyAggGroupBy.isBinaryArithmeticSlot(mod));
|
||||
|
||||
Add add = new Add(id, Literal.of(2));
|
||||
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(add));
|
||||
|
||||
Subtract subtract = new Subtract(id, Literal.of(2));
|
||||
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(subtract));
|
||||
|
||||
Multiply multiply = new Multiply(id, Literal.of(2));
|
||||
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(multiply));
|
||||
|
||||
Divide divide = new Divide(id, Literal.of(2));
|
||||
Assertions.assertTrue(SimplifyAggGroupBy.isBinaryArithmeticSlot(divide));
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,9 @@
|
||||
-- This file is automatically generated. You should know what you did if you want to edit this
|
||||
-- !aggregate --
|
||||
0 0
|
||||
0 1
|
||||
1 0
|
||||
1 1
|
||||
2 0
|
||||
2 1
|
||||
|
||||
@ -0,0 +1,23 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
suite("aggregate_groupby_simplify") {
|
||||
sql "SET enable_nereids_planner=true"
|
||||
sql "SET enable_fallback_to_original_planner=false"
|
||||
|
||||
qt_aggregate "select number % 3 as a, number % 2 as b from numbers('number' = '10') group by a, b order by a, b;"
|
||||
}
|
||||
Reference in New Issue
Block a user