[fix](DECIMALV3)fix cumulative precision when literal and DECIMALV3 operations in Legacy (#20354)
The precision handling for division with DECIMALV3 is as follows (excluding cases where division increases precision):
(p1, s1) / (p2, s2) ----> (p1 + s2, s1)
However, due to precision loss in division, it is considered to increase the precision of the left operand:
(p1, s1) / (p2, s2) =====> (p1 + s2, s1 + s2) / (p2, s2) ----> (p1 + s2, s1)
However, the legacy optimizer repeats the analyze and substitute steps for an expression, which can result in the accumulation of precision:
(p1, s1) / (p2, s2) =====> (p1 + s2, s1 + s2) / (p2, s2) =====> (p1 + s2 + s2, s1 + s2 + s2) / (p2, s2)
To address this, the previous approach was to forcibly convert the left operand of DECIMALV3 calculations. This results in rewriting the expression as:
(p1, s1) / (p2, s2) =====> cast((p1, s1) as (p1 + s2, s1 + s2)) / (p2, s2)
Then, during the substitution step, a check is performed. If it is a cast expression, the expression modified by the cast is extracted:
cast((p1, s1) as (p1 + s2, s1 + s2)) =====> (p1, s1)
protected Expr substituteImpl(ExprSubstitutionMap smap, ExprSubstitutionMap disjunctsMap, Analyzer analyzer) {
if (isImplicitCast()) {
return getChild(0).substituteImpl(smap, disjunctsMap, analyzer);
}
This way, there won't be repeated analysis, preventing the continuous increase in precision. However, if the left expression is a constant (literal), theoretically, the precision would continue to increase. Unfortunately, the code that was removed in this PR (#19926) obscured this issue.
for (Expr child : children) {
if (child instanceof DecimalLiteral && child.getType().isDecimalV3()) {
((DecimalLiteral)child).tryToReduceType();
}
}
An attempt will be made to reduce the precision of literals in the expressions. However, this code snippet can cause such a bug.
mysql [test]>select cast(1 as DECIMALV3(16, 2)) / cast(3 as DECIMALV3(16, 2));
+-----------------------------------------------------------+
| CAST(1 AS DECIMALV3(16, 2)) / CAST(3 AS DECIMALV3(16, 2)) |
+-----------------------------------------------------------+
| 0.00 |
+-----------------------------------------------------------+
1.00 / 3.00, due to reduced precision, becomes 1 / 3.
<--Describe your changes.-->
This commit is contained in:
@ -505,7 +505,7 @@ public class ArithmeticExpr extends Expr {
|
||||
if (((ScalarType) type).getScalarScale() != ((ScalarType) children.get(1).type).getScalarScale()) {
|
||||
castChild(type, 1);
|
||||
}
|
||||
} else if (op == Operator.DIVIDE && t1TargetType.isDecimalV3()) {
|
||||
} else if (op == Operator.DIVIDE && (t1TargetType.isDecimalV3())) {
|
||||
int leftPrecision = t1Precision + t2Scale + Config.div_precision_increment;
|
||||
int leftScale = t1Scale + t2Scale + Config.div_precision_increment;
|
||||
if (leftPrecision > ScalarType.MAX_DECIMAL128_PRECISION) {
|
||||
@ -515,7 +515,15 @@ public class ArithmeticExpr extends Expr {
|
||||
type = castBinaryOp(Type.DOUBLE);
|
||||
break;
|
||||
}
|
||||
castChild(ScalarType.createDecimalV3Type(leftPrecision, leftScale), 0);
|
||||
Expr child = getChild(0);
|
||||
if (child instanceof DecimalLiteral) {
|
||||
DecimalLiteral literalChild = (DecimalLiteral) child;
|
||||
Expr newChild = literalChild
|
||||
.castToDecimalV3ByDivde(ScalarType.createDecimalV3Type(leftPrecision, leftScale));
|
||||
setChild(0, newChild);
|
||||
} else {
|
||||
castChild(ScalarType.createDecimalV3Type(leftPrecision, leftScale), 0);
|
||||
}
|
||||
} else if (op == Operator.MOD) {
|
||||
// TODO use max int part + max scale of two operands as result type
|
||||
// because BE require the result and operands types are the exact the same decimalv3 type
|
||||
|
||||
@ -63,6 +63,8 @@ public class CastExpr extends Expr {
|
||||
// True if this cast does not change the type.
|
||||
private boolean noOp = false;
|
||||
|
||||
private boolean notFold = false;
|
||||
|
||||
private static final Map<Pair<Type, Type>, Function.NullableMode> TYPE_NULLABLE_MODE;
|
||||
|
||||
static {
|
||||
@ -582,5 +584,13 @@ public class CastExpr extends Expr {
|
||||
public String getStringValueForArray() {
|
||||
return children.get(0).getStringValueForArray();
|
||||
}
|
||||
|
||||
public void setNotFold(boolean notFold) {
|
||||
this.notFold = notFold;
|
||||
}
|
||||
|
||||
public boolean isNotFold() {
|
||||
return this.notFold;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -384,6 +384,13 @@ public class DecimalLiteral extends LiteralExpr {
|
||||
return super.uncheckedCastTo(targetType);
|
||||
}
|
||||
|
||||
public Expr castToDecimalV3ByDivde(Type targetType) {
|
||||
// onlye use in DecimalLiteral divide DecimalV3
|
||||
CastExpr expr = new CastExpr(targetType, this);
|
||||
expr.setNotFold(true);
|
||||
return expr;
|
||||
}
|
||||
|
||||
@Override
|
||||
public int hashCode() {
|
||||
return 31 * super.hashCode() + Objects.hashCode(value);
|
||||
|
||||
@ -109,6 +109,9 @@ public class FoldConstantsRule implements ExprRewriteRule {
|
||||
// cast-to-types and that can lead to query failures, e.g., CTAS
|
||||
if (expr instanceof CastExpr) {
|
||||
CastExpr castExpr = (CastExpr) expr;
|
||||
if (castExpr.isNotFold()) {
|
||||
return castExpr;
|
||||
}
|
||||
if (castExpr.getChild(0) instanceof NullLiteral) {
|
||||
return castExpr.getChild(0);
|
||||
}
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
-- This file is automatically generated. You should know what you did if you want to edit this
|
||||
-- !select1 --
|
||||
0.333333333333333333333333
|
||||
0.250000000000000000000000
|
||||
0.200000000000000000000000
|
||||
0.333333
|
||||
0.250000
|
||||
0.200000
|
||||
|
||||
-- !select2 --
|
||||
0.333333
|
||||
|
||||
-- !select3 --
|
||||
0.33333
|
||||
0.25000
|
||||
0.20000
|
||||
|
||||
|
||||
@ -47,4 +47,7 @@ suite("test_cast_as_decimalv3") {
|
||||
qt_select2 """
|
||||
select cast(1 as DECIMALV3(5, 2)) / cast(3 as DECIMALV3(5, 2))
|
||||
"""
|
||||
qt_select3 """
|
||||
select 1.0 / val from divtest order by id
|
||||
"""
|
||||
}
|
||||
Reference in New Issue
Block a user