[opt](Nereids) support search from override udfs with same arity (#40432) (#40751)

pick from master #40432

create alias function f1(int) with parameter(id) as abs(id); create
alias function f1(string) with parameter(id) as substr(id, 2); select
f1('1'); -- bind on f1(string)
select f1(1);   -- bind on f1(int)

test case already existed in P0
This commit is contained in:
morrySnow
2024-09-12 19:58:49 +08:00
committed by GitHub
parent 4b7b43b5ca
commit fedadbba6e
7 changed files with 124 additions and 41 deletions

View File

@ -21,7 +21,9 @@ import org.apache.doris.datasource.InternalCatalog;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder;
import org.apache.doris.nereids.types.DataType;
@ -156,18 +158,33 @@ public class FunctionRegistry {
+ "' which has " + arity + " arity. Candidate functions are: " + candidateHints);
}
if (candidateBuilders.size() > 1) {
String candidateHints = getCandidateHint(name, candidateBuilders);
// TODO: NereidsPlanner not supported override function by the same arity, we will support it later
if (ConnectContext.get() != null) {
try {
ConnectContext.get().getSessionVariable().enableFallbackToOriginalPlannerOnce();
} catch (Throwable t) {
// ignore error
boolean needChooseOne = true;
List<FunctionSignature> signatures = Lists.newArrayListWithCapacity(candidateBuilders.size());
for (FunctionBuilder functionBuilder : candidateBuilders) {
if (functionBuilder instanceof UdfBuilder) {
signatures.addAll(((UdfBuilder) functionBuilder).getSignatures());
} else {
needChooseOne = false;
break;
}
}
for (Object argument : arguments) {
if (!(argument instanceof Expression)) {
needChooseOne = false;
break;
}
}
if (needChooseOne) {
FunctionSignature signature = new UdfSignatureSearcher(signatures, (List) arguments).getSignature();
for (int i = 0; i < signatures.size(); i++) {
if (signatures.get(i).equals(signature)) {
return candidateBuilders.get(i);
}
}
}
String candidateHints = getCandidateHint(name, candidateBuilders);
throw new AnalysisException("Function '" + qualifiedName + "' is ambiguous: " + candidateHints);
}
return candidateBuilders.get(0);
}
@ -235,4 +252,63 @@ public class FunctionRegistry {
.removeIf(builder -> ((UdfBuilder) builder).getArgTypes().equals(argTypes));
}
}
/**
* use for search appropriate signature for UDFs if candidate more than one.
*/
static class UdfSignatureSearcher implements ExplicitlyCastableSignature {
private final List<FunctionSignature> signatures;
private final List<Expression> arguments;
public UdfSignatureSearcher(List<FunctionSignature> signatures, List<Expression> arguments) {
this.signatures = signatures;
this.arguments = arguments;
}
@Override
public List<FunctionSignature> getSignatures() {
return signatures;
}
@Override
public FunctionSignature getSignature() {
return searchSignature(signatures);
}
@Override
public boolean nullable() {
throw new AnalysisException("could not call nullable on UdfSignatureSearcher");
}
@Override
public List<Expression> children() {
return arguments;
}
@Override
public Expression child(int index) {
return arguments.get(index);
}
@Override
public int arity() {
return arguments.size();
}
@Override
public <T> Optional<T> getMutableState(String key) {
return Optional.empty();
}
@Override
public void setMutableState(String key, Object value) {
}
@Override
public Expression withChildren(List<Expression> children) {
throw new AnalysisException("could not call withChildren on UdfSignatureSearcher");
}
}
}

View File

@ -30,7 +30,6 @@ import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.BiFunction;
public class FunctionSignature {
public final DataType returnType;
@ -78,21 +77,6 @@ public class FunctionSignature {
return new FunctionSignature(returnType, hasVarArgs, argumentsTypes);
}
/**
* change argument type by the signature's type and the corresponding argument's type
* @param arguments arguments
* @param transform param1: signature's type, param2: argument's type, return new type you want to change
* @return
*/
public FunctionSignature withArgumentTypes(List<Expression> arguments,
BiFunction<DataType, Expression, DataType> transform) {
List<DataType> newTypes = Lists.newArrayList();
for (int i = 0; i < arguments.size(); i++) {
newTypes.add(transform.apply(getArgType(i), arguments.get(i)));
}
return withArgumentTypes(hasVarArgs, newTypes);
}
/**
* change argument type by the signature's type and the corresponding argument's type
* @param arguments arguments
@ -145,6 +129,24 @@ public class FunctionSignature {
.toString();
}
@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
FunctionSignature signature = (FunctionSignature) o;
return hasVarArgs == signature.hasVarArgs && arity == signature.arity && Objects.equals(returnType,
signature.returnType) && Objects.equals(argumentsTypes, signature.argumentsTypes);
}
@Override
public int hashCode() {
return Objects.hash(returnType, hasVarArgs, argumentsTypes, arity);
}
public static class FuncSigBuilder {
public final DataType returnType;

View File

@ -32,7 +32,6 @@ import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.NullType;
import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Maps;
@ -62,8 +61,7 @@ public class AliasUdf extends ScalarFunction implements ExplicitlyCastableSignat
@Override
public List<FunctionSignature> getSignatures() {
return ImmutableList.of(Suppliers.memoize(() -> FunctionSignature
.of(NullType.INSTANCE, argTypes)).get());
return ImmutableList.of(FunctionSignature.of(NullType.INSTANCE, argTypes));
}
public List<String> getParameters() {

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.trees.expressions.functions.udf;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.ReflectionUtils;
import org.apache.doris.nereids.analyzer.Scope;
@ -25,7 +26,6 @@ import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.util.TypeCoercionUtils;
@ -53,6 +53,11 @@ public class AliasUdfBuilder extends UdfBuilder {
return aliasUdf.getArgTypes();
}
@Override
public List<FunctionSignature> getSignatures() {
return aliasUdf.getSignatures();
}
@Override
public Class<? extends BoundFunction> functionClass() {
return AliasUdf.class;
@ -109,17 +114,4 @@ public class AliasUdfBuilder extends UdfBuilder {
return Pair.of(udfAnalyzer.analyze(aliasUdf.getUnboundFunction()), boundAliasFunction);
}
private static class SlotReplacer extends DefaultExpressionRewriter<Map<SlotReference, Expression>> {
public static final SlotReplacer INSTANCE = new SlotReplacer();
public Expression replace(Expression expression, Map<SlotReference, Expression> context) {
return expression.accept(this, context);
}
@Override
public Expression visitSlotReference(SlotReference slot, Map<SlotReference, Expression> context) {
return context.get(slot);
}
}
}

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.trees.expressions.functions.udf;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.ReflectionUtils;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -50,6 +51,11 @@ public class JavaUdafBuilder extends UdfBuilder {
.collect(Collectors.toList())).get();
}
@Override
public List<FunctionSignature> getSignatures() {
return udaf.getSignatures();
}
@Override
public Class<? extends BoundFunction> functionClass() {
return JavaUdaf.class;

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.trees.expressions.functions.udf;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.common.Pair;
import org.apache.doris.common.util.ReflectionUtils;
import org.apache.doris.nereids.trees.expressions.Expression;
@ -52,6 +53,11 @@ public class JavaUdfBuilder extends UdfBuilder {
.collect(Collectors.toList())).get();
}
@Override
public List<FunctionSignature> getSignatures() {
return udf.getSignatures();
}
@Override
public Class<? extends BoundFunction> functionClass() {
return JavaUdf.class;

View File

@ -17,6 +17,7 @@
package org.apache.doris.nereids.trees.expressions.functions.udf;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.types.DataType;
@ -27,4 +28,6 @@ import java.util.List;
*/
public abstract class UdfBuilder extends FunctionBuilder {
public abstract List<DataType> getArgTypes();
public abstract List<FunctionSignature> getSignatures();
}