[Bug](DECIMALV3) Fix wrong precision for plus/minus (#18052)
Result type for DECIMAL(x, y) plus/minus DECIMAL(m, n) should be DECIMAL(max(x - y, m - n) + max(y + n) + 1, max(y + n))
This commit is contained in:
@ -236,7 +236,7 @@ DataTypePtr decimal_result_type(const DataTypeDecimal<T>& tx, const DataTypeDeci
|
||||
size_t divide_precision = tx.get_precision() + ty.get_scale();
|
||||
size_t plus_minus_precision =
|
||||
std::max(tx.get_precision() - tx.get_scale(), ty.get_precision() - ty.get_scale()) +
|
||||
scale;
|
||||
scale + 1;
|
||||
if (is_multiply) {
|
||||
scale = tx.get_scale() + ty.get_scale();
|
||||
precision = std::min(multiply_precision, max_decimal_precision<Decimal128I>());
|
||||
|
||||
@ -47,7 +47,7 @@ DECIMALV3 has a very complex set of type inference rules. For different expressi
|
||||
|
||||
#### Arithmetic Expressions
|
||||
|
||||
* Plus / Minus: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y), max(b, y)). That is, the integer part and the decimal part use the larger value of the two operands respectively.
|
||||
* Plus / Minus: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y) + 1, max(b, y)).
|
||||
* Multiply: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + x, b + y).
|
||||
* Divide: DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + y, b).
|
||||
|
||||
|
||||
@ -45,7 +45,7 @@ DECIMALV3有一套很复杂的类型推演规则,针对不同的表达式,
|
||||
|
||||
#### 四则运算
|
||||
|
||||
* 加法 / 减法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y), max(b, y)),即整数部分和小数部分都分别使用两个操作数中较大的值。
|
||||
* 加法 / 减法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(max(a - b, x - y) + max(b, y) + 1, max(b, y))。
|
||||
* 乘法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + x, b + y)。
|
||||
* 除法:DECIMALV3(a, b) + DECIMALV3(x, y) -> DECIMALV3(a + y, b)。
|
||||
|
||||
|
||||
@ -1816,6 +1816,10 @@ public abstract class Type {
|
||||
|
||||
// Whether `type1` matches the exact type of `type2`.
|
||||
public static boolean matchExactType(Type type1, Type type2) {
|
||||
return matchExactType(type1, type2, false);
|
||||
}
|
||||
|
||||
public static boolean matchExactType(Type type1, Type type2, boolean ignorePrecision) {
|
||||
if (type1.matchesType(type2)) {
|
||||
if (PrimitiveType.typeWithPrecision.contains(type2.getPrimitiveType())) {
|
||||
// For types which has precision and scale, we also need to check quality between precisions and scales
|
||||
@ -1823,6 +1827,10 @@ public abstract class Type {
|
||||
== ((ScalarType) type1).decimalPrecision()) && (((ScalarType) type2).decimalScale()
|
||||
== ((ScalarType) type1).decimalScale())) {
|
||||
return true;
|
||||
} else if (((ScalarType) type2).decimalScale() == ((ScalarType) type1).decimalScale()
|
||||
&& ignorePrecision) {
|
||||
return isSameDecimalTypeWithDifferentPrecision(((ScalarType) type2).decimalPrecision(),
|
||||
((ScalarType) type1).decimalPrecision());
|
||||
}
|
||||
} else if (type2.isArrayType()) {
|
||||
// For types array, we also need to check contains null for case like
|
||||
@ -1836,5 +1844,20 @@ public abstract class Type {
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
public static boolean isSameDecimalTypeWithDifferentPrecision(int precision1, int precision2) {
|
||||
if (precision1 <= ScalarType.MAX_DECIMAL32_PRECISION && precision2 <= ScalarType.MAX_DECIMAL32_PRECISION) {
|
||||
return true;
|
||||
} else if (precision1 > ScalarType.MAX_DECIMAL32_PRECISION && precision2 > ScalarType.MAX_DECIMAL32_PRECISION
|
||||
&& precision1 <= ScalarType.MAX_DECIMAL64_PRECISION
|
||||
&& precision2 <= ScalarType.MAX_DECIMAL64_PRECISION) {
|
||||
return true;
|
||||
} else if (precision1 > ScalarType.MAX_DECIMAL64_PRECISION && precision2 > ScalarType.MAX_DECIMAL64_PRECISION
|
||||
&& precision1 <= ScalarType.MAX_DECIMAL128_PRECISION
|
||||
&& precision2 <= ScalarType.MAX_DECIMAL128_PRECISION) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -544,7 +544,7 @@ public class ArithmeticExpr extends Expr {
|
||||
// target type: DECIMALV3(max(widthOfIntPart1, widthOfIntPart2) + max(scale1, scale2) + 1,
|
||||
// max(scale1, scale2))
|
||||
scale = Math.max(t1Scale, t2Scale);
|
||||
precision = Math.max(widthOfIntPart1, widthOfIntPart2) + scale;
|
||||
precision = Math.max(widthOfIntPart1, widthOfIntPart2) + scale + 1;
|
||||
} else {
|
||||
scale = Math.max(t1Scale, t2Scale);
|
||||
precision = widthOfIntPart2 + scale;
|
||||
@ -559,10 +559,10 @@ public class ArithmeticExpr extends Expr {
|
||||
}
|
||||
type = ScalarType.createDecimalV3Type(precision, scale);
|
||||
if (op == Operator.ADD || op == Operator.SUBTRACT) {
|
||||
if (!Type.matchExactType(type, children.get(0).type)) {
|
||||
if (((ScalarType) type).getScalarScale() != ((ScalarType) children.get(0).type).getScalarScale()) {
|
||||
castChild(type, 0);
|
||||
}
|
||||
if (!Type.matchExactType(type, children.get(1).type)) {
|
||||
if (((ScalarType) type).getScalarScale() != ((ScalarType) children.get(1).type).getScalarScale()) {
|
||||
castChild(type, 1);
|
||||
}
|
||||
} else if (op == Operator.DIVIDE && (t2Scale != 0) && t1.isDecimalV3()) {
|
||||
|
||||
@ -288,7 +288,7 @@ public class CastExpr extends Expr {
|
||||
Type childType = getChild(0).getType();
|
||||
|
||||
// this cast may result in loss of precision, but the user requested it
|
||||
noOp = Type.matchExactType(childType, type);
|
||||
noOp = Type.matchExactType(childType, type, true);
|
||||
|
||||
if (noOp) {
|
||||
// For decimalv2, we do not perform an actual cast between different precision/scale. Instead, we just
|
||||
|
||||
@ -32,3 +32,18 @@
|
||||
2.0736
|
||||
3.2399999999999998
|
||||
|
||||
-- !select_all --
|
||||
999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999 999999.999
|
||||
|
||||
-- !select --
|
||||
2999999.997
|
||||
|
||||
-- !select --
|
||||
2999999994000.000003
|
||||
|
||||
-- !select --
|
||||
3.000
|
||||
|
||||
-- !select --
|
||||
10999999.989
|
||||
|
||||
|
||||
@ -49,4 +49,20 @@ suite("test_arithmetic_expressions") {
|
||||
qt_select "select k1 * k2 * k3 * k1 * k2 * k3 from ${table1} order by k1"
|
||||
qt_select "select k1 * k2 / k3 * k1 * k2 * k3 from ${table1} order by k1"
|
||||
sql "drop table if exists ${table1}"
|
||||
|
||||
sql """
|
||||
CREATE TABLE IF NOT EXISTS ${table1} ( `a` DECIMALV3(9, 3) NOT NULL, `b` DECIMALV3(9, 3) NOT NULL, `c` DECIMALV3(9, 3) NOT NULL, `d` DECIMALV3(9, 3) NOT NULL, `e` DECIMALV3(9, 3) NOT NULL, `f` DECIMALV3(9, 3) NOT
|
||||
NULL, `g` DECIMALV3(9, 3) NOT NULL , `h` DECIMALV3(9, 3) NOT NULL, `i` DECIMALV3(9, 3) NOT NULL, `j` DECIMALV3(9, 3) NOT NULL, `k` DECIMALV3(9, 3) NOT NULL) DISTRIBUTED BY HASH(a) PROPERTIES("replication_num" = "1");
|
||||
"""
|
||||
|
||||
sql """
|
||||
insert into ${table1} values(999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999,999999.999);
|
||||
"""
|
||||
qt_select_all "select * from ${table1} order by a"
|
||||
|
||||
qt_select "select a + b + c from ${table1};"
|
||||
qt_select "select (a + b + c) * d from ${table1};"
|
||||
qt_select "select (a + b + c) / d from ${table1};"
|
||||
qt_select "select a + b + c + d + e + f + g + h + i + j + k from ${table1};"
|
||||
sql "drop table if exists ${table1}"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user