[feature](Nereids) Support function registry (#12481)
Support function registry.
The classes:
- BuiltinFunctions: contains the built-in functions list
- FunctionRegistry: used to register scalar functions and aggregate functions, it can find the function by name
- FunctionBuilder: used to resolve a BoundFunction class, extract the constructor, and build to a BoundFunction by arguments(`List<Expression>`)
Register example: you can add built-in functions in the list for simplicity
```java
public class BuiltinFunctions implements FunctionHelper {
public final List<ScalarFunc> scalarFunctions = ImmutableList.of(
scalar(Substring.class, "substr", "substring"),
scalar(WeekOfYear.class),
scalar(Year.class)
);
public final ImmutableList<AggregateFunc> aggregateFunctions = ImmutableList.of(
agg(Avg.class),
agg(Count.class),
agg(Max.class),
agg(Min.class),
agg(Sum.class)
);
}
```
Note:
- Currently, we only support register scalar functions add aggregate functions, we will support register table functions.
- Currently, we only support resolve function by function name and difference arity, but can not resolve the same arity override function, e.g. `some_function(Expression)` and `some_function(Literal)`
This commit is contained in:
@ -0,0 +1,58 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.catalog;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Avg;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Max;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Min;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Substring;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.WeekOfYear;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Year;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Built-in functions.
|
||||
*
|
||||
* Note: Please ensure that this class only has some lists and no procedural code.
|
||||
* It helps to be clear and concise.
|
||||
*/
|
||||
public class BuiltinFunctions implements FunctionHelper {
|
||||
public final List<ScalarFunc> scalarFunctions = ImmutableList.of(
|
||||
scalar(Substring.class, "substr", "substring"),
|
||||
scalar(WeekOfYear.class),
|
||||
scalar(Year.class)
|
||||
);
|
||||
|
||||
public final ImmutableList<AggregateFunc> aggregateFunctions = ImmutableList.of(
|
||||
agg(Avg.class),
|
||||
agg(Count.class),
|
||||
agg(Max.class),
|
||||
agg(Min.class),
|
||||
agg(Sum.class)
|
||||
);
|
||||
|
||||
public static final BuiltinFunctions INSTANCE = new BuiltinFunctions();
|
||||
|
||||
// Note: Do not add any code here!
|
||||
private BuiltinFunctions() {}
|
||||
}
|
||||
@ -375,6 +375,9 @@ public class Env {
|
||||
private CatalogRecycleBin recycleBin;
|
||||
private FunctionSet functionSet;
|
||||
|
||||
// for nereids
|
||||
private FunctionRegistry functionRegistry;
|
||||
|
||||
private MetaReplayState metaReplayState;
|
||||
|
||||
private BrokerMgr brokerMgr;
|
||||
@ -555,6 +558,8 @@ public class Env {
|
||||
this.functionSet = new FunctionSet();
|
||||
this.functionSet.init();
|
||||
|
||||
this.functionRegistry = new FunctionRegistry();
|
||||
|
||||
this.metaReplayState = new MetaReplayState();
|
||||
|
||||
this.isDefaultClusterCreated = false;
|
||||
@ -4306,6 +4311,10 @@ public class Env {
|
||||
LOG.info("successfully create view[" + tableName + "-" + newView.getId() + "]");
|
||||
}
|
||||
|
||||
public FunctionRegistry getFunctionRegistry() {
|
||||
return functionRegistry;
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the function that best matches 'desc' that is registered with the
|
||||
* catalog using 'mode' to check for matching. If desc matches multiple
|
||||
|
||||
@ -0,0 +1,109 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.catalog;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ScalarFunction;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
public interface FunctionHelper {
|
||||
/**
|
||||
* put functions into the target map, which the key is the function name,
|
||||
* and value is FunctionBuilder is converted from NamedFunc.
|
||||
* @param name2FuncBuilders target Map
|
||||
* @param functions the NamedFunc list to be put into the target map
|
||||
*/
|
||||
static void addFunctions(Map<String, List<FunctionBuilder>> name2FuncBuilders,
|
||||
List<? extends NamedFunc<? extends BoundFunction>> functions) {
|
||||
for (NamedFunc<? extends BoundFunction> func : functions) {
|
||||
for (String name : func.names) {
|
||||
if (name2FuncBuilders.containsKey(name)) {
|
||||
throw new IllegalStateException("Function '" + name + "' already exists in function registry");
|
||||
}
|
||||
|
||||
name2FuncBuilders.put(name, func.functionBuilders);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
default ScalarFunc scalar(Class<? extends ScalarFunction> functionClass) {
|
||||
String functionName = functionClass.getSimpleName();
|
||||
return scalar(functionClass, functionName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve ScalaFunction class, convert to FunctionBuilder and wrap to ScalarFunc
|
||||
* @param functionClass the ScalaFunction class
|
||||
* @return ScalaFunc which contains the functionName and the FunctionBuilder
|
||||
*/
|
||||
default ScalarFunc scalar(Class<? extends ScalarFunction> functionClass, String... functionNames) {
|
||||
return new ScalarFunc(functionClass, functionNames);
|
||||
}
|
||||
|
||||
default AggregateFunc agg(Class<? extends AggregateFunction> functionClass) {
|
||||
String functionName = functionClass.getSimpleName();
|
||||
return new AggregateFunc(functionClass, functionName);
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolve AggregateFunction class, convert to FunctionBuilder and wrap to AggregateFunc
|
||||
* @param functionClass the AggregateFunction class
|
||||
* @return AggregateFunc which contains the functionName and the AggregateFunc
|
||||
*/
|
||||
default AggregateFunc agg(Class<? extends AggregateFunction> functionClass, String... functionNames) {
|
||||
return new AggregateFunc(functionClass, functionNames);
|
||||
}
|
||||
|
||||
/**
|
||||
* use this class to prevent the wrong type from being registered, and support multi function names
|
||||
* like substring and substr.
|
||||
*/
|
||||
class NamedFunc<T extends BoundFunction> {
|
||||
public final List<String> names;
|
||||
public final Class<? extends T> functionClass;
|
||||
|
||||
public final List<FunctionBuilder> functionBuilders;
|
||||
|
||||
public NamedFunc(Class<? extends T> functionClass, String... names) {
|
||||
this.functionClass = functionClass;
|
||||
this.names = Arrays.stream(names)
|
||||
.map(String::toLowerCase)
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
this.functionBuilders = FunctionBuilder.resolve(functionClass);
|
||||
}
|
||||
}
|
||||
|
||||
class ScalarFunc extends NamedFunc<ScalarFunction> {
|
||||
public ScalarFunc(Class<? extends ScalarFunction> functionClass, String... names) {
|
||||
super(functionClass, names);
|
||||
}
|
||||
}
|
||||
|
||||
class AggregateFunc extends NamedFunc<AggregateFunction> {
|
||||
public AggregateFunc(Class<? extends AggregateFunction> functionClass, String... names) {
|
||||
super(functionClass, names);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,91 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.catalog;
|
||||
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
|
||||
import com.google.common.annotations.VisibleForTesting;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ConcurrentHashMap;
|
||||
import java.util.stream.Collectors;
|
||||
import javax.annotation.concurrent.ThreadSafe;
|
||||
|
||||
/**
|
||||
* New function registry for nereids.
|
||||
*
|
||||
* this class is developing for more functions.
|
||||
*/
|
||||
@Developing
|
||||
@ThreadSafe
|
||||
public class FunctionRegistry {
|
||||
private final Map<String, List<FunctionBuilder>> name2Builders;
|
||||
|
||||
public FunctionRegistry() {
|
||||
name2Builders = new ConcurrentHashMap<>();
|
||||
registerBuiltinFunctions(name2Builders);
|
||||
afterRegisterBuiltinFunctions(name2Builders);
|
||||
}
|
||||
|
||||
// this function is used to test.
|
||||
// for example, you can create child class of FunctionRegistry and clear builtin functions or add more functions
|
||||
// in this method
|
||||
@VisibleForTesting
|
||||
protected void afterRegisterBuiltinFunctions(Map<String, List<FunctionBuilder>> name2Builders) {}
|
||||
|
||||
// currently we only find function by name and arity
|
||||
public FunctionBuilder findFunctionBuilder(String name, List<Expression> arguments) {
|
||||
int arity = arguments.size();
|
||||
List<FunctionBuilder> functionBuilders = name2Builders.get(name.toLowerCase());
|
||||
if (functionBuilders == null || functionBuilders.isEmpty()) {
|
||||
throw new AnalysisException("Can not found function '" + name + "'");
|
||||
}
|
||||
|
||||
List<FunctionBuilder> candidateBuilders = functionBuilders.stream()
|
||||
.filter(functionBuilder -> functionBuilder.arity == arity)
|
||||
.collect(Collectors.toList());
|
||||
if (candidateBuilders.isEmpty()) {
|
||||
String candidateHints = getCandidateHint(name, candidateBuilders);
|
||||
throw new AnalysisException("Can not found function '" + name
|
||||
+ "' which has " + arity + " arity. Candidate functions are: " + candidateHints);
|
||||
}
|
||||
|
||||
if (candidateBuilders.size() > 1) {
|
||||
String candidateHints = getCandidateHint(name, candidateBuilders);
|
||||
// NereidsPlanner not supported override function by the same arity, should we support it?
|
||||
|
||||
throw new AnalysisException("Function '" + name + "' is ambiguous: " + candidateHints);
|
||||
}
|
||||
return candidateBuilders.get(0);
|
||||
}
|
||||
|
||||
private void registerBuiltinFunctions(Map<String, List<FunctionBuilder>> name2Builders) {
|
||||
FunctionHelper.addFunctions(name2Builders, BuiltinFunctions.INSTANCE.scalarFunctions);
|
||||
FunctionHelper.addFunctions(name2Builders, BuiltinFunctions.INSTANCE.aggregateFunctions);
|
||||
}
|
||||
|
||||
public String getCandidateHint(String name, List<FunctionBuilder> candidateBuilders) {
|
||||
return candidateBuilders.stream()
|
||||
.map(builder -> name + builder.toString())
|
||||
.collect(Collectors.joining(", "));
|
||||
}
|
||||
}
|
||||
@ -18,21 +18,19 @@
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.analysis.ArithmeticExpr.Operator;
|
||||
import org.apache.doris.catalog.Env;
|
||||
import org.apache.doris.catalog.FunctionRegistry;
|
||||
import org.apache.doris.nereids.analyzer.UnboundFunction;
|
||||
import org.apache.doris.nereids.rules.Rule;
|
||||
import org.apache.doris.nereids.rules.RuleType;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.NamedExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Avg;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Max;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Min;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Substring;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Sum;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.WeekOfYear;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Year;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
|
||||
import org.apache.doris.nereids.trees.plans.GroupPlan;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
|
||||
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
|
||||
@ -54,9 +52,10 @@ public class BindFunction implements AnalysisRuleFactory {
|
||||
public List<Rule> buildRules() {
|
||||
return ImmutableList.of(
|
||||
RuleType.BINDING_ONE_ROW_RELATION_FUNCTION.build(
|
||||
logicalOneRowRelation().then(oneRowRelation -> {
|
||||
logicalOneRowRelation().thenApply(ctx -> {
|
||||
LogicalOneRowRelation oneRowRelation = ctx.root;
|
||||
List<NamedExpression> projects = oneRowRelation.getProjects();
|
||||
List<NamedExpression> boundProjects = bind(projects);
|
||||
List<NamedExpression> boundProjects = bind(projects, ctx.connectContext.getEnv());
|
||||
// TODO:
|
||||
// trick logic: currently XxxRelation in GroupExpression always difference to each other,
|
||||
// so this rule must check the expression whether is changed to prevent dead loop because
|
||||
@ -70,113 +69,64 @@ public class BindFunction implements AnalysisRuleFactory {
|
||||
})
|
||||
),
|
||||
RuleType.BINDING_PROJECT_FUNCTION.build(
|
||||
logicalProject().then(project -> {
|
||||
List<NamedExpression> boundExpr = bind(project.getProjects());
|
||||
logicalProject().thenApply(ctx -> {
|
||||
LogicalProject<GroupPlan> project = ctx.root;
|
||||
List<NamedExpression> boundExpr = bind(project.getProjects(), ctx.connectContext.getEnv());
|
||||
return new LogicalProject<>(boundExpr, project.child());
|
||||
})
|
||||
),
|
||||
RuleType.BINDING_AGGREGATE_FUNCTION.build(
|
||||
logicalAggregate().then(agg -> {
|
||||
List<Expression> groupBy = bind(agg.getGroupByExpressions());
|
||||
List<NamedExpression> output = bind(agg.getOutputExpressions());
|
||||
logicalAggregate().thenApply(ctx -> {
|
||||
LogicalAggregate<GroupPlan> agg = ctx.root;
|
||||
List<Expression> groupBy = bind(agg.getGroupByExpressions(), ctx.connectContext.getEnv());
|
||||
List<NamedExpression> output = bind(agg.getOutputExpressions(), ctx.connectContext.getEnv());
|
||||
return agg.withGroupByAndOutput(groupBy, output);
|
||||
})
|
||||
),
|
||||
RuleType.BINDING_FILTER_FUNCTION.build(
|
||||
logicalFilter().then(filter -> {
|
||||
List<Expression> predicates = bind(filter.getExpressions());
|
||||
logicalFilter().thenApply(ctx -> {
|
||||
LogicalFilter<GroupPlan> filter = ctx.root;
|
||||
List<Expression> predicates = bind(filter.getExpressions(), ctx.connectContext.getEnv());
|
||||
return new LogicalFilter<>(predicates.get(0), filter.child());
|
||||
})
|
||||
),
|
||||
RuleType.BINDING_HAVING_FUNCTION.build(
|
||||
logicalHaving(logicalAggregate()).then(filter -> {
|
||||
List<Expression> predicates = bind(filter.getExpressions());
|
||||
return new LogicalHaving<>(predicates.get(0), filter.child());
|
||||
logicalHaving(logicalAggregate()).thenApply(ctx -> {
|
||||
LogicalHaving<LogicalAggregate<GroupPlan>> having = ctx.root;
|
||||
List<Expression> predicates = bind(having.getExpressions(), ctx.connectContext.getEnv());
|
||||
return new LogicalHaving<>(predicates.get(0), having.child());
|
||||
})
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
private <E extends Expression> List<E> bind(List<E> exprList) {
|
||||
private <E extends Expression> List<E> bind(List<E> exprList, Env env) {
|
||||
return exprList.stream()
|
||||
.map(FunctionBinder.INSTANCE::bind)
|
||||
.map(expr -> FunctionBinder.INSTANCE.bind(expr, env))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
private static class FunctionBinder extends DefaultExpressionRewriter<Void> {
|
||||
private static class FunctionBinder extends DefaultExpressionRewriter<Env> {
|
||||
public static final FunctionBinder INSTANCE = new FunctionBinder();
|
||||
|
||||
public <E extends Expression> E bind(E expression) {
|
||||
return (E) expression.accept(this, null);
|
||||
public <E extends Expression> E bind(E expression, Env env) {
|
||||
return (E) expression.accept(this, env);
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitUnboundFunction(UnboundFunction unboundFunction, Void context) {
|
||||
String name = unboundFunction.getName();
|
||||
// TODO: lookup function in the function registry
|
||||
if (name.equalsIgnoreCase("sum")) {
|
||||
List<Expression> arguments = unboundFunction.getArguments();
|
||||
if (arguments.size() != 1) {
|
||||
return unboundFunction;
|
||||
}
|
||||
return new Sum(unboundFunction.getArguments().get(0));
|
||||
} else if (name.equalsIgnoreCase("count")) {
|
||||
List<Expression> arguments = unboundFunction.getArguments();
|
||||
if (arguments.size() > 1 || (arguments.size() == 0 && !unboundFunction.isStar())) {
|
||||
return unboundFunction;
|
||||
}
|
||||
if (unboundFunction.isStar() || arguments.stream().allMatch(Expression::isConstant)) {
|
||||
return new Count();
|
||||
}
|
||||
return new Count(unboundFunction.getArguments().get(0));
|
||||
} else if (name.equalsIgnoreCase("max")) {
|
||||
List<Expression> arguments = unboundFunction.getArguments();
|
||||
if (arguments.size() != 1) {
|
||||
return unboundFunction;
|
||||
}
|
||||
return new Max(unboundFunction.getArguments().get(0));
|
||||
} else if (name.equalsIgnoreCase("min")) {
|
||||
List<Expression> arguments = unboundFunction.getArguments();
|
||||
if (arguments.size() != 1) {
|
||||
return unboundFunction;
|
||||
}
|
||||
return new Min(unboundFunction.getArguments().get(0));
|
||||
} else if (name.equalsIgnoreCase("avg")) {
|
||||
List<Expression> arguments = unboundFunction.getArguments();
|
||||
if (arguments.size() != 1) {
|
||||
return unboundFunction;
|
||||
}
|
||||
return new Avg(unboundFunction.getArguments().get(0));
|
||||
} else if (name.equalsIgnoreCase("substr") || name.equalsIgnoreCase("substring")) {
|
||||
List<Expression> arguments = unboundFunction.getArguments();
|
||||
if (arguments.size() == 2) {
|
||||
return new Substring(unboundFunction.getArguments().get(0),
|
||||
unboundFunction.getArguments().get(1));
|
||||
} else if (arguments.size() == 3) {
|
||||
return new Substring(unboundFunction.getArguments().get(0), unboundFunction.getArguments().get(1),
|
||||
unboundFunction.getArguments().get(2));
|
||||
}
|
||||
return unboundFunction;
|
||||
} else if (name.equalsIgnoreCase("year")) {
|
||||
List<Expression> arguments = unboundFunction.getArguments();
|
||||
if (arguments.size() != 1) {
|
||||
return unboundFunction;
|
||||
}
|
||||
return new Year(unboundFunction.getArguments().get(0));
|
||||
} else if (name.equalsIgnoreCase("WeekOfYear")) {
|
||||
List<Expression> arguments = unboundFunction.getArguments();
|
||||
if (arguments.size() != 1) {
|
||||
return unboundFunction;
|
||||
}
|
||||
return new WeekOfYear(unboundFunction.getArguments().get(0));
|
||||
}
|
||||
return unboundFunction;
|
||||
public BoundFunction visitUnboundFunction(UnboundFunction unboundFunction, Env env) {
|
||||
FunctionRegistry functionRegistry = env.getFunctionRegistry();
|
||||
String functionName = unboundFunction.getName();
|
||||
FunctionBuilder builder = functionRegistry.findFunctionBuilder(
|
||||
functionName, unboundFunction.getArguments());
|
||||
return builder.build(functionName, unboundFunction.getArguments());
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Void context) {
|
||||
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env env) {
|
||||
String funcOpName;
|
||||
if (arithmetic.getFuncName() == null) {
|
||||
// e.g. YEARS_ADD, MONTHS_SUB
|
||||
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
|
||||
(arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB");
|
||||
} else {
|
||||
|
||||
@ -21,7 +21,9 @@ import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
/** AggregateFunction. */
|
||||
/**
|
||||
* The function which consume arguments in lots of rows and product one value.
|
||||
*/
|
||||
public abstract class AggregateFunction extends BoundFunction {
|
||||
|
||||
private DataType intermediate;
|
||||
|
||||
@ -0,0 +1,90 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.lang.reflect.Constructor;
|
||||
import java.lang.reflect.Modifier;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
import java.util.stream.Collectors;
|
||||
|
||||
/**
|
||||
* This class used to resolve a concrete BoundFunction's class and build BoundFunction by a list of Expressions.
|
||||
*/
|
||||
public class FunctionBuilder {
|
||||
public final int arity;
|
||||
|
||||
// Concrete BoundFunction's constructor
|
||||
private final Constructor<BoundFunction> builderMethod;
|
||||
|
||||
public FunctionBuilder(Constructor<BoundFunction> builderMethod) {
|
||||
this.builderMethod = Objects.requireNonNull(builderMethod, "builderMethod can not be null");
|
||||
this.arity = builderMethod.getParameterCount();
|
||||
}
|
||||
|
||||
/**
|
||||
* build a BoundFunction by function name and arguments.
|
||||
* @param name function name which in the sql expression
|
||||
* @param arguments the function's argument expressions
|
||||
* @return the concrete bound function instance
|
||||
*/
|
||||
public BoundFunction build(String name, List<Expression> arguments) {
|
||||
try {
|
||||
return builderMethod.newInstance(arguments.toArray(new Expression[0]));
|
||||
} catch (Throwable t) {
|
||||
String argString = arguments.stream()
|
||||
.map(arg -> arg == null ? "null" : arg.toSql())
|
||||
.collect(Collectors.joining(", ", "(", ")"));
|
||||
throw new IllegalStateException("Can not build function: '" + name
|
||||
+ "', expression: " + name + argString, t);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return Arrays.stream(builderMethod.getParameterTypes())
|
||||
.map(type -> type.getSimpleName())
|
||||
.collect(Collectors.joining(", ", "(", ")"));
|
||||
}
|
||||
|
||||
/**
|
||||
* resolve a Concrete boundFunction's class and convert the constructors to FunctionBuilder
|
||||
* @param functionClass a class which is the child class of BoundFunction and can not be abstract class
|
||||
* @return list of FunctionBuilder which contains the constructor
|
||||
*/
|
||||
public static List<FunctionBuilder> resolve(Class<? extends BoundFunction> functionClass) {
|
||||
Preconditions.checkArgument(!Modifier.isAbstract(functionClass.getModifiers()),
|
||||
"Can not resolve bind function which is abstract class: "
|
||||
+ functionClass.getSimpleName());
|
||||
return Arrays.stream(functionClass.getConstructors())
|
||||
.filter(constructor -> Modifier.isPublic(constructor.getModifiers()))
|
||||
.filter(constructor ->
|
||||
// all arguments must be Expression
|
||||
Arrays.stream(functionClass.getTypeParameters())
|
||||
.allMatch(Expression.class::isInstance)
|
||||
)
|
||||
.map(constructor -> new FunctionBuilder((Constructor<BoundFunction>) constructor))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,29 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
|
||||
/**
|
||||
* The function which consume zero or more arguments in a row and product one value.
|
||||
*/
|
||||
public abstract class ScalarFunction extends BoundFunction {
|
||||
public ScalarFunction(String name, Expression... arguments) {
|
||||
super(name, arguments);
|
||||
}
|
||||
}
|
||||
@ -19,7 +19,6 @@ package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
import org.apache.doris.nereids.types.IntegerType;
|
||||
@ -31,11 +30,12 @@ import com.google.common.base.Preconditions;
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Optional;
|
||||
|
||||
/**
|
||||
* substring function.
|
||||
*/
|
||||
public class Substring extends BoundFunction implements TernaryExpression, ImplicitCastInputTypes {
|
||||
public class Substring extends ScalarFunction implements ImplicitCastInputTypes {
|
||||
|
||||
// used in interface expectedInputTypes to avoid new list in each time it be called
|
||||
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(
|
||||
@ -49,10 +49,10 @@ public class Substring extends BoundFunction implements TernaryExpression, Impli
|
||||
}
|
||||
|
||||
public Substring(Expression str, Expression pos) {
|
||||
super("substring", str, pos, new IntegerLiteral(Integer.MAX_VALUE));
|
||||
super("substring", str, pos);
|
||||
}
|
||||
|
||||
public Expression getTarget() {
|
||||
public Expression getSource() {
|
||||
return child(0);
|
||||
}
|
||||
|
||||
@ -60,21 +60,22 @@ public class Substring extends BoundFunction implements TernaryExpression, Impli
|
||||
return child(1);
|
||||
}
|
||||
|
||||
public Expression getLength() {
|
||||
return child(2);
|
||||
public Optional<Expression> getLength() {
|
||||
return arity() == 3 ? Optional.of(child(2)) : Optional.empty();
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
if (getLength() instanceof IntegerLiteral) {
|
||||
return VarcharType.createVarcharType(((IntegerLiteral) getLength()).getValue());
|
||||
Optional<Expression> length = getLength();
|
||||
if (length.isPresent() && length.get() instanceof IntegerLiteral) {
|
||||
return VarcharType.createVarcharType(((IntegerLiteral) length.get()).getValue());
|
||||
}
|
||||
return VarcharType.SYSTEM_DEFAULT;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean nullable() {
|
||||
return first().nullable();
|
||||
return children().stream().anyMatch(Expression::nullable);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
||||
@ -34,7 +34,7 @@ import java.util.List;
|
||||
/**
|
||||
* weekOfYear function
|
||||
*/
|
||||
public class WeekOfYear extends BoundFunction implements UnaryExpression, ImplicitCastInputTypes {
|
||||
public class WeekOfYear extends ScalarFunction implements UnaryExpression, ImplicitCastInputTypes {
|
||||
|
||||
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(
|
||||
new TypeCollection(DateTimeType.INSTANCE)
|
||||
|
||||
@ -35,7 +35,7 @@ import java.util.List;
|
||||
/**
|
||||
* year function.
|
||||
*/
|
||||
public class Year extends BoundFunction implements UnaryExpression, ImplicitCastInputTypes {
|
||||
public class Year extends ScalarFunction implements UnaryExpression, ImplicitCastInputTypes {
|
||||
|
||||
// used in interface expectedInputTypes to avoid new list in each time it be called
|
||||
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(
|
||||
|
||||
@ -0,0 +1,153 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.rules.analysis;
|
||||
|
||||
import org.apache.doris.catalog.FunctionRegistry;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Substring;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.Year;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.Literal;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.util.MemoTestUtils;
|
||||
import org.apache.doris.nereids.util.PatternMatchSupported;
|
||||
import org.apache.doris.nereids.util.PlanChecker;
|
||||
import org.apache.doris.qe.ConnectContext;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
import org.junit.jupiter.api.Assertions;
|
||||
import org.junit.jupiter.api.Test;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
|
||||
// this ut will add more test case later
|
||||
public class FunctionRegistryTest implements PatternMatchSupported {
|
||||
private ConnectContext connectContext = MemoTestUtils.createConnectContext();
|
||||
|
||||
@Test
|
||||
public void testDefaultFunctionNameIsClassName() {
|
||||
// we register Year by the code in FunctionRegistry: scalar(Year.class).
|
||||
// and default class name should be year.
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze("select year('2021-01-01')")
|
||||
.matchesFromRoot(
|
||||
logicalOneRowRelation().when(r -> {
|
||||
Year year = (Year) r.getProjects().get(0).child(0);
|
||||
Assertions.assertEquals("2021-01-01",
|
||||
((Literal) year.getArguments().get(0)).getValue());
|
||||
return true;
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testMultiName() {
|
||||
// the substring function has 2 names:
|
||||
// 1. substring
|
||||
// 2. substr
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze("select substring('abc', 1, 2), substr(substring('abcdefg', 4, 3), 1, 2)")
|
||||
.matchesFromRoot(
|
||||
logicalOneRowRelation().when(r -> {
|
||||
Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
|
||||
Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
|
||||
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
|
||||
Assertions.assertEquals((byte) 2, ((Literal) firstSubstring.getLength().get()).getValue());
|
||||
|
||||
Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
|
||||
Assertions.assertTrue(secondSubstring.getSource() instanceof Substring);
|
||||
Assertions.assertEquals((byte) 1, ((Literal) secondSubstring.getPosition()).getValue());
|
||||
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getLength().get()).getValue());
|
||||
return true;
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOverrideArity() {
|
||||
// the substring function has 2 override functions:
|
||||
// 1. substring(string, position)
|
||||
// 2. substring(string, position, length)
|
||||
PlanChecker.from(connectContext)
|
||||
.analyze("select substr('abc', 1), substring('def', 2, 3)")
|
||||
.matchesFromRoot(
|
||||
logicalOneRowRelation().when(r -> {
|
||||
Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
|
||||
Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
|
||||
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
|
||||
Assertions.assertFalse(firstSubstring.getLength().isPresent());
|
||||
|
||||
Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
|
||||
Assertions.assertEquals("def", ((Literal) secondSubstring.getSource()).getValue());
|
||||
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getPosition()).getValue());
|
||||
Assertions.assertEquals((byte) 3, ((Literal) secondSubstring.getLength().get()).getValue());
|
||||
return true;
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testAddFunction() {
|
||||
FunctionRegistry functionRegistry = new FunctionRegistry() {
|
||||
@Override
|
||||
protected void afterRegisterBuiltinFunctions(Map<String, List<FunctionBuilder>> name2builders) {
|
||||
name2builders.put("foo", FunctionBuilder.resolve(ExtendFunction.class));
|
||||
}
|
||||
};
|
||||
|
||||
ImmutableList<Expression> arguments = ImmutableList.of(Literal.of(1));
|
||||
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder("foo", arguments);
|
||||
BoundFunction function = functionBuilder.build("foo", arguments);
|
||||
Assertions.assertTrue(function.getClass().equals(ExtendFunction.class));
|
||||
Assertions.assertEquals(arguments, function.getArguments());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void testOverrideDifferenceTypes() {
|
||||
FunctionRegistry functionRegistry = new FunctionRegistry() {
|
||||
@Override
|
||||
protected void afterRegisterBuiltinFunctions(Map<String, List<FunctionBuilder>> name2builders) {
|
||||
name2builders.put("abc", FunctionBuilder.resolve(AmbiguousFunction.class));
|
||||
}
|
||||
};
|
||||
|
||||
// currently we can not support the override same arity function with difference types
|
||||
Assertions.assertThrowsExactly(AnalysisException.class, () -> {
|
||||
functionRegistry.findFunctionBuilder("abc", ImmutableList.of(Literal.of(1)));
|
||||
});
|
||||
}
|
||||
|
||||
public static class ExtendFunction extends BoundFunction implements UnaryExpression {
|
||||
public ExtendFunction(Expression a1) {
|
||||
super("foo", a1);
|
||||
}
|
||||
}
|
||||
|
||||
public static class AmbiguousFunction extends BoundFunction implements UnaryExpression {
|
||||
public AmbiguousFunction(Expression a1) {
|
||||
super("abc", a1);
|
||||
}
|
||||
|
||||
public AmbiguousFunction(Literal a1) {
|
||||
super("abc", a1);
|
||||
}
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user