[enhancement](Nereids) remove unnecessary decimal cast (#13745)

This commit is contained in:
morrySnow
2022-11-07 19:24:10 +08:00
committed by GitHub
parent f2978fb6ff
commit 4ea1b39cb2
3 changed files with 38 additions and 2 deletions

View File

@ -21,14 +21,22 @@ import org.apache.doris.nereids.rules.expression.rewrite.AbstractExpressionRewri
import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;
import java.math.BigDecimal;
/**
* Rewrite rule of simplify CAST expression.
* Remove redundant cast like
@ -55,7 +63,7 @@ public class SimplifyCastRule extends AbstractExpressionRewriteRule {
}
if (child instanceof Literal) {
// TODO: just trick here, process other type
// TODO: process other type
DataType castType = cast.getDataType();
if (castType instanceof StringType) {
if (child instanceof VarcharLiteral) {
@ -69,6 +77,16 @@ public class SimplifyCastRule extends AbstractExpressionRewriteRule {
} else if (child instanceof CharLiteral) {
return new VarcharLiteral(((CharLiteral) child).getValue(), ((VarcharType) castType).getLen());
}
} else if (castType instanceof DecimalType) {
if (child instanceof TinyIntLiteral) {
return new DecimalLiteral(new BigDecimal(((TinyIntLiteral) child).getValue()));
} else if (child instanceof SmallIntLiteral) {
return new DecimalLiteral(new BigDecimal(((SmallIntLiteral) child).getValue()));
} else if (child instanceof IntegerLiteral) {
return new DecimalLiteral(new BigDecimal(((IntegerLiteral) child).getValue()));
} else if (child instanceof BigIntLiteral) {
return new DecimalLiteral(new BigDecimal(((BigIntLiteral) child).getValue()));
}
}
}

View File

@ -48,7 +48,7 @@ public abstract class Literal extends Expression implements LeafExpression {
* @param dataType logical data type in Nereids
*/
public Literal(DataType dataType) {
this.dataType = dataType;
this.dataType = Objects.requireNonNull(dataType);
}
/**

View File

@ -25,15 +25,23 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeBinaryPr
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyCastRule;
import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.CharLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.types.DecimalType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.VarcharType;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
import java.math.BigDecimal;
/**
* all expr rewrite rule test case.
*/
@ -193,5 +201,15 @@ public class ExpressionRewriteTest extends ExpressionRewriteTestHelper {
new VarcharLiteral("123", 10));
assertRewrite(new Cast(new CharLiteral("123", 3), StringType.INSTANCE), new StringLiteral("123"));
assertRewrite(new Cast(new VarcharLiteral("123", 3), StringType.INSTANCE), new StringLiteral("123"));
// decimal literal
assertRewrite(new Cast(new TinyIntLiteral((byte) 1), DecimalType.createDecimalType(15, 9)),
new DecimalLiteral(new BigDecimal(1)));
assertRewrite(new Cast(new SmallIntLiteral((short) 1), DecimalType.createDecimalType(15, 9)),
new DecimalLiteral(new BigDecimal(1)));
assertRewrite(new Cast(new IntegerLiteral(1), DecimalType.createDecimalType(15, 9)),
new DecimalLiteral(new BigDecimal(1)));
assertRewrite(new Cast(new BigIntLiteral(1L), DecimalType.createDecimalType(15, 9)),
new DecimalLiteral(new BigDecimal(1)));
}
}