branch-2.1: [fix](Nereids) simplify comparison predicate do wrong cast #44054 (#44119)

Cherry-picked from #44054

Co-authored-by: morrySnow <zhangwenxin@selectdb.com>
This commit is contained in:
github-actions[bot]
2024-11-18 16:53:51 +08:00
committed by GitHub
parent 7f129433ec
commit abbb12f93f
2 changed files with 19 additions and 8 deletions

View File

@ -229,15 +229,15 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule i
left = cast.child();
DecimalV3Literal literal = (DecimalV3Literal) right;
if (left.getDataType().isDecimalV3Type()) {
if (((DecimalV3Type) left.getDataType())
.getScale() < ((DecimalV3Type) literal.getDataType()).getScale()) {
DecimalV3Type leftType = (DecimalV3Type) left.getDataType();
DecimalV3Type literalType = (DecimalV3Type) literal.getDataType();
if (leftType.getScale() < literalType.getScale()) {
int toScale = ((DecimalV3Type) left.getDataType()).getScale();
if (comparisonPredicate instanceof EqualTo) {
try {
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY))));
comparisonPredicate.withChildren(left, new DecimalV3Literal(
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY))));
} catch (ArithmeticException e) {
if (left.nullable()) {
// TODO: the ideal way is to return an If expr like:
@ -255,9 +255,8 @@ public class SimplifyComparisonPredicate extends AbstractExpressionRewriteRule i
} else if (comparisonPredicate instanceof NullSafeEqual) {
try {
return TypeCoercionUtils.processComparisonPredicate((ComparisonPredicate)
comparisonPredicate.withChildren(left,
new DecimalV3Literal((DecimalV3Type) left.getDataType(),
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY))));
comparisonPredicate.withChildren(left, new DecimalV3Literal(
literal.getValue().setScale(toScale, RoundingMode.UNNECESSARY))));
} catch (ArithmeticException e) {
return BooleanLiteral.of(false);
}

View File

@ -279,5 +279,17 @@ class SimplifyComparisonPredicateTest extends ExpressionRewriteTestHelper {
rewrittenExpression.child(0).getDataType());
Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1));
Assertions.assertEquals(new BigDecimal("12.35"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue());
// left's child range smaller than right literal
leftChild = new DecimalV3Literal(new BigDecimal("1234.12"));
left = new Cast(leftChild, DecimalV3Type.createDecimalV3Type(10, 5));
right = new DecimalV3Literal(new BigDecimal("12345.12000"));
expression = new EqualTo(left, right);
rewrittenExpression = executor.rewrite(expression, context);
Assertions.assertInstanceOf(Cast.class, rewrittenExpression.child(0));
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(7, 2),
rewrittenExpression.child(0).getDataType());
Assertions.assertInstanceOf(DecimalV3Literal.class, rewrittenExpression.child(1));
Assertions.assertEquals(new BigDecimal("12345.12"), ((DecimalV3Literal) rewrittenExpression.child(1)).getValue());
}
}