[enhancement](nereids) convert string literal to common type in in-expr and case-when-expr (#17200)

This commit is contained in:
morrySnow
2023-03-02 22:05:35 +08:00
committed by GitHub
parent 93d2d461b4
commit 3eeeff09fd
7 changed files with 169 additions and 80 deletions

View File

@ -35,6 +35,7 @@ header:
- "**/*.log"
- "**/*.sql"
- "**/*.lock"
- "**/*.out"
- "tsan_suppressions"
- "docs/.markdownlintignore"
- "fe/fe-core/src/test/resources/data/net_snmp_normal"

View File

@ -35,21 +35,26 @@ import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonArray;
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonObject;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
@ -218,9 +223,6 @@ class FunctionBinder extends DefaultExpressionRewriter<CascadesContext> {
.map(e -> e.accept(this, context)).collect(Collectors.toList());
CaseWhen newCaseWhen = caseWhen.withChildren(rewrittenChildren);
// check
newCaseWhen.checkLegalityBeforeTypeCoercion();
// type coercion
List<DataType> dataTypesForCoercion = newCaseWhen.dataTypesForCoercion();
if (dataTypesForCoercion.size() <= 1) {
@ -230,20 +232,37 @@ class FunctionBinder extends DefaultExpressionRewriter<CascadesContext> {
if (dataTypesForCoercion.stream().allMatch(dataType -> dataType.equals(first))) {
return newCaseWhen;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(dataTypesForCoercion);
return optionalCommonType
.map(commonType -> {
List<Expression> newChildren
= newCaseWhen.getWhenClauses().stream()
.map(wc -> wc.withChildren(wc.getOperand(),
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)))
.collect(Collectors.toList());
newCaseWhen.getDefaultValue()
.map(dv -> TypeCoercionUtils.castIfNotMatchType(dv, commonType))
.ifPresent(newChildren::add);
return newCaseWhen.withChildren(newChildren);
})
.orElse(newCaseWhen);
Map<Boolean, List<Expression>> filteredStringLiteral = newCaseWhen.expressionForCoercion()
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
.stream().map(Expression::getDataType).collect(Collectors.toList()));
if (!optionalCommonType.isPresent()) {
return newCaseWhen;
}
DataType commonType = optionalCommonType.get();
// process character literal
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
Literal literal = (Literal) stringLikeLiteral;
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
literal.getStringValue(), commonType).isPresent()) {
commonType = StringType.INSTANCE;
break;
}
}
List<Expression> newChildren = Lists.newArrayList();
for (WhenClause wc : newCaseWhen.getWhenClauses()) {
newChildren.add(wc.withChildren(wc.getOperand(),
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)));
}
if (newCaseWhen.getDefaultValue().isPresent()) {
newChildren.add(TypeCoercionUtils.castIfNotMatchType(newCaseWhen.getDefaultValue().get(), commonType));
}
return newCaseWhen.withChildren(newChildren);
}
@Override
@ -257,17 +276,32 @@ class FunctionBinder extends DefaultExpressionRewriter<CascadesContext> {
.allMatch(dt -> dt.equals(newInPredicate.getCompareExpr().getDataType()))) {
return newInPredicate;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(newInPredicate.children()
Map<Boolean, List<Expression>> filteredStringLiteral = newInPredicate.children()
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
.stream().map(Expression::getDataType).collect(Collectors.toList()));
return optionalCommonType
.map(commonType -> {
List<Expression> newChildren = newInPredicate.children().stream()
.map(e -> TypeCoercionUtils.castIfNotMatchType(e, commonType))
.collect(Collectors.toList());
return newInPredicate.withChildren(newChildren);
})
.orElse(newInPredicate);
if (!optionalCommonType.isPresent()) {
return newInPredicate;
}
DataType commonType = optionalCommonType.get();
// process character literal
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
Literal literal = (Literal) stringLikeLiteral;
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
literal.getStringValue(), commonType).isPresent()) {
commonType = StringType.INSTANCE;
break;
}
}
List<Expression> newChildren = Lists.newArrayList();
for (Expression child : newInPredicate.children()) {
newChildren.add(TypeCoercionUtils.castIfNotMatchType(child, commonType));
}
return newInPredicate.withChildren(newChildren);
}
private Expression visitImplicitCastInputTypes(Expression expr, List<AbstractDataType> expectedInputTypes) {

View File

@ -35,13 +35,19 @@ import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.WhenClause;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.coercion.AbstractDataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.Lists;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
@ -152,20 +158,37 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
if (dataTypesForCoercion.stream().allMatch(dataType -> dataType.equals(first))) {
return newCaseWhen;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(dataTypesForCoercion);
return optionalCommonType
.map(commonType -> {
List<Expression> newChildren
= newCaseWhen.getWhenClauses().stream()
.map(wc -> wc.withChildren(wc.getOperand(),
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)))
.collect(Collectors.toList());
newCaseWhen.getDefaultValue()
.map(dv -> TypeCoercionUtils.castIfNotMatchType(dv, commonType))
.ifPresent(newChildren::add);
return newCaseWhen.withChildren(newChildren);
})
.orElse(newCaseWhen);
Map<Boolean, List<Expression>> filteredStringLiteral = newCaseWhen.expressionForCoercion()
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
.stream().map(Expression::getDataType).collect(Collectors.toList()));
if (!optionalCommonType.isPresent()) {
return newCaseWhen;
}
DataType commonType = optionalCommonType.get();
// process character literal
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
Literal literal = (Literal) stringLikeLiteral;
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
literal.getStringValue(), commonType).isPresent()) {
commonType = StringType.INSTANCE;
break;
}
}
List<Expression> newChildren = Lists.newArrayList();
for (WhenClause wc : newCaseWhen.getWhenClauses()) {
newChildren.add(wc.withChildren(wc.getOperand(),
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)));
}
if (newCaseWhen.getDefaultValue().isPresent()) {
newChildren.add(TypeCoercionUtils.castIfNotMatchType(newCaseWhen.getDefaultValue().get(), commonType));
}
return newCaseWhen.withChildren(newChildren);
}
@Override
@ -178,17 +201,32 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
.allMatch(dt -> dt.equals(newInPredicate.getCompareExpr().getDataType()))) {
return newInPredicate;
}
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(newInPredicate.children()
Map<Boolean, List<Expression>> filteredStringLiteral = newInPredicate.children()
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
.stream().map(Expression::getDataType).collect(Collectors.toList()));
return optionalCommonType
.map(commonType -> {
List<Expression> newChildren = newInPredicate.children().stream()
.map(e -> TypeCoercionUtils.castIfNotMatchType(e, commonType))
.collect(Collectors.toList());
return newInPredicate.withChildren(newChildren);
})
.orElse(newInPredicate);
if (!optionalCommonType.isPresent()) {
return newInPredicate;
}
DataType commonType = optionalCommonType.get();
// process character literal
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
Literal literal = (Literal) stringLikeLiteral;
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
literal.getStringValue(), commonType).isPresent()) {
commonType = StringType.INSTANCE;
break;
}
}
List<Expression> newChildren = Lists.newArrayList();
for (Expression child : newInPredicate.children()) {
newChildren.add(TypeCoercionUtils.castIfNotMatchType(child, commonType));
}
return newInPredicate.withChildren(newChildren);
}
@Override

