[fix](Nereids) polish function signature search algorithm (#38497) (#39436)

pick from master #38497 and #39342

use array<double> for array<string>
- array_avg
- array_cum_sum
- array_difference
- array_product

use array<bigint> for array<string>
- bitmap_from_array

use double first
- fmod
- pmod

let high order function throw friendly exception
- array_filter
- array_first
- array_last
- array_reverse_split
- array_sort_by
- array_split

let return type same as parameter's type
- array_push_back
- array_push_front
- array_with_constant
- if

let greatest / least work same as mysql's greatest
This commit is contained in:
morrySnow
2024-08-16 08:24:25 +08:00
committed by GitHub
parent 6257e706fa
commit 3aaee8f7d5
25 changed files with 247 additions and 75 deletions

View File

@ -811,9 +811,7 @@ public class ExpressionAnalyzer extends SubExprAnalyzer<ExpressionRewriteContext
Lambda lambdaClosure = lambda.withLambdaFunctionArguments(lambdaFunction, arrayItemReferences);
// We don't add the ArrayExpression in high order function at all
return unboundFunction.withChildren(ImmutableList.<Expression>builder()
.add(lambdaClosure)
.build());
return unboundFunction.withChildren(ImmutableList.of(lambdaClosure));
}
private boolean shouldBindSlotBy(int namePartSize, Slot boundSlot) {

View File

@ -48,6 +48,7 @@ public class ArrayAvg extends ScalarFunction implements ExplicitlyCastableSignat
ComputePrecisionForArrayItemAgg, UnaryExpression, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(BooleanType.INSTANCE)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(TinyIntType.INSTANCE)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(SmallIntType.INSTANCE)),
@ -56,8 +57,7 @@ public class ArrayAvg extends ScalarFunction implements ExplicitlyCastableSignat
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(LargeIntType.INSTANCE)),
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(ArrayType.of(DecimalV3Type.WILDCARD)),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(FloatType.INSTANCE)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(DoubleType.INSTANCE))
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(FloatType.INSTANCE))
);
/**

View File

@ -47,6 +47,7 @@ public class ArrayCumSum extends ScalarFunction
implements ExplicitlyCastableSignature, ComputePrecisionForArrayItemAgg, UnaryExpression, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)).args(ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE)).args(ArrayType.of(TinyIntType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE)).args(ArrayType.of(SmallIntType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE)).args(ArrayType.of(IntegerType.INSTANCE)),
@ -55,8 +56,7 @@ public class ArrayCumSum extends ScalarFunction
FunctionSignature.ret(ArrayType.of(DecimalV3Type.WILDCARD)).args(ArrayType.of(DecimalV3Type.WILDCARD)),
FunctionSignature.ret(ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT))
.args(ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)).args(ArrayType.of(FloatType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)).args(ArrayType.of(DoubleType.INSTANCE))
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)).args(ArrayType.of(FloatType.INSTANCE))
);
/**

View File

@ -46,13 +46,13 @@ public class ArrayDifference extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)).args(ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(SmallIntType.INSTANCE)).args(ArrayType.of(TinyIntType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(IntegerType.INSTANCE)).args(ArrayType.of(SmallIntType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE)).args(ArrayType.of(IntegerType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE)).args(ArrayType.of(BigIntType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(LargeIntType.INSTANCE)).args(ArrayType.of(LargeIntType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)).args(ArrayType.of(FloatType.INSTANCE)),
FunctionSignature.ret(ArrayType.of(DoubleType.INSTANCE)).args(ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.retArgType(0).args(ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT)),
FunctionSignature.retArgType(0).args(ArrayType.of(DecimalV3Type.WILDCARD))
);

View File

@ -41,7 +41,8 @@ public class ArrayEnumerateUniq extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE)).varArgs(ArrayType.of(new AnyDataType(0)))
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE))
.varArgs(ArrayType.of(AnyDataType.INSTANCE_WITHOUT_INDEX))
);
/**

View File

@ -49,7 +49,7 @@ public class ArrayFilter extends ScalarFunction
* array_filter(lambda, a1, ...) = array_filter(a1, array_map(lambda, a1, ...))
*/
public ArrayFilter(Expression arg) {
super("array_filter", arg.child(1).child(0), new ArrayMap(arg));
super("array_filter", arg instanceof Lambda ? arg.child(1).child(0) : arg, new ArrayMap(arg));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
@ -36,10 +35,6 @@ public class ArrayFirst extends ElementAt
*/
public ArrayFirst(Expression arg) {
super(new ArrayFilter(arg), new BigIntLiteral(1));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
}
@Override

View File

@ -18,7 +18,6 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
@ -36,10 +35,6 @@ public class ArrayLast extends ElementAt
*/
public ArrayLast(Expression arg) {
super(new ArrayFilter(arg), new BigIntLiteral(-1));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
}
@Override

View File

@ -48,6 +48,7 @@ public class ArrayProduct extends ScalarFunction implements ExplicitlyCastableSi
ComputePrecisionForArrayItemAgg, UnaryExpression, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(DoubleType.INSTANCE)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(BooleanType.INSTANCE)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(TinyIntType.INSTANCE)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(SmallIntType.INSTANCE)),
@ -56,8 +57,7 @@ public class ArrayProduct extends ScalarFunction implements ExplicitlyCastableSi
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(LargeIntType.INSTANCE)),
FunctionSignature.ret(DecimalV3Type.WILDCARD).args(ArrayType.of(DecimalV3Type.WILDCARD)),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).args(ArrayType.of(DecimalV2Type.SYSTEM_DEFAULT)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(FloatType.INSTANCE)),
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(DoubleType.INSTANCE))
FunctionSignature.ret(DoubleType.INSTANCE).args(ArrayType.of(FloatType.INSTANCE))
);
/**

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -38,7 +39,8 @@ public class ArrayPushBack extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
/**

View File

@ -25,6 +25,7 @@ import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -38,7 +39,8 @@ public class ArrayPushFront extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(ArrayType.of(new AnyDataType(0)), new AnyDataType(0))
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), new FollowToAnyDataType(0))
);
/**

View File

@ -48,7 +48,7 @@ public class ArraySortBy extends ScalarFunction
* array_sortby(lambda, a1, ...) = array_sortby(a1, array_map(lambda, a1, ...))
*/
public ArraySortBy(Expression arg) {
super("array_sortby", arg.child(1).child(0), new ArrayMap(arg));
super("array_sortby", arg instanceof Lambda ? arg.child(1).child(0) : arg, new ArrayMap(arg));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));

