[opt](Nereids) remove decimalv2 signature from min, max, sum, nvl and case when (#29282)
This commit is contained in:
@ -27,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindow
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DecimalV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -58,6 +60,9 @@ public class Max extends NullableAggregateFunction
|
||||
@Override
|
||||
public FunctionSignature customSignature() {
|
||||
DataType dataType = getArgument(0).getDataType();
|
||||
if (dataType instanceof DecimalV2Type) {
|
||||
dataType = DecimalV3Type.forType(dataType);
|
||||
}
|
||||
return FunctionSignature.ret(dataType).args(dataType);
|
||||
}
|
||||
|
||||
|
||||
@ -27,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindow
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DecimalV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
@ -59,6 +61,9 @@ public class Min extends NullableAggregateFunction
|
||||
@Override
|
||||
public FunctionSignature customSignature() {
|
||||
DataType dataType = getArgument(0).getDataType();
|
||||
if (dataType instanceof DecimalV2Type) {
|
||||
dataType = DecimalV3Type.forType(dataType);
|
||||
}
|
||||
return FunctionSignature.ret(dataType).args(dataType);
|
||||
}
|
||||
|
||||
|
||||
@ -29,9 +29,9 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.BooleanType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.DecimalV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.FloatType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.LargeIntType;
|
||||
import org.apache.doris.nereids.types.SmallIntType;
|
||||
@ -56,9 +56,8 @@ public class Sum extends NullableAggregateFunction
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE),
|
||||
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE),
|
||||
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD)
|
||||
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE)
|
||||
);
|
||||
|
||||
/**
|
||||
@ -112,6 +111,14 @@ public class Sum extends NullableAggregateFunction
|
||||
return SIGNATURES;
|
||||
}
|
||||
|
||||
@Override
|
||||
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
|
||||
if (getArgument(0).getDataType() instanceof FloatType) {
|
||||
return FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE);
|
||||
}
|
||||
return ExplicitlyCastableSignature.super.searchSignature(signatures);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Function constructRollUp(Expression param, Expression... varParams) {
|
||||
return new Sum(this.distinct, param);
|
||||
|
||||
@ -29,7 +29,6 @@ import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateTimeV2Type;
|
||||
import org.apache.doris.nereids.types.DateType;
|
||||
import org.apache.doris.nereids.types.DateV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.FloatType;
|
||||
@ -58,6 +57,7 @@ public class Nvl extends ScalarFunction
|
||||
FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE),
|
||||
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
|
||||
FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE, LargeIntType.INSTANCE),
|
||||
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD),
|
||||
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
|
||||
FunctionSignature.ret(FloatType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE),
|
||||
FunctionSignature.ret(DateType.INSTANCE).args(DateType.INSTANCE, DateType.INSTANCE),
|
||||
@ -66,10 +66,6 @@ public class Nvl extends ScalarFunction
|
||||
.args(DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(DateV2Type.INSTANCE)
|
||||
.args(DateV2Type.INSTANCE, DateV2Type.INSTANCE),
|
||||
FunctionSignature.ret(DecimalV3Type.WILDCARD)
|
||||
.args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD),
|
||||
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT)
|
||||
.args(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE, BitmapType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
|
||||
.args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
|
||||
|
||||
@ -986,25 +986,27 @@ public class TypeCoercionUtils {
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonTypeForCaseWhen(dataTypesForCoercion);
|
||||
return optionalCommonType
|
||||
.map(commonType -> {
|
||||
DataType realCommonType = commonType instanceof DecimalV2Type
|
||||
? DecimalV3Type.forType(commonType) : commonType;
|
||||
List<Expression> newChildren
|
||||
= caseWhen.getWhenClauses().stream()
|
||||
.map(wc -> {
|
||||
Expression valueExpr = TypeCoercionUtils.castIfNotSameType(
|
||||
wc.getResult(), commonType);
|
||||
wc.getResult(), realCommonType);
|
||||
// we must cast every child to the common type, and then
|
||||
// FoldConstantRuleOnFe can eliminate some branches and direct
|
||||
// return a branch value
|
||||
if (!valueExpr.getDataType().equals(commonType)) {
|
||||
valueExpr = new Cast(valueExpr, commonType);
|
||||
if (!valueExpr.getDataType().equals(realCommonType)) {
|
||||
valueExpr = new Cast(valueExpr, realCommonType);
|
||||
}
|
||||
return wc.withChildren(wc.getOperand(), valueExpr);
|
||||
})
|
||||
.collect(Collectors.toList());
|
||||
caseWhen.getDefaultValue()
|
||||
.map(dv -> {
|
||||
Expression defaultExpr = TypeCoercionUtils.castIfNotSameType(dv, commonType);
|
||||
if (!defaultExpr.getDataType().equals(commonType)) {
|
||||
defaultExpr = new Cast(defaultExpr, commonType);
|
||||
Expression defaultExpr = TypeCoercionUtils.castIfNotSameType(dv, realCommonType);
|
||||
if (!defaultExpr.getDataType().equals(realCommonType)) {
|
||||
defaultExpr = new Cast(defaultExpr, realCommonType);
|
||||
}
|
||||
return defaultExpr;
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user