[fix](Nereids) support AnyDataType in function signature (#25173)
1. support AnyDataType in function signature 2. update histogram signature
This commit is contained in:
@ -108,6 +108,7 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes
|
||||
// If you want to add some special cases, please override this method in the special
|
||||
// function class, like 'If' function and 'Substring' function.
|
||||
return ComputeSignatureChain.from(this, signature, getArguments())
|
||||
.then(ComputeSignatureHelper::implementAnyDataTypeWithOutIndex)
|
||||
.then(ComputeSignatureHelper::implementAnyDataTypeWithIndex)
|
||||
.then(ComputeSignatureHelper::computePrecision)
|
||||
.then(ComputeSignatureHelper::implementFollowToArgumentReturnType)
|
||||
|
||||
@ -64,6 +64,46 @@ public class ComputeSignatureHelper {
|
||||
return signature;
|
||||
}
|
||||
|
||||
private static DataType replaceAnyDataTypeWithOutIndex(DataType sigType, DataType expressionType) {
|
||||
if (expressionType instanceof NullType) {
|
||||
if (sigType instanceof ArrayType) {
|
||||
return ArrayType.of(replaceAnyDataTypeWithOutIndex(
|
||||
((ArrayType) sigType).getItemType(), NullType.INSTANCE));
|
||||
} else if (sigType instanceof MapType) {
|
||||
return MapType.of(replaceAnyDataTypeWithOutIndex(((MapType) sigType).getKeyType(), NullType.INSTANCE),
|
||||
replaceAnyDataTypeWithOutIndex(((MapType) sigType).getValueType(), NullType.INSTANCE));
|
||||
} else if (sigType instanceof StructType) {
|
||||
// TODO: do not support struct type now
|
||||
// throw new AnalysisException("do not support struct type now");
|
||||
return sigType;
|
||||
} else {
|
||||
if (sigType instanceof AnyDataType
|
||||
&& ((AnyDataType) sigType).getIndex() == AnyDataType.INDEX_OF_INSTANCE_WITHOUT_INDEX) {
|
||||
return expressionType;
|
||||
}
|
||||
return sigType;
|
||||
}
|
||||
} else if (sigType instanceof ArrayType && expressionType instanceof ArrayType) {
|
||||
return ArrayType.of(replaceAnyDataTypeWithOutIndex(
|
||||
((ArrayType) sigType).getItemType(), ((ArrayType) expressionType).getItemType()));
|
||||
} else if (sigType instanceof MapType && expressionType instanceof MapType) {
|
||||
return MapType.of(replaceAnyDataTypeWithOutIndex(
|
||||
((MapType) sigType).getKeyType(), ((MapType) expressionType).getKeyType()),
|
||||
replaceAnyDataTypeWithOutIndex(
|
||||
((MapType) sigType).getValueType(), ((MapType) expressionType).getValueType()));
|
||||
} else if (sigType instanceof StructType && expressionType instanceof StructType) {
|
||||
// TODO: do not support struct type now
|
||||
// throw new AnalysisException("do not support struct type now");
|
||||
return sigType;
|
||||
} else {
|
||||
if (sigType instanceof AnyDataType
|
||||
&& ((AnyDataType) sigType).getIndex() == AnyDataType.INDEX_OF_INSTANCE_WITHOUT_INDEX) {
|
||||
return expressionType;
|
||||
}
|
||||
return sigType;
|
||||
}
|
||||
}
|
||||
|
||||
private static void collectAnyDataType(DataType sigType, DataType expressionType,
|
||||
Map<Integer, List<DataType>> indexToArgumentTypes) {
|
||||
if (expressionType instanceof NullType) {
|
||||
@ -173,6 +213,26 @@ public class ComputeSignatureHelper {
|
||||
}
|
||||
}
|
||||
|
||||
/** implementFollowToAnyDataType */
|
||||
public static FunctionSignature implementAnyDataTypeWithOutIndex(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
// collect all any data type with index
|
||||
List<DataType> newArgTypes = Lists.newArrayList();
|
||||
for (int i = 0; i < arguments.size(); i++) {
|
||||
DataType sigType;
|
||||
if (i >= signature.argumentsTypes.size()) {
|
||||
sigType = signature.getVarArgType().orElseThrow(
|
||||
() -> new AnalysisException("function arity not match with signature"));
|
||||
} else {
|
||||
sigType = signature.argumentsTypes.get(i);
|
||||
}
|
||||
DataType expressionType = arguments.get(i).getDataType();
|
||||
newArgTypes.add(replaceAnyDataTypeWithOutIndex(sigType, expressionType));
|
||||
}
|
||||
signature = signature.withArgumentTypes(signature.hasVarArgs, newArgTypes);
|
||||
return signature;
|
||||
}
|
||||
|
||||
/** implementFollowToAnyDataType */
|
||||
public static FunctionSignature implementAnyDataTypeWithIndex(
|
||||
FunctionSignature signature, List<Expression> arguments) {
|
||||
|
||||
@ -21,27 +21,14 @@ import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.SearchSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.BigIntType;
|
||||
import org.apache.doris.nereids.types.BooleanType;
|
||||
import org.apache.doris.nereids.types.CharType;
|
||||
import org.apache.doris.nereids.types.DateTimeType;
|
||||
import org.apache.doris.nereids.types.DateTimeV2Type;
|
||||
import org.apache.doris.nereids.types.DateType;
|
||||
import org.apache.doris.nereids.types.DateV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV2Type;
|
||||
import org.apache.doris.nereids.types.DecimalV3Type;
|
||||
import org.apache.doris.nereids.types.DoubleType;
|
||||
import org.apache.doris.nereids.types.FloatType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
import org.apache.doris.nereids.types.LargeIntType;
|
||||
import org.apache.doris.nereids.types.SmallIntType;
|
||||
import org.apache.doris.nereids.types.StringType;
|
||||
import org.apache.doris.nereids.types.TinyIntType;
|
||||
import org.apache.doris.nereids.types.VarcharType;
|
||||
import org.apache.doris.nereids.types.coercion.AnyDataType;
|
||||
import org.apache.doris.nereids.types.coercion.PrimitiveType;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
@ -50,27 +37,21 @@ import java.util.List;
|
||||
* AggregateFunction 'histogram'. This class is generated by GenerateFunction.
|
||||
*/
|
||||
public class Histogram extends AggregateFunction
|
||||
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
|
||||
implements ExplicitlyCastableSignature, PropagateNullable {
|
||||
|
||||
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(BooleanType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(TinyIntType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(SmallIntType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(IntegerType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(BigIntType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(LargeIntType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(FloatType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DoubleType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV3Type.WILDCARD),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateTimeType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateV2Type.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateTimeV2Type.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(CharType.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE)
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
|
||||
.args(AnyDataType.INSTANCE_WITHOUT_INDEX),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
|
||||
.args(AnyDataType.INSTANCE_WITHOUT_INDEX, IntegerType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
|
||||
.args(AnyDataType.INSTANCE_WITHOUT_INDEX, DoubleType.INSTANCE, IntegerType.INSTANCE)
|
||||
);
|
||||
|
||||
private Histogram(boolean distinct, List<Expression> args) {
|
||||
super("histogram", distinct, args);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 1 argument.
|
||||
*/
|
||||
@ -78,6 +59,14 @@ public class Histogram extends AggregateFunction
|
||||
super("histogram", arg);
|
||||
}
|
||||
|
||||
public Histogram(Expression arg0, Expression arg1) {
|
||||
super("histogram", arg0, arg1);
|
||||
}
|
||||
|
||||
public Histogram(Expression arg0, Expression arg1, Expression arg2) {
|
||||
super("histogram", arg0, arg1, arg2);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 1 argument.
|
||||
*/
|
||||
@ -85,13 +74,33 @@ public class Histogram extends AggregateFunction
|
||||
super("histogram", distinct, arg);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 2 argument.
|
||||
*/
|
||||
public Histogram(boolean distinct, Expression arg0, Expression arg1) {
|
||||
super("histogram", distinct, arg0, arg1);
|
||||
}
|
||||
|
||||
/**
|
||||
* constructor with 3 argument.
|
||||
*/
|
||||
public Histogram(boolean distinct, Expression arg0, Expression arg1, Expression arg2) {
|
||||
super("histogram", distinct, arg0, arg1, arg2);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkLegalityBeforeTypeCoercion() {
|
||||
if (!(child(0).getDataType() instanceof PrimitiveType)) {
|
||||
SearchSignature.throwCanNotFoundFunctionException(this.getName(), getArguments());
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* withDistinctAndChildren.
|
||||
*/
|
||||
@Override
|
||||
public Histogram withDistinctAndChildren(boolean distinct, List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() == 1);
|
||||
return new Histogram(distinct, children.get(0));
|
||||
return new Histogram(distinct, children);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -28,7 +28,8 @@ import java.util.Locale;
|
||||
*/
|
||||
public class AnyDataType extends DataType {
|
||||
|
||||
public static final AnyDataType INSTANCE_WITHOUT_INDEX = new AnyDataType(-1);
|
||||
public static final int INDEX_OF_INSTANCE_WITHOUT_INDEX = -1;
|
||||
public static final AnyDataType INSTANCE_WITHOUT_INDEX = new AnyDataType(INDEX_OF_INSTANCE_WITHOUT_INDEX);
|
||||
|
||||
private final int index;
|
||||
|
||||
|
||||
@ -146,10 +146,9 @@ public class TypeCoercionUtils {
|
||||
* Return Optional.empty() if we cannot do implicit cast.
|
||||
*/
|
||||
public static Optional<DataType> implicitCast(DataType input, DataType expected) {
|
||||
if ((input instanceof ArrayType || input instanceof NullType) && expected instanceof ArrayType) {
|
||||
if (input instanceof ArrayType && expected instanceof ArrayType) {
|
||||
Optional<DataType> itemType = implicitCast(
|
||||
input instanceof ArrayType ? ((ArrayType) input).getItemType() : input,
|
||||
((ArrayType) expected).getItemType());
|
||||
((ArrayType) input).getItemType(), ((ArrayType) expected).getItemType());
|
||||
return itemType.map(ArrayType::of);
|
||||
} else if (input instanceof MapType && expected instanceof MapType) {
|
||||
Optional<DataType> keyType = implicitCast(
|
||||
|
||||
Reference in New Issue
Block a user