View File

@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -39,7 +40,8 @@ public class ArrayWithConstant extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(new AnyDataType(0))).args(BigIntType.INSTANCE, new AnyDataType(0))
FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0)))
.args(BigIntType.INSTANCE, new AnyDataType(0))
);
/**

View File

@ -42,10 +42,10 @@ public class BitmapFromArray extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BitmapType.INSTANCE).args(ArrayType.of(TinyIntType.INSTANCE)),
FunctionSignature.ret(BitmapType.INSTANCE).args(ArrayType.of(SmallIntType.INSTANCE)),
FunctionSignature.ret(BitmapType.INSTANCE).args(ArrayType.of(BigIntType.INSTANCE)),
FunctionSignature.ret(BitmapType.INSTANCE).args(ArrayType.of(IntegerType.INSTANCE)),
FunctionSignature.ret(BitmapType.INSTANCE).args(ArrayType.of(BigIntType.INSTANCE))
FunctionSignature.ret(BitmapType.INSTANCE).args(ArrayType.of(SmallIntType.INSTANCE)),
FunctionSignature.ret(BitmapType.INSTANCE).args(ArrayType.of(TinyIntType.INSTANCE))
);
/**

View File

@ -58,12 +58,12 @@ public class Coalesce extends ScalarFunction
FunctionSignature.ret(IntegerType.INSTANCE).varArgs(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).varArgs(LargeIntType.INSTANCE),
FunctionSignature.ret(FloatType.INSTANCE).varArgs(FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).varArgs(DoubleType.INSTANCE),
FunctionSignature.ret(DateTimeType.INSTANCE).varArgs(DateTimeType.INSTANCE),
FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE),
FunctionSignature.ret(FloatType.INSTANCE).varArgs(FloatType.INSTANCE),
FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DateTimeType.INSTANCE).varArgs(DateTimeType.INSTANCE),
FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE),
FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE),
FunctionSignature.ret(DecimalV3Type.WILDCARD).varArgs(DecimalV3Type.WILDCARD),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(BitmapType.INSTANCE).varArgs(BitmapType.INSTANCE),

View File

@ -38,8 +38,8 @@ public class Fmod extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(FloatType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE)
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE),
FunctionSignature.ret(FloatType.INSTANCE).args(FloatType.INSTANCE, FloatType.INSTANCE)
);
/**

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSi
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
@ -37,6 +38,7 @@ import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
@ -51,21 +53,21 @@ public class Greatest extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).varArgs(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).varArgs(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).varArgs(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).varArgs(LargeIntType.INSTANCE),
FunctionSignature.ret(FloatType.INSTANCE).varArgs(FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).varArgs(DoubleType.INSTANCE),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DecimalV3Type.WILDCARD).varArgs(DecimalV3Type.WILDCARD),
FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE),
FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE),
FunctionSignature.ret(DateTimeType.INSTANCE).varArgs(DateTimeType.INSTANCE),
FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE).varArgs(StringType.INSTANCE)
FunctionSignature.ret(StringType.INSTANCE).varArgs(StringType.INSTANCE),
FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DateTimeType.INSTANCE).varArgs(DateTimeType.INSTANCE),
FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE),
FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).varArgs(DoubleType.INSTANCE),
FunctionSignature.ret(FloatType.INSTANCE).varArgs(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV3Type.WILDCARD).varArgs(DecimalV3Type.WILDCARD),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(LargeIntType.INSTANCE).varArgs(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).varArgs(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).varArgs(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).varArgs(TinyIntType.INSTANCE)
);
/**
@ -80,11 +82,28 @@ public class Greatest extends ScalarFunction
*/
@Override
public Greatest withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 1);
Preconditions.checkArgument(!children.isEmpty());
return new Greatest(children.get(0),
children.subList(1, children.size()).toArray(new Expression[0]));
}
@Override
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
List<DataType> argTypes = getArgumentsTypes();
if (argTypes.stream().anyMatch(CharacterType.class::isInstance)) {
return FunctionSignature.ret(StringType.INSTANCE).varArgs(StringType.INSTANCE);
} else if (argTypes.stream().anyMatch(DateTimeV2Type.class::isInstance)) {
return FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT);
} else if (argTypes.stream().anyMatch(DateTimeType.class::isInstance)) {
return FunctionSignature.ret(DateTimeType.INSTANCE).varArgs(DateTimeType.INSTANCE);
} else if (argTypes.stream().anyMatch(DateV2Type.class::isInstance)) {
return FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE);
} else if (argTypes.stream().anyMatch(DateType.class::isInstance)) {
return FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE);
}
return ExplicitlyCastableSignature.super.searchSignature(signatures);
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;

