[opt](Nereids) simplify decimalv3 comparison predicate (#18975)

1. fix constant folding failed on decimalv3 type
2. support reduce decimalv3 literal precision in comparison predicate
3. support fe config enable_decimal_conversion
This commit is contained in:
morrySnow
2023-04-26 23:57:09 +08:00
committed by GitHub
parent 925efc1902
commit ae252d1cfa
7 changed files with 119 additions and 5 deletions

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.parser;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.analysis.SetType;
import org.apache.doris.analysis.UserIdentity;
import org.apache.doris.common.Config;
import org.apache.doris.common.DdlException;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.DorisParser;
@ -185,6 +186,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Interval;
@ -1742,8 +1744,12 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
}
@Override
public DecimalLiteral visitDecimalLiteral(DecimalLiteralContext ctx) {
return new DecimalLiteral(new BigDecimal(ctx.getText()));
public Literal visitDecimalLiteral(DecimalLiteralContext ctx) {
if (Config.enable_decimal_conversion) {
return new DecimalV3Literal(new BigDecimal(ctx.getText()));
} else {
return new DecimalLiteral(new BigDecimal(ctx.getText()));
}
}
private String parseTVFPropertyItem(TvfPropertyItemContext item) {

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression;
import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
import org.apache.doris.nereids.rules.expression.rules.SimplifyRange;
import com.google.common.collect.ImmutableList;
@ -34,6 +35,7 @@ public class ExpressionOptimization extends ExpressionRewrite {
ExtractCommonFactorRule.INSTANCE,
DistinctPredicatesRule.INSTANCE,
SimplifyComparisonPredicate.INSTANCE,
SimplifyDecimalV3Comparison.INSTANCE,
SimplifyRange.INSTANCE
);
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);

View File

@ -79,6 +79,7 @@ import java.util.Objects;
* evaluate an expression on fe.
*/
public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
public static final FoldConstantRuleOnFE INSTANCE = new FoldConstantRuleOnFE();
@Override

View File

@ -0,0 +1,77 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.types.DecimalV3Type;
import com.google.common.base.Preconditions;
import java.math.BigDecimal;
/**
* if we have a column with decimalv3 type and set enable_decimal_conversion = false.
* we have a column named col1 with type decimalv3(15, 2)
* and we have a comparison like col1 > 0.5 + 0.1
* then the result type of 0.5 + 0.1 is decimalv2(27, 9)
* and the col1 need to convert to decimalv3(27, 9) to match the precision of right hand
* this rule simplify it from cast(col1 as decimalv3(27, 9)) > 0.6 to col1 > 0.6
*/
public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
public static SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison();
@Override
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
Expression left = rewrite(cp.left(), context);
Expression right = rewrite(cp.right(), context);
if (left.getDataType() instanceof DecimalV3Type
&& left instanceof Cast
&& ((Cast) left).child().getDataType() instanceof DecimalV3Type
&& right instanceof DecimalV3Literal) {
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
}
if (left != cp.left() || right != cp.right()) {
return cp.withChildren(left, right);
} else {
return cp;
}
}
private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros();
int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue);
int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue);
Expression castChild = left.child();
Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type);
DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
// precision and scale of literal must all smaller than left, otherwise we need to do cast on right.
Preconditions.checkState(scale <= leftType.getScale(), "right scale should not greater than left");
Preconditions.checkState(precision <= leftType.getPrecision(), "right precision should not greater than left");
DecimalV3Literal newRight = new DecimalV3Literal(
DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()), trailingZerosValue);
return cp.withChildren(castChild, newRight);
}
}

View File

@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
import com.google.common.collect.ImmutableMultimap;
@ -104,7 +105,7 @@ public enum ExpressionEvaluator {
}
boolean match = true;
for (int i = 0; i < candidateTypes.length; i++) {
if (!candidateTypes[i].equals(expectedTypes[i])) {
if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) {
match = false;
break;
}
@ -142,7 +143,11 @@ public enum ExpressionEvaluator {
DataType returnType = DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
argTypes.add(DataType.convertFromString(type));
if (type.equalsIgnoreCase("DECIMALV3")) {
argTypes.add(DecimalV3Type.WILDCARD);
} else {
argTypes.add(DataType.convertFromString(type));
}
}
FunctionSignature signature = new FunctionSignature(name,
argTypes.toArray(new DataType[argTypes.size()]), returnType);

View File

@ -39,7 +39,9 @@ public class DecimalV3Literal extends Literal {
public DecimalV3Literal(DecimalV3Type dataType, BigDecimal value) {
super(DecimalV3Type.createDecimalV3Type(dataType.getPrecision(), dataType.getScale()));
this.value = Objects.requireNonNull(value.setScale(dataType.getScale(), RoundingMode.DOWN));
Objects.requireNonNull(value, "value not be null");
BigDecimal adjustedValue = value.scale() < 0 ? value : value.setScale(dataType.getScale(), RoundingMode.DOWN);
this.value = Objects.requireNonNull(adjustedValue);
}
@Override