[fix](nereids) fix bug in CaseWhen.getDataType and add some missing case for findTightestCommonType (#15776)

This commit is contained in:
minghong
2023-01-19 15:30:25 +08:00
committed by GitHub
parent f9406234c6
commit 0144c51ddb
6 changed files with 44 additions and 4 deletions

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -78,7 +79,13 @@ public class CaseWhen extends Expression {
@Override
public DataType getDataType() {
return child(0).getDataType();
DataType outputType = child(0).getDataType();
for (Expression child : children) {
DataType tempType = outputType;
outputType = TypeCoercionUtils.findTightestCommonType(null,
outputType, child.getDataType()).orElseGet(() -> tempType);
}
return outputType;
}
@Override

View File

@ -239,6 +239,26 @@ public class TypeCoercionUtils {
} else if (left instanceof DateV2Type || right instanceof DateV2Type) {
tightestCommonType = DateV2Type.INSTANCE;
}
} else if (left instanceof DoubleType && right instanceof DecimalV2Type
|| left instanceof DecimalV2Type && right instanceof DoubleType) {
tightestCommonType = DoubleType.INSTANCE;
} else if (left instanceof DecimalV2Type && right instanceof DecimalV2Type) {
tightestCommonType = DecimalV2Type.widerDecimalV2Type((DecimalV2Type) left, (DecimalV2Type) right);
} else if (left instanceof FloatType && right instanceof DecimalV2Type
|| left instanceof DecimalV2Type && right instanceof FloatType) {
//TODO: need refactor. let operator upgrade data type.
if (binaryOperator != null) {
// for arithmetic, like Float + Decimal, upgrade to Double
tightestCommonType = DoubleType.INSTANCE;
} else {
//of other case, like
// case
// when 1=1 then cast(1 as int)
// when 1>1 then cast(1 as float)
// else 0.0 end;
//do not upgrade data type, keep Float
tightestCommonType = FloatType.INSTANCE;
}
} else if (canCompareDate(left, right)) {
if (binaryOperator instanceof BinaryArithmetic) {
tightestCommonType = IntegerType.INSTANCE;

View File

@ -145,7 +145,8 @@ public class TypeCoercionUtilsTest {
testFindTightestCommonType(BigIntType.INSTANCE, IntegerType.INSTANCE, BigIntType.INSTANCE);
testFindTightestCommonType(StringType.INSTANCE, StringType.INSTANCE, IntegerType.INSTANCE);
testFindTightestCommonType(StringType.INSTANCE, IntegerType.INSTANCE, StringType.INSTANCE);
testFindTightestCommonType(DoubleType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.createDecimalV2Type(2, 1));
testFindTightestCommonType(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.createDecimalV2Type(2, 1));
testFindTightestCommonType(FloatType.INSTANCE, FloatType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT);
testFindTightestCommonType(VarcharType.createVarcharType(10), CharType.createCharType(8), CharType.createCharType(10));
testFindTightestCommonType(VarcharType.createVarcharType(10), VarcharType.createVarcharType(8), VarcharType.createVarcharType(10));
testFindTightestCommonType(VarcharType.createVarcharType(10), VarcharType.createVarcharType(8), CharType.createCharType(10));

View File

@ -34,6 +34,7 @@ false
-- !between11 --
-- !between12 --
6.333
-- !between13 --
123.123

View File

@ -53,4 +53,15 @@ suite("nereids_explain") {
sql("plan with s as (select * from supplier) select * from s as s1, s as s2")
contains "*LogicalSubQueryAlias"
}
explain {
sql """
verbose
select case
when 1=1 then cast(1 as int)
when 1>1 then cast(1 as float)
else 0.0 end;
"""
contains "SlotDescriptor{id=0, col=null, colUniqueId=null, type=FLOAT, nullable=false}"
}
}

View File

@ -36,6 +36,6 @@ suite("nereids_test_query_between", "query,p0") {
and \"9999-12-31 12:12:12\" order by k1, k2, k3, k4"""
qt_between11 """select k10 from ${tableName} where k10 between \"2015-04-02\"
and \"9999-12-31\" order by k1, k2, k3, k4"""
qt_between12 "select k9 from ${tableName} where k9 between -1 and 6.333 order by k1, k2, k3, k4"
qt_between13 "select k5 from ${tableName} where k5 between 0 and 1243.5 order by k1, k2, k3, k4"
qt_between12 "select k9 from ${tableName} where k9 between -1 and 6.34 order by k1, k2, k3, k4"
qt_between13 "select k5 from ${tableName} where k5 between 0 and 1243.6 order by k1, k2, k3, k4"
}