[fix](Nereids) nested type literal type coercion and insert values with map (#26669)

This commit is contained in:
morrySnow
2023-11-15 11:13:26 +08:00
committed by GitHub
parent febf4bcb23
commit 2c6d2255c3
13 changed files with 669 additions and 141 deletions

View File

@ -52,7 +52,6 @@ import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.RelationUtil;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -226,7 +225,6 @@ public class BindSink implements AnalysisRuleFactory {
// we skip it.
continue;
}
maybeFallbackCastUnsupportedType(expr, ctx.connectContext);
DataType inputType = expr.getDataType();
DataType targetType = DataType.fromCatalogType(table.getFullSchema().get(i).getType());
Expression castExpr = expr;
@ -309,17 +307,6 @@ public class BindSink implements AnalysisRuleFactory {
}).collect(ImmutableList.toImmutableList());
}
private void maybeFallbackCastUnsupportedType(Expression expression, ConnectContext ctx) {
if (expression.getDataType().isMapType()) {
try {
ctx.getSessionVariable().enableFallbackToOriginalPlannerOnce();
} catch (Exception e) {
throw new AnalysisException("failed to try to fall back to original planner");
}
throw new AnalysisException("failed to cast type when binding sink, type is: " + expression.getDataType());
}
}
private boolean isSourceAndTargetStringLikeType(DataType input, DataType target) {
return input.isStringLikeType() && target.isStringLikeType();
}

View File

