[fix](Nereids) support AnyDataType in function signature (#25173)

1. support AnyDataType in function signature
2. update histogram signature
This commit is contained in:
morrySnow
2023-10-10 19:09:47 +08:00
committed by GitHub
parent ba1edcf2dc
commit 8b56ca84c7
6 changed files with 113 additions and 39 deletions

View File

@ -108,6 +108,7 @@ public interface ComputeSignature extends FunctionTrait, ImplicitCastInputTypes
// If you want to add some special cases, please override this method in the special
// function class, like 'If' function and 'Substring' function.
return ComputeSignatureChain.from(this, signature, getArguments())
.then(ComputeSignatureHelper::implementAnyDataTypeWithOutIndex)
.then(ComputeSignatureHelper::implementAnyDataTypeWithIndex)
.then(ComputeSignatureHelper::computePrecision)
.then(ComputeSignatureHelper::implementFollowToArgumentReturnType)

View File

@ -64,6 +64,46 @@ public class ComputeSignatureHelper {
return signature;
}
/**
 * Substitutes index-less {@link AnyDataType} placeholders inside a signature type with the
 * matching part of the actual argument type.
 *
 * <p>Array and map signature types are rebuilt recursively. A {@link NullType} argument is
 * treated as matching any nesting, so {@link NullType#INSTANCE} is pushed down to the
 * element, key and value positions. Struct signature types are returned unchanged.
 *
 * @param sigType the type declared in the function signature
 * @param expressionType the data type of the actual argument
 * @return {@code expressionType} when {@code sigType} is an {@link AnyDataType} whose index is
 *         {@link AnyDataType#INDEX_OF_INSTANCE_WITHOUT_INDEX}; a rebuilt array or map type when
 *         the shapes line up; otherwise {@code sigType} itself
 */
private static DataType replaceAnyDataTypeWithOutIndex(DataType sigType, DataType expressionType) {
    boolean nullArgument = expressionType instanceof NullType;

    if (sigType instanceof ArrayType && (nullArgument || expressionType instanceof ArrayType)) {
        DataType sigItem = ((ArrayType) sigType).getItemType();
        DataType argItem = nullArgument
                ? NullType.INSTANCE
                : ((ArrayType) expressionType).getItemType();
        return ArrayType.of(replaceAnyDataTypeWithOutIndex(sigItem, argItem));
    }

    if (sigType instanceof MapType && (nullArgument || expressionType instanceof MapType)) {
        MapType sigMap = (MapType) sigType;
        DataType argKey = nullArgument
                ? NullType.INSTANCE
                : ((MapType) expressionType).getKeyType();
        DataType argValue = nullArgument
                ? NullType.INSTANCE
                : ((MapType) expressionType).getValueType();
        // Resolve the key before the value, the same order as the argument list of MapType.of.
        DataType resolvedKey = replaceAnyDataTypeWithOutIndex(sigMap.getKeyType(), argKey);
        DataType resolvedValue = replaceAnyDataTypeWithOutIndex(sigMap.getValueType(), argValue);
        return MapType.of(resolvedKey, resolvedValue);
    }

    if (sigType instanceof StructType && (nullArgument || expressionType instanceof StructType)) {
        // TODO: do not support struct type now
        // throw new AnalysisException("do not support struct type now");
        return sigType;
    }

    // Leaf position: only the index-less placeholder is replaced by the argument's type.
    boolean indexLessPlaceholder = sigType instanceof AnyDataType
            && ((AnyDataType) sigType).getIndex() == AnyDataType.INDEX_OF_INSTANCE_WITHOUT_INDEX;
    return indexLessPlaceholder ? expressionType : sigType;
}
private static void collectAnyDataType(DataType sigType, DataType expressionType,
Map<Integer, List<DataType>> indexToArgumentTypes) {
if (expressionType instanceof NullType) {
@ -173,6 +213,26 @@ public class ComputeSignatureHelper {
}
}
/**
 * Resolves index-less {@link AnyDataType} placeholders in a signature against the actual
 * arguments.
 *
 * <p>One argument type is produced per actual argument. Arguments positioned beyond the
 * declared argument types are matched against the signature's var-arg type.
 *
 * @param signature the signature whose argument types should be resolved
 * @param arguments the actual arguments of the function call
 * @return the signature with its argument types replaced by the resolved ones; the
 *         {@code hasVarArgs} flag is carried over unchanged
 * @throws AnalysisException if there are more arguments than declared types and the
 *         signature provides no var-arg type
 */
public static FunctionSignature implementAnyDataTypeWithOutIndex(
        FunctionSignature signature, List<Expression> arguments) {
    int declaredCount = signature.argumentsTypes.size();
    List<DataType> resolvedTypes = Lists.newArrayList();
    int position = 0;
    for (Expression argument : arguments) {
        // Past the declared types, fall back to the var-arg type or fail on arity mismatch.
        DataType declaredType = position < declaredCount
                ? signature.argumentsTypes.get(position)
                : signature.getVarArgType().orElseThrow(
                        () -> new AnalysisException("function arity not match with signature"));
        resolvedTypes.add(replaceAnyDataTypeWithOutIndex(declaredType, argument.getDataType()));
        position++;
    }
    return signature.withArgumentTypes(signature.hasVarArgs, resolvedTypes);
}
/** implementFollowToAnyDataType */
public static FunctionSignature implementAnyDataTypeWithIndex(
FunctionSignature signature, List<Expression> arguments) {

View File

@ -21,27 +21,14 @@ import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.functions.SearchSignature;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DateTimeType;
import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DateV2Type;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.types.FloatType;
import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.StringType;
import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.types.coercion.PrimitiveType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
@ -50,27 +37,21 @@ import java.util.List;
* AggregateFunction 'histogram'. This class is generated by GenerateFunction.
*/
public class Histogram extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable {
implements ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(BooleanType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(TinyIntType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(SmallIntType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(IntegerType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(BigIntType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(LargeIntType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(FloatType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DoubleType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DecimalV3Type.WILDCARD),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateTimeType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateV2Type.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(DateTimeV2Type.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(CharType.SYSTEM_DEFAULT),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(StringType.INSTANCE)
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(AnyDataType.INSTANCE_WITHOUT_INDEX),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(AnyDataType.INSTANCE_WITHOUT_INDEX, IntegerType.INSTANCE),
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
.args(AnyDataType.INSTANCE_WITHOUT_INDEX, DoubleType.INSTANCE, IntegerType.INSTANCE)
);
/**
 * Private constructor that accepts the complete argument list.
 *
 * @param distinct distinct flag, forwarded to the superclass constructor
 * @param args all arguments of the call, forwarded to the superclass constructor as-is
 */
private Histogram(boolean distinct, List<Expression> args) {
super("histogram", distinct, args);
}
/**
* constructor with 1 argument.
*/
@ -78,6 +59,14 @@ public class Histogram extends AggregateFunction
super("histogram", arg);
}
/**
 * constructor with 2 arguments.
 *
 * @param arg0 first argument, forwarded to the superclass constructor
 * @param arg1 second argument, forwarded to the superclass constructor
 */
public Histogram(Expression arg0, Expression arg1) {
super("histogram", arg0, arg1);
}
/**
 * constructor with 3 arguments.
 *
 * @param arg0 first argument, forwarded to the superclass constructor
 * @param arg1 second argument, forwarded to the superclass constructor
 * @param arg2 third argument, forwarded to the superclass constructor
 */
public Histogram(Expression arg0, Expression arg1, Expression arg2) {
super("histogram", arg0, arg1, arg2);
}
/**
* constructor with 1 argument.
*/
@ -85,13 +74,33 @@ public class Histogram extends AggregateFunction
super("histogram", distinct, arg);
}
/**
 * constructor with 2 arguments and an explicit distinct flag.
 *
 * @param distinct distinct flag, forwarded to the superclass constructor
 * @param arg0 first argument, forwarded to the superclass constructor
 * @param arg1 second argument, forwarded to the superclass constructor
 */
public Histogram(boolean distinct, Expression arg0, Expression arg1) {
super("histogram", distinct, arg0, arg1);
}
/**
 * constructor with 3 arguments and an explicit distinct flag.
 *
 * @param distinct distinct flag, forwarded to the superclass constructor
 * @param arg0 first argument, forwarded to the superclass constructor
 * @param arg1 second argument, forwarded to the superclass constructor
 * @param arg2 third argument, forwarded to the superclass constructor
 */
public Histogram(boolean distinct, Expression arg0, Expression arg1, Expression arg2) {
super("histogram", distinct, arg0, arg1, arg2);
}
/**
 * Rejects the call before type coercion when the first argument is not a
 * {@link PrimitiveType}.
 *
 * <p>In that case the failure is reported through
 * {@link SearchSignature#throwCanNotFoundFunctionException}, passing this function's name
 * and arguments.
 */
@Override
public void checkLegalityBeforeTypeCoercion() {
    boolean primitiveFirstArgument = child(0).getDataType() instanceof PrimitiveType;
    if (primitiveFirstArgument) {
        return;
    }
    SearchSignature.throwCanNotFoundFunctionException(this.getName(), getArguments());
}
/**
* withDistinctAndChildren.
*/
@Override
public Histogram withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new Histogram(distinct, children.get(0));
return new Histogram(distinct, children);
}
@Override

View File

@ -28,7 +28,8 @@ import java.util.Locale;
*/
public class AnyDataType extends DataType {
public static final AnyDataType INSTANCE_WITHOUT_INDEX = new AnyDataType(-1);
public static final int INDEX_OF_INSTANCE_WITHOUT_INDEX = -1;
public static final AnyDataType INSTANCE_WITHOUT_INDEX = new AnyDataType(INDEX_OF_INSTANCE_WITHOUT_INDEX);
private final int index;

View File

@ -146,10 +146,9 @@ public class TypeCoercionUtils {
* Return Optional.empty() if we cannot do implicit cast.
*/
public static Optional<DataType> implicitCast(DataType input, DataType expected) {
if ((input instanceof ArrayType || input instanceof NullType) && expected instanceof ArrayType) {
if (input instanceof ArrayType && expected instanceof ArrayType) {
Optional<DataType> itemType = implicitCast(
input instanceof ArrayType ? ((ArrayType) input).getItemType() : input,
((ArrayType) expected).getItemType());
((ArrayType) input).getItemType(), ((ArrayType) expected).getItemType());
return itemType.map(ArrayType::of);
} else if (input instanceof MapType && expected instanceof MapType) {
Optional<DataType> keyType = implicitCast(