[opt](Nereids) remove decimalv2 signature from min, max, sum, nvl and case when (#29282)

This commit is contained in:
morrySnow
2023-12-29 23:22:32 +08:00
committed by GitHub
parent 03ece437f0
commit 989d20e0ac
6 changed files with 32 additions and 17 deletions

View File

@ -27,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindow
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -58,6 +60,9 @@ public class Max extends NullableAggregateFunction
@Override
public FunctionSignature customSignature() {
DataType dataType = getArgument(0).getDataType();
if (dataType instanceof DecimalV2Type) {
dataType = DecimalV3Type.forType(dataType);
}
return FunctionSignature.ret(dataType).args(dataType);
}

View File

@ -27,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindow
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -59,6 +61,9 @@ public class Min extends NullableAggregateFunction
@Override
public FunctionSignature customSignature() {
DataType dataType = getArgument(0).getDataType();
if (dataType instanceof DecimalV2Type) {
dataType = DecimalV3Type.forType(dataType);
}
return FunctionSignature.ret(dataType).args(dataType);
}

View File

@ -29,9 +29,9 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.SmallIntType;
@ -56,9 +56,8 @@ public class Sum extends NullableAggregateFunction
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD)
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
);
/**
@ -112,6 +111,14 @@ public class Sum extends NullableAggregateFunction
return SIGNATURES;
}
@Override
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
if (getArgument(0).getDataType() instanceof FloatType) {
return FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE);
}
return ExplicitlyCastableSignature.super.searchSignature(signatures);
}
@Override
public Function constructRollUp(Expression param, Expression... varParams) {
return new Sum(this.distinct, param);

View File

@ -29,7 +29,6 @@ import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
@ -58,6 +57,7 @@ public class Nvl extends ScalarFunction
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE, LargeIntType.INSTANCE),
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(FloatType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE),
FunctionSignature.ret(DateType.INSTANCE).args(DateType.INSTANCE, DateType.INSTANCE),
@ -66,10 +66,6 @@ public class Nvl extends ScalarFunction
.args(DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DateV2Type.INSTANCE)
.args(DateV2Type.INSTANCE, DateV2Type.INSTANCE),
FunctionSignature.ret(DecimalV3Type.WILDCARD)
.args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT)
.args(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE, BitmapType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),

View File

@ -986,25 +986,27 @@ public class TypeCoercionUtils {
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonTypeForCaseWhen(dataTypesForCoercion);
return optionalCommonType
.map(commonType -> {
DataType realCommonType = commonType instanceof DecimalV2Type
? DecimalV3Type.forType(commonType) : commonType;
List<Expression> newChildren
= caseWhen.getWhenClauses().stream()
.map(wc -> {
Expression valueExpr = TypeCoercionUtils.castIfNotSameType(
wc.getResult(), commonType);
wc.getResult(), realCommonType);
// we must cast every child to the common type, and then
// FoldConstantRuleOnFe can eliminate some branches and direct
// return a branch value
if (!valueExpr.getDataType().equals(commonType)) {
valueExpr = new Cast(valueExpr, commonType);
if (!valueExpr.getDataType().equals(realCommonType)) {
valueExpr = new Cast(valueExpr, realCommonType);
}
return wc.withChildren(wc.getOperand(), valueExpr);
})
.collect(Collectors.toList());
caseWhen.getDefaultValue()
.map(dv -> {
Expression defaultExpr = TypeCoercionUtils.castIfNotSameType(dv, commonType);
if (!defaultExpr.getDataType().equals(commonType)) {
defaultExpr = new Cast(defaultExpr, commonType);
Expression defaultExpr = TypeCoercionUtils.castIfNotSameType(dv, realCommonType);
if (!defaultExpr.getDataType().equals(realCommonType)) {
defaultExpr = new Cast(defaultExpr, realCommonType);
}
return defaultExpr;
})