@ -31,11 +31,8 @@ import org.apache.doris.nereids.trees.expressions.functions.executable.TimeRound
import org.apache.doris.nereids.trees.expressions.literal.DateLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.StructType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMultimap;
@ -167,7 +164,7 @@ public enum ExpressionEvaluator {
DataType returnType = DataType.convertFromString(annotation.returnType());
List<DataType> argTypes = new ArrayList<>();
for (String type : annotation.argTypes()) {
argTypes.add(replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
argTypes.add(TypeCoercionUtils.replaceDecimalV3WithWildcard(DataType.convertFromString(type)));
}
FunctionSignature signature = new FunctionSignature(name,
argTypes.toArray(new DataType[0]), returnType);
@ -175,31 +172,6 @@ public enum ExpressionEvaluator {
}
}
private DataType replaceDecimalV3WithWildcard(DataType input) {
if (input instanceof ArrayType) {
DataType item = replaceDecimalV3WithWildcard(((ArrayType) input).getItemType());
if (item == ((ArrayType) input).getItemType()) {
return input;
}
return ArrayType.of(item);
} else if (input instanceof MapType) {
DataType keyType = replaceDecimalV3WithWildcard(((MapType) input).getKeyType());
DataType valueType = replaceDecimalV3WithWildcard(((MapType) input).getValueType());
if (keyType == ((MapType) input).getKeyType() && valueType == ((MapType) input).getValueType()) {
return input;
}
return MapType.of(keyType, valueType);
} else if (input instanceof StructType) {
// TODO: support struct type
return input;
} else {
if (input instanceof DecimalV3Type) {
return DecimalV3Type.WILDCARD;
}
return input;
}
}
/**
* function invoker.
*/

View File

@ -57,11 +57,27 @@ public class ComputeSignatureHelper {
/** implementAbstractReturnType */
public static FunctionSignature implementFollowToArgumentReturnType(
FunctionSignature signature, List<Expression> arguments) {
if (signature.returnType instanceof FollowToArgumentType) {
int argumentIndex = ((FollowToArgumentType) signature.returnType).argumentIndex;
return signature.withReturnType(arguments.get(argumentIndex).getDataType());
return signature.withReturnType(replaceFollowToArgumentReturnType(
signature.returnType, signature.argumentsTypes));
}
private static DataType replaceFollowToArgumentReturnType(DataType returnType, List<DataType> argumentTypes) {
if (returnType instanceof ArrayType) {
return ArrayType.of(replaceFollowToArgumentReturnType(
((ArrayType) returnType).getItemType(), argumentTypes));
} else if (returnType instanceof MapType) {
return MapType.of(replaceFollowToArgumentReturnType(((MapType) returnType).getKeyType(), argumentTypes),
replaceFollowToArgumentReturnType(((MapType) returnType).getValueType(), argumentTypes));
} else if (returnType instanceof StructType) {
// TODO: do not support struct type now
// throw new AnalysisException("do not support struct type now");
return returnType;
} else if (returnType instanceof FollowToArgumentType) {
int argumentIndex = ((FollowToArgumentType) returnType).argumentIndex;
return argumentTypes.get(argumentIndex);
} else {
return returnType;
}
return signature;
}
private static DataType replaceAnyDataTypeWithOutIndex(DataType sigType, DataType expressionType) {
@ -308,10 +324,10 @@ public class ComputeSignatureHelper {
if (computeSignature instanceof ComputePrecision) {
return ((ComputePrecision) computeSignature).computePrecision(signature);
}
if (signature.argumentsTypes.stream().anyMatch(DateTimeV2Type.class::isInstance)) {
if (signature.argumentsTypes.stream().anyMatch(TypeCoercionUtils::hasDateTimeV2Type)) {
signature = defaultDateTimeV2PrecisionPromotion(signature, arguments);
}
if (signature.argumentsTypes.stream().anyMatch(DecimalV3Type.class::isInstance)) {
if (signature.argumentsTypes.stream().anyMatch(TypeCoercionUtils::hasDecimalV3Type)) {
// do decimal v3 precision
signature = defaultDecimalV3PrecisionPromotion(signature, arguments);
}
@ -354,30 +370,34 @@ public class ComputeSignatureHelper {
} else {
targetType = signature.getArgType(i);
}
if (!(targetType instanceof DateTimeV2Type)) {
List<DataType> argTypes = extractArgumentType(DateTimeV2Type.class,
targetType, arguments.get(i).getDataType());
if (argTypes.isEmpty()) {
continue;
}
if (finalType == null) {
if (arguments.get(i) instanceof StringLikeLiteral) {
// We need to determine the scale based on the string literal.
for (DataType argType : argTypes) {
Expression arg = arguments.get(i);
DateTimeV2Type dateTimeV2Type;
if (arg instanceof StringLikeLiteral) {
StringLikeLiteral str = (StringLikeLiteral) arguments.get(i);
finalType = DateTimeV2Type.forTypeFromString(str.getStringValue());
dateTimeV2Type = DateTimeV2Type.forTypeFromString(str.getStringValue());
} else {
finalType = DateTimeV2Type.forType(arguments.get(i).getDataType());
dateTimeV2Type = DateTimeV2Type.forType(argType);
}
if (finalType == null) {
finalType = dateTimeV2Type;
} else {
finalType = DateTimeV2Type.getWiderDatetimeV2Type(finalType,
DateTimeV2Type.forType(arguments.get(i).getDataType()));
}
} else {
finalType = DateTimeV2Type.getWiderDatetimeV2Type(finalType,
DateTimeV2Type.forType(arguments.get(i).getDataType()));
}
}
DateTimeV2Type argType = finalType;
List<DataType> newArgTypes = signature.argumentsTypes.stream().map(t -> {
if (t instanceof DateTimeV2Type) {
return argType;
} else {
return t;
}
}).collect(Collectors.toList());
List<DataType> newArgTypes = signature.argumentsTypes.stream()
.map(at -> TypeCoercionUtils.replaceDateTimeV2WithTarget(at, argType))
.collect(Collectors.toList());
signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes);
signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes);
if (signature.returnType instanceof DateTimeV2Type) {
signature = signature.withReturnType(argType);
@ -387,7 +407,7 @@ public class ComputeSignatureHelper {
private static FunctionSignature defaultDecimalV3PrecisionPromotion(
FunctionSignature signature, List<Expression> arguments) {
DataType finalType = null;
DecimalV3Type finalType = null;
for (int i = 0; i < arguments.size(); i++) {
DataType targetType;
if (i >= signature.argumentsTypes.size()) {
@ -397,37 +417,32 @@ public class ComputeSignatureHelper {
} else {
targetType = signature.getArgType(i);
}
if (!(targetType instanceof DecimalV3Type)) {
List<DataType> argTypes = extractArgumentType(DecimalV3Type.class,
targetType, arguments.get(i).getDataType());
if (argTypes.isEmpty()) {
continue;
}
// only process wildcard decimalv3
if (((DecimalV3Type) targetType).getPrecision() > 0) {
continue;
}
if (finalType == null) {
finalType = DecimalV3Type.forType(arguments.get(i).getDataType());
} else {
for (DataType argType : argTypes) {
Expression arg = arguments.get(i);
DecimalV3Type argType;
DecimalV3Type decimalV3Type;
if (arg.isLiteral() && arg.getDataType().isIntegralType()) {
// create decimalV3 with minimum scale enough to hold the integral literal
argType = DecimalV3Type.createDecimalV3Type(new BigDecimal(((Literal) arg).getStringValue()));
decimalV3Type = DecimalV3Type.createDecimalV3Type(new BigDecimal(((Literal) arg).getStringValue()));
} else {
argType = DecimalV3Type.forType(arg.getDataType());
decimalV3Type = DecimalV3Type.forType(argType);
}
if (finalType == null) {
finalType = decimalV3Type;
} else {
finalType = (DecimalV3Type) DecimalV3Type.widerDecimalV3Type(finalType, decimalV3Type, false);
}
finalType = DecimalV3Type.widerDecimalV3Type((DecimalV3Type) finalType, argType, true);
}
Preconditions.checkState(finalType.isDecimalV3Type(), "decimalv3 precision promotion failed.");
}
DataType argType = finalType;
List<DataType> newArgTypes = signature.argumentsTypes.stream().map(t -> {
// only process wildcard decimalv3
if (t instanceof DecimalV3Type && ((DecimalV3Type) t).getPrecision() <= 0) {
return argType;
} else {
return t;
}
}).collect(Collectors.toList());
DecimalV3Type argType = finalType;
List<DataType> newArgTypes = signature.argumentsTypes.stream()
.map(at -> TypeCoercionUtils.replaceDecimalV3WithTarget(at, argType))
.collect(Collectors.toList());
signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes);
if (signature.returnType instanceof DecimalV3Type
&& ((DecimalV3Type) signature.returnType).getPrecision() <= 0) {
@ -436,6 +451,42 @@ public class ComputeSignatureHelper {
return signature;
}
private static List<DataType> extractArgumentType(Class<? extends DataType> targetType,
DataType signatureType, DataType argumentType) {
if (targetType.isAssignableFrom(signatureType.getClass())) {
return Lists.newArrayList(argumentType);
} else if (signatureType instanceof ArrayType) {
if (argumentType instanceof NullType) {
return extractArgumentType(targetType, ((ArrayType) signatureType).getItemType(), argumentType);
} else if (argumentType instanceof ArrayType) {
return extractArgumentType(targetType,
((ArrayType) signatureType).getItemType(), ((ArrayType) argumentType).getItemType());
} else {
return Lists.newArrayList();
}
} else if (signatureType instanceof MapType) {
if (argumentType instanceof NullType) {
List<DataType> ret = extractArgumentType(targetType,
((MapType) signatureType).getKeyType(), argumentType);
ret.addAll(extractArgumentType(targetType, ((MapType) signatureType).getValueType(), argumentType));
return ret;
} else if (argumentType instanceof MapType) {
List<DataType> ret = extractArgumentType(targetType,
((MapType) signatureType).getKeyType(), ((MapType) argumentType).getKeyType());
ret.addAll(extractArgumentType(targetType,
((MapType) signatureType).getValueType(), ((MapType) argumentType).getValueType()));
return ret;
} else {
return Lists.newArrayList();
}
} else if (signatureType instanceof StructType) {
// TODO: do not support struct type now
return Lists.newArrayList();
} else {
return Lists.newArrayList();
}
}
static class ComputeSignatureChain {
private final ResponsibilityChain<SignatureContext> computeChain;

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.FollowToArgumentType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.collect.ImmutableList;
@ -79,9 +80,27 @@ public class Array extends ScalarFunction
.map(TypeCoercionUtils::replaceCharacterToString)
.collect(Collectors.toList());
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDecimalV2Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDecimalV2WithDefault).collect(Collectors.toList()));
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDecimalV3Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDecimalV3WithWildcard).collect(Collectors.toList()));
}
partitioned = partitioned.get(false).stream()
.collect(Collectors.partitioningBy(TypeCoercionUtils::hasDateTimeV2Type));
if (!partitioned.get(true).isEmpty()) {
needTypeCoercion.addAll(partitioned.get(true).stream()
.map(TypeCoercionUtils::replaceDateTimeV2WithMax).collect(Collectors.toList()));
}
needTypeCoercion.addAll(partitioned.get(false));
return needTypeCoercion.stream()
.map(dataType -> FunctionSignature.ret(ArrayType.of(dataType)).varArgs(dataType))
.map(dataType -> FunctionSignature.ret(ArrayType.of(new FollowToArgumentType(0))).varArgs(dataType))
.collect(ImmutableList.toImmutableList());
}
}

View File

@ -37,6 +37,8 @@ import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.HllType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
@ -55,9 +57,8 @@ public class If extends ScalarFunction
implements TernaryExpression, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(1)
.args(BooleanType.INSTANCE, ArrayType.of(new AnyDataType(0)),
ArrayType.of(new AnyDataType(0))),
FunctionSignature.ret(NullType.INSTANCE)
.args(BooleanType.INSTANCE, NullType.INSTANCE, NullType.INSTANCE),
FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE, DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DateV2Type.INSTANCE)
@ -88,6 +89,15 @@ public class If extends ScalarFunction
FunctionSignature.ret(BitmapType.INSTANCE)
.args(BooleanType.INSTANCE, BitmapType.INSTANCE, BitmapType.INSTANCE),
FunctionSignature.ret(HllType.INSTANCE).args(BooleanType.INSTANCE, HllType.INSTANCE, HllType.INSTANCE),
FunctionSignature.retArgType(1)
.args(BooleanType.INSTANCE, ArrayType.of(new AnyDataType(0)),
ArrayType.of(new AnyDataType(0))),
FunctionSignature.retArgType(1)
.args(BooleanType.INSTANCE, MapType.of(new AnyDataType(0), new AnyDataType(1)),
MapType.of(new AnyDataType(0), new AnyDataType(1))),
FunctionSignature.retArgType(1)
.args(BooleanType.INSTANCE, new AnyDataType(0), new AnyDataType(0)),
// NOTICE string must at least of signature list, because all complex type could implicit cast to string
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(BooleanType.INSTANCE, VarcharType.SYSTEM_DEFAULT, VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE)

View File

@ -274,33 +274,75 @@ public class TypeCoercionUtils {
* return ture if datatype has character type in it, cannot use instance of CharacterType because of complex type.
*/
public static boolean hasCharacterType(DataType dataType) {
return hasSpecifiedType(dataType, CharacterType.class);
}
public static boolean hasDecimalV2Type(DataType dataType) {
return hasSpecifiedType(dataType, DecimalV2Type.class);
}
public static boolean hasDecimalV3Type(DataType dataType) {
return hasSpecifiedType(dataType, DecimalV3Type.class);
}
public static boolean hasDateTimeV2Type(DataType dataType) {
return hasSpecifiedType(dataType, DateTimeV2Type.class);
}
private static boolean hasSpecifiedType(DataType dataType, Class<? extends DataType> specifiedType) {
if (dataType instanceof ArrayType) {
return hasCharacterType(((ArrayType) dataType).getItemType());
return hasSpecifiedType(((ArrayType) dataType).getItemType(), specifiedType);
} else if (dataType instanceof MapType) {
return hasCharacterType(((MapType) dataType).getKeyType())
|| hasCharacterType(((MapType) dataType).getValueType());
return hasSpecifiedType(((MapType) dataType).getKeyType(), specifiedType)
|| hasSpecifiedType(((MapType) dataType).getValueType(), specifiedType);
} else if (dataType instanceof StructType) {
return ((StructType) dataType).getFields().stream().anyMatch(f -> hasCharacterType(f.getDataType()));
return ((StructType) dataType).getFields().stream()
.anyMatch(f -> hasSpecifiedType(f.getDataType(), specifiedType));
}
return dataType instanceof CharacterType;
return specifiedType.isAssignableFrom(dataType.getClass());
}
/**
* replace all character types to string for correct type coercion
*/
public static DataType replaceCharacterToString(DataType dataType) {
return replaceSpecifiedType(dataType, CharacterType.class, StringType.INSTANCE);
}
public static DataType replaceDecimalV2WithDefault(DataType dataType) {
return replaceSpecifiedType(dataType, DecimalV2Type.class, DecimalV2Type.SYSTEM_DEFAULT);
}
public static DataType replaceDecimalV3WithTarget(DataType dataType, DecimalV3Type target) {
return replaceSpecifiedType(dataType, DecimalV3Type.class, target);
}
public static DataType replaceDecimalV3WithWildcard(DataType dataType) {
return replaceSpecifiedType(dataType, DecimalV3Type.class, DecimalV3Type.WILDCARD);
}
public static DataType replaceDateTimeV2WithTarget(DataType dataType, DateTimeV2Type target) {
return replaceSpecifiedType(dataType, DateTimeV2Type.class, target);
}
public static DataType replaceDateTimeV2WithMax(DataType dataType) {
return replaceSpecifiedType(dataType, DateTimeV2Type.class, DateTimeV2Type.MAX);
}
private static DataType replaceSpecifiedType(DataType dataType,
Class<? extends DataType> specifiedType, DataType newType) {
if (dataType instanceof ArrayType) {
return ArrayType.of(replaceCharacterToString(((ArrayType) dataType).getItemType()));
return ArrayType.of(replaceSpecifiedType(((ArrayType) dataType).getItemType(), specifiedType, newType));
} else if (dataType instanceof MapType) {
return MapType.of(replaceCharacterToString(((MapType) dataType).getKeyType()),
replaceCharacterToString(((MapType) dataType).getValueType()));
return MapType.of(replaceSpecifiedType(((MapType) dataType).getKeyType(), specifiedType, newType),
replaceSpecifiedType(((MapType) dataType).getValueType(), specifiedType, newType));
} else if (dataType instanceof StructType) {
List<StructField> newFields = ((StructType) dataType).getFields().stream()
.map(f -> f.withDataType(replaceCharacterToString(f.getDataType())))
.map(f -> f.withDataType(replaceSpecifiedType(f.getDataType(), specifiedType, newType)))
.collect(ImmutableList.toImmutableList());
return new StructType(newFields);
} else if (dataType instanceof CharacterType) {
return StringType.INSTANCE;
} else if (specifiedType.isAssignableFrom(dataType.getClass())) {
return newType;
} else {
return dataType;
}

View File

@ -0,0 +1,447 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.ArrayLiteral;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.MapLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.literal.SmallIntLiteral;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.MapType;
import org.apache.doris.nereids.types.NullType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;
import org.apache.doris.nereids.types.coercion.FollowToArgumentType;
import com.google.common.collect.Lists;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.math.BigDecimal;
import java.util.Collections;
import java.util.List;
public class ComputeSignatureHelperTest {
/////////////////////////////////////////
// implementFollowToArgumentReturnType
/////////////////////////////////////////
@Test
void testNoImplementFollowToArgumentReturnType() {
FunctionSignature signature = FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE);
signature = ComputeSignatureHelper.implementFollowToArgumentReturnType(signature, Collections.emptyList());
Assertions.assertTrue(signature.returnType instanceof DoubleType);
}
@Test
void testArrayImplementFollowToArgumentReturnType() {
FunctionSignature signature = FunctionSignature.ret(ArrayType.of(new FollowToArgumentType(0)))
.args(IntegerType.INSTANCE);
signature = ComputeSignatureHelper.implementFollowToArgumentReturnType(signature, Collections.emptyList());
Assertions.assertTrue(signature.returnType instanceof ArrayType);
Assertions.assertTrue(((ArrayType) signature.returnType).getItemType() instanceof IntegerType);
}
@Test
void testMapImplementFollowToArgumentReturnType() {
FunctionSignature signature = FunctionSignature.ret(MapType.of(
new FollowToArgumentType(0), new FollowToArgumentType(1)))
.args(IntegerType.INSTANCE, DoubleType.INSTANCE);
signature = ComputeSignatureHelper.implementFollowToArgumentReturnType(signature, Collections.emptyList());
Assertions.assertTrue(signature.returnType instanceof MapType);
Assertions.assertTrue(((MapType) signature.returnType).getKeyType() instanceof IntegerType);
Assertions.assertTrue(((MapType) signature.returnType).getValueType() instanceof DoubleType);
}
/////////////////////////////////////////
// implementAnyDataTypeWithOutIndex
/////////////////////////////////////////
@Test
void testNoImplementAnyDataTypeWithOutIndex() {
FunctionSignature signature = FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE);
signature = ComputeSignatureHelper.implementAnyDataTypeWithOutIndex(signature, Collections.emptyList());
Assertions.assertTrue(signature.returnType instanceof DoubleType);
}
@Test
void testArraySigWithNullArgImplementAnyDataTypeWithOutIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX));
List<Expression> arguments = Lists.newArrayList(new NullLiteral());
signature = ComputeSignatureHelper.implementAnyDataTypeWithOutIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof ArrayType);
Assertions.assertTrue(((ArrayType) signature.getArgType(0)).getItemType() instanceof NullType);
}
@Test
void testMapSigWithNullArgImplementAnyDataTypeWithOutIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, AnyDataType.INSTANCE_WITHOUT_INDEX));
List<Expression> arguments = Lists.newArrayList(new NullLiteral());
signature = ComputeSignatureHelper.implementAnyDataTypeWithOutIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof MapType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getKeyType() instanceof NullType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getValueType() instanceof NullType);
}
@Test
void testArrayImplementAnyDataTypeWithOutIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX));
List<Expression> arguments = Lists.newArrayList(new ArrayLiteral(Lists.newArrayList(new IntegerLiteral(0))));
signature = ComputeSignatureHelper.implementAnyDataTypeWithOutIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof ArrayType);
Assertions.assertTrue(((ArrayType) signature.getArgType(0)).getItemType() instanceof IntegerType);
}
@Test
void testMapImplementAnyDataTypeWithOutIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, AnyDataType.INSTANCE_WITHOUT_INDEX));
List<Expression> arguments = Lists.newArrayList(new MapLiteral(Lists.newArrayList(new IntegerLiteral(0)),
Lists.newArrayList(new BigIntLiteral(0))));
signature = ComputeSignatureHelper.implementAnyDataTypeWithOutIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof MapType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getKeyType() instanceof IntegerType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getValueType() instanceof BigIntType);
}
/////////////////////////////////////////
// implementAnyDataTypeWithIndex
/////////////////////////////////////////
@Test
void testNoImplementAnyDataTypeWithIndex() {
FunctionSignature signature = FunctionSignature.ret(DoubleType.INSTANCE).args(IntegerType.INSTANCE);
signature = ComputeSignatureHelper.implementAnyDataTypeWithIndex(signature, Collections.emptyList());
Assertions.assertTrue(signature.returnType instanceof DoubleType);
}
@Test
void testArraySigWithNullArgImplementAnyDataTypeWithIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0));
List<Expression> arguments = Lists.newArrayList(
new NullLiteral(),
new BigIntLiteral(0));
signature = ComputeSignatureHelper.implementAnyDataTypeWithIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof ArrayType);
Assertions.assertTrue(((ArrayType) signature.getArgType(0)).getItemType() instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(1) instanceof BigIntType);
}
@Test
void testMapSigWithNullArgImplementAnyDataTypeWithIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(MapType.of(new AnyDataType(0), new AnyDataType(1)),
new AnyDataType(0), new AnyDataType(1));
List<Expression> arguments = Lists.newArrayList(
new NullLiteral(), new BigIntLiteral(0), new IntegerLiteral(0));
signature = ComputeSignatureHelper.implementAnyDataTypeWithIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof MapType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getKeyType() instanceof BigIntType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getValueType() instanceof IntegerType);
Assertions.assertTrue(signature.getArgType(1) instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(2) instanceof IntegerType);
}
@Test
void testArrayImplementAnyDataTypeWithIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0));
List<Expression> arguments = Lists.newArrayList(
new ArrayLiteral(Lists.newArrayList(new IntegerLiteral(0))),
new BigIntLiteral(0));
signature = ComputeSignatureHelper.implementAnyDataTypeWithIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof ArrayType);
Assertions.assertTrue(((ArrayType) signature.getArgType(0)).getItemType() instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(1) instanceof BigIntType);
}
@Test
void testMapImplementAnyDataTypeWithIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(MapType.of(new AnyDataType(0), new AnyDataType(1)),
new AnyDataType(0), new AnyDataType(1));
List<Expression> arguments = Lists.newArrayList(
new MapLiteral(Lists.newArrayList(new IntegerLiteral(0)), Lists.newArrayList(new BigIntLiteral(0))),
new BigIntLiteral(0), new IntegerLiteral(0));
signature = ComputeSignatureHelper.implementAnyDataTypeWithIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof MapType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getKeyType() instanceof BigIntType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getValueType() instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(1) instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(2) instanceof BigIntType);
}
@Test
void testArraySigWithNullArgWithFollowToAnyImplementAnyDataTypeWithIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0),
ArrayType.of(new FollowToAnyDataType(0)));
List<Expression> arguments = Lists.newArrayList(
new NullLiteral(),
new NullLiteral(),
new ArrayLiteral(Lists.newArrayList(new SmallIntLiteral((byte) 0))));
signature = ComputeSignatureHelper.implementAnyDataTypeWithIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof ArrayType);
Assertions.assertTrue(((ArrayType) signature.getArgType(0)).getItemType() instanceof SmallIntType);
Assertions.assertTrue(signature.getArgType(1) instanceof SmallIntType);
Assertions.assertTrue(signature.getArgType(2) instanceof ArrayType);
Assertions.assertTrue(((ArrayType) signature.getArgType(2)).getItemType() instanceof SmallIntType);
}
@Test
void testMapSigWithNullArgWithFollowToAnyImplementAnyDataTypeWithIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(MapType.of(new AnyDataType(0), new AnyDataType(1)),
new AnyDataType(0), new AnyDataType(1),
MapType.of(new FollowToAnyDataType(0), new FollowToAnyDataType(1)));
List<Expression> arguments = Lists.newArrayList(
new NullLiteral(), new NullLiteral(), new NullLiteral(),
new MapLiteral(Lists.newArrayList(new BigIntLiteral(0)),
Lists.newArrayList(new IntegerLiteral(0))));
signature = ComputeSignatureHelper.implementAnyDataTypeWithIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof MapType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getKeyType() instanceof BigIntType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getValueType() instanceof IntegerType);
Assertions.assertTrue(signature.getArgType(1) instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(2) instanceof IntegerType);
Assertions.assertTrue(signature.getArgType(3) instanceof MapType);
Assertions.assertTrue(((MapType) signature.getArgType(3)).getKeyType() instanceof BigIntType);
Assertions.assertTrue(((MapType) signature.getArgType(3)).getValueType() instanceof IntegerType);
}
@Test
void testArrayWithFollowToAnyImplementAnyDataTypeWithIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0),
ArrayType.of(new FollowToAnyDataType(0)));
List<Expression> arguments = Lists.newArrayList(
new ArrayLiteral(Lists.newArrayList(new IntegerLiteral(0))),
new BigIntLiteral(0),
new ArrayLiteral(Lists.newArrayList(new IntegerLiteral(0))));
signature = ComputeSignatureHelper.implementAnyDataTypeWithIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof ArrayType);
Assertions.assertTrue(((ArrayType) signature.getArgType(0)).getItemType() instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(1) instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(2) instanceof ArrayType);
Assertions.assertTrue(((ArrayType) signature.getArgType(2)).getItemType() instanceof BigIntType);
}
@Test
void testMapWithFollowToAnyImplementAnyDataTypeWithIndex() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE)
.args(MapType.of(new AnyDataType(0), new AnyDataType(1)),
new AnyDataType(0), new AnyDataType(1),
MapType.of(new FollowToAnyDataType(0), new FollowToAnyDataType(1)));
List<Expression> arguments = Lists.newArrayList(
new MapLiteral(Lists.newArrayList(new IntegerLiteral(0)), Lists.newArrayList(new BigIntLiteral(0))),
new BigIntLiteral(0), new IntegerLiteral(0),
new MapLiteral(Lists.newArrayList(new IntegerLiteral(0)), Lists.newArrayList(new BigIntLiteral(0))));
signature = ComputeSignatureHelper.implementAnyDataTypeWithIndex(signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof MapType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getKeyType() instanceof BigIntType);
Assertions.assertTrue(((MapType) signature.getArgType(0)).getValueType() instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(1) instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(2) instanceof BigIntType);
Assertions.assertTrue(signature.getArgType(3) instanceof MapType);
Assertions.assertTrue(((MapType) signature.getArgType(3)).getKeyType() instanceof BigIntType);
Assertions.assertTrue(((MapType) signature.getArgType(3)).getValueType() instanceof BigIntType);
}
@Test
void testNoNormalizeDecimalV2() {
FunctionSignature signature = FunctionSignature.ret(IntegerType.INSTANCE).args();
signature = ComputeSignatureHelper.normalizeDecimalV2(signature, Collections.emptyList());
Assertions.assertEquals(IntegerType.INSTANCE, signature.returnType);
}
@Test
void testNormalizeDecimalV2() {
FunctionSignature signature = FunctionSignature.ret(DecimalV2Type.createDecimalV2Type(15, 3)).args();
signature = ComputeSignatureHelper.normalizeDecimalV2(signature, Collections.emptyList());
Assertions.assertEquals(DecimalV2Type.SYSTEM_DEFAULT, signature.returnType);
}
@Test
void testArrayDecimalV3ComputePrecision() {
FunctionSignature signature = FunctionSignature.ret(BooleanType.INSTANCE)
.args(ArrayType.of(DecimalV3Type.WILDCARD),
ArrayType.of(DecimalV3Type.WILDCARD),
DecimalV3Type.WILDCARD,
IntegerType.INSTANCE,
ArrayType.of(IntegerType.INSTANCE));
List<Expression> arguments = Lists.newArrayList(
new ArrayLiteral(Lists.newArrayList(new DecimalV3Literal(new BigDecimal("1.1234")))),
new NullLiteral(),
new DecimalV3Literal(new BigDecimal("123.123")),
new IntegerLiteral(0),
new ArrayLiteral(Lists.newArrayList(new IntegerLiteral(0))));
signature = ComputeSignatureHelper.computePrecision(new FakeComputeSignature(), signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof ArrayType);
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(7, 4),
((ArrayType) signature.getArgType(0)).getItemType());
Assertions.assertTrue(signature.getArgType(1) instanceof ArrayType);
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(7, 4),
((ArrayType) signature.getArgType(1)).getItemType());
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(7, 4),
signature.getArgType(2));
Assertions.assertTrue(signature.getArgType(4) instanceof ArrayType);
Assertions.assertEquals(IntegerType.INSTANCE,
((ArrayType) signature.getArgType(4)).getItemType());
}
@Test
void testMapDecimalV3ComputePrecision() {
FunctionSignature signature = FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD),
MapType.of(DecimalV3Type.WILDCARD, DecimalV3Type.WILDCARD),
DecimalV3Type.WILDCARD);
List<Expression> arguments = Lists.newArrayList(
new MapLiteral(Lists.newArrayList(new DecimalV3Literal(new BigDecimal("1.1234"))),
Lists.newArrayList(new DecimalV3Literal(new BigDecimal("12.12345")))),
new NullLiteral(),
new DecimalV3Literal(new BigDecimal("123.123")));
signature = ComputeSignatureHelper.computePrecision(new FakeComputeSignature(), signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof MapType);
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(8, 5),
((MapType) signature.getArgType(0)).getKeyType());
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(8, 5),
((MapType) signature.getArgType(0)).getValueType());
Assertions.assertTrue(signature.getArgType(1) instanceof MapType);
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(8, 5),
((MapType) signature.getArgType(1)).getKeyType());
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(8, 5),
((MapType) signature.getArgType(1)).getValueType());
Assertions.assertEquals(DecimalV3Type.createDecimalV3Type(8, 5),
signature.getArgType(2));
}
@Test
void testArrayDateTimeV2ComputePrecision() {
FunctionSignature signature = FunctionSignature.ret(BooleanType.INSTANCE)
.args(ArrayType.of(DateTimeV2Type.SYSTEM_DEFAULT),
ArrayType.of(DateTimeV2Type.SYSTEM_DEFAULT),
DateTimeV2Type.SYSTEM_DEFAULT,
IntegerType.INSTANCE,
ArrayType.of(IntegerType.INSTANCE));
List<Expression> arguments = Lists.newArrayList(
new ArrayLiteral(Lists.newArrayList(new DateTimeV2Literal("2020-02-02 00:00:00.123"))),
new NullLiteral(),
new DateTimeV2Literal("2020-02-02 00:00:00.12"),
new IntegerLiteral(0),
new ArrayLiteral(Lists.newArrayList(new IntegerLiteral(0))));
signature = ComputeSignatureHelper.computePrecision(new FakeComputeSignature(), signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof ArrayType);
Assertions.assertEquals(DateTimeV2Type.of(3),
((ArrayType) signature.getArgType(0)).getItemType());
Assertions.assertTrue(signature.getArgType(1) instanceof ArrayType);
Assertions.assertEquals(DateTimeV2Type.of(3),
((ArrayType) signature.getArgType(1)).getItemType());
Assertions.assertEquals(DateTimeV2Type.of(3),
signature.getArgType(2));
Assertions.assertTrue(signature.getArgType(4) instanceof ArrayType);
Assertions.assertEquals(IntegerType.INSTANCE,
((ArrayType) signature.getArgType(4)).getItemType());
}
@Test
void testMapDateTimeV2ComputePrecision() {
FunctionSignature signature = FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT),
MapType.of(DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT),
DateTimeV2Type.SYSTEM_DEFAULT);
List<Expression> arguments = Lists.newArrayList(
new MapLiteral(Lists.newArrayList(new DateTimeV2Literal("2020-02-02 00:00:00.123")),
Lists.newArrayList(new DateTimeV2Literal("2020-02-02 00:00:00.12"))),
new NullLiteral(),
new DateTimeV2Literal("2020-02-02 00:00:00.1234"));
signature = ComputeSignatureHelper.computePrecision(new FakeComputeSignature(), signature, arguments);
Assertions.assertTrue(signature.getArgType(0) instanceof MapType);
Assertions.assertEquals(DateTimeV2Type.of(6),
((MapType) signature.getArgType(0)).getKeyType());
Assertions.assertEquals(DateTimeV2Type.of(6),
((MapType) signature.getArgType(0)).getValueType());
Assertions.assertTrue(signature.getArgType(1) instanceof MapType);
Assertions.assertEquals(DateTimeV2Type.of(6),
((MapType) signature.getArgType(1)).getKeyType());
Assertions.assertEquals(DateTimeV2Type.of(6),
((MapType) signature.getArgType(1)).getValueType());
Assertions.assertEquals(DateTimeV2Type.of(6),
signature.getArgType(2));
}
private static class FakeComputeSignature implements ComputeSignature {
@Override
public List<Expression> children() {
return null;
}
@Override
public Expression child(int index) {
return null;
}
@Override
public int arity() {
return 0;
}
@Override
public Expression withChildren(List<Expression> children) {
return null;
}
@Override
public List<FunctionSignature> getSignatures() {
return null;
}
@Override
public FunctionSignature getSignature() {
return null;
}
@Override
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
return null;
}
@Override
public boolean nullable() {
return false;
}
}
}