[fix](nereids) fix bug in CaseWhen.getDataType and add some missing case for findTightestCommonType (#15776)

This commit is contained in:
minghong
2023-01-19 15:30:25 +08:00
committed by GitHub
parent f9406234c6
commit 0144c51ddb
6 changed files with 44 additions and 4 deletions

View File

@ -20,6 +20,7 @@ package org.apache.doris.nereids.trees.expressions;
import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -78,7 +79,13 @@ public class CaseWhen extends Expression {
@Override
public DataType getDataType() {
return child(0).getDataType();
DataType outputType = child(0).getDataType();
for (Expression child : children) {
DataType tempType = outputType;
outputType = TypeCoercionUtils.findTightestCommonType(null,
outputType, child.getDataType()).orElseGet(() -> tempType);
}
return outputType;
}
@Override

View File

@ -239,6 +239,26 @@ public class TypeCoercionUtils {
} else if (left instanceof DateV2Type || right instanceof DateV2Type) {
tightestCommonType = DateV2Type.INSTANCE;
}
} else if (left instanceof DoubleType && right instanceof DecimalV2Type
|| left instanceof DecimalV2Type && right instanceof DoubleType) {
tightestCommonType = DoubleType.INSTANCE;
} else if (left instanceof DecimalV2Type && right instanceof DecimalV2Type) {
tightestCommonType = DecimalV2Type.widerDecimalV2Type((DecimalV2Type) left, (DecimalV2Type) right);
} else if (left instanceof FloatType && right instanceof DecimalV2Type
|| left instanceof DecimalV2Type && right instanceof FloatType) {
//TODO: need refactor. let operator upgrade data type.
if (binaryOperator != null) {
// for arithmetic, like Float + Decimal, upgrade to Double
tightestCommonType = DoubleType.INSTANCE;
} else {
//of other case, like
// case
// when 1=1 then cast(1 as int)
// when 1>1 then cast(1 as float)
// else 0.0 end;
//do not upgrade data type, keep Float
tightestCommonType = FloatType.INSTANCE;
}
} else if (canCompareDate(left, right)) {
if (binaryOperator instanceof BinaryArithmetic) {
tightestCommonType = IntegerType.INSTANCE;

View File

@ -145,7 +145,8 @@ public class TypeCoercionUtilsTest {
testFindTightestCommonType(BigIntType.INSTANCE, IntegerType.INSTANCE, BigIntType.INSTANCE);
testFindTightestCommonType(StringType.INSTANCE, StringType.INSTANCE, IntegerType.INSTANCE);
testFindTightestCommonType(StringType.INSTANCE, IntegerType.INSTANCE, StringType.INSTANCE);
testFindTightestCommonType(DoubleType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.createDecimalV2Type(2, 1));
testFindTightestCommonType(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.createDecimalV2Type(2, 1));
testFindTightestCommonType(FloatType.INSTANCE, FloatType.INSTANCE, DecimalV2Type.SYSTEM_DEFAULT);
testFindTightestCommonType(VarcharType.createVarcharType(10), CharType.createCharType(8), CharType.createCharType(10));
testFindTightestCommonType(VarcharType.createVarcharType(10), VarcharType.createVarcharType(8), VarcharType.createVarcharType(10));
testFindTightestCommonType(VarcharType.createVarcharType(10), VarcharType.createVarcharType(8), CharType.createCharType(10));