diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java index f22ff8d63c..8f265cff78 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.parser; import org.apache.doris.analysis.ArithmeticExpr.Operator; import org.apache.doris.analysis.SetType; import org.apache.doris.analysis.UserIdentity; +import org.apache.doris.common.Config; import org.apache.doris.common.DdlException; import org.apache.doris.common.Pair; import org.apache.doris.nereids.DorisParser; @@ -185,6 +186,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.Interval; @@ -1742,8 +1744,12 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor { } @Override - public DecimalLiteral visitDecimalLiteral(DecimalLiteralContext ctx) { - return new DecimalLiteral(new BigDecimal(ctx.getText())); + public Literal visitDecimalLiteral(DecimalLiteralContext ctx) { + if (Config.enable_decimal_conversion) { + return new DecimalV3Literal(new BigDecimal(ctx.getText())); + } else { + return new DecimalLiteral(new BigDecimal(ctx.getText())); + } } private String parseTVFPropertyItem(TvfPropertyItemContext item) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java index 47840f226e..50c0af4402 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionOptimization.java @@ -20,6 +20,7 @@ package org.apache.doris.nereids.rules.expression; import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; +import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison; import org.apache.doris.nereids.rules.expression.rules.SimplifyRange; import com.google.common.collect.ImmutableList; @@ -34,6 +35,7 @@ public class ExpressionOptimization extends ExpressionRewrite { ExtractCommonFactorRule.INSTANCE, DistinctPredicatesRule.INSTANCE, SimplifyComparisonPredicate.INSTANCE, + SimplifyDecimalV3Comparison.INSTANCE, SimplifyRange.INSTANCE ); private static final ExpressionRuleExecutor EXECUTOR = new ExpressionRuleExecutor(OPTIMIZE_REWRITE_RULES); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java index c83d3029ce..d729273a64 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/FoldConstantRuleOnFE.java @@ -79,6 +79,7 @@ import java.util.Objects; * evaluate an expression on fe. */ public class FoldConstantRuleOnFE extends AbstractExpressionRewriteRule { + public static final FoldConstantRuleOnFE INSTANCE = new FoldConstantRuleOnFE(); @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java new file mode 100644 index 0000000000..93021f0b58 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyDecimalV3Comparison.java @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.AbstractExpressionRewriteRule; +import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext; +import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; +import org.apache.doris.nereids.types.DecimalV3Type; + +import com.google.common.base.Preconditions; + +import java.math.BigDecimal; + +/** + * if we have a column with decimalv3 type and set enable_decimal_conversion = false. + * we have a column named col1 with type decimalv3(15, 2) + * and we have a comparison like col1 > 0.5 + 0.1 + * then the result type of 0.5 + 0.1 is decimalv2(27, 9) + * and the col1 need to convert to decimalv3(27, 9) to match the precision of right hand + * this rule simplify it from cast(col1 as decimalv3(27, 9)) > 0.6 to col1 > 0.6 + */ +public class SimplifyDecimalV3Comparison extends AbstractExpressionRewriteRule { + + public static SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison(); + + @Override + public Expression visitComparisonPredicate(ComparisonPredicate cp, ExpressionRewriteContext context) { + Expression left = rewrite(cp.left(), context); + Expression right = rewrite(cp.right(), context); + + if (left.getDataType() instanceof DecimalV3Type + && left instanceof Cast + && ((Cast) left).child().getDataType() instanceof DecimalV3Type + && right instanceof DecimalV3Literal) { + return doProcess(cp, (Cast) left, (DecimalV3Literal) right); + } + + if (left != cp.left() || right != cp.right()) { + return cp.withChildren(left, right); + } else { + return cp; + } + } + + private Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) { + BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros(); + int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue); + int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue); + Expression castChild = left.child(); + Preconditions.checkState(castChild.getDataType() instanceof DecimalV3Type); + DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType(); + // precision and scale of literal must all smaller than left, otherwise we need to do cast on right. + Preconditions.checkState(scale <= leftType.getScale(), "right scale should not greater than left"); + Preconditions.checkState(precision <= leftType.getPrecision(), "right precision should not greater than left"); + DecimalV3Literal newRight = new DecimalV3Literal( + DecimalV3Type.createDecimalV3Type(leftType.getPrecision(), leftType.getScale()), trailingZerosValue); + return cp.withChildren(castChild, newRight); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java index 0eb013614a..f68309f26d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/ExpressionEvaluator.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunctio import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV3Type; import com.google.common.collect.ImmutableMultimap; @@ -104,7 +105,7 @@ public enum ExpressionEvaluator { } boolean match = true; for (int i = 0; i < candidateTypes.length; i++) { - if (!candidateTypes[i].equals(expectedTypes[i])) { + if (!(expectedTypes[i].toCatalogDataType().matchesType(candidateTypes[i].toCatalogDataType()))) { match = false; break; } @@ -142,7 +143,11 @@ public enum ExpressionEvaluator { DataType returnType = DataType.convertFromString(annotation.returnType()); List argTypes = new ArrayList<>(); for (String type : annotation.argTypes()) { - argTypes.add(DataType.convertFromString(type)); + if (type.equalsIgnoreCase("DECIMALV3")) { + argTypes.add(DecimalV3Type.WILDCARD); + } else { + argTypes.add(DataType.convertFromString(type)); + } } FunctionSignature signature = new FunctionSignature(name, argTypes.toArray(new DataType[argTypes.size()]), returnType); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java index 18c13c4ac5..c6c28fa3e9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java @@ -39,7 +39,9 @@ public class DecimalV3Literal extends Literal { public DecimalV3Literal(DecimalV3Type dataType, BigDecimal value) { super(DecimalV3Type.createDecimalV3Type(dataType.getPrecision(), dataType.getScale())); - this.value = Objects.requireNonNull(value.setScale(dataType.getScale(), RoundingMode.DOWN)); + Objects.requireNonNull(value, "value not be null"); + BigDecimal adjustedValue = value.scale() < 0 ? value : value.setScale(dataType.getScale(), RoundingMode.DOWN); + this.value = Objects.requireNonNull(adjustedValue); } @Override diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java index 0dbaf6d30c..65d9e8c6ac 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java @@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule; import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; +import org.apache.doris.nereids.rules.expression.rules.SimplifyDecimalV3Comparison; import org.apache.doris.nereids.rules.expression.rules.SimplifyNotExprRule; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.EqualTo; @@ -38,6 +39,7 @@ import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal; import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; @@ -46,6 +48,7 @@ import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.types.VarcharType; @@ -283,4 +286,22 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { new EqualTo(dv2, dv2)); } + + @Test + public void testSimplifyDecimalV3Comparison() { + executor = new ExpressionRuleExecutor(ImmutableList.of(SimplifyDecimalV3Comparison.INSTANCE)); + + // do rewrite + Expression left = new DecimalV3Literal(new BigDecimal("12345.67")); + Expression cast = new Cast(left, DecimalV3Type.createDecimalV3Type(27, 9)); + Expression right = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(27, 9), new BigDecimal("0.01")); + Expression expectedRight = new DecimalV3Literal(DecimalV3Type.createDecimalV3Type(7, 2), new BigDecimal("0.01")); + Expression comparison = new EqualTo(cast, right); + Expression expected = new EqualTo(left, expectedRight); + assertRewrite(comparison, expected); + + // not cast + comparison = new EqualTo(new DecimalV3Literal(new BigDecimal("12345.67")), new DecimalV3Literal(new BigDecimal("76543.21"))); + assertRewrite(comparison, comparison); + } }