[fix](DECIMALV3)fix cumulative precision when literal and DECIMALV3 operations in Legacy (#20354)

The precision handling for division with DECIMALV3 is as follows (excluding cases where division increases precision):

(p1, s1) / (p2, s2) ----> (p1 + s2, s1)

However, due to precision loss in division, it is considered to increase the precision of the left operand:

(p1, s1) / (p2, s2) =====> (p1 + s2, s1 + s2) / (p2, s2) ----> (p1 + s2, s1)

However, the legacy optimizer repeats the analyze and substitute steps for an expression, which can result in the accumulation of precision:

(p1, s1) / (p2, s2) =====> (p1 + s2, s1 + s2) / (p2, s2) =====> (p1 + s2 + s2, s1 + s2 + s2) / (p2, s2)

To address this, the previous approach was to forcibly convert the left operand of DECIMALV3 calculations. This results in rewriting the expression as:

(p1, s1) / (p2, s2) =====> cast((p1, s1) as (p1 + s2, s1 + s2)) / (p2, s2)

Then, during the substitution step, a check is performed. If it is a cast expression, the expression modified by the cast is extracted:

cast((p1, s1) as (p1 + s2, s1 + s2)) =====> (p1, s1)

protected Expr substituteImpl(ExprSubstitutionMap smap, ExprSubstitutionMap disjunctsMap, Analyzer analyzer) {
        if (isImplicitCast()) {
            return getChild(0).substituteImpl(smap, disjunctsMap, analyzer);
        }
This way, there won't be repeated analysis, preventing the continuous increase in precision. However, if the left expression is a constant (literal), theoretically, the precision would continue to increase. Unfortunately, the code that was removed in this PR (#19926) obscured this issue.

for (Expr child : children) {
    if (child instanceof DecimalLiteral && child.getType().isDecimalV3()) {
      ((DecimalLiteral)child).tryToReduceType();
    }
}
An attempt will be made to reduce the precision of literals in the expressions. However, this code snippet can cause such a bug.

mysql [test]>select cast(1 as DECIMALV3(16, 2)) /  cast(3 as DECIMALV3(16, 2));
+-----------------------------------------------------------+
| CAST(1 AS DECIMALV3(16, 2)) / CAST(3 AS DECIMALV3(16, 2)) |
+-----------------------------------------------------------+
|                                                      0.00 |
+-----------------------------------------------------------+
1.00 / 3.00, due to reduced precision, becomes 1 / 3.
<--Describe your changes.-->
This commit is contained in:
Mryange
2023-06-09 08:58:55 +08:00
committed by GitHub
parent 079fb0e56d
commit 4c6df9062e
6 changed files with 41 additions and 5 deletions

View File

@ -505,7 +505,7 @@ public class ArithmeticExpr extends Expr {
if (((ScalarType) type).getScalarScale() != ((ScalarType) children.get(1).type).getScalarScale()) {
castChild(type, 1);
}
} else if (op == Operator.DIVIDE && t1TargetType.isDecimalV3()) {
} else if (op == Operator.DIVIDE && (t1TargetType.isDecimalV3())) {
int leftPrecision = t1Precision + t2Scale + Config.div_precision_increment;
int leftScale = t1Scale + t2Scale + Config.div_precision_increment;
if (leftPrecision > ScalarType.MAX_DECIMAL128_PRECISION) {
@ -515,7 +515,15 @@ public class ArithmeticExpr extends Expr {
type = castBinaryOp(Type.DOUBLE);
break;
}
castChild(ScalarType.createDecimalV3Type(leftPrecision, leftScale), 0);
Expr child = getChild(0);
if (child instanceof DecimalLiteral) {
DecimalLiteral literalChild = (DecimalLiteral) child;
Expr newChild = literalChild
.castToDecimalV3ByDivde(ScalarType.createDecimalV3Type(leftPrecision, leftScale));
setChild(0, newChild);
} else {
castChild(ScalarType.createDecimalV3Type(leftPrecision, leftScale), 0);
}
} else if (op == Operator.MOD) {
// TODO use max int part + max scale of two operands as result type
// because BE require the result and operands types are the exact the same decimalv3 type

View File

@ -63,6 +63,8 @@ public class CastExpr extends Expr {
// True if this cast does not change the type.
private boolean noOp = false;
private boolean notFold = false;
private static final Map<Pair<Type, Type>, Function.NullableMode> TYPE_NULLABLE_MODE;
static {
@ -582,5 +584,13 @@ public class CastExpr extends Expr {
public String getStringValueForArray() {
return children.get(0).getStringValueForArray();
}
public void setNotFold(boolean notFold) {
this.notFold = notFold;
}
public boolean isNotFold() {
return this.notFold;
}
}

View File

@ -384,6 +384,13 @@ public class DecimalLiteral extends LiteralExpr {
return super.uncheckedCastTo(targetType);
}
public Expr castToDecimalV3ByDivde(Type targetType) {
// onlye use in DecimalLiteral divide DecimalV3
CastExpr expr = new CastExpr(targetType, this);
expr.setNotFold(true);
return expr;
}
@Override
public int hashCode() {
return 31 * super.hashCode() + Objects.hashCode(value);

View File

@ -109,6 +109,9 @@ public class FoldConstantsRule implements ExprRewriteRule {
// cast-to-types and that can lead to query failures, e.g., CTAS
if (expr instanceof CastExpr) {
CastExpr castExpr = (CastExpr) expr;
if (castExpr.isNotFold()) {
return castExpr;
}
if (castExpr.getChild(0) instanceof NullLiteral) {
return castExpr.getChild(0);
}

View File

@ -1,9 +1,14 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !select1 --
0.333333333333333333333333
0.250000000000000000000000
0.200000000000000000000000
0.333333
0.250000
0.200000
-- !select2 --
0.333333
-- !select3 --
0.33333
0.25000
0.20000

View File

@ -47,4 +47,7 @@ suite("test_cast_as_decimalv3") {
qt_select2 """
select cast(1 as DECIMALV3(5, 2)) / cast(3 as DECIMALV3(5, 2))
"""
qt_select3 """
select 1.0 / val from divtest order by id
"""
}