[opt](Nereids) simplify decimalv3 comparison predicate (#18975)
1. fix constant folding failed on decimalv3 type 2. support reduce decimalv3 literal precision in comparison predicate 3. support fe config enable_decimal_conversion
This commit is contained in:
@ -20,6 +20,7 @@ package org.apache.doris.nereids.parser;
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.analysis.SetType;
|
||||
import org.apache.doris.analysis.UserIdentity;
|
||||
import org.apache.doris.common.Config;
|
||||
import org.apache.doris.common.DdlException;
|
||||
import org.apache.doris.common.Pair;
|
||||
import org.apache.doris.nereids.DorisParser;
|
||||
@ -185,6 +186,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Interval;
|
||||
@ -1742,8 +1744,12 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
|
||||
}
|
||||
|
||||
@Override
|
||||
public DecimalLiteral visitDecimalLiteral(DecimalLiteralContext ctx) {
|
||||
return new DecimalLiteral(new BigDecimal(ctx.getText()));
|
||||
public Literal visitDecimalLiteral(DecimalLiteralContext ctx) {
|
||||
if (Config.enable_decimal_conversion) {
|
||||
return new DecimalV3Literal(new BigDecimal(ctx.getText()));
|
||||
} else {
|
||||
return new DecimalLiteral(new BigDecimal(ctx.getText()));
|
||||
}
|
||||
}
|
||||
|
||||
private String parseTVFPropertyItem(TvfPropertyItemContext item) {
|
||||
|
||||
@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression;
|
||||
import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison;
|
||||
import org.apache.doris.nereids.rules.expression.rules.SimplifyRange;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -34,6 +35,7 @@ public class ExpressionOptimization extends ExpressionRewrite {
|
||||
ExtractCommonFactorRule.INSTANCE,
|
||||
DistinctPredicatesRule.INSTANCE,
|
||||
SimplifyComparisonPredicate.INSTANCE,
|
||||
SimplifyDecimalV3Comparison.INSTANCE,
|
||||
SimplifyRange.INSTANCE
|
||||
);
|
||||
private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES);
|
||||
|
||||
@ -79,6 +79,7 @@ import java.util.Objects;
|
||||
* evaluate an expression on fe.
|
||||
*/
|
||||
public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule {
|
||||
|
||||
public static final FoldConstantRuleOnFE INSTANCE = new FoldConstantRuleOnFE();
|
||||
|
||||
@Override
|
||||
|
||||
@ -0,0 +1,77 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.expression.rules;
|
||||
|
||||
import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule;
|
||||
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
/**
|
||||
* if we have a column with decimalv3 type and set enable_decimal_conversion = false.
|
||||
* we have a column named col1 with type decimalv3(15, 2)
|
||||
* and we have a comparison like col1 > 0.5 + 0.1
|
||||
* then the result type of 0.5 + 0.1 is decimalv2(27, 9)
|
||||
* and the col1 need to convert to decimalv3(27, 9) to match the precision of right hand
|
||||
* this rule simplify it from cast(col1 as decimalv3(27, 9)) > 0.6 to col1 > 0.6
|
||||
*/
|
||||
public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule {
|
||||
|
||||
public static SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison();
|
||||
|
||||
@Override
|
||||
public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) {
|
||||
Expression left = rewrite(cp.left(), context);
|
||||
Expression right = rewrite(cp.right(), context);
|
||||
|
||||
if (left.getDataType() instanceof DecimalV3Type
|
||||
&& left instanceof Cast
|
||||
&& ((Cast) left).child().getDataType() instanceof DecimalV3Type
|
||||
&& right instanceof DecimalV3Literal) {
|
||||
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
|
||||
}
|
||||
|
||||
if (left != cp.left() || right != cp.right()) {
|
||||
return cp.withChildren(left, right);
|
||||
} else {
|
||||
return cp;
|
||||
}
|
||||
}
|
||||
|
||||
private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
|
||||
BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros();
|
||||
int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue);
|
||||
int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue);
|
||||
Expression castChild = left.child();
|
||||
Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type);
|
||||
DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
|
||||
// precision and scale of literal must all smaller than left, otherwise we need to do cast on right.
|
||||
Preconditions.checkState(scale <= leftType.getScale(), "right scale should not greater than left");
|
||||
Preconditions.checkState(precision <= leftType.getPrecision(), "right precision should not greater than left");
|
||||
DecimalV3Literal newRight = new DecimalV3Literal(
|
||||
DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()), trailingZerosValue);
|
||||
return cp.withChildren(castChild, newRight);
|
||||
}
|
||||
}
|
||||
@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
|
||||
import com.google.common.collect.ImmutableMultimap;
|
||||
|
||||
@ -104,7 +105,7 @@ public enum ExpressionEvaluator {
|
||||
}
|
||||
boolean match = true;
|
||||
for (int i = 0; i < candidateTypes.length; i++) {
|
||||
if (!candidateTypes[i].equals(expectedTypes[i])) {
|
||||
if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) {
|
||||
match = false;
|
||||
break;
|
||||
}
|
||||
@ -142,7 +143,11 @@ public enum ExpressionEvaluator {
|
||||
DataType returnType = DataType.convertFromString(annotation.returnType());
|
||||
List<DataType> argTypes = new ArrayList<>();
|
||||
for (String type : annotation.argTypes()) {
|
||||
argTypes.add(DataType.convertFromString(type));
|
||||
if (type.equalsIgnoreCase("DECIMALV3")) {
|
||||
argTypes.add(DecimalV3Type.WILDCARD);
|
||||
} else {
|
||||
argTypes.add(DataType.convertFromString(type));
|
||||
}
|
||||
}
|
||||
FunctionSignature signature = new FunctionSignature(name,
|
||||
argTypes.toArray(new DataType[argTypes.size()]), returnType);
|
||||
|
||||
@ -39,7 +39,9 @@ public class DecimalV3Literal extends Literal {
|
||||
|
||||
public DecimalV3Literal(DecimalV3Type dataType, BigDecimal value) {
|
||||
super(DecimalV3Type.createDecimalV3Type(dataType.getPrecision(), dataType.getScale()));
|
||||
this.value = Objects.requireNonNull(value.setScale(dataType.getScale(), RoundingMode.DOWN));
|
||||
Objects.requireNonNull(value, "value not be null");
|
||||
BigDecimal adjustedValue = value.scale() < 0 ? value : value.setScale(dataType.getScale(), RoundingMode.DOWN);
|
||||
this.value = Objects.requireNonNull(adjustedValue);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
Reference in New Issue
Block a user