From de4bdc7f6ff19d1d0bad6c3bed3ec94be7e623b8 Mon Sep 17 00:00:00 2001 From: morrySnow <101034200+morrySnow@users.noreply.github.com> Date: Tue, 30 Aug 2022 19:53:25 +0800 Subject: [PATCH] [fix](Nereids)Sum return DoubleType when child is DecimalType by mistake (#12169) When Sum's child is Decimal, Return Double Type by mistake lead to result error, so we should keep the return type to decimal when the child expression's type is decimal. --- .../trees/expressions/functions/Sum.java | 4 + .../trees/expressions/GetDataTypeTest.java | 84 +++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Sum.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Sum.java index 9c31349d09..f946713945 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Sum.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Sum.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DecimalType; import org.apache.doris.nereids.types.DoubleType; import org.apache.doris.nereids.types.LargeIntType; import org.apache.doris.nereids.types.coercion.AbstractDataType; @@ -49,6 +50,9 @@ public class Sum extends AggregateFunction implements UnaryExpression, ImplicitC DataType dataType = child().getDataType(); if (dataType instanceof LargeIntType) { return dataType; + } else if (dataType instanceof DecimalType) { + // TODO: precision + 10 + return dataType; } else if (dataType instanceof IntegralType) { return BigIntType.INSTANCE; } else if (dataType instanceof FractionalType) { diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java new file mode 100644 index 0000000000..e956a58e16 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/trees/expressions/GetDataTypeTest.java @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions; + +import org.apache.doris.nereids.trees.expressions.functions.Sum; +import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; +import org.apache.doris.nereids.trees.expressions.literal.CharLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DateLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DateTimeLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DecimalLiteral; +import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral; +import org.apache.doris.nereids.trees.expressions.literal.FloatLiteral; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.expressions.literal.LargeIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; +import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.StringLiteral; +import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral; +import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral; +import org.apache.doris.nereids.types.BigIntType; +import org.apache.doris.nereids.types.DecimalType; +import org.apache.doris.nereids.types.DoubleType; +import org.apache.doris.nereids.types.LargeIntType; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.math.BigDecimal; +import java.math.BigInteger; + +public class GetDataTypeTest { + + NullLiteral nullLiteral = NullLiteral.INSTANCE; + BooleanLiteral booleanLiteral = BooleanLiteral.FALSE; + TinyIntLiteral tinyIntLiteral = new TinyIntLiteral((byte) 1); + SmallIntLiteral smallIntLiteral = new SmallIntLiteral((short) 1); + IntegerLiteral integerLiteral = new IntegerLiteral(1); + BigIntLiteral bigIntLiteral = new BigIntLiteral(1L); + LargeIntLiteral largeIntLiteral = new LargeIntLiteral(BigInteger.valueOf(1L)); + FloatLiteral floatLiteral = new FloatLiteral(1.0F); + DoubleLiteral doubleLiteral = new DoubleLiteral(1.0); + DecimalLiteral decimalLiteral = new DecimalLiteral(BigDecimal.ONE); + CharLiteral charLiteral = new CharLiteral("hello", 5); + VarcharLiteral varcharLiteral = new VarcharLiteral("hello", 5); + StringLiteral stringLiteral = new StringLiteral("hello"); + DateLiteral dateLiteral = new DateLiteral(2022, 2, 2); + DateTimeLiteral dateTimeLiteral = new DateTimeLiteral(2022, 2, 2, 2, 2, 2); + + @Test + public void testSum() { + Assertions.assertThrows(RuntimeException.class, () -> new Sum(nullLiteral).getDataType()); + Assertions.assertThrows(RuntimeException.class, () -> new Sum(booleanLiteral).getDataType()); + Assertions.assertEquals(BigIntType.INSTANCE, new Sum(tinyIntLiteral).getDataType()); + Assertions.assertEquals(BigIntType.INSTANCE, new Sum(smallIntLiteral).getDataType()); + Assertions.assertEquals(BigIntType.INSTANCE, new Sum(integerLiteral).getDataType()); + Assertions.assertEquals(BigIntType.INSTANCE, new Sum(bigIntLiteral).getDataType()); + Assertions.assertEquals(LargeIntType.INSTANCE, new Sum(largeIntLiteral).getDataType()); + Assertions.assertEquals(DoubleType.INSTANCE, new Sum(floatLiteral).getDataType()); + Assertions.assertEquals(DoubleType.INSTANCE, new Sum(doubleLiteral).getDataType()); + Assertions.assertEquals(DecimalType.createDecimalType(BigDecimal.ONE), new Sum(decimalLiteral).getDataType()); + Assertions.assertEquals(BigIntType.INSTANCE, new Sum(bigIntLiteral).getDataType()); + Assertions.assertThrows(RuntimeException.class, () -> new Sum(charLiteral).getDataType()); + Assertions.assertThrows(RuntimeException.class, () -> new Sum(varcharLiteral).getDataType()); + Assertions.assertThrows(RuntimeException.class, () -> new Sum(stringLiteral).getDataType()); + Assertions.assertThrows(RuntimeException.class, () -> new Sum(dateLiteral).getDataType()); + Assertions.assertThrows(RuntimeException.class, () -> new Sum(dateTimeLiteral).getDataType()); + } +}