diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java index d88489084d..34143043a0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyCastRule.java @@ -82,14 +82,15 @@ public class SimplifyCastRule extends AbstractExpressionRewriteRule { ((VarcharType) castType).getLen()); } } else if (castType instanceof DecimalV2Type) { + DecimalV2Type decimalV2Type = (DecimalV2Type) castType; if (child instanceof TinyIntLiteral) { - return new DecimalLiteral(new BigDecimal(((TinyIntLiteral) child).getValue())); + return new DecimalLiteral(decimalV2Type, new BigDecimal(((TinyIntLiteral) child).getValue())); } else if (child instanceof SmallIntLiteral) { - return new DecimalLiteral(new BigDecimal(((SmallIntLiteral) child).getValue())); + return new DecimalLiteral(decimalV2Type, new BigDecimal(((SmallIntLiteral) child).getValue())); } else if (child instanceof IntegerLiteral) { - return new DecimalLiteral(new BigDecimal(((IntegerLiteral) child).getValue())); + return new DecimalLiteral(decimalV2Type, new BigDecimal(((IntegerLiteral) child).getValue())); } else if (child instanceof BigIntLiteral) { - return new DecimalLiteral(new BigDecimal(((BigIntLiteral) child).getValue())); + return new DecimalLiteral(decimalV2Type, new BigDecimal(((BigIntLiteral) child).getValue())); } } else if (castType instanceof DecimalV3Type) { DecimalV3Type decimalV3Type = (DecimalV3Type) castType; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalLiteral.java index 711673cc2f..03973fe6b0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalLiteral.java @@ -81,6 +81,7 @@ public class DecimalLiteral extends Literal { boolean valid = true; if (precision != -1 && scale != -1) { if (precision < realPrecision || scale < realScale + || realPrecision - realScale > precision - scale || realPrecision - realScale > DecimalV2Type.MAX_PRECISION - DecimalV2Type.MAX_SCALE) { valid = false; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java index 65296b4f79..2a10196e05 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/DecimalV3Literal.java @@ -93,7 +93,7 @@ public class DecimalV3Literal extends Literal { int realScale = value.scale(); boolean valid = true; if (precision != -1 && scale != -1) { - if (precision < realPrecision || scale < realScale) { + if (precision < realPrecision || scale < realScale || precision - scale < realPrecision - realScale) { valid = false; } } else { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java index 646174c9eb..de859058ec 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java @@ -200,13 +200,13 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper { // decimal literal assertRewrite(new Cast(new TinyIntLiteral((byte) 1), DecimalV2Type.createDecimalV2Type(15, 9)), - new DecimalLiteral(new BigDecimal(1))); + new DecimalLiteral(new BigDecimal("1.000000000"))); assertRewrite(new Cast(new SmallIntLiteral((short) 1), DecimalV2Type.createDecimalV2Type(15, 9)), - new DecimalLiteral(new BigDecimal(1))); + new DecimalLiteral(new BigDecimal("1.000000000"))); assertRewrite(new Cast(new IntegerLiteral(1), DecimalV2Type.createDecimalV2Type(15, 9)), - new DecimalLiteral(new BigDecimal(1))); + new DecimalLiteral(new BigDecimal("1.000000000"))); assertRewrite(new Cast(new BigIntLiteral(1L), DecimalV2Type.createDecimalV2Type(15, 9)), - new DecimalLiteral(new BigDecimal(1))); + new DecimalLiteral(new BigDecimal("1.000000000"))); } @Test diff --git a/regression-test/suites/nereids_syntax_p0/cast.groovy b/regression-test/suites/nereids_syntax_p0/cast.groovy index 0589ec5275..4a4192175b 100644 --- a/regression-test/suites/nereids_syntax_p0/cast.groovy +++ b/regression-test/suites/nereids_syntax_p0/cast.groovy @@ -233,6 +233,11 @@ suite("cast") { sql """select cast(k5 as time) ct from test order by ct;""" exception "cannot cast" } + test { + sql "select cast(12 as decimalv3(2,1))" + exception "Arithmetic overflow" + } + // date test { sql """select cast(k10 as time) ct from test order by ct;"""