From 6b8a139f2dfc241ac3948a2ba28d025b811c88be Mon Sep 17 00:00:00 2001 From: 924060929 <924060929@qq.com> Date: Fri, 9 Sep 2022 15:19:45 +0800 Subject: [PATCH] [feature](Nereids) Support function registry (#12481) Support function registry. The classes: - BuiltinFunctions: contains the list of built-in functions - FunctionRegistry: used to register scalar functions and aggregate functions; it can find a function by name - FunctionBuilder: used to resolve a BoundFunction class, extract its constructors, and build a BoundFunction from arguments (`List<Expression>`) Register example: you can add built-in functions to the list for simplicity ```java public class BuiltinFunctions implements FunctionHelper { public final List<ScalarFunc> scalarFunctions = ImmutableList.of( scalar(Substring.class, "substr", "substring"), scalar(WeekOfYear.class), scalar(Year.class) ); public final ImmutableList<AggregateFunc> aggregateFunctions = ImmutableList.of( agg(Avg.class), agg(Count.class), agg(Max.class), agg(Min.class), agg(Sum.class) ); } ``` Note: - Currently, we only support registering scalar functions and aggregate functions; support for registering table functions will be added later. - Currently, we only support resolving a function by its name and arity; overloads with the same arity cannot be resolved, e.g.
`some_function(Expression)` and `some_function(Literal)` --- .../doris/catalog/BuiltinFunctions.java | 58 +++++++ .../java/org/apache/doris/catalog/Env.java | 9 ++ .../apache/doris/catalog/FunctionHelper.java | 109 +++++++++++++ .../doris/catalog/FunctionRegistry.java | 91 +++++++++++ .../nereids/rules/analysis/BindFunction.java | 122 +++++--------- .../functions/AggregateFunction.java | 4 +- .../functions/FunctionBuilder.java | 90 +++++++++++ .../expressions/functions/ScalarFunction.java | 29 ++++ .../expressions/functions/Substring.java | 19 +-- .../expressions/functions/WeekOfYear.java | 2 +- .../trees/expressions/functions/Year.java | 2 +- .../rules/analysis/FunctionRegistryTest.java | 153 ++++++++++++++++++ 12 files changed, 590 insertions(+), 98 deletions(-) create mode 100644 fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinFunctions.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionHelper.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java create mode 100644 fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ScalarFunction.java create mode 100644 fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinFunctions.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinFunctions.java new file mode 100644 index 0000000000..365e4db2a6 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/BuiltinFunctions.java @@ -0,0 +1,58 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.catalog; + +import org.apache.doris.nereids.trees.expressions.functions.Avg; +import org.apache.doris.nereids.trees.expressions.functions.Count; +import org.apache.doris.nereids.trees.expressions.functions.Max; +import org.apache.doris.nereids.trees.expressions.functions.Min; +import org.apache.doris.nereids.trees.expressions.functions.Substring; +import org.apache.doris.nereids.trees.expressions.functions.Sum; +import org.apache.doris.nereids.trees.expressions.functions.WeekOfYear; +import org.apache.doris.nereids.trees.expressions.functions.Year; + +import com.google.common.collect.ImmutableList; + +import java.util.List; + +/** + * Built-in functions. + * + * Note: Please ensure that this class only has some lists and no procedural code. + * It helps to be clear and concise. + */ +public class BuiltinFunctions implements FunctionHelper { + public final List scalarFunctions = ImmutableList.of( + scalar(Substring.class, "substr", "substring"), + scalar(WeekOfYear.class), + scalar(Year.class) + ); + + public final ImmutableList aggregateFunctions = ImmutableList.of( + agg(Avg.class), + agg(Count.class), + agg(Max.class), + agg(Min.class), + agg(Sum.class) + ); + + public static final BuiltinFunctions INSTANCE = new BuiltinFunctions(); + + // Note: Do not add any code here! 
+ private BuiltinFunctions() {} +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java index 805f46b35a..c07da5df24 100755 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Env.java @@ -375,6 +375,9 @@ public class Env { private CatalogRecycleBin recycleBin; private FunctionSet functionSet; + // for nereids + private FunctionRegistry functionRegistry; + private MetaReplayState metaReplayState; private BrokerMgr brokerMgr; @@ -555,6 +558,8 @@ public class Env { this.functionSet = new FunctionSet(); this.functionSet.init(); + this.functionRegistry = new FunctionRegistry(); + this.metaReplayState = new MetaReplayState(); this.isDefaultClusterCreated = false; @@ -4306,6 +4311,10 @@ public class Env { LOG.info("successfully create view[" + tableName + "-" + newView.getId() + "]"); } + public FunctionRegistry getFunctionRegistry() { + return functionRegistry; + } + /** * Returns the function that best matches 'desc' that is registered with the * catalog using 'mode' to check for matching. If desc matches multiple diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionHelper.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionHelper.java new file mode 100644 index 0000000000..3c0c312c44 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionHelper.java @@ -0,0 +1,109 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.catalog; + +import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; +import org.apache.doris.nereids.trees.expressions.functions.ScalarFunction; + +import com.google.common.collect.ImmutableList; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +public interface FunctionHelper { + /** + * put functions into the target map, which the key is the function name, + * and value is FunctionBuilder is converted from NamedFunc. + * @param name2FuncBuilders target Map + * @param functions the NamedFunc list to be put into the target map + */ + static void addFunctions(Map> name2FuncBuilders, + List> functions) { + for (NamedFunc func : functions) { + for (String name : func.names) { + if (name2FuncBuilders.containsKey(name)) { + throw new IllegalStateException("Function '" + name + "' already exists in function registry"); + } + + name2FuncBuilders.put(name, func.functionBuilders); + } + } + } + + default ScalarFunc scalar(Class functionClass) { + String functionName = functionClass.getSimpleName(); + return scalar(functionClass, functionName); + } + + /** + * Resolve ScalaFunction class, convert to FunctionBuilder and wrap to ScalarFunc + * @param functionClass the ScalaFunction class + * @return ScalaFunc which contains the functionName and the FunctionBuilder + */ + default ScalarFunc scalar(Class functionClass, String... 
functionNames) { + return new ScalarFunc(functionClass, functionNames); + } + + default AggregateFunc agg(Class functionClass) { + String functionName = functionClass.getSimpleName(); + return new AggregateFunc(functionClass, functionName); + } + + /** + * Resolve AggregateFunction class, convert to FunctionBuilder and wrap to AggregateFunc + * @param functionClass the AggregateFunction class + * @return AggregateFunc which contains the functionName and the AggregateFunc + */ + default AggregateFunc agg(Class functionClass, String... functionNames) { + return new AggregateFunc(functionClass, functionNames); + } + + /** + * use this class to prevent the wrong type from being registered, and support multi function names + * like substring and substr. + */ + class NamedFunc { + public final List names; + public final Class functionClass; + + public final List functionBuilders; + + public NamedFunc(Class functionClass, String... names) { + this.functionClass = functionClass; + this.names = Arrays.stream(names) + .map(String::toLowerCase) + .collect(ImmutableList.toImmutableList()); + this.functionBuilders = FunctionBuilder.resolve(functionClass); + } + } + + class ScalarFunc extends NamedFunc { + public ScalarFunc(Class functionClass, String... names) { + super(functionClass, names); + } + } + + class AggregateFunc extends NamedFunc { + public AggregateFunc(Class functionClass, String... names) { + super(functionClass, names); + } + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java new file mode 100644 index 0000000000..fa91d2ff9e --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java @@ -0,0 +1,91 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.catalog; + +import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; + +import com.google.common.annotations.VisibleForTesting; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.stream.Collectors; +import javax.annotation.concurrent.ThreadSafe; + +/** + * New function registry for nereids. + * + * this class is developing for more functions. + */ +@Developing +@ThreadSafe +public class FunctionRegistry { + private final Map> name2Builders; + + public FunctionRegistry() { + name2Builders = new ConcurrentHashMap<>(); + registerBuiltinFunctions(name2Builders); + afterRegisterBuiltinFunctions(name2Builders); + } + + // this function is used to test. 
+ // for example, you can create child class of FunctionRegistry and clear builtin functions or add more functions + // in this method + @VisibleForTesting + protected void afterRegisterBuiltinFunctions(Map> name2Builders) {} + + // currently we only find function by name and arity + public FunctionBuilder findFunctionBuilder(String name, List arguments) { + int arity = arguments.size(); + List functionBuilders = name2Builders.get(name.toLowerCase()); + if (functionBuilders == null || functionBuilders.isEmpty()) { + throw new AnalysisException("Can not found function '" + name + "'"); + } + + List candidateBuilders = functionBuilders.stream() + .filter(functionBuilder -> functionBuilder.arity == arity) + .collect(Collectors.toList()); + if (candidateBuilders.isEmpty()) { + String candidateHints = getCandidateHint(name, candidateBuilders); + throw new AnalysisException("Can not found function '" + name + + "' which has " + arity + " arity. Candidate functions are: " + candidateHints); + } + + if (candidateBuilders.size() > 1) { + String candidateHints = getCandidateHint(name, candidateBuilders); + // NereidsPlanner not supported override function by the same arity, should we support it? 
+ + throw new AnalysisException("Function '" + name + "' is ambiguous: " + candidateHints); + } + return candidateBuilders.get(0); + } + + private void registerBuiltinFunctions(Map> name2Builders) { + FunctionHelper.addFunctions(name2Builders, BuiltinFunctions.INSTANCE.scalarFunctions); + FunctionHelper.addFunctions(name2Builders, BuiltinFunctions.INSTANCE.aggregateFunctions); + } + + public String getCandidateHint(String name, List candidateBuilders) { + return candidateBuilders.stream() + .map(builder -> name + builder.toString()) + .collect(Collectors.joining(", ")); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java index 7f309d768a..db2135fcc3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindFunction.java @@ -18,21 +18,19 @@ package org.apache.doris.nereids.rules.analysis; import org.apache.doris.analysis.ArithmeticExpr.Operator; +import org.apache.doris.catalog.Env; +import org.apache.doris.catalog.FunctionRegistry; import org.apache.doris.nereids.analyzer.UnboundFunction; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.TimestampArithmetic; -import org.apache.doris.nereids.trees.expressions.functions.Avg; -import org.apache.doris.nereids.trees.expressions.functions.Count; -import org.apache.doris.nereids.trees.expressions.functions.Max; -import org.apache.doris.nereids.trees.expressions.functions.Min; -import org.apache.doris.nereids.trees.expressions.functions.Substring; -import org.apache.doris.nereids.trees.expressions.functions.Sum; -import 
org.apache.doris.nereids.trees.expressions.functions.WeekOfYear; -import org.apache.doris.nereids.trees.expressions.functions.Year; +import org.apache.doris.nereids.trees.expressions.functions.BoundFunction; +import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalHaving; import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation; @@ -54,9 +52,10 @@ public class BindFunction implements AnalysisRuleFactory { public List buildRules() { return ImmutableList.of( RuleType.BINDING_ONE_ROW_RELATION_FUNCTION.build( - logicalOneRowRelation().then(oneRowRelation -> { + logicalOneRowRelation().thenApply(ctx -> { + LogicalOneRowRelation oneRowRelation = ctx.root; List projects = oneRowRelation.getProjects(); - List boundProjects = bind(projects); + List boundProjects = bind(projects, ctx.connectContext.getEnv()); // TODO: // trick logic: currently XxxRelation in GroupExpression always difference to each other, // so this rule must check the expression whether is changed to prevent dead loop because @@ -70,113 +69,64 @@ public class BindFunction implements AnalysisRuleFactory { }) ), RuleType.BINDING_PROJECT_FUNCTION.build( - logicalProject().then(project -> { - List boundExpr = bind(project.getProjects()); + logicalProject().thenApply(ctx -> { + LogicalProject project = ctx.root; + List boundExpr = bind(project.getProjects(), ctx.connectContext.getEnv()); return new LogicalProject<>(boundExpr, project.child()); }) ), RuleType.BINDING_AGGREGATE_FUNCTION.build( - logicalAggregate().then(agg -> { - List groupBy = bind(agg.getGroupByExpressions()); - List output = bind(agg.getOutputExpressions()); + 
logicalAggregate().thenApply(ctx -> { + LogicalAggregate agg = ctx.root; + List groupBy = bind(agg.getGroupByExpressions(), ctx.connectContext.getEnv()); + List output = bind(agg.getOutputExpressions(), ctx.connectContext.getEnv()); return agg.withGroupByAndOutput(groupBy, output); }) ), RuleType.BINDING_FILTER_FUNCTION.build( - logicalFilter().then(filter -> { - List predicates = bind(filter.getExpressions()); + logicalFilter().thenApply(ctx -> { + LogicalFilter filter = ctx.root; + List predicates = bind(filter.getExpressions(), ctx.connectContext.getEnv()); return new LogicalFilter<>(predicates.get(0), filter.child()); }) ), RuleType.BINDING_HAVING_FUNCTION.build( - logicalHaving(logicalAggregate()).then(filter -> { - List predicates = bind(filter.getExpressions()); - return new LogicalHaving<>(predicates.get(0), filter.child()); + logicalHaving(logicalAggregate()).thenApply(ctx -> { + LogicalHaving> having = ctx.root; + List predicates = bind(having.getExpressions(), ctx.connectContext.getEnv()); + return new LogicalHaving<>(predicates.get(0), having.child()); }) ) ); } - private List bind(List exprList) { + private List bind(List exprList, Env env) { return exprList.stream() - .map(FunctionBinder.INSTANCE::bind) + .map(expr -> FunctionBinder.INSTANCE.bind(expr, env)) .collect(Collectors.toList()); } - private static class FunctionBinder extends DefaultExpressionRewriter { + private static class FunctionBinder extends DefaultExpressionRewriter { public static final FunctionBinder INSTANCE = new FunctionBinder(); - public E bind(E expression) { - return (E) expression.accept(this, null); + public E bind(E expression, Env env) { + return (E) expression.accept(this, env); } @Override - public Expression visitUnboundFunction(UnboundFunction unboundFunction, Void context) { - String name = unboundFunction.getName(); - // TODO: lookup function in the function registry - if (name.equalsIgnoreCase("sum")) { - List arguments = unboundFunction.getArguments(); - if 
(arguments.size() != 1) { - return unboundFunction; - } - return new Sum(unboundFunction.getArguments().get(0)); - } else if (name.equalsIgnoreCase("count")) { - List arguments = unboundFunction.getArguments(); - if (arguments.size() > 1 || (arguments.size() == 0 && !unboundFunction.isStar())) { - return unboundFunction; - } - if (unboundFunction.isStar() || arguments.stream().allMatch(Expression::isConstant)) { - return new Count(); - } - return new Count(unboundFunction.getArguments().get(0)); - } else if (name.equalsIgnoreCase("max")) { - List arguments = unboundFunction.getArguments(); - if (arguments.size() != 1) { - return unboundFunction; - } - return new Max(unboundFunction.getArguments().get(0)); - } else if (name.equalsIgnoreCase("min")) { - List arguments = unboundFunction.getArguments(); - if (arguments.size() != 1) { - return unboundFunction; - } - return new Min(unboundFunction.getArguments().get(0)); - } else if (name.equalsIgnoreCase("avg")) { - List arguments = unboundFunction.getArguments(); - if (arguments.size() != 1) { - return unboundFunction; - } - return new Avg(unboundFunction.getArguments().get(0)); - } else if (name.equalsIgnoreCase("substr") || name.equalsIgnoreCase("substring")) { - List arguments = unboundFunction.getArguments(); - if (arguments.size() == 2) { - return new Substring(unboundFunction.getArguments().get(0), - unboundFunction.getArguments().get(1)); - } else if (arguments.size() == 3) { - return new Substring(unboundFunction.getArguments().get(0), unboundFunction.getArguments().get(1), - unboundFunction.getArguments().get(2)); - } - return unboundFunction; - } else if (name.equalsIgnoreCase("year")) { - List arguments = unboundFunction.getArguments(); - if (arguments.size() != 1) { - return unboundFunction; - } - return new Year(unboundFunction.getArguments().get(0)); - } else if (name.equalsIgnoreCase("WeekOfYear")) { - List arguments = unboundFunction.getArguments(); - if (arguments.size() != 1) { - return 
unboundFunction; - } - return new WeekOfYear(unboundFunction.getArguments().get(0)); - } - return unboundFunction; + public BoundFunction visitUnboundFunction(UnboundFunction unboundFunction, Env env) { + FunctionRegistry functionRegistry = env.getFunctionRegistry(); + String functionName = unboundFunction.getName(); + FunctionBuilder builder = functionRegistry.findFunctionBuilder( + functionName, unboundFunction.getArguments()); + return builder.build(functionName, unboundFunction.getArguments()); } @Override - public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Void context) { + public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env env) { String funcOpName; if (arithmetic.getFuncName() == null) { + // e.g. YEARS_ADD, MONTHS_SUB funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(), (arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB"); } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java index 371aff83b6..73de61a058 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggregateFunction.java @@ -21,7 +21,9 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.DataType; -/** AggregateFunction. */ +/** + * The function which consume arguments in lots of rows and product one value. 
+ */ public abstract class AggregateFunction extends BoundFunction { private DataType intermediate; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java new file mode 100644 index 0000000000..9d5b620775 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java @@ -0,0 +1,90 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions; + +import org.apache.doris.nereids.trees.expressions.Expression; + +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Modifier; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.stream.Collectors; + +/** + * This class used to resolve a concrete BoundFunction's class and build BoundFunction by a list of Expressions. 
+ */ +public class FunctionBuilder { + public final int arity; + + // Concrete BoundFunction's constructor + private final Constructor builderMethod; + + public FunctionBuilder(Constructor builderMethod) { + this.builderMethod = Objects.requireNonNull(builderMethod, "builderMethod can not be null"); + this.arity = builderMethod.getParameterCount(); + } + + /** + * build a BoundFunction by function name and arguments. + * @param name function name which in the sql expression + * @param arguments the function's argument expressions + * @return the concrete bound function instance + */ + public BoundFunction build(String name, List arguments) { + try { + return builderMethod.newInstance(arguments.toArray(new Expression[0])); + } catch (Throwable t) { + String argString = arguments.stream() + .map(arg -> arg == null ? "null" : arg.toSql()) + .collect(Collectors.joining(", ", "(", ")")); + throw new IllegalStateException("Can not build function: '" + name + + "', expression: " + name + argString, t); + } + } + + @Override + public String toString() { + return Arrays.stream(builderMethod.getParameterTypes()) + .map(type -> type.getSimpleName()) + .collect(Collectors.joining(", ", "(", ")")); + } + + /** + * resolve a Concrete boundFunction's class and convert the constructors to FunctionBuilder + * @param functionClass a class which is the child class of BoundFunction and can not be abstract class + * @return list of FunctionBuilder which contains the constructor + */ + public static List resolve(Class functionClass) { + Preconditions.checkArgument(!Modifier.isAbstract(functionClass.getModifiers()), + "Can not resolve bind function which is abstract class: " + + functionClass.getSimpleName()); + return Arrays.stream(functionClass.getConstructors()) + .filter(constructor -> Modifier.isPublic(constructor.getModifiers())) + .filter(constructor -> + // all arguments must be Expression + Arrays.stream(functionClass.getTypeParameters()) + 
.allMatch(Expression.class::isInstance) + ) + .map(constructor -> new FunctionBuilder((Constructor) constructor)) + .collect(ImmutableList.toImmutableList()); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ScalarFunction.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ScalarFunction.java new file mode 100644 index 0000000000..4dd1b0689a --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/ScalarFunction.java @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.expressions.functions; + +import org.apache.doris.nereids.trees.expressions.Expression; + +/** + * The function which consume zero or more arguments in a row and product one value. + */ +public abstract class ScalarFunction extends BoundFunction { + public ScalarFunction(String name, Expression... 
arguments) { + super(name, arguments); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Substring.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Substring.java index ac90843349..ead4df8d9b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Substring.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Substring.java @@ -19,7 +19,6 @@ package org.apache.doris.nereids.trees.expressions.functions; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; -import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression; import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes; import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.IntegerType; @@ -31,11 +30,12 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; import java.util.List; +import java.util.Optional; /** * substring function. 
*/ -public class Substring extends BoundFunction implements TernaryExpression, ImplicitCastInputTypes { +public class Substring extends ScalarFunction implements ImplicitCastInputTypes { // used in interface expectedInputTypes to avoid new list in each time it be called private static final List EXPECTED_INPUT_TYPES = ImmutableList.of( @@ -49,10 +49,10 @@ public class Substring extends BoundFunction implements TernaryExpression, Impli } public Substring(Expression str, Expression pos) { - super("substring", str, pos, new IntegerLiteral(Integer.MAX_VALUE)); + super("substring", str, pos); } - public Expression getTarget() { + public Expression getSource() { return child(0); } @@ -60,21 +60,22 @@ public class Substring extends BoundFunction implements TernaryExpression, Impli return child(1); } - public Expression getLength() { - return child(2); + public Optional getLength() { + return arity() == 3 ? Optional.of(child(2)) : Optional.empty(); } @Override public DataType getDataType() { - if (getLength() instanceof IntegerLiteral) { - return VarcharType.createVarcharType(((IntegerLiteral) getLength()).getValue()); + Optional length = getLength(); + if (length.isPresent() && length.get() instanceof IntegerLiteral) { + return VarcharType.createVarcharType(((IntegerLiteral) length.get()).getValue()); } return VarcharType.SYSTEM_DEFAULT; } @Override public boolean nullable() { - return first().nullable(); + return children().stream().anyMatch(Expression::nullable); } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/WeekOfYear.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/WeekOfYear.java index 8660b72492..baba3673ca 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/WeekOfYear.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/WeekOfYear.java @@ -34,7 +34,7 @@ import java.util.List; /** * weekOfYear 
function */ -public class WeekOfYear extends BoundFunction implements UnaryExpression, ImplicitCastInputTypes { +public class WeekOfYear extends ScalarFunction implements UnaryExpression, ImplicitCastInputTypes { private static final List EXPECTED_INPUT_TYPES = ImmutableList.of( new TypeCollection(DateTimeType.INSTANCE) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Year.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Year.java index 454ea937db..20594bd576 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Year.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/Year.java @@ -35,7 +35,7 @@ import java.util.List; /** * year function. */ -public class Year extends BoundFunction implements UnaryExpression, ImplicitCastInputTypes { +public class Year extends ScalarFunction implements UnaryExpression, ImplicitCastInputTypes { // used in interface expectedInputTypes to avoid new list in each time it be called private static final List EXPECTED_INPUT_TYPES = ImmutableList.of( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java new file mode 100644 index 0000000000..1d8f9c5bd1 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/analysis/FunctionRegistryTest.java @@ -0,0 +1,153 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.analysis;
+
+import org.apache.doris.catalog.FunctionRegistry;
+import org.apache.doris.nereids.exceptions.AnalysisException;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
+import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
+import org.apache.doris.nereids.trees.expressions.functions.Substring;
+import org.apache.doris.nereids.trees.expressions.functions.Year;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
+import org.apache.doris.qe.ConnectContext;
+
+import com.google.common.collect.ImmutableList;
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+import java.util.List;
+import java.util.Map;
+
+/** Tests for {@link FunctionRegistry}: lookup by name, alias and arity. More cases will be added later. */
+public class FunctionRegistryTest implements PatternMatchSupported {
+    private final ConnectContext connectContext = MemoTestUtils.createConnectContext();
+
+    @Test
+    public void testDefaultFunctionNameIsClassName() {
+        // Year is listed in BuiltinFunctions as scalar(Year.class), i.e. without an explicit name,
+        // so the default function name derived from the class name should be "year".
+        PlanChecker.from(connectContext)
+                .analyze("select year('2021-01-01')")
+                .matchesFromRoot(
+                        logicalOneRowRelation().when(r -> {
+                            Year year = (Year) r.getProjects().get(0).child(0);
+                            Assertions.assertEquals("2021-01-01", ((Literal) year.getArguments().get(0)).getValue());
+                            return true;
+                        })
+                );
+    }
+
+    @Test
+    public void testMultiName() {
+        // the substring function is registered under 2 names:
+        // 1. substring
+        // 2. substr
+        PlanChecker.from(connectContext)
+                .analyze("select substring('abc', 1, 2), substr(substring('abcdefg', 4, 3), 1, 2)")
+                .matchesFromRoot(
+                        logicalOneRowRelation().when(r -> {
+                            Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
+                            Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
+                            Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
+                            Assertions.assertEquals((byte) 2, ((Literal) firstSubstring.getLength().get()).getValue());
+
+                            Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
+                            Assertions.assertTrue(secondSubstring.getSource() instanceof Substring);
+                            Assertions.assertEquals((byte) 1, ((Literal) secondSubstring.getPosition()).getValue());
+                            Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getLength().get()).getValue());
+                            return true;
+                        })
+                );
+    }
+
+    @Test
+    public void testOverrideArity() {
+        // the substring function has 2 overloads that differ in arity:
+        // 1. substring(string, position)
+        // 2. substring(string, position, length)
+        PlanChecker.from(connectContext)
+                .analyze("select substr('abc', 1), substring('def', 2, 3)")
+                .matchesFromRoot(
+                        logicalOneRowRelation().when(r -> {
+                            Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
+                            Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
+                            Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
+                            Assertions.assertFalse(firstSubstring.getLength().isPresent());
+
+                            Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
+                            Assertions.assertEquals("def", ((Literal) secondSubstring.getSource()).getValue());
+                            Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getPosition()).getValue());
+                            Assertions.assertEquals((byte) 3, ((Literal) secondSubstring.getLength().get()).getValue());
+                            return true;
+                        })
+                );
+    }
+
+    @Test
+    public void testAddFunction() {
+        // a FunctionRegistry subclass can register extra functions through afterRegisterBuiltinFunctions
+        FunctionRegistry functionRegistry = new FunctionRegistry() {
+            @Override
+            protected void afterRegisterBuiltinFunctions(Map<String, List<FunctionBuilder>> name2builders) {
+                name2builders.put("foo", FunctionBuilder.resolve(ExtendFunction.class));
+            }
+        };
+
+        ImmutableList<Expression> arguments = ImmutableList.of(Literal.of(1));
+        BoundFunction function = functionRegistry.findFunctionBuilder("foo", arguments).build("foo", arguments);
+        Assertions.assertEquals(ExtendFunction.class, function.getClass());
+        Assertions.assertEquals(arguments, function.getArguments());
+    }
+
+    @Test
+    public void testOverrideDifferenceTypes() {
+        FunctionRegistry functionRegistry = new FunctionRegistry() {
+            @Override
+            protected void afterRegisterBuiltinFunctions(Map<String, List<FunctionBuilder>> name2builders) {
+                name2builders.put("abc", FunctionBuilder.resolve(AmbiguousFunction.class));
+            }
+        };
+
+        // overloads with the same arity but different parameter types cannot be resolved yet
+        Assertions.assertThrowsExactly(AnalysisException.class,
+                () -> functionRegistry.findFunctionBuilder("abc", ImmutableList.of(Literal.of(1))));
+    }
+
+    /** Minimal single-argument function, registered under the name "foo" by testAddFunction. */
+    public static class ExtendFunction extends BoundFunction implements UnaryExpression {
+        public ExtendFunction(Expression a1) {
+            super("foo", a1);
+        }
+    }
+
+    /** Declares two one-argument constructors (Expression and Literal): same arity, different types. */
+    public static class AmbiguousFunction extends BoundFunction implements UnaryExpression {
+        public AmbiguousFunction(Expression a1) {
+            super("abc", a1);
+        }
+
+        public AmbiguousFunction(Literal a1) {
+            super("abc", a1);
+        }
+    }
+}