View File

@ -18,6 +18,7 @@
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.analyzer.Unbound;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression;
@ -45,6 +46,7 @@ import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -111,7 +113,8 @@ public class If extends ScalarFunction
* constructor with 3 arguments.
*/
public If(Expression arg0, Expression arg1, Expression arg2) {
super("if", arg0, arg1, arg2);
super("if", arg0 instanceof Unbound ? arg0 : TypeCoercionUtils.castIfNotSameType(arg0, BooleanType.INSTANCE),
arg1, arg2);
}
/**
@ -145,4 +148,10 @@ public class If extends ScalarFunction
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
@Override
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
return ExplicitlyCastableSignature.super.searchSignature(signatures);
}
}

View File

@ -40,10 +40,10 @@ public class LastDay extends ScalarFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullableOnDateLikeV2Args {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(DateType.INSTANCE).args(DateTimeType.INSTANCE),
FunctionSignature.ret(DateType.INSTANCE).args(DateType.INSTANCE),
FunctionSignature.ret(DateV2Type.INSTANCE).args(DateV2Type.INSTANCE),
FunctionSignature.ret(DateV2Type.INSTANCE).args(DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DateV2Type.INSTANCE).args(DateV2Type.INSTANCE)
FunctionSignature.ret(DateType.INSTANCE).args(DateType.INSTANCE),
FunctionSignature.ret(DateType.INSTANCE).args(DateTimeType.INSTANCE)
);
/**

View File

@ -23,6 +23,7 @@ import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSi
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
@ -37,6 +38,7 @@ import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.CharacterType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
@ -51,21 +53,21 @@ public class Least extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(TinyIntType.INSTANCE).varArgs(TinyIntType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).varArgs(SmallIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).varArgs(IntegerType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
FunctionSignature.ret(LargeIntType.INSTANCE).varArgs(LargeIntType.INSTANCE),
FunctionSignature.ret(FloatType.INSTANCE).varArgs(FloatType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).varArgs(DoubleType.INSTANCE),
FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE),
FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE),
FunctionSignature.ret(DateTimeType.INSTANCE).varArgs(DateTimeType.INSTANCE),
FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DecimalV3Type.WILDCARD).varArgs(DecimalV3Type.WILDCARD),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).varArgs(VarcharType.SYSTEM_DEFAULT),
FunctionSignature.ret(StringType.INSTANCE).varArgs(StringType.INSTANCE)
FunctionSignature.ret(StringType.INSTANCE).varArgs(StringType.INSTANCE),
FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(DateTimeType.INSTANCE).varArgs(DateTimeType.INSTANCE),
FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE),
FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).varArgs(DoubleType.INSTANCE),
FunctionSignature.ret(FloatType.INSTANCE).varArgs(FloatType.INSTANCE),
FunctionSignature.ret(DecimalV3Type.WILDCARD).varArgs(DecimalV3Type.WILDCARD),
FunctionSignature.ret(DecimalV2Type.SYSTEM_DEFAULT).varArgs(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(LargeIntType.INSTANCE).varArgs(LargeIntType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).varArgs(BigIntType.INSTANCE),
FunctionSignature.ret(IntegerType.INSTANCE).varArgs(IntegerType.INSTANCE),
FunctionSignature.ret(SmallIntType.INSTANCE).varArgs(SmallIntType.INSTANCE),
FunctionSignature.ret(TinyIntType.INSTANCE).varArgs(TinyIntType.INSTANCE)
);
/**
@ -80,11 +82,28 @@ public class Least extends ScalarFunction
*/
@Override
public Least withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() >= 1);
Preconditions.checkArgument(!children.isEmpty());
return new Least(children.get(0),
children.subList(1, children.size()).toArray(new Expression[0]));
}
@Override
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
List<DataType> argTypes = getArgumentsTypes();
if (argTypes.stream().anyMatch(CharacterType.class::isInstance)) {
return FunctionSignature.ret(StringType.INSTANCE).varArgs(StringType.INSTANCE);
} else if (argTypes.stream().anyMatch(DateTimeV2Type.class::isInstance)) {
return FunctionSignature.ret(DateTimeV2Type.SYSTEM_DEFAULT).varArgs(DateTimeV2Type.SYSTEM_DEFAULT);
} else if (argTypes.stream().anyMatch(DateTimeType.class::isInstance)) {
return FunctionSignature.ret(DateTimeType.INSTANCE).varArgs(DateTimeType.INSTANCE);
} else if (argTypes.stream().anyMatch(DateV2Type.class::isInstance)) {
return FunctionSignature.ret(DateV2Type.INSTANCE).varArgs(DateV2Type.INSTANCE);
} else if (argTypes.stream().anyMatch(DateType.class::isInstance)) {
return FunctionSignature.ret(DateType.INSTANCE).varArgs(DateType.INSTANCE);
}
return ExplicitlyCastableSignature.super.searchSignature(signatures);
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;

