diff --git a/be/src/vec/data_types/data_type_decimal.h b/be/src/vec/data_types/data_type_decimal.h index 9123b3837f..7b719adc35 100644 --- a/be/src/vec/data_types/data_type_decimal.h +++ b/be/src/vec/data_types/data_type_decimal.h @@ -236,7 +236,7 @@ DataTypePtr decimal_result_type(const DataTypeDecimal& tx, const DataTypeDeci size_t divide_precision = tx.get_precision() + ty.get_scale(); size_t plus_minus_precision = std::max(tx.get_precision() - tx.get_scale(), ty.get_precision() - ty.get_scale()) + - scale; + scale + 1; if (is_multiply) { scale = tx.get_scale() + ty.get_scale(); precision = std::min(multiply_precision, max_decimal_precision()); diff --git a/docs/en/docs/sql-manual/sql-reference/Data-Types/DECIMALV3.md b/docs/en/docs/sql-manual/sql-reference/Data-Types/DECIMALV3.md index ef65446193..5e303ad67d 100644 --- a/docs/en/docs/sql-manual/sql-reference/Data-Types/DECIMALV3.md +++ b/docs/en/docs/sql-manual/sql-reference/Data-Types/DECIMALV3.md @@ -47,7 +47,7 @@ DECIMALV3 has a very complex set of type inference rules. For different expressi #### Arithmetic Expressions -* Plus / Minus: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y), max(b, y)). That is, the integer part and the decimal part use the larger value of the two operands respectively. +* Plus / Minus: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y) + 1, max(b, y)). * Multiply: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + x, b + y). * Divide: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + y, b). diff --git a/docs/zh-CN/docs/sql-manual/sql-reference/Data-Types/DECIMALV3.md b/docs/zh-CN/docs/sql-manual/sql-reference/Data-Types/DECIMALV3.md index 7eea5f8109..42838f9ef1 100644 --- a/docs/zh-CN/docs/sql-manual/sql-reference/Data-Types/DECIMALV3.md +++ b/docs/zh-CN/docs/sql-manual/sql-reference/Data-Types/DECIMALV3.md @@ -45,7 +45,7 @@ DECIMALV3有一套很复杂的类型推演规则,针对不同的表达式, #### 四则运算 -* 加法 / 减法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y), max(b, y)),即整数部分和小数部分都分别使用两个操作数中较大的值。 +* 加法 / 减法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y) + 1, max(b, y))。 * 乘法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + x, b + y)。 * 除法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + y, b)。 diff --git a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java index 8f991c92f3..3f9acb8dcb 100644 --- a/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java +++ b/fe/fe-common/src/main/java/org/apache/doris/catalog/Type.java @@ -1816,6 +1816,10 @@ public abstract class Type { // Whether `type1` matches the exact type of `type2`. public static boolean matchExactType(Type type1, Type type2) { + return matchExactType(type1, type2, false); + } + + public static boolean matchExactType(Type type1, Type type2, boolean ignorePrecision) { if (type1.matchesType(type2)) { if (PrimitiveType.typeWithPrecision.contains(type2.getPrimitiveType())) { // For types which has precision and scale, we also need to check quality between precisions and scales @@ -1823,6 +1827,10 @@ public abstract class Type { == ((ScalarType) type1).decimalPrecision()) && (((ScalarType) type2).decimalScale() == ((ScalarType) type1).decimalScale())) { return true; + } else if (((ScalarType) type2).decimalScale() == ((ScalarType) type1).decimalScale() + && ignorePrecision) { + return isSameDecimalTypeWithDifferentPrecision(((ScalarType) type2).decimalPrecision(), + ((ScalarType) type1).decimalPrecision()); } } else if (type2.isArrayType()) { // For types array, we also need to check contains null for case like @@ -1836,5 +1844,20 @@ public abstract class Type { } return false; } + + public static boolean isSameDecimalTypeWithDifferentPrecision(int precision1, int precision2) { + if (precision1 <= ScalarType.MAX_DECIMAL32_PRECISION && precision2 <= ScalarType.MAX_DECIMAL32_PRECISION) { + return true; + } else if (precision1 > ScalarType.MAX_DECIMAL32_PRECISION && precision2 > ScalarType.MAX_DECIMAL32_PRECISION + && precision1 <= ScalarType.MAX_DECIMAL64_PRECISION + && precision2 <= ScalarType.MAX_DECIMAL64_PRECISION) { + return true; + } else if (precision1 > ScalarType.MAX_DECIMAL64_PRECISION && precision2 > ScalarType.MAX_DECIMAL64_PRECISION + && precision1 <= ScalarType.MAX_DECIMAL128_PRECISION + && precision2 <= ScalarType.MAX_DECIMAL128_PRECISION) { + return true; + } + return false; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java index 4b3c252418..383556fc32 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/ArithmeticExpr.java @@ -544,7 +544,7 @@ public class ArithmeticExpr extends Expr { // target type: DECIMALV3(max(widthOfIntPart1, widthOfIntPart2) + max(scale1, scale2) + 1, // max(scale1, scale2)) scale = Math.max(t1Scale, t2Scale); - precision = Math.max(widthOfIntPart1, widthOfIntPart2) + scale; + precision = Math.max(widthOfIntPart1, widthOfIntPart2) + scale + 1; } else { scale = Math.max(t1Scale, t2Scale); precision = widthOfIntPart2 + scale; @@ -559,10 +559,10 @@ public class ArithmeticExpr extends Expr { } type = ScalarType.createDecimalV3Type(precision, scale); if (op == Operator.ADD || op == Operator.SUBTRACT) { - if (!Type.matchExactType(type, children.get(0).type)) { + if (((ScalarType) type).getScalarScale() != ((ScalarType) children.get(0).type).getScalarScale()) { castChild(type, 0); } - if (!Type.matchExactType(type, children.get(1).type)) { + if (((ScalarType) type).getScalarScale() != ((ScalarType) children.get(1).type).getScalarScale()) { castChild(type, 1); } } else if (op == Operator.DIVIDE && (t2Scale != 0) && t1.isDecimalV3()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java index c1b1ee1fa2..ec3fda9eb1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/CastExpr.java @@ -288,7 +288,7 @@ public class CastExpr extends Expr { Type childType = getChild(0).getType(); // this cast may result in loss of precision, but the user requested it - noOp = Type.matchExactType(childType, type); + noOp = Type.matchExactType(childType, type, true); if (noOp) { // For decimalv2, we do not perform an actual cast between different precision/scale. Instead, we just diff --git a/regression-test/data/datatype_p0/decimalv3/test_arithmetic_expressions.out b/regression-test/data/datatype_p0/decimalv3/test_arithmetic_expressions.out index 4f68777f2a..085b844d7c 100644 --- a/regression-test/data/datatype_p0/decimalv3/test_arithmetic_expressions.out +++ b/regression-test/data/datatype_p0/decimalv3/test_arithmetic_expressions.out @@ -32,3 +32,18 @@ 2.0736 3.2399999999999998 +-- !select_all -- +999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 + +-- !select -- +2999999.997 + +-- !select -- +2999999994000.000003 + +-- !select -- +3.000 + +-- !select -- +10999999.989 + diff --git a/regression-test/suites/datatype_p0/decimalv3/test_arithmetic_expressions.groovy b/regression-test/suites/datatype_p0/decimalv3/test_arithmetic_expressions.groovy index 301d719b15..284cf482e4 100644 --- a/regression-test/suites/datatype_p0/decimalv3/test_arithmetic_expressions.groovy +++ b/regression-test/suites/datatype_p0/decimalv3/test_arithmetic_expressions.groovy @@ -49,4 +49,20 @@ suite("test_arithmetic_expressions") { qt_select "select k1 * k2 * k3 * k1 * k2 * k3 from ${table1} order by k1" qt_select "select k1 * k2 / k3 * k1 * k2 * k3 from ${table1} order by k1" sql "drop table if exists ${table1}" + + sql """ + CREATE TABLE IF NOT EXISTS ${table1} ( `a` DECIMALV3(9, 3) NOT NULL, `b` DECIMALV3(9, 3) NOT NULL, `c` DECIMALV3(9, 3) NOT NULL, `d` DECIMALV3(9, 3) NOT NULL, `e` DECIMALV3(9, 3) NOT NULL, `f` DECIMALV3(9, 3) NOT + NULL, `g` DECIMALV3(9, 3) NOT NULL , `h` DECIMALV3(9, 3) NOT NULL, `i` DECIMALV3(9, 3) NOT NULL, `j` DECIMALV3(9, 3) NOT NULL, `k` DECIMALV3(9, 3) NOT NULL) DISTRIBUTED BY HASH(a) PROPERTIES("replication_num" = "1"); + """ + + sql """ + insert into ${table1} values(999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999); + """ + qt_select_all "select * from ${table1} order by a" + + qt_select "select a + b + c from ${table1};" + qt_select "select (a + b + c) * d from ${table1};" + qt_select "select (a + b + c) / d from ${table1};" + qt_select "select a + b + c + d + e + f + g + h + i + j + k from ${table1};" + sql "drop table if exists ${table1}" }