View File

@ -29,6 +29,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
@ -71,6 +72,12 @@ public class CaseWhen extends Expression {
.collect(ImmutableList.toImmutableList());
}
/**
 * Collects the expressions that take part in type coercion of this CASE WHEN:
 * the result expression of every when-clause, in clause order, followed by the
 * default value when one is present.
 *
 * @return a freshly created list holding the when-clause results and, if present,
 *         the default value as the last element
 */
public List<Expression> expressionForCoercion() {
    // Collectors.toList() makes no guarantee about the mutability of the list it
    // returns, so collect into an explicit ArrayList before appending to it.
    List<Expression> ret = whenClauses.stream()
            .map(WhenClause::getResult)
            .collect(Collectors.toCollection(ArrayList::new));
    defaultValue.ifPresent(ret::add);
    return ret;
}
/**
 * Visitor entry point for this CaseWhen expression: hands this node and the
 * caller-supplied context to {@code visitor.visitCaseWhen} and returns the
 * value produced by that call.
 *
 * @param visitor the expression visitor to dispatch to
 * @param context the context object passed through to the visitor unchanged
 * @return the result of {@code visitor.visitCaseWhen(this, context)}
 */
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitCaseWhen(this, context);
}

View File

@ -0,0 +1,37 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !test_compare_expression --
0
-- !test_compare_expression_2 --
0
-- !test_compare_expression_3 --
true
-- !test_compare_expression_4 --
\N
-- !test_compare_expression_5 --
false
-- !test_compare_expression_6 --
\N
-- !test_compare_expression_7 --
true
-- !test_compare_expression_8 --
\N
-- !test_compare_expression_9 --
false
-- !test_compare_expression_10 --
\N
-- !test_compare_expression_11 --
true
-- !test_compare_expression_12 --
2008-08-08T00:00

View File

@ -1,31 +0,0 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !test_in_predicate_with_null --
0
-- !test_in_predicate_with_null_2 --
0
-- !test_in_predicate_with_null_3 --
true
-- !test_in_predicate_with_null_4 --
\N
-- !test_in_predicate_with_null_5 --
false
-- !test_in_predicate_with_null_6 --
\N
-- !test_in_predicate_with_null_7 --
true
-- !test_in_predicate_with_null_8 --
\N
-- !test_in_predicate_with_null_9 --
false
-- !test_in_predicate_with_null_10 --
\N

View File

@ -31,3 +31,6 @@ select 1 not in (null, 1);
select 1 not in (null, 2);
select timestamp '2008-08-08 00:00:00' in ('2008-08-08');
select case when true then timestamp '2008-08-08 00:00:00' else '2008-08-08' end;