[fix](Nereids) fix create array function type coercion (#30329)

This commit is contained in:
morrySnow
2024-01-25 13:20:32 +08:00
committed by yiguolei
parent b60a272be0
commit f87484d6b3
2 changed files with 39 additions and 34 deletions

View File

@ -34,6 +34,7 @@ import com.google.common.collect.Sets;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
/**
@ -70,38 +71,43 @@ public class Array extends ScalarFunction
public List<FunctionSignature> getSignatures() {
if (arity() == 0) {
return SIGNATURES;
} else {
Map<Boolean, List<DataType>> partitioned = children.stream()
.map(ExpressionTrait::getDataType)
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasCharacterType));
List<DataType> needTypeCoercion = Lists.newArrayList(Sets.newHashSet(partitioned.get(true)));
if (needTypeCoercion.size() > 1 || !partitioned.get(false).isEmpty()) {
needTypeCoercion = needTypeCoercion.stream()
.map(TypeCoercionUtils::replaceCharacterToString)
.collect(Collectors.toList());
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDecimalV2Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDecimalV2WithDefault).collect(Collectors.toList()));
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDecimalV3Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDecimalV3WithWildcard).collect(Collectors.toList()));
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDateTimeV2Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDateTimeV2WithMax).collect(Collectors.toList()));
}
needTypeCoercion.addAll(partitioned.get(false));
return needTypeCoercion.stream()
.map(dataType -> FunctionSignature.ret(ArrayType.of(new FollowToArgumentType(0))).varArgs(dataType))
.collect(ImmutableList.toImmutableList());
}
Optional<DataType> commonDataType = TypeCoercionUtils.findWiderCommonTypeForCaseWhen(
children.stream().map(ExpressionTrait::getDataType).collect(Collectors.toList()));
if (commonDataType.isPresent()) {
return ImmutableList.of(
FunctionSignature.ret(ArrayType.of(commonDataType.get())).varArgs(commonDataType.get()));
}
Map<Boolean, List<DataType>> partitioned = children.stream()
.map(ExpressionTrait::getDataType)
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasCharacterType));
List<DataType> needTypeCoercion = Lists.newArrayList(Sets.newHashSet(partitioned.get(true)));
if (needTypeCoercion.size() > 1 || !partitioned.get(false).isEmpty()) {
needTypeCoercion = needTypeCoercion.stream()
.map(TypeCoercionUtils::replaceCharacterToString)
.collect(Collectors.toList());
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDecimalV2Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDecimalV2WithDefault).collect(Collectors.toList()));
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDecimalV3Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDecimalV3WithWildcard).collect(Collectors.toList()));
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDateTimeV2Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDateTimeV2WithMax).collect(Collectors.toList()));
}
needTypeCoercion.addAll(partitioned.get(false));
return needTypeCoercion.stream()
.map(dataType -> FunctionSignature.ret(ArrayType.of(new FollowToArgumentType(0))).varArgs(dataType))
.collect(ImmutableList.toImmutableList());
}
}

View File

@ -1277,8 +1277,7 @@ public class TypeCoercionUtils {
/**
* find wider common type for data type list.
*/
@Developing
private static Optional<DataType> findWiderCommonTypeForCaseWhen(List<DataType> dataTypes) {
public static Optional<DataType> findWiderCommonTypeForCaseWhen(List<DataType> dataTypes) {
Map<Boolean, List<DataType>> partitioned = dataTypes.stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasCharacterType));
List<DataType> needTypeCoercion = Lists.newArrayList(Sets.newHashSet(partitioned.get(true)));