View File

@ -42,9 +42,9 @@ public class MinutesDiff extends ScalarFunction
private static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE)
.args(DateTimeV2Type.SYSTEM_DEFAULT, DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(BigIntType.INSTANCE).args(DateTimeType.INSTANCE, DateTimeType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(DateV2Type.INSTANCE, DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(BigIntType.INSTANCE).args(DateTimeV2Type.SYSTEM_DEFAULT, DateV2Type.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(DateTimeType.INSTANCE, DateTimeType.INSTANCE),
FunctionSignature.ret(BigIntType.INSTANCE).args(DateV2Type.INSTANCE, DateV2Type.INSTANCE)
);

View File

@ -21,10 +21,12 @@ import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
@ -37,10 +39,10 @@ import java.util.List;
public class Pmod extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(BigIntType.INSTANCE).args(BigIntType.INSTANCE, BigIntType.INSTANCE),
FunctionSignature.ret(DoubleType.INSTANCE).args(DoubleType.INSTANCE, DoubleType.INSTANCE)
);
public static final FunctionSignature BIGINT_SIGNATURE = FunctionSignature.ret(BigIntType.INSTANCE)
.args(BigIntType.INSTANCE, BigIntType.INSTANCE);
public static final FunctionSignature DOUBLE_SIGNATURE = FunctionSignature.ret(DoubleType.INSTANCE)
.args(DoubleType.INSTANCE, DoubleType.INSTANCE);
/**
* constructor with 2 arguments.
@ -58,9 +60,39 @@ public class Pmod extends ScalarFunction
return new Pmod(children.get(0), children.get(1));
}
/**
* already override searchSignature and computeSignature, so getSignatures is useless anymore.
*
* @return empty list
*/
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
return ImmutableList.of();
}
@Override
public FunctionSignature computeSignature(FunctionSignature signature) {
return signature;
}
@Override
public FunctionSignature searchSignature(List<FunctionSignature> signatures) {
boolean leftCouldBeBigInt = false;
boolean rightCouldBeBigInt = false;
if (getArgument(0) instanceof StringLikeLiteral) {
leftCouldBeBigInt = TypeCoercionUtils.characterLiteralTypeCoercion(
((StringLikeLiteral) getArgument(0)).getValue(), BigIntType.INSTANCE).isPresent();
}
if (getArgument(1) instanceof StringLikeLiteral) {
rightCouldBeBigInt = TypeCoercionUtils.characterLiteralTypeCoercion(
((StringLikeLiteral) getArgument(1)).getValue(), BigIntType.INSTANCE).isPresent();
}
if ((getArgument(0).getDataType().isIntegerLikeType() || leftCouldBeBigInt)
&& (getArgument(1).getDataType().isIntegerLikeType() || rightCouldBeBigInt)) {
return BIGINT_SIGNATURE;
} else {
return DOUBLE_SIGNATURE;
}
}
@Override

View File

@ -534,7 +534,7 @@ public class TypeCoercionUtils {
if ("false".equalsIgnoreCase(value)) {
ret = BooleanLiteral.FALSE;
}
} else if (dataType instanceof IntegralType) {
} else if (dataType instanceof TinyIntType) {
BigInteger bigInt = new BigInteger(value);
if (BigInteger.valueOf(bigInt.byteValue()).equals(bigInt)) {
ret = new TinyIntLiteral(bigInt.byteValue());
@ -547,6 +547,36 @@ public class TypeCoercionUtils {
} else {
ret = new LargeIntLiteral(bigInt);
}
} else if (dataType instanceof SmallIntType) {
BigInteger bigInt = new BigInteger(value);
if (BigInteger.valueOf(bigInt.shortValue()).equals(bigInt)) {
ret = new SmallIntLiteral(bigInt.shortValue());
} else if (BigInteger.valueOf(bigInt.intValue()).equals(bigInt)) {
ret = new IntegerLiteral(bigInt.intValue());
} else if (BigInteger.valueOf(bigInt.longValue()).equals(bigInt)) {
ret = new BigIntLiteral(bigInt.longValueExact());
} else {
ret = new LargeIntLiteral(bigInt);
}
} else if (dataType instanceof IntegerType) {
BigInteger bigInt = new BigInteger(value);
if (BigInteger.valueOf(bigInt.intValue()).equals(bigInt)) {
ret = new IntegerLiteral(bigInt.intValue());
} else if (BigInteger.valueOf(bigInt.longValue()).equals(bigInt)) {
ret = new BigIntLiteral(bigInt.longValueExact());
} else {
ret = new LargeIntLiteral(bigInt);
}
} else if (dataType instanceof BigIntType) {
BigInteger bigInt = new BigInteger(value);
if (BigInteger.valueOf(bigInt.longValue()).equals(bigInt)) {
ret = new BigIntLiteral(bigInt.longValueExact());
} else {
ret = new LargeIntLiteral(bigInt);
}
} else if (dataType instanceof LargeIntType) {
BigInteger bigInt = new BigInteger(value);
ret = new LargeIntLiteral(bigInt);
} else if (dataType instanceof FloatType) {
ret = new FloatLiteral(Float.parseFloat(value));
} else if (dataType instanceof DoubleType) {

View File

@ -0,0 +1,37 @@
-- This file is automatically generated. You should know what you did if you want to edit this
-- !greatest --
333
-- !least --
2000000
-- !if --
2222
-- !array_product --
6000.0
-- !array_avg --
1001.0
-- !array_pushfront --
[4444, 1, 2, 3, 555555]
-- !array_pushback --
[1, 2, 3, 555555, 4444]
-- !array_difference --
[0, 1, 198]
-- !array_enumerate_uniq --
[1, 2, 1]
-- !array_cum_sum --
[1, 3, 3003]
-- !pmod --
0.0
-- !nullif --
13

View File

@ -0,0 +1,31 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
suite("function_type_coercion") {
sql """set enable_fold_constant_by_be=false""" // remove this if array<double> BE return result be fixed.
qt_greatest """select greatest(1, 2222, '333')"""
qt_least """select least(5,2000000,'3.0023')"""
qt_if """select if (1, 2222, 33)"""
qt_array_product """select array_product(array(1, 2, '3000'))"""
qt_array_avg """select array_avg(array(1, 2, '3000'))"""
qt_array_pushfront """select array_pushfront(array(1,2,3,555555), '4444')"""
qt_array_pushback """select array_pushback(array(1,2,3,555555), '4444')"""
qt_array_difference """select array_difference(array(1,2,'200'))"""
qt_array_enumerate_uniq """select array_enumerate_uniq([1,1,1],['1','1','1.0'])"""
qt_array_cum_sum """select array_cum_sum(array('1', '2', '3000'))"""
qt_pmod """select pmod(2, '1.0')"""
qt_nullif """SELECT nullif(13, -4851)"""
}