diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyCastRule.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyCastRule.java index 5e9fd3465a..24cebf3bc8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyCastRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/rules/SimplifyCastRule.java @@ -21,14 +21,22 @@ import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewri import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.CharLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalType; import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.types.VarcharType; +import java.math.BigDecimal; + /** * Rewrite rule of simplify CAST expression. * Remove redundant cast like @@ -55,7 +63,7 @@ public class SimplifyCastRule extends AbstractExpressionRewriteRule { } if (child instanceof Literal) { - // TODO: just trick here, process other type + // TODO: process other type DataType castType = cast.getDataType(); if (castType instanceof StringType) { if (child instanceof VarcharLiteral) { @@ -69,6 +77,16 @@ public class SimplifyCastRule extends AbstractExpressionRewriteRule { } else if (child instanceof CharLiteral) { return new VarcharLiteral(((CharLiteral) child).getValue(), ((VarcharType) castType).getLen()); } + } else if (castType instanceof DecimalType) { + if (child instanceof TinyIntLiteral) { + return new DecimalLiteral(new BigDecimal(((TinyIntLiteral) child).getValue())); + } else if (child instanceof SmallIntLiteral) { + return new DecimalLiteral(new BigDecimal(((SmallIntLiteral) child).getValue())); + } else if (child instanceof IntegerLiteral) { + return new DecimalLiteral(new BigDecimal(((IntegerLiteral) child).getValue())); + } else if (child instanceof BigIntLiteral) { + return new DecimalLiteral(new BigDecimal(((BigIntLiteral) child).getValue())); + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java index 561c2c8180..6c8dc2677e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java @@ -48,7 +48,7 @@ public abstract class Literal extends Expression implements LeafExpression { * @param dataType logical data type in Nereids */ public Literal(DataType dataType) { - this.dataType = dataType; + this.dataType = Objects.requireNonNull(dataType); } /** diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java index ab7ca66768..ef5f78f28e 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewriteTest.java @@ -25,15 +25,23 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPr import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule; import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule; import org.apache.doris.nereids.trees.expressions.Cast; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.CharLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; +import org.apache.doris.nereids.types.DecimalType; import org.apache.doris.nereids.types.StringType; import org.apache.doris.nereids.types.VarcharType; import com.google.common.collect.ImmutableList; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; + /** * all expr rewrite rule test case. */ @@ -193,5 +201,15 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper { new VarcharLiteral("123", 10)); assertRewrite(new Cast(new CharLiteral("123", 3), StringType.INSTANCE), new StringLiteral("123")); assertRewrite(new Cast(new VarcharLiteral("123", 3), StringType.INSTANCE), new StringLiteral("123")); + + // decimal literal + assertRewrite(new Cast(new TinyIntLiteral((byte) 1), DecimalType.createDecimalType(15, 9)), + new DecimalLiteral(new BigDecimal(1))); + assertRewrite(new Cast(new SmallIntLiteral((short) 1), DecimalType.createDecimalType(15, 9)), + new DecimalLiteral(new BigDecimal(1))); + assertRewrite(new Cast(new IntegerLiteral(1), DecimalType.createDecimalType(15, 9)), + new DecimalLiteral(new BigDecimal(1))); + assertRewrite(new Cast(new BigIntLiteral(1L), DecimalType.createDecimalType(15, 9)), + new DecimalLiteral(new BigDecimal(1))); } }