[Bug](DECIMALV3) Fix wrong precision for plus/minus (#18052)

Result type for DECIMAL(x, y) plus/minus DECIMAL(m, n) should be DECIMAL(max(x - y, m - n) + max(y + n) + 1, max(y + n))
This commit is contained in:
Gabriel
2023-03-25 09:42:39 +08:00
committed by GitHub
parent b2c70b51cc
commit 2408ca5da8
8 changed files with 61 additions and 7 deletions

View File

@ -236,7 +236,7 @@ DataTypePtr decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeDeci
size_t divide_precision = tx.get_precision() + ty.get_scale();
size_t plus_minus_precision =
std::max(tx.get_precision() - tx.get_scale(), ty.get_precision() - ty.get_scale()) +
scale;
scale + 1;
if (is_multiply) {
scale = tx.get_scale() + ty.get_scale();
precision = std::min(multiply_precision, max_decimal_precision<Decimal128I>());

View File

@ -47,7 +47,7 @@ DECIMALV3 has a very complex set of type inference rules. For different expressi
#### Arithmetic Expressions
* Plus / Minus: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y), max(b, y)). That is, the integer part and the decimal part use the larger value of the two operands respectively.
* Plus / Minus: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y) + 1, max(b, y)).
* Multiply: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + x, b + y).
* Divide: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + y, b).

View File

@ -45,7 +45,7 @@ DECIMALV3有一套很复杂的类型推演规则,针对不同的表达式,
#### 四则运算
* 加法 / 减法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y), max(b, y)),即整数部分和小数部分都分别使用两个操作数中较大的值
* 加法 / 减法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y) + 1, max(b, y))。
* 乘法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + x, b + y)。
* 除法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + y, b)。

View File

@ -1816,6 +1816,10 @@ public abstract class Type {
// Whether `type1` matches the exact type of `type2`.
public static boolean matchExactType(Type type1, Type type2) {
return matchExactType(type1, type2, false);
}
public static boolean matchExactType(Type type1, Type type2, boolean ignorePrecision) {
if (type1.matchesType(type2)) {
if (PrimitiveType.typeWithPrecision.contains(type2.getPrimitiveType())) {
// For types which has precision and scale, we also need to check quality between precisions and scales
@ -1823,6 +1827,10 @@ public abstract class Type {
== ((ScalarType) type1).decimalPrecision()) && (((ScalarType) type2).decimalScale()
== ((ScalarType) type1).decimalScale())) {
return true;
} else if (((ScalarType) type2).decimalScale() == ((ScalarType) type1).decimalScale()
&& ignorePrecision) {
return isSameDecimalTypeWithDifferentPrecision(((ScalarType) type2).decimalPrecision(),
((ScalarType) type1).decimalPrecision());
}
} else if (type2.isArrayType()) {
// For types array, we also need to check contains null for case like
@ -1836,5 +1844,20 @@ public abstract class Type {
}
return false;
}
public static boolean isSameDecimalTypeWithDifferentPrecision(int precision1, int precision2) {
if (precision1 <= ScalarType.MAX_DECIMAL32_PRECISION && precision2 <= ScalarType.MAX_DECIMAL32_PRECISION) {
return true;
} else if (precision1 > ScalarType.MAX_DECIMAL32_PRECISION && precision2 > ScalarType.MAX_DECIMAL32_PRECISION
&& precision1 <= ScalarType.MAX_DECIMAL64_PRECISION
&& precision2 <= ScalarType.MAX_DECIMAL64_PRECISION) {
return true;
} else if (precision1 > ScalarType.MAX_DECIMAL64_PRECISION && precision2 > ScalarType.MAX_DECIMAL64_PRECISION
&& precision1 <= ScalarType.MAX_DECIMAL128_PRECISION
&& precision2 <= ScalarType.MAX_DECIMAL128_PRECISION) {
return true;
}
return false;
}
}

View File

@ -544,7 +544,7 @@ public class ArithmeticExpr extends Expr {
// target type: DECIMALV3(max(widthOfIntPart1, widthOfIntPart2) + max(scale1, scale2) + 1,
// max(scale1, scale2))
scale = Math.max(t1Scale, t2Scale);
precision = Math.max(widthOfIntPart1, widthOfIntPart2) + scale;
precision = Math.max(widthOfIntPart1, widthOfIntPart2) + scale + 1;
} else {
scale = Math.max(t1Scale, t2Scale);
precision = widthOfIntPart2 + scale;
@ -559,10 +559,10 @@ public class ArithmeticExpr extends Expr {
}
type = ScalarType.createDecimalV3Type(precision, scale);
if (op == Operator.ADD || op == Operator.SUBTRACT) {
if (!Type.matchExactType(type, children.get(0).type)) {
if (((ScalarType) type).getScalarScale() != ((ScalarType) children.get(0).type).getScalarScale()) {
castChild(type, 0);
}
if (!Type.matchExactType(type, children.get(1).type)) {
if (((ScalarType) type).getScalarScale() != ((ScalarType) children.get(1).type).getScalarScale()) {
castChild(type, 1);
}
} else if (op == Operator.DIVIDE && (t2Scale != 0) && t1.isDecimalV3()) {

View File

@ -288,7 +288,7 @@ public class CastExpr extends Expr {
Type childType = getChild(0).getType();
// this cast may result in loss of precision, but the user requested it
noOp = Type.matchExactType(childType, type);
noOp = Type.matchExactType(childType, type, true);
if (noOp) {
// For decimalv2, we do not perform an actual cast between different precision/scale. Instead, we just

View File

@ -32,3 +32,18 @@
2.0736
3.2399999999999998
-- !select_all --
999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999
-- !select --
2999999.997
-- !select --
2999999994000.000003
-- !select --
3.000
-- !select --
10999999.989

View File

@ -49,4 +49,20 @@ suite("test_arithmetic_expressions") {
qt_select "select k1 * k2 * k3 * k1 * k2 * k3 from ${table1} order by k1"
qt_select "select k1 * k2 / k3 * k1 * k2 * k3 from ${table1} order by k1"
sql "drop table if exists ${table1}"
sql """
CREATE TABLE IF NOT EXISTS ${table1} ( `a` DECIMALV3(9, 3) NOT NULL, `b` DECIMALV3(9, 3) NOT NULL, `c` DECIMALV3(9, 3) NOT NULL, `d` DECIMALV3(9, 3) NOT NULL, `e` DECIMALV3(9, 3) NOT NULL, `f` DECIMALV3(9, 3) NOT
NULL, `g` DECIMALV3(9, 3) NOT NULL , `h` DECIMALV3(9, 3) NOT NULL, `i` DECIMALV3(9, 3) NOT NULL, `j` DECIMALV3(9, 3) NOT NULL, `k` DECIMALV3(9, 3) NOT NULL) DISTRIBUTED BY HASH(a) PROPERTIES("replication_num" = "1");
"""
sql """
insert into ${table1} values(999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999);
"""
qt_select_all "select * from ${table1} order by a"
qt_select "select a + b + c from ${table1};"
qt_select "select (a + b + c) * d from ${table1};"
qt_select "select (a + b + c) / d from ${table1};"
qt_select "select a + b + c + d + e + f + g + h + i + j + k from ${table1};"
sql "drop table if exists ${table1}"
}