[Fix](Nereids) fix type coercion for binary arithmetic (#15185)

support sql like: select true + 1 + '2.0' and prevent select true + 1 + 'x';
This commit is contained in:
mch_ucchi
2023-01-11 02:55:44 +08:00
committed by GitHub
parent c87a9a5949
commit bc34a44f06
5 changed files with 492 additions and 8 deletions

View File

@ -32,6 +32,7 @@ import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
@ -46,6 +47,7 @@ import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import java.math.BigDecimal;
import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
@ -84,6 +86,15 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
if (binaryOperator instanceof ImplicitCastInputTypes) {
List<AbstractDataType> expectedInputTypes = ((ImplicitCastInputTypes) binaryOperator).expectedInputTypes();
if (!expectedInputTypes.isEmpty()) {
binaryOperator.children().stream().filter(e -> e instanceof StringLikeLiteral)
.forEach(expr -> {
try {
new BigDecimal(((StringLikeLiteral) expr).getStringValue());
} catch (NumberFormatException e) {
throw new IllegalStateException(String.format(
"string literal %s cannot be cast to double", expr.toSql()));
}
});
binaryOperator = (BinaryOperator) visitImplicitCastInputTypes(binaryOperator, expectedInputTypes,
context);
}
@ -111,13 +122,14 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
public Expression visitDivide(Divide divide, ExpressionRewriteContext context) {
Expression left = rewrite(divide.left(), context);
Expression right = rewrite(divide.right(), context);
DataType t1 = TypeCoercionUtils.getNumResultType(left.getDataType());
DataType t2 = TypeCoercionUtils.getNumResultType(right.getDataType());
DataType commonType = TypeCoercionUtils.findCommonNumericsType(t1, t2);
if (divide.getLegacyOperator() == Operator.DIVIDE) {
if (commonType.isBigIntType() || commonType.isLargeIntType()) {
commonType = DoubleType.INSTANCE;
}
if (divide.getLegacyOperator() == Operator.DIVIDE
&& (commonType.isBigIntType() || commonType.isLargeIntType())) {
commonType = DoubleType.INSTANCE;
}
Expression newLeft = TypeCoercionUtils.castIfNotSameType(left, commonType);
Expression newRight = TypeCoercionUtils.castIfNotSameType(right, commonType);

View File

@ -148,10 +148,8 @@ public class TypeCoercionUtils {
* return ture if two type could do type coercion.
*/
public static boolean canHandleTypeCoercion(DataType leftType, DataType rightType) {
if (leftType instanceof DecimalV2Type && rightType instanceof NullType) {
return true;
}
if (leftType instanceof NullType && rightType instanceof DecimalV2Type) {
if (leftType instanceof DecimalV2Type && rightType instanceof NullType
|| leftType instanceof NullType && rightType instanceof DecimalV2Type) {
return true;
}
if (leftType instanceof DecimalV2Type && rightType instanceof IntegralType

View File

@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.expression.rewrite;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.rewrite.rules.TypeCoercion;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.CaseWhen;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Divide;
@ -186,6 +187,15 @@ public class TypeCoercionTest extends ExpressionRewriteTestHelper {
Expression expected = new Divide(new Cast(Literal.of((short) 1), DoubleType.INSTANCE),
new Cast(Literal.of(10L), DoubleType.INSTANCE));
assertRewrite(actual, expected);
Expression actual1 = new Add(new IntegerLiteral(1), new Add(BooleanLiteral.TRUE, new StringLiteral("2")));
Expression expected1 = new Add(new Cast(new IntegerLiteral(1), DoubleType.INSTANCE), new Add(
new Cast(BooleanLiteral.TRUE, DoubleType.INSTANCE),
new Cast(new StringLiteral("2"), DoubleType.INSTANCE)));
assertRewrite(actual1, expected1);
Expression actual2 = new Add(new IntegerLiteral(1), new Add(BooleanLiteral.TRUE, new StringLiteral("x")));
Assertions.assertThrows(IllegalStateException.class, () -> assertRewrite(actual2, null));
}
private DataType checkAndGetDataType(Expression expression) {