[opt](Nereids) add boolean type signature for sum aggregate function (#21959)

This commit is contained in:
zhangstar333
2023-07-27 09:41:19 +08:00
committed by GitHub
parent 12222eb145
commit fb41265c27
4 changed files with 341 additions and 448 deletions

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.functions.window.SupportWindow
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
@ -47,6 +48,7 @@ public class Sum extends NullableAggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, ComputePrecisionForSum, SupportWindowAnalytic {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(BooleanType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(TinyIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(SmallIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(IntegerType.INSTANCE),
@ -78,8 +80,9 @@ public class Sum extends NullableAggregateFunction
@Override
public void checkLegalityBeforeTypeCoercion() {
DataType argType = child().getDataType();
if (((!argType.isNumericType() && !argType.isNullType()) || argType.isOnlyMetricType())) {
throw new AnalysisException("sum requires a numeric parameter: " + this.toSql());
if ((!argType.isNumericType() && !argType.isBooleanType() && !argType.isNullType())
|| argType.isOnlyMetricType()) {
throw new AnalysisException("sum requires a numeric or boolean parameter: " + this.toSql());
}
}

View File

@ -66,7 +66,7 @@ public class GetDataTypeTest {
@Test
public void testSum() {
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum(nullLiteral)));
Assertions.assertThrows(RuntimeException.class, () -> checkAndGetDataType(new Sum(booleanLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum(booleanLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum(tinyIntLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum(smallIntLiteral)));
Assertions.assertEquals(BigIntType.INSTANCE, checkAndGetDataType(new Sum(integerLiteral)));