[feature](function) support type template in SQL function (#17344)

A new way just like c++ template is proposed in this PR. The previous functions can be defined much simpler using template function. 

    # map element extract template function
    [['element_at', '%element_extract%'], 'E', ['ARRAY<E>', 'BIGINT'], 'ALWAYS_NULLABLE', ['E']],

    # map element extract template function
    [['element_at', '%element_extract%'], 'V', ['MAP<K, V>', 'K'], 'ALWAYS_NULLABLE', ['K', 'V']],


BTW, the plain type function is not affected and the legacy ARRAY_X MAP_K_V is still supported for compatability.
This commit is contained in:
Kang
2023-03-08 10:51:31 +08:00
committed by GitHub
parent 9213dd906a
commit 4b743061b4
12 changed files with 413 additions and 63 deletions

View File

@ -1369,37 +1369,34 @@ public class FunctionCallExpr extends Expr {
}
}
if (!fn.getFunctionName().getFunction().equals(ELEMENT_EXTRACT_FN_NAME)) {
Type[] args = fn.getArgs();
if (args.length > 0) {
// Implicitly cast all the children to match the function if necessary
for (int i = 0; i < argTypes.length - orderByElements.size(); ++i) {
// For varargs, we must compare with the last type in callArgs.argTypes.
int ix = Math.min(args.length - 1, i);
if (fnName.getFunction().equalsIgnoreCase("money_format")
&& children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()) {
continue;
} else if (fnName.getFunction().equalsIgnoreCase("array")
&& (children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()
|| children.get(0).getType().isDatetimeV2() && args[ix].isDatetimeV2())) {
continue;
} else if ((fnName.getFunction().equalsIgnoreCase("array_min") || fnName.getFunction()
.equalsIgnoreCase("array_max") || fnName.getFunction().equalsIgnoreCase("element_at"))
&& ((
children.get(0).getType().isDecimalV3() && ((ArrayType) args[ix]).getItemType()
.isDecimalV3())
|| (children.get(0).getType().isDatetimeV2()
&& ((ArrayType) args[ix]).getItemType().isDatetimeV2())
|| (children.get(0).getType().isDecimalV2()
&& ((ArrayType) args[ix]).getItemType().isDecimalV2()))) {
continue;
} else if (!argTypes[i].matchesType(args[ix])
&& !(argTypes[i].isDateOrDateTime() && args[ix].isDateOrDateTime())
&& (!fn.getReturnType().isDecimalV3()
|| (argTypes[i].isValid() && !argTypes[i].isDecimalV3()
&& args[ix].isDecimalV3()))) {
uncheckedCastChild(args[ix], i);
}
Type[] args = fn.getArgs();
if (args.length > 0) {
// Implicitly cast all the children to match the function if necessary
for (int i = 0; i < argTypes.length - orderByElements.size(); ++i) {
// For varargs, we must compare with the last type in callArgs.argTypes.
int ix = Math.min(args.length - 1, i);
if (fnName.getFunction().equalsIgnoreCase("money_format")
&& children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()) {
continue;
} else if (fnName.getFunction().equalsIgnoreCase("array")
&& (children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()
|| children.get(0).getType().isDatetimeV2() && args[ix].isDatetimeV2())) {
continue;
} else if ((fnName.getFunction().equalsIgnoreCase("array_min") || fnName.getFunction()
.equalsIgnoreCase("array_max") || fnName.getFunction().equalsIgnoreCase("element_at"))
&& ((
children.get(0).getType().isDecimalV3() && ((ArrayType) args[ix]).getItemType()
.isDecimalV3())
|| (children.get(0).getType().isDatetimeV2()
&& ((ArrayType) args[ix]).getItemType().isDatetimeV2())
|| (children.get(0).getType().isDecimalV2()
&& ((ArrayType) args[ix]).getItemType().isDecimalV2()))) {
continue;
} else if (!argTypes[i].matchesType(args[ix]) && !(
argTypes[i].isDateOrDateTime() && args[ix].isDateOrDateTime())
&& (!fn.getReturnType().isDecimalV3()
|| (argTypes[i].isValid() && !argTypes[i].isDecimalV3() && args[ix].isDecimalV3()))) {
uncheckedCastChild(args[ix], i);
}
}
}

View File

@ -814,4 +814,14 @@ public class Function implements Writable {
throw new UserException("failed to serialize function: " + functionName(), t);
}
}
public boolean hasTemplateArg() {
for (Type t : getArgs()) {
if (t.hasTemplateType()) {
return true;
}
}
return false;
}
}

View File

@ -1223,6 +1223,34 @@ public class FunctionSet<T> {
return null;
}
List<Function> normalFunctions = Lists.newArrayList();
List<Function> templateFunctions = Lists.newArrayList();
for (Function fn : fns) {
if (fn.hasTemplateArg()) {
templateFunctions.add(fn);
} else {
normalFunctions.add(fn);
}
}
// try normal functions first
Function fn = getFunction(desc, mode, normalFunctions);
if (fn != null) {
return fn;
}
// then specialize template functions and try them
List<Function> specializedTemplateFunctions = Lists.newArrayList();
for (Function f : templateFunctions) {
f = FunctionSet.specializeTemplateFunction(f, desc);
if (f != null) {
specializedTemplateFunctions.add(f);
}
}
return getFunction(desc, mode, specializedTemplateFunctions);
}
private Function getFunction(Function desc, Function.CompareMode mode, List<Function> fns) {
// First check for identical
for (Function f : fns) {
if (f.compare(desc, Function.CompareMode.IS_IDENTICAL)) {
@ -1262,6 +1290,45 @@ public class FunctionSet<T> {
return null;
}
public static Function specializeTemplateFunction(Function templateFunction, Function requestFunction) {
try {
boolean hasTemplateType = false;
LOG.debug("templateFunction signature: " + templateFunction.signatureString()
+ " return: " + templateFunction.getReturnType());
LOG.debug("requestFunction signature: " + requestFunction.signatureString()
+ " return: " + requestFunction.getReturnType());
Function specializedFunction = templateFunction;
if (templateFunction instanceof ScalarFunction) {
ScalarFunction f = (ScalarFunction) templateFunction;
specializedFunction = new ScalarFunction(f.getFunctionName(), Lists.newArrayList(f.getArgs()),
f.getReturnType(), f.hasVarArgs(), f.getSymbolName(), f.getBinaryType(),
f.isUserVisible(), f.isVectorized(), f.getNullableMode());
} else {
// TODO(xk)
}
Type[] args = specializedFunction.getArgs();
Map<String, Type> specializedTypeMap = Maps.newHashMap();
for (int i = 0; i < args.length; i++) {
if (args[i].hasTemplateType()) {
hasTemplateType = true;
args[i] = args[i].specializeTemplateType(requestFunction.getArgs()[i], specializedTypeMap, false);
}
}
if (specializedFunction.getReturnType().hasTemplateType()) {
hasTemplateType = true;
specializedFunction.setReturnType(
specializedFunction.getReturnType().specializeTemplateType(
requestFunction.getReturnType(), specializedTypeMap, true));
}
LOG.debug("specializedFunction signature: " + specializedFunction.signatureString()
+ " return: " + specializedFunction.getReturnType());
return hasTemplateType ? specializedFunction : templateFunction;
} catch (TypeException e) {
LOG.warn("specializeTemplateFunction exception", e);
return null;
}
}
/**
* There are essential differences in the implementation of some functions for different
* types params, which should be prohibited.