diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java index a8530bbc18..a879fe470b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Max.java @@ -27,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindow import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -58,6 +60,9 @@ public class Max extends NullableAggregateFunction @Override public FunctionSignature customSignature() { DataType dataType = getArgument(0).getDataType(); + if (dataType instanceof DecimalV2Type) { + dataType = DecimalV3Type.forType(dataType); + } return FunctionSignature.ret(dataType).args(dataType); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java index d40c32844e..8e8d0c7e87 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Min.java @@ -27,6 +27,8 @@ import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindow import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -59,6 +61,9 @@ public class Min extends NullableAggregateFunction @Override public FunctionSignature customSignature() { DataType dataType = getArgument(0).getDataType(); + if (dataType instanceof DecimalV2Type) { + dataType = DecimalV3Type.forType(dataType); + } return FunctionSignature.ret(dataType).args(dataType); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java index 34681a2611..8799203936 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/agg/Sum.java @@ -29,9 +29,9 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.BooleanType; import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.DecimalV2Type; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.FloatType; import org.apache.doris.nereids.types.IntegerType; import org.apache.doris.nereids.types.LargeIntType; import org.apache.doris.nereids.types.SmallIntType; @@ -56,9 +56,8 @@ public class Sum extends NullableAggregateFunction FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE), FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE), FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE), - FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE), - FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT), - FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD) + FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD), + FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE) ); /** @@ -112,6 +111,14 @@ public class Sum extends NullableAggregateFunction return SIGNATURES; } + @Override + public FunctionSignature searchSignature(List signatures) { + if (getArgument(0).getDataType() instanceof FloatType) { + return FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE); + } + return ExplicitlyCastableSignature.super.searchSignature(signatures); + } + @Override public Function constructRollUp(Expression param, Expression... varParams) { return new Sum(this.distinct, param); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Nvl.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Nvl.java index 9b8e51d238..e3d0335cbc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Nvl.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Nvl.java @@ -29,7 +29,6 @@ import org.apache.doris.nereids.types.DateTimeType; import org.apache.doris.nereids.types.DateTimeV2Type; import org.apache.doris.nereids.types.DateType; import org.apache.doris.nereids.types.DateV2Type; -import org.apache.doris.nereids.types.DecimalV2Type; import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.FloatType; @@ -58,6 +57,7 @@ public class Nvl extends ScalarFunction FunctionSignature.ret(IntegerType.INSTANCE).args(IntegerType.INSTANCE, IntegerType.INSTANCE), FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE), FunctionSignature.ret(LargeIntType.INSTANCE).args(LargeIntType.INSTANCE, LargeIntType.INSTANCE), + FunctionSignature.ret(DecimalV3Type.WILDCARD).args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD), FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE), FunctionSignature.ret(FloatType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE), FunctionSignature.ret(DateType.INSTANCE).args(DateType.INSTANCE, DateType.INSTANCE), @@ -66,10 +66,6 @@ public class Nvl extends ScalarFunction .args(DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(DateV2Type.INSTANCE) .args(DateV2Type.INSTANCE, DateV2Type.INSTANCE), - FunctionSignature.ret(DecimalV3Type.WILDCARD) - .args(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD), - FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT) - .args(DecimalV2Type.SYSTEM_DEFAULT, DecimalV2Type.SYSTEM_DEFAULT), FunctionSignature.ret(BitmapType.INSTANCE).args(BitmapType.INSTANCE, BitmapType.INSTANCE), FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT) .args(VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java index 273ba4b643..168ee50c55 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/TypeCoercionUtils.java @@ -986,25 +986,27 @@ public class TypeCoercionUtils { Optional optionalCommonType = TypeCoercionUtils.findWiderCommonTypeForCaseWhen(dataTypesForCoercion); return optionalCommonType .map(commonType -> { + DataType realCommonType = commonType instanceof DecimalV2Type + ? DecimalV3Type.forType(commonType) : commonType; List newChildren = caseWhen.getWhenClauses().stream() .map(wc -> { Expression valueExpr = TypeCoercionUtils.castIfNotSameType( - wc.getResult(), commonType); + wc.getResult(), realCommonType); // we must cast every child to the common type, and then // FoldConstantRuleOnFe can eliminate some branches and direct // return a branch value - if (!valueExpr.getDataType().equals(commonType)) { - valueExpr = new Cast(valueExpr, commonType); + if (!valueExpr.getDataType().equals(realCommonType)) { + valueExpr = new Cast(valueExpr, realCommonType); } return wc.withChildren(wc.getOperand(), valueExpr); }) .collect(Collectors.toList()); caseWhen.getDefaultValue() .map(dv -> { - Expression defaultExpr = TypeCoercionUtils.castIfNotSameType(dv, commonType); - if (!defaultExpr.getDataType().equals(commonType)) { - defaultExpr = new Cast(defaultExpr, commonType); + Expression defaultExpr = TypeCoercionUtils.castIfNotSameType(dv, realCommonType); + if (!defaultExpr.getDataType().equals(realCommonType)) { + defaultExpr = new Cast(defaultExpr, realCommonType); } return defaultExpr; }) diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java index 62019d8bac..05824d3280 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java @@ -35,7 +35,7 @@ import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; -import org.apache.doris.nereids.types.DecimalV2Type; +import org.apache.doris.nereids.types.DecimalV3Type; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.LargeIntType; @@ -74,7 +74,7 @@ public class GetDataTypeTest { Assertions.assertEquals(LargeIntType.INSTANCE, checkAndGetDataType(new Sum(largeIntLiteral))); Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum(floatLiteral))); Assertions.assertEquals(DoubleType.INSTANCE, checkAndGetDataType(new Sum(doubleLiteral))); - Assertions.assertEquals(DecimalV2Type.createDecimalV2Type(27, 9), checkAndGetDataType(new Sum(decimalLiteral))); + Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(38, 0), checkAndGetDataType(new Sum(decimalLiteral))); Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum(bigIntLiteral))); Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum(charLiteral))); Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum(varcharLiteral)));