branch-2.1: [fix](function) fixed some nested type func's param type which is not suitable and make result wrong #44923 (#45798)

Cherry-picked from #44923
This commit is contained in:
amory
2024-12-24 14:57:33 +08:00
committed by GitHub
parent 69704df447
commit 8b35b0e477
12 changed files with 208 additions and 18 deletions

View File

@ -41,11 +41,16 @@ import java.util.List;
*/
public class ArrayApply extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), VarcharType.SYSTEM_DEFAULT,
new FollowToAnyDataType(0)));
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), VarcharType.SYSTEM_DEFAULT,
new AnyDataType(0)));
/**
* constructor
*/
@ -93,6 +98,13 @@ public class ArrayApply extends ScalarFunction
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(2).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}

View File

@ -38,10 +38,14 @@ import java.util.List;
public class ArrayContains extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);
/**
* constructor with 2 arguments.
@ -71,6 +75,13 @@ public class ArrayContains extends ScalarFunction
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}

View File

@ -38,11 +38,16 @@ import java.util.List;
public class ArrayPosition extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);
/**
* constructor with 2 arguments.
*/
@ -71,6 +76,13 @@ public class ArrayPosition extends ScalarFunction
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}

View File

@ -38,11 +38,16 @@ import java.util.List;
public class ArrayPushBack extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);
/**
* constructor with 1 argument.
*/
@ -66,6 +71,13 @@ public class ArrayPushBack extends ScalarFunction
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}

View File

@ -38,11 +38,16 @@ import java.util.List;
public class ArrayPushFront extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);
/**
* constructor with 1 argument.
*/
@ -66,6 +71,13 @@ public class ArrayPushFront extends ScalarFunction
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}

View File

@ -38,11 +38,16 @@ import java.util.List;
public class ArrayRemove extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.retArgType(0).args(
ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(
ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);
/**
* constructor with 2 arguments.
*/
@ -66,6 +71,13 @@ public class ArrayRemove extends ScalarFunction
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}

View File

@ -38,11 +38,16 @@ import java.util.List;
public class CountEqual extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
);
/**
* constructor with 2 arguments.
*/
@ -71,6 +76,16 @@ public class CountEqual extends ScalarFunction
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
// to find out element type in array vs param type,
// if they are different, return first array element type,
// else return least common type between element type and param
if (getArgument(0).getDataType().isArrayType()
&&
((ArrayType) getArgument(0).getDataType()).getItemType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}

View File

@ -38,12 +38,18 @@ import java.util.List;
public class MapContainsKey extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(new AnyDataType(0), AnyDataType.INSTANCE_WITHOUT_INDEX),
new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(new AnyDataType(0), AnyDataType.INSTANCE_WITHOUT_INDEX),
new AnyDataType(0))
);
/**
* constructor with 2 arguments.
*/
@ -72,6 +78,13 @@ public class MapContainsKey extends ScalarFunction
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isMapType()
&&
((MapType) getArgument(0).getDataType()).getKeyType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}

View File

@ -38,12 +38,18 @@ import java.util.List;
public class MapContainsValue extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
public static final List<FunctionSignature> FOLLOW_DATATYPE_SIGNATURE = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, new AnyDataType(0)),
new FollowToAnyDataType(0))
);
public static final List<FunctionSignature> MIN_COMMON_TYPE_SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BooleanType.INSTANCE)
.args(MapType.of(AnyDataType.INSTANCE_WITHOUT_INDEX, new AnyDataType(0)),
new AnyDataType(0))
);
/**
* constructor with 2 arguments.
*/
@ -72,6 +78,13 @@ public class MapContainsValue extends ScalarFunction
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
if (getArgument(0).getDataType().isMapType()
&&
((MapType) getArgument(0).getDataType()).getValueType()
.isSameTypeForComplexTypeParam(getArgument(1).getDataType())) {
// return least common type
return MIN_COMMON_TYPE_SIGNATURES;
}
return FOLLOW_DATATYPE_SIGNATURE;
}
}

View File

@ -648,6 +648,41 @@ public abstract class DataType {
}
}
/**
* whether the param dataType is same-like type for nested in complex type
* same-like type means: string-like, date-like, number type
*/
public boolean isSameTypeForComplexTypeParam(DataType paramType) {
if (this.isArrayType() && paramType.isArrayType()) {
return ((ArrayType) this).getItemType()
.isSameTypeForComplexTypeParam(((ArrayType) paramType).getItemType());
} else if (this.isMapType() && paramType.isMapType()) {
MapType thisMapType = (MapType) this;
MapType otherMapType = (MapType) paramType;
return thisMapType.getKeyType().isSameTypeForComplexTypeParam(otherMapType.getKeyType())
&& thisMapType.getValueType().isSameTypeForComplexTypeParam(otherMapType.getValueType());
} else if (this.isStructType() && paramType.isStructType()) {
StructType thisStructType = (StructType) this;
StructType otherStructType = (StructType) paramType;
if (thisStructType.getFields().size() != otherStructType.getFields().size()) {
return false;
}
for (int i = 0; i < thisStructType.getFields().size(); i++) {
if (!thisStructType.getFields().get(i).getDataType().isSameTypeForComplexTypeParam(
otherStructType.getFields().get(i).getDataType())) {
return false;
}
}
return true;
} else if (this.isStringLikeType() && paramType.isStringLikeType()) {
return true;
} else if (this.isDateLikeType() && paramType.isDateLikeType()) {
return true;
} else {
return this.isNumericType() && paramType.isNumericType();
}
}
/** getAllPromotions */
public List<DataType> getAllPromotions() {
if (this instanceof ArrayType) {

View File

@ -15579,3 +15579,30 @@ false
\N
\N
-- !sql --
0 0
-- !sql --
[258] []
-- !sql --
false false
-- !sql --
[257, 258] [258, 1, 2, 3]
-- !sql --
[1, 258, 257] [1, 2, 3, 258]
-- !sql --
[1, 258] [1, 2, 3]
-- !sql --
0 0
-- !sql --
false false
-- !sql --
false false

View File

@ -1375,4 +1375,20 @@ suite("nereids_scalar_fn_Array") {
order_qt_sql_array_overlaps_5 """select arrays_overlap(b, c) from fn_test_array_with_large_decimal order by id"""
order_qt_sql_array_overlaps_6 """select arrays_overlap(c, b) from fn_test_array_with_large_decimal order by id"""
// tests for nereids array functions for number overflow cases
qt_sql """ SELECT array_position([1,258],257),array_position([2],258);"""
qt_sql """ select array_apply([258], '>' , 257), array_apply([1,2,3], '>', 258);"""
qt_sql """ select array_contains([258], 257), array_contains([1,2,3], 258);"""
// pushfront and pushback
qt_sql """ select array_pushfront([258], 257), array_pushfront([1,2,3], 258);"""
qt_sql """ select array_pushback([1,258], 257), array_pushback([1,2,3], 258);"""
// array_remove
qt_sql """ select array_remove([1,258], 257), array_remove([1,2,3], 258);"""
// countequal
qt_sql """ select countequal([1,258], 257), countequal([1,2,3], 258);"""
// map_contains_key
qt_sql """ select map_contains_key(map(1,258), 257), map_contains_key(map(2,1), 258);"""
// map_contains_value
qt_sql """ select map_contains_value(map(1,1), 257), map_contains_value(map(1,2), 258);"""
}