[feature](function) support type template in SQL function (#17344)
A new way just like c++ template is proposed in this PR. The previous functions can be defined much simpler using template function.
# map element extract template function
[['element_at', '%element_extract%'], 'E', ['ARRAY<E>', 'BIGINT'], 'ALWAYS_NULLABLE', ['E']],
# map element extract template function
[['element_at', '%element_extract%'], 'V', ['MAP<K, V>', 'K'], 'ALWAYS_NULLABLE', ['K', 'V']],
BTW, the plain type function is not affected and the legacy ARRAY_X MAP_K_V is still supported for compatability.
This commit is contained in:
@ -1369,37 +1369,34 @@ public class FunctionCallExpr extends Expr {
|
||||
}
|
||||
}
|
||||
|
||||
if (!fn.getFunctionName().getFunction().equals(ELEMENT_EXTRACT_FN_NAME)) {
|
||||
Type[] args = fn.getArgs();
|
||||
if (args.length > 0) {
|
||||
// Implicitly cast all the children to match the function if necessary
|
||||
for (int i = 0; i < argTypes.length - orderByElements.size(); ++i) {
|
||||
// For varargs, we must compare with the last type in callArgs.argTypes.
|
||||
int ix = Math.min(args.length - 1, i);
|
||||
if (fnName.getFunction().equalsIgnoreCase("money_format")
|
||||
&& children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()) {
|
||||
continue;
|
||||
} else if (fnName.getFunction().equalsIgnoreCase("array")
|
||||
&& (children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()
|
||||
|| children.get(0).getType().isDatetimeV2() && args[ix].isDatetimeV2())) {
|
||||
continue;
|
||||
} else if ((fnName.getFunction().equalsIgnoreCase("array_min") || fnName.getFunction()
|
||||
.equalsIgnoreCase("array_max") || fnName.getFunction().equalsIgnoreCase("element_at"))
|
||||
&& ((
|
||||
children.get(0).getType().isDecimalV3() && ((ArrayType) args[ix]).getItemType()
|
||||
.isDecimalV3())
|
||||
|| (children.get(0).getType().isDatetimeV2()
|
||||
&& ((ArrayType) args[ix]).getItemType().isDatetimeV2())
|
||||
|| (children.get(0).getType().isDecimalV2()
|
||||
&& ((ArrayType) args[ix]).getItemType().isDecimalV2()))) {
|
||||
continue;
|
||||
} else if (!argTypes[i].matchesType(args[ix])
|
||||
&& !(argTypes[i].isDateOrDateTime() && args[ix].isDateOrDateTime())
|
||||
&& (!fn.getReturnType().isDecimalV3()
|
||||
|| (argTypes[i].isValid() && !argTypes[i].isDecimalV3()
|
||||
&& args[ix].isDecimalV3()))) {
|
||||
uncheckedCastChild(args[ix], i);
|
||||
}
|
||||
Type[] args = fn.getArgs();
|
||||
if (args.length > 0) {
|
||||
// Implicitly cast all the children to match the function if necessary
|
||||
for (int i = 0; i < argTypes.length - orderByElements.size(); ++i) {
|
||||
// For varargs, we must compare with the last type in callArgs.argTypes.
|
||||
int ix = Math.min(args.length - 1, i);
|
||||
if (fnName.getFunction().equalsIgnoreCase("money_format")
|
||||
&& children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()) {
|
||||
continue;
|
||||
} else if (fnName.getFunction().equalsIgnoreCase("array")
|
||||
&& (children.get(0).getType().isDecimalV3() && args[ix].isDecimalV3()
|
||||
|| children.get(0).getType().isDatetimeV2() && args[ix].isDatetimeV2())) {
|
||||
continue;
|
||||
} else if ((fnName.getFunction().equalsIgnoreCase("array_min") || fnName.getFunction()
|
||||
.equalsIgnoreCase("array_max") || fnName.getFunction().equalsIgnoreCase("element_at"))
|
||||
&& ((
|
||||
children.get(0).getType().isDecimalV3() && ((ArrayType) args[ix]).getItemType()
|
||||
.isDecimalV3())
|
||||
|| (children.get(0).getType().isDatetimeV2()
|
||||
&& ((ArrayType) args[ix]).getItemType().isDatetimeV2())
|
||||
|| (children.get(0).getType().isDecimalV2()
|
||||
&& ((ArrayType) args[ix]).getItemType().isDecimalV2()))) {
|
||||
continue;
|
||||
} else if (!argTypes[i].matchesType(args[ix]) && !(
|
||||
argTypes[i].isDateOrDateTime() && args[ix].isDateOrDateTime())
|
||||
&& (!fn.getReturnType().isDecimalV3()
|
||||
|| (argTypes[i].isValid() && !argTypes[i].isDecimalV3() && args[ix].isDecimalV3()))) {
|
||||
uncheckedCastChild(args[ix], i);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -814,4 +814,14 @@ public class Function implements Writable {
|
||||
throw new UserException("failed to serialize function: " + functionName(), t);
|
||||
}
|
||||
}
|
||||
|
||||
public boolean hasTemplateArg() {
|
||||
for (Type t : getArgs()) {
|
||||
if (t.hasTemplateType()) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
@ -1223,6 +1223,34 @@ public class FunctionSet<T> {
|
||||
return null;
|
||||
}
|
||||
|
||||
List<Function> normalFunctions = Lists.newArrayList();
|
||||
List<Function> templateFunctions = Lists.newArrayList();
|
||||
for (Function fn : fns) {
|
||||
if (fn.hasTemplateArg()) {
|
||||
templateFunctions.add(fn);
|
||||
} else {
|
||||
normalFunctions.add(fn);
|
||||
}
|
||||
}
|
||||
|
||||
// try normal functions first
|
||||
Function fn = getFunction(desc, mode, normalFunctions);
|
||||
if (fn != null) {
|
||||
return fn;
|
||||
}
|
||||
|
||||
// then specialize template functions and try them
|
||||
List<Function> specializedTemplateFunctions = Lists.newArrayList();
|
||||
for (Function f : templateFunctions) {
|
||||
f = FunctionSet.specializeTemplateFunction(f, desc);
|
||||
if (f != null) {
|
||||
specializedTemplateFunctions.add(f);
|
||||
}
|
||||
}
|
||||
return getFunction(desc, mode, specializedTemplateFunctions);
|
||||
}
|
||||
|
||||
private Function getFunction(Function desc, Function.CompareMode mode, List<Function> fns) {
|
||||
// First check for identical
|
||||
for (Function f : fns) {
|
||||
if (f.compare(desc, Function.CompareMode.IS_IDENTICAL)) {
|
||||
@ -1262,6 +1290,45 @@ public class FunctionSet<T> {
|
||||
return null;
|
||||
}
|
||||
|
||||
public static Function specializeTemplateFunction(Function templateFunction, Function requestFunction) {
|
||||
try {
|
||||
boolean hasTemplateType = false;
|
||||
LOG.debug("templateFunction signature: " + templateFunction.signatureString()
|
||||
+ " return: " + templateFunction.getReturnType());
|
||||
LOG.debug("requestFunction signature: " + requestFunction.signatureString()
|
||||
+ " return: " + requestFunction.getReturnType());
|
||||
Function specializedFunction = templateFunction;
|
||||
if (templateFunction instanceof ScalarFunction) {
|
||||
ScalarFunction f = (ScalarFunction) templateFunction;
|
||||
specializedFunction = new ScalarFunction(f.getFunctionName(), Lists.newArrayList(f.getArgs()),
|
||||
f.getReturnType(), f.hasVarArgs(), f.getSymbolName(), f.getBinaryType(),
|
||||
f.isUserVisible(), f.isVectorized(), f.getNullableMode());
|
||||
} else {
|
||||
// TODO(xk)
|
||||
}
|
||||
Type[] args = specializedFunction.getArgs();
|
||||
Map<String, Type> specializedTypeMap = Maps.newHashMap();
|
||||
for (int i = 0; i < args.length; i++) {
|
||||
if (args[i].hasTemplateType()) {
|
||||
hasTemplateType = true;
|
||||
args[i] = args[i].specializeTemplateType(requestFunction.getArgs()[i], specializedTypeMap, false);
|
||||
}
|
||||
}
|
||||
if (specializedFunction.getReturnType().hasTemplateType()) {
|
||||
hasTemplateType = true;
|
||||
specializedFunction.setReturnType(
|
||||
specializedFunction.getReturnType().specializeTemplateType(
|
||||
requestFunction.getReturnType(), specializedTypeMap, true));
|
||||
}
|
||||
LOG.debug("specializedFunction signature: " + specializedFunction.signatureString()
|
||||
+ " return: " + specializedFunction.getReturnType());
|
||||
return hasTemplateType ? specializedFunction : templateFunction;
|
||||
} catch (TypeException e) {
|
||||
LOG.warn("specializeTemplateFunction exception", e);
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* There are essential differences in the implementation of some functions for different
|
||||
* types params, which should be prohibited.
|
||||
|
||||
Reference in New Issue
Block a user