diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java index cc1694575e..cb61795865 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicate.java @@ -229,15 +229,15 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule i left = cast.child(); DecimalV3Literal literal = (DecimalV3Literal) right; if (left.getDataType().isDecimalV3Type()) { - if (((DecimalV3Type) left.getDataType()) - .getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) { + DecimalV3Type leftType = (DecimalV3Type) left.getDataType(); + DecimalV3Type literalType = (DecimalV3Type) literal.getDataType(); + if (leftType.getScale() < literalType.getScale()) { int toScale = ((DecimalV3Type) left.getDataType()).getScale(); if (comparisonPredicate instanceof EqualTo) { try { return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) - comparisonPredicate.withChildren(left, - new DecimalV3Literal((DecimalV3Type) left.getDataType(), - literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); + comparisonPredicate.withChildren(left, new DecimalV3Literal( + literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); } catch (ArithmeticException e) { if (left.nullable()) { // TODO: the ideal way is to return an If expr like: @@ -255,9 +255,8 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule i } else if (comparisonPredicate instanceof NullSafeEqual) { try { return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate) - comparisonPredicate.withChildren(left, - new DecimalV3Literal((DecimalV3Type) left.getDataType(), - literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); + comparisonPredicate.withChildren(left, new DecimalV3Literal( + literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY)))); } catch (ArithmeticException e) { return BooleanLiteral.of(false); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java index 402594d686..84ebd7c725 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/SimplifyComparisonPredicateTest.java @@ -279,5 +279,17 @@ class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper { rewrittenExpression.child(0).getDataType()); Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); Assertions.assertEquals(new BigDecimal("12.35"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); + + // left's child range smaller than right literal + leftChild = new DecimalV3Literal(new BigDecimal("1234.12")); + left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(10, 5)); + right = new DecimalV3Literal(new BigDecimal("12345.12000")); + expression = new EqualTo(left, right); + rewrittenExpression = executor.rewrite(expression, context); + Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0)); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(7, 2), + rewrittenExpression.child(0).getDataType()); + Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1)); + Assertions.assertEquals(new BigDecimal("12345.12"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue()); } }