diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 5c776b00ac..c3548d0eaa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -143,6 +143,10 @@ public class TypeCoercionUtils { if (leftType instanceof NullType && rightType instanceof DecimalType) { return true; } + if (leftType instanceof DecimalType && rightType instanceof IntegralType + || leftType instanceof IntegralType && rightType instanceof DecimalType) { + return true; + } // TODO: add decimal promotion support if (!(leftType instanceof DecimalType) && !(rightType instanceof DecimalType) && !leftType.equals(rightType)) { return true; @@ -189,6 +193,10 @@ public class TypeCoercionUtils { } } else if (left instanceof CharacterType || right instanceof CharacterType) { tightestCommonType = StringType.INSTANCE; + } else if (left instanceof DecimalType && right instanceof IntegralType) { + tightestCommonType = DecimalType.widerDecimalType((DecimalType) left, DecimalType.forType(right)); + } else if (left instanceof IntegralType && right instanceof DecimalType) { + tightestCommonType = DecimalType.widerDecimalType((DecimalType) right, DecimalType.forType(left)); } return Optional.ofNullable(tightestCommonType); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java index a4358d56aa..155637ec8c 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rewrite/TypeCoercionTest.java @@ -24,6 +24,7 @@ import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.Divide; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.trees.expressions.LessThanEqual; import org.apache.doris.nereids.trees.expressions.WhenClause; import org.apache.doris.nereids.trees.expressions.functions.Avg; import org.apache.doris.nereids.trees.expressions.functions.Substring; @@ -32,20 +33,25 @@ import org.apache.doris.nereids.trees.expressions.functions.Year; import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DecimalType; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.StringType; +import org.apache.doris.nereids.types.TinyIntType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.math.BigDecimal; import java.util.List; public class TypeCoercionTest { @@ -103,6 +109,35 @@ public class TypeCoercionTest { assertRewrite(expected, expression); } + @Test + public void testBinaryPredicate() { + Expression left = new DecimalLiteral(new BigDecimal(2.4)); + Expression right = new TinyIntLiteral((byte) 2); + Expression lessThanEq = new LessThanEqual(left, right); + Expression rewrittenPred = + new LessThanEqual( + left, + new Cast(right, left.getDataType())); + assertRewrite(rewrittenPred, lessThanEq); + + rewrittenPred = + new LessThanEqual( + new Cast(right, left.getDataType()), + left + ); + lessThanEq = new LessThanEqual(right, left); + assertRewrite(rewrittenPred, lessThanEq); + + left = new DecimalLiteral(new BigDecimal(1)); + lessThanEq = new LessThanEqual(left, right); + rewrittenPred = + new LessThanEqual( + new Cast(left, DecimalType.forType(TinyIntType.INSTANCE)), + new Cast(right, DecimalType.forType(TinyIntType.INSTANCE)) + ); + assertRewrite(rewrittenPred, lessThanEq); + } + @Test public void testCaseWhenTypeCoercion() { WhenClause actualWhenClause1 = new WhenClause(new BooleanLiteral(true), new SmallIntLiteral((short) 1)); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java index 92613e4b86..74a864d3af 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/TypeCoercionUtilsTest.java @@ -110,8 +110,8 @@ public class TypeCoercionUtilsTest { Assertions.assertTrue(TypeCoercionUtils.canHandleTypeCoercion(decimalType, nullType)); Assertions.assertTrue(TypeCoercionUtils.canHandleTypeCoercion(nullType, decimalType)); Assertions.assertTrue(TypeCoercionUtils.canHandleTypeCoercion(smallIntType, integerType)); - Assertions.assertFalse(TypeCoercionUtils.canHandleTypeCoercion(integerType, decimalType)); - Assertions.assertFalse(TypeCoercionUtils.canHandleTypeCoercion(decimalType, integerType)); + Assertions.assertTrue(TypeCoercionUtils.canHandleTypeCoercion(integerType, decimalType)); + Assertions.assertTrue(TypeCoercionUtils.canHandleTypeCoercion(decimalType, integerType)); Assertions.assertFalse(TypeCoercionUtils.canHandleTypeCoercion(integerType, integerType)); }