[enhancement](Nereids) remove unnecessary decimal cast (#13745)
This commit is contained in:
@ -21,14 +21,22 @@ import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewri
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DecimalType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
/**
|
||||
* Rewrite rule of simplify CAST expression.
|
||||
* Remove redundant cast like
|
||||
@ -55,7 +63,7 @@ public class SimplifyCastRule extends AbstractExpressionRewriteRule {
|
||||
}
|
||||
|
||||
if (child instanceof Literal) {
|
||||
// TODO: just trick here, process other type
|
||||
// TODO: process other type
|
||||
DataType castType = cast.getDataType();
|
||||
if (castType instanceof StringType) {
|
||||
if (child instanceof VarcharLiteral) {
|
||||
@ -69,6 +77,16 @@ public class SimplifyCastRule extends AbstractExpressionRewriteRule {
|
||||
} else if (child instanceof CharLiteral) {
|
||||
return new VarcharLiteral(((CharLiteral) child).getValue(), ((VarcharType) castType).getLen());
|
||||
}
|
||||
} else if (castType instanceof DecimalType) {
|
||||
if (child instanceof TinyIntLiteral) {
|
||||
return new DecimalLiteral(new BigDecimal(((TinyIntLiteral) child).getValue()));
|
||||
} else if (child instanceof SmallIntLiteral) {
|
||||
return new DecimalLiteral(new BigDecimal(((SmallIntLiteral) child).getValue()));
|
||||
} else if (child instanceof IntegerLiteral) {
|
||||
return new DecimalLiteral(new BigDecimal(((IntegerLiteral) child).getValue()));
|
||||
} else if (child instanceof BigIntLiteral) {
|
||||
return new DecimalLiteral(new BigDecimal(((BigIntLiteral) child).getValue()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -48,7 +48,7 @@ public abstract class Literal extends Expression implements LeafExpression {
|
||||
* @param dataType logical data type in Nereids
|
||||
*/
|
||||
public Literal(DataType dataType) {
|
||||
this.dataType = dataType;
|
||||
this.dataType = Objects.requireNonNull(dataType);
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@ -25,15 +25,23 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPr
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule;
|
||||
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
|
||||
import org.apache.doris.nereids.trees.expressions.Cast;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
|
||||
import org.apache.doris.nereids.types.DecimalType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.math.BigDecimal;
|
||||
|
||||
/**
|
||||
* all expr rewrite rule test case.
|
||||
*/
|
||||
@ -193,5 +201,15 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
|
||||
new VarcharLiteral("123", 10));
|
||||
assertRewrite(new Cast(new CharLiteral("123", 3), StringType.INSTANCE), new StringLiteral("123"));
|
||||
assertRewrite(new Cast(new VarcharLiteral("123", 3), StringType.INSTANCE), new StringLiteral("123"));
|
||||
|
||||
// decimal literal
|
||||
assertRewrite(new Cast(new TinyIntLiteral((byte) 1), DecimalType.createDecimalType(15, 9)),
|
||||
new DecimalLiteral(new BigDecimal(1)));
|
||||
assertRewrite(new Cast(new SmallIntLiteral((short) 1), DecimalType.createDecimalType(15, 9)),
|
||||
new DecimalLiteral(new BigDecimal(1)));
|
||||
assertRewrite(new Cast(new IntegerLiteral(1), DecimalType.createDecimalType(15, 9)),
|
||||
new DecimalLiteral(new BigDecimal(1)));
|
||||
assertRewrite(new Cast(new BigIntLiteral(1L), DecimalType.createDecimalType(15, 9)),
|
||||
new DecimalLiteral(new BigDecimal(1)));
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user