[enhancement](nereids) convert string literal to commontype in in-expr and cass-when-expr (#17200)
This commit is contained in:
@ -35,6 +35,7 @@ header:
|
||||
- "**/*.log"
|
||||
- "**/*.sql"
|
||||
- "**/*.lock"
|
||||
- "**/*.out"
|
||||
- "tsan_suppressions"
|
||||
- "docs/.markdownlintignore"
|
||||
- "fe/fe-core/src/test/resources/data/net_snmp_normal"
|
||||
|
||||
@ -35,21 +35,26 @@ import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
|
||||
import org.apache.doris.nereids.trees.expressions.Not;
|
||||
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.WhenClause;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonArray;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.scalar.JsonObject;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Locale;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -218,9 +223,6 @@ class FunctionBinder extends DefaultExpressionRewriter<CascadesContext> {
|
||||
.map(e -> e.accept(this, context)).collect(Collectors.toList());
|
||||
CaseWhen newCaseWhen = caseWhen.withChildren(rewrittenChildren);
|
||||
|
||||
// check
|
||||
newCaseWhen.checkLegalityBeforeTypeCoercion();
|
||||
|
||||
// type coercion
|
||||
List<DataType> dataTypesForCoercion = newCaseWhen.dataTypesForCoercion();
|
||||
if (dataTypesForCoercion.size() <= 1) {
|
||||
@ -230,20 +232,37 @@ class FunctionBinder extends DefaultExpressionRewriter<CascadesContext> {
|
||||
if (dataTypesForCoercion.stream().allMatch(dataType -> dataType.equals(first))) {
|
||||
return newCaseWhen;
|
||||
}
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(dataTypesForCoercion);
|
||||
return optionalCommonType
|
||||
.map(commonType -> {
|
||||
List<Expression> newChildren
|
||||
= newCaseWhen.getWhenClauses().stream()
|
||||
.map(wc -> wc.withChildren(wc.getOperand(),
|
||||
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)))
|
||||
.collect(Collectors.toList());
|
||||
newCaseWhen.getDefaultValue()
|
||||
.map(dv -> TypeCoercionUtils.castIfNotMatchType(dv, commonType))
|
||||
.ifPresent(newChildren::add);
|
||||
return newCaseWhen.withChildren(newChildren);
|
||||
})
|
||||
.orElse(newCaseWhen);
|
||||
|
||||
Map<Boolean, List<Expression>> filteredStringLiteral = newCaseWhen.expressionForCoercion()
|
||||
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
|
||||
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
|
||||
.stream().map(Expression::getDataType).collect(Collectors.toList()));
|
||||
|
||||
if (!optionalCommonType.isPresent()) {
|
||||
return newCaseWhen;
|
||||
}
|
||||
DataType commonType = optionalCommonType.get();
|
||||
|
||||
// process character literal
|
||||
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
|
||||
Literal literal = (Literal) stringLikeLiteral;
|
||||
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
|
||||
literal.getStringValue(), commonType).isPresent()) {
|
||||
commonType = StringType.INSTANCE;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
List<Expression> newChildren = Lists.newArrayList();
|
||||
for (WhenClause wc : newCaseWhen.getWhenClauses()) {
|
||||
newChildren.add(wc.withChildren(wc.getOperand(),
|
||||
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)));
|
||||
}
|
||||
if (newCaseWhen.getDefaultValue().isPresent()) {
|
||||
newChildren.add(TypeCoercionUtils.castIfNotMatchType(newCaseWhen.getDefaultValue().get(), commonType));
|
||||
}
|
||||
return newCaseWhen.withChildren(newChildren);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -257,17 +276,32 @@ class FunctionBinder extends DefaultExpressionRewriter<CascadesContext> {
|
||||
.allMatch(dt -> dt.equals(newInPredicate.getCompareExpr().getDataType()))) {
|
||||
return newInPredicate;
|
||||
}
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(newInPredicate.children()
|
||||
|
||||
Map<Boolean, List<Expression>> filteredStringLiteral = newInPredicate.children()
|
||||
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
|
||||
.stream().map(Expression::getDataType).collect(Collectors.toList()));
|
||||
|
||||
return optionalCommonType
|
||||
.map(commonType -> {
|
||||
List<Expression> newChildren = newInPredicate.children().stream()
|
||||
.map(e -> TypeCoercionUtils.castIfNotMatchType(e, commonType))
|
||||
.collect(Collectors.toList());
|
||||
return newInPredicate.withChildren(newChildren);
|
||||
})
|
||||
.orElse(newInPredicate);
|
||||
if (!optionalCommonType.isPresent()) {
|
||||
return newInPredicate;
|
||||
}
|
||||
DataType commonType = optionalCommonType.get();
|
||||
|
||||
// process character literal
|
||||
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
|
||||
Literal literal = (Literal) stringLikeLiteral;
|
||||
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
|
||||
literal.getStringValue(), commonType).isPresent()) {
|
||||
commonType = StringType.INSTANCE;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
List<Expression> newChildren = Lists.newArrayList();
|
||||
for (Expression child : newInPredicate.children()) {
|
||||
newChildren.add(TypeCoercionUtils.castIfNotMatchType(child, commonType));
|
||||
}
|
||||
return newInPredicate.withChildren(newChildren);
|
||||
}
|
||||
|
||||
private Expression visitImplicitCastInputTypes(Expression expr, List<AbstractDataType> expectedInputTypes) {
|
||||
|
||||
@ -35,13 +35,19 @@ import org.apache.doris.nereids.trees.expressions.InPredicate;
|
||||
import org.apache.doris.nereids.trees.expressions.IntegralDivide;
|
||||
import org.apache.doris.nereids.trees.expressions.Not;
|
||||
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.WhenClause;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.types.coercion.AbstractDataType;
|
||||
import org.apache.doris.nereids.util.TypeCoercionUtils;
|
||||
|
||||
import com.google.common.collect.Lists;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
@ -152,20 +158,37 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
|
||||
if (dataTypesForCoercion.stream().allMatch(dataType -> dataType.equals(first))) {
|
||||
return newCaseWhen;
|
||||
}
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(dataTypesForCoercion);
|
||||
return optionalCommonType
|
||||
.map(commonType -> {
|
||||
List<Expression> newChildren
|
||||
= newCaseWhen.getWhenClauses().stream()
|
||||
.map(wc -> wc.withChildren(wc.getOperand(),
|
||||
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)))
|
||||
.collect(Collectors.toList());
|
||||
newCaseWhen.getDefaultValue()
|
||||
.map(dv -> TypeCoercionUtils.castIfNotMatchType(dv, commonType))
|
||||
.ifPresent(newChildren::add);
|
||||
return newCaseWhen.withChildren(newChildren);
|
||||
})
|
||||
.orElse(newCaseWhen);
|
||||
|
||||
Map<Boolean, List<Expression>> filteredStringLiteral = newCaseWhen.expressionForCoercion()
|
||||
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
|
||||
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
|
||||
.stream().map(Expression::getDataType).collect(Collectors.toList()));
|
||||
|
||||
if (!optionalCommonType.isPresent()) {
|
||||
return newCaseWhen;
|
||||
}
|
||||
DataType commonType = optionalCommonType.get();
|
||||
|
||||
// process character literal
|
||||
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
|
||||
Literal literal = (Literal) stringLikeLiteral;
|
||||
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
|
||||
literal.getStringValue(), commonType).isPresent()) {
|
||||
commonType = StringType.INSTANCE;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
List<Expression> newChildren = Lists.newArrayList();
|
||||
for (WhenClause wc : newCaseWhen.getWhenClauses()) {
|
||||
newChildren.add(wc.withChildren(wc.getOperand(),
|
||||
TypeCoercionUtils.castIfNotMatchType(wc.getResult(), commonType)));
|
||||
}
|
||||
if (newCaseWhen.getDefaultValue().isPresent()) {
|
||||
newChildren.add(TypeCoercionUtils.castIfNotMatchType(newCaseWhen.getDefaultValue().get(), commonType));
|
||||
}
|
||||
return newCaseWhen.withChildren(newChildren);
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -178,17 +201,32 @@ public class TypeCoercion extends AbstractExpressionRewriteRule {
|
||||
.allMatch(dt -> dt.equals(newInPredicate.getCompareExpr().getDataType()))) {
|
||||
return newInPredicate;
|
||||
}
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(newInPredicate.children()
|
||||
|
||||
Map<Boolean, List<Expression>> filteredStringLiteral = newInPredicate.children()
|
||||
.stream().collect(Collectors.partitioningBy(e -> e.isLiteral() && e.getDataType().isStringLikeType()));
|
||||
Optional<DataType> optionalCommonType = TypeCoercionUtils.findWiderCommonType(filteredStringLiteral.get(false)
|
||||
.stream().map(Expression::getDataType).collect(Collectors.toList()));
|
||||
|
||||
return optionalCommonType
|
||||
.map(commonType -> {
|
||||
List<Expression> newChildren = newInPredicate.children().stream()
|
||||
.map(e -> TypeCoercionUtils.castIfNotMatchType(e, commonType))
|
||||
.collect(Collectors.toList());
|
||||
return newInPredicate.withChildren(newChildren);
|
||||
})
|
||||
.orElse(newInPredicate);
|
||||
if (!optionalCommonType.isPresent()) {
|
||||
return newInPredicate;
|
||||
}
|
||||
DataType commonType = optionalCommonType.get();
|
||||
|
||||
// process character literal
|
||||
for (Expression stringLikeLiteral : filteredStringLiteral.get(true)) {
|
||||
Literal literal = (Literal) stringLikeLiteral;
|
||||
if (!TypeCoercionUtils.characterLiteralTypeCoercion(
|
||||
literal.getStringValue(), commonType).isPresent()) {
|
||||
commonType = StringType.INSTANCE;
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
List<Expression> newChildren = Lists.newArrayList();
|
||||
for (Expression child : newInPredicate.children()) {
|
||||
newChildren.add(TypeCoercionUtils.castIfNotMatchType(child, commonType));
|
||||
}
|
||||
return newInPredicate.withChildren(newChildren);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -29,6 +29,7 @@ import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.Optional;
|
||||
import java.util.stream.Collectors;
|
||||
import java.util.stream.Stream;
|
||||
|
||||
/**
|
||||
@ -71,6 +72,12 @@ public class CaseWhen extends Expression {
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
public List<Expression> expressionForCoercion() {
|
||||
List<Expression> ret = whenClauses.stream().map(WhenClause::getResult).collect(Collectors.toList());
|
||||
defaultValue.ifPresent(ret::add);
|
||||
return ret;
|
||||
}
|
||||
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitCaseWhen(this, context);
|
||||
}
|
||||
|
||||
@ -0,0 +1,37 @@
|
||||
-- This file is automatically generated. You should know what you did if you want to edit this
|
||||
-- !test_compare_expression --
|
||||
0
|
||||
|
||||
-- !test_compare_expression_2 --
|
||||
0
|
||||
|
||||
-- !test_compare_expression_3 --
|
||||
true
|
||||
|
||||
-- !test_compare_expression_4 --
|
||||
\N
|
||||
|
||||
-- !test_compare_expression_5 --
|
||||
false
|
||||
|
||||
-- !test_compare_expression_6 --
|
||||
\N
|
||||
|
||||
-- !test_compare_expression_7 --
|
||||
true
|
||||
|
||||
-- !test_compare_expression_8 --
|
||||
\N
|
||||
|
||||
-- !test_compare_expression_9 --
|
||||
false
|
||||
|
||||
-- !test_compare_expression_10 --
|
||||
\N
|
||||
|
||||
-- !test_compare_expression_11 --
|
||||
true
|
||||
|
||||
-- !test_compare_expression_12 --
|
||||
2008-08-08T00:00
|
||||
|
||||
@ -1,31 +0,0 @@
|
||||
-- This file is automatically generated. You should know what you did if you want to edit this
|
||||
-- !test_in_predicate_with_null --
|
||||
0
|
||||
|
||||
-- !test_in_predicate_with_null_2 --
|
||||
0
|
||||
|
||||
-- !test_in_predicate_with_null_3 --
|
||||
true
|
||||
|
||||
-- !test_in_predicate_with_null_4 --
|
||||
\N
|
||||
|
||||
-- !test_in_predicate_with_null_5 --
|
||||
false
|
||||
|
||||
-- !test_in_predicate_with_null_6 --
|
||||
\N
|
||||
|
||||
-- !test_in_predicate_with_null_7 --
|
||||
true
|
||||
|
||||
-- !test_in_predicate_with_null_8 --
|
||||
\N
|
||||
|
||||
-- !test_in_predicate_with_null_9 --
|
||||
false
|
||||
|
||||
-- !test_in_predicate_with_null_10 --
|
||||
\N
|
||||
|
||||
@ -31,3 +31,6 @@ select 1 not in (null, 1);
|
||||
select 1 not in (null, 2);
|
||||
|
||||
|
||||
select timestamp '2008-08-08 00:00:00' in ('2008-08-08');
|
||||
select case when true then timestamp '2008-08-08 00:00:00' else '2008-08-08' end;
|
||||
|
||||
Reference in New Issue
Block a user