[feature](Nereids) Support function registry (#12481)

Support function registry.

The classes:
- BuiltinFunctions: contains the built-in functions list
- FunctionRegistry: used to register scalar functions and aggregate functions, it can find the function by name
- FunctionBuilder: used to resolve a BoundFunction class, extract the constructor, and build to a BoundFunction by arguments(`List<Expression>`)

Register example: you can add built-in functions in the list for simplicity

```java
public class BuiltinFunctions implements FunctionHelper {
    public final List<ScalarFunc> scalarFunctions = ImmutableList.of(
            scalar(Substring.class, "substr", "substring"),
            scalar(WeekOfYear.class),
            scalar(Year.class)
    );

    public final ImmutableList<AggregateFunc> aggregateFunctions = ImmutableList.of(
            agg(Avg.class),
            agg(Count.class),
            agg(Max.class),
            agg(Min.class),
            agg(Sum.class)
    );
}
```

Note:
- Currently, we only support register scalar functions add aggregate functions, we will support register table functions.
- Currently, we only support resolve function by function name and difference arity, but can not resolve the same arity override function, e.g. `some_function(Expression)` and `some_function(Literal)`
This commit is contained in:
924060929
2022-09-09 15:19:45 +08:00
committed by GitHub
parent c9a6486f8c
commit 6b8a139f2d
12 changed files with 590 additions and 98 deletions

View File

@ -0,0 +1,58 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.catalog;
import org.apache.doris.nereids.trees.expressions.functions.Avg;
import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.functions.Max;
import org.apache.doris.nereids.trees.expressions.functions.Min;
import org.apache.doris.nereids.trees.expressions.functions.Substring;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.trees.expressions.functions.WeekOfYear;
import org.apache.doris.nereids.trees.expressions.functions.Year;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* Built-in functions.
*
* Note: Please ensure that this class only has some lists and no procedural code.
* It helps to be clear and concise.
*/
public class BuiltinFunctions implements FunctionHelper {
public final List<ScalarFunc> scalarFunctions = ImmutableList.of(
scalar(Substring.class, "substr", "substring"),
scalar(WeekOfYear.class),
scalar(Year.class)
);
public final ImmutableList<AggregateFunc> aggregateFunctions = ImmutableList.of(
agg(Avg.class),
agg(Count.class),
agg(Max.class),
agg(Min.class),
agg(Sum.class)
);
public static final BuiltinFunctions INSTANCE = new BuiltinFunctions();
// Note: Do not add any code here!
private BuiltinFunctions() {}
}

View File

@ -375,6 +375,9 @@ public class Env {
private CatalogRecycleBin recycleBin;
private FunctionSet functionSet;
// for nereids
private FunctionRegistry functionRegistry;
private MetaReplayState metaReplayState;
private BrokerMgr brokerMgr;
@ -555,6 +558,8 @@ public class Env {
this.functionSet = new FunctionSet();
this.functionSet.init();
this.functionRegistry = new FunctionRegistry();
this.metaReplayState = new MetaReplayState();
this.isDefaultClusterCreated = false;
@ -4306,6 +4311,10 @@ public class Env {
LOG.info("successfully create view[" + tableName + "-" + newView.getId() + "]");
}
public FunctionRegistry getFunctionRegistry() {
return functionRegistry;
}
/**
* Returns the function that best matches 'desc' that is registered with the
* catalog using 'mode' to check for matching. If desc matches multiple

View File

@ -0,0 +1,109 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.catalog;
import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.ScalarFunction;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
public interface FunctionHelper {
/**
* put functions into the target map, which the key is the function name,
* and value is FunctionBuilder is converted from NamedFunc.
* @param name2FuncBuilders target Map
* @param functions the NamedFunc list to be put into the target map
*/
static void addFunctions(Map<String, List<FunctionBuilder>> name2FuncBuilders,
List<? extends NamedFunc<? extends BoundFunction>> functions) {
for (NamedFunc<? extends BoundFunction> func : functions) {
for (String name : func.names) {
if (name2FuncBuilders.containsKey(name)) {
throw new IllegalStateException("Function '" + name + "' already exists in function registry");
}
name2FuncBuilders.put(name, func.functionBuilders);
}
}
}
default ScalarFunc scalar(Class<? extends ScalarFunction> functionClass) {
String functionName = functionClass.getSimpleName();
return scalar(functionClass, functionName);
}
/**
* Resolve ScalaFunction class, convert to FunctionBuilder and wrap to ScalarFunc
* @param functionClass the ScalaFunction class
* @return ScalaFunc which contains the functionName and the FunctionBuilder
*/
default ScalarFunc scalar(Class<? extends ScalarFunction> functionClass, String... functionNames) {
return new ScalarFunc(functionClass, functionNames);
}
default AggregateFunc agg(Class<? extends AggregateFunction> functionClass) {
String functionName = functionClass.getSimpleName();
return new AggregateFunc(functionClass, functionName);
}
/**
* Resolve AggregateFunction class, convert to FunctionBuilder and wrap to AggregateFunc
* @param functionClass the AggregateFunction class
* @return AggregateFunc which contains the functionName and the AggregateFunc
*/
default AggregateFunc agg(Class<? extends AggregateFunction> functionClass, String... functionNames) {
return new AggregateFunc(functionClass, functionNames);
}
/**
* use this class to prevent the wrong type from being registered, and support multi function names
* like substring and substr.
*/
class NamedFunc<T extends BoundFunction> {
public final List<String> names;
public final Class<? extends T> functionClass;
public final List<FunctionBuilder> functionBuilders;
public NamedFunc(Class<? extends T> functionClass, String... names) {
this.functionClass = functionClass;
this.names = Arrays.stream(names)
.map(String::toLowerCase)
.collect(ImmutableList.toImmutableList());
this.functionBuilders = FunctionBuilder.resolve(functionClass);
}
}
class ScalarFunc extends NamedFunc<ScalarFunction> {
public ScalarFunc(Class<? extends ScalarFunction> functionClass, String... names) {
super(functionClass, names);
}
}
class AggregateFunc extends NamedFunc<AggregateFunction> {
public AggregateFunc(Class<? extends AggregateFunction> functionClass, String... names) {
super(functionClass, names);
}
}
}

View File

@ -0,0 +1,91 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.catalog;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import com.google.common.annotations.VisibleForTesting;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import javax.annotation.concurrent.ThreadSafe;
/**
* New function registry for nereids.
*
* this class is developing for more functions.
*/
@Developing
@ThreadSafe
public class FunctionRegistry {
private final Map<String, List<FunctionBuilder>> name2Builders;
public FunctionRegistry() {
name2Builders = new ConcurrentHashMap<>();
registerBuiltinFunctions(name2Builders);
afterRegisterBuiltinFunctions(name2Builders);
}
// this function is used to test.
// for example, you can create child class of FunctionRegistry and clear builtin functions or add more functions
// in this method
@VisibleForTesting
protected void afterRegisterBuiltinFunctions(Map<String, List<FunctionBuilder>> name2Builders) {}
// currently we only find function by name and arity
public FunctionBuilder findFunctionBuilder(String name, List<Expression> arguments) {
int arity = arguments.size();
List<FunctionBuilder> functionBuilders = name2Builders.get(name.toLowerCase());
if (functionBuilders == null || functionBuilders.isEmpty()) {
throw new AnalysisException("Can not found function '" + name + "'");
}
List<FunctionBuilder> candidateBuilders = functionBuilders.stream()
.filter(functionBuilder -> functionBuilder.arity == arity)
.collect(Collectors.toList());
if (candidateBuilders.isEmpty()) {
String candidateHints = getCandidateHint(name, candidateBuilders);
throw new AnalysisException("Can not found function '" + name
+ "' which has " + arity + " arity. Candidate functions are: " + candidateHints);
}
if (candidateBuilders.size() > 1) {
String candidateHints = getCandidateHint(name, candidateBuilders);
// NereidsPlanner not supported override function by the same arity, should we support it?
throw new AnalysisException("Function '" + name + "' is ambiguous: " + candidateHints);
}
return candidateBuilders.get(0);
}
private void registerBuiltinFunctions(Map<String, List<FunctionBuilder>> name2Builders) {
FunctionHelper.addFunctions(name2Builders, BuiltinFunctions.INSTANCE.scalarFunctions);
FunctionHelper.addFunctions(name2Builders, BuiltinFunctions.INSTANCE.aggregateFunctions);
}
public String getCandidateHint(String name, List<FunctionBuilder> candidateBuilders) {
return candidateBuilders.stream()
.map(builder -> name + builder.toString())
.collect(Collectors.joining(", "));
}
}

View File

@ -18,21 +18,19 @@
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.analysis.ArithmeticExpr.Operator;
import org.apache.doris.catalog.Env;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.nereids.analyzer.UnboundFunction;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.TimestampArithmetic;
import org.apache.doris.nereids.trees.expressions.functions.Avg;
import org.apache.doris.nereids.trees.expressions.functions.Count;
import org.apache.doris.nereids.trees.expressions.functions.Max;
import org.apache.doris.nereids.trees.expressions.functions.Min;
import org.apache.doris.nereids.trees.expressions.functions.Substring;
import org.apache.doris.nereids.trees.expressions.functions.Sum;
import org.apache.doris.nereids.trees.expressions.functions.WeekOfYear;
import org.apache.doris.nereids.trees.expressions.functions.Year;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalHaving;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
@ -54,9 +52,10 @@ public class BindFunction implements AnalysisRuleFactory {
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.BINDING_ONE_ROW_RELATION_FUNCTION.build(
logicalOneRowRelation().then(oneRowRelation -> {
logicalOneRowRelation().thenApply(ctx -> {
LogicalOneRowRelation oneRowRelation = ctx.root;
List<NamedExpression> projects = oneRowRelation.getProjects();
List<NamedExpression> boundProjects = bind(projects);
List<NamedExpression> boundProjects = bind(projects, ctx.connectContext.getEnv());
// TODO:
// trick logic: currently XxxRelation in GroupExpression always difference to each other,
// so this rule must check the expression whether is changed to prevent dead loop because
@ -70,113 +69,64 @@ public class BindFunction implements AnalysisRuleFactory {
})
),
RuleType.BINDING_PROJECT_FUNCTION.build(
logicalProject().then(project -> {
List<NamedExpression> boundExpr = bind(project.getProjects());
logicalProject().thenApply(ctx -> {
LogicalProject<GroupPlan> project = ctx.root;
List<NamedExpression> boundExpr = bind(project.getProjects(), ctx.connectContext.getEnv());
return new LogicalProject<>(boundExpr, project.child());
})
),
RuleType.BINDING_AGGREGATE_FUNCTION.build(
logicalAggregate().then(agg -> {
List<Expression> groupBy = bind(agg.getGroupByExpressions());
List<NamedExpression> output = bind(agg.getOutputExpressions());
logicalAggregate().thenApply(ctx -> {
LogicalAggregate<GroupPlan> agg = ctx.root;
List<Expression> groupBy = bind(agg.getGroupByExpressions(), ctx.connectContext.getEnv());
List<NamedExpression> output = bind(agg.getOutputExpressions(), ctx.connectContext.getEnv());
return agg.withGroupByAndOutput(groupBy, output);
})
),
RuleType.BINDING_FILTER_FUNCTION.build(
logicalFilter().then(filter -> {
List<Expression> predicates = bind(filter.getExpressions());
logicalFilter().thenApply(ctx -> {
LogicalFilter<GroupPlan> filter = ctx.root;
List<Expression> predicates = bind(filter.getExpressions(), ctx.connectContext.getEnv());
return new LogicalFilter<>(predicates.get(0), filter.child());
})
),
RuleType.BINDING_HAVING_FUNCTION.build(
logicalHaving(logicalAggregate()).then(filter -> {
List<Expression> predicates = bind(filter.getExpressions());
return new LogicalHaving<>(predicates.get(0), filter.child());
logicalHaving(logicalAggregate()).thenApply(ctx -> {
LogicalHaving<LogicalAggregate<GroupPlan>> having = ctx.root;
List<Expression> predicates = bind(having.getExpressions(), ctx.connectContext.getEnv());
return new LogicalHaving<>(predicates.get(0), having.child());
})
)
);
}
private <E extends Expression> List<E> bind(List<E> exprList) {
private <E extends Expression> List<E> bind(List<E> exprList, Env env) {
return exprList.stream()
.map(FunctionBinder.INSTANCE::bind)
.map(expr -> FunctionBinder.INSTANCE.bind(expr, env))
.collect(Collectors.toList());
}
private static class FunctionBinder extends DefaultExpressionRewriter<Void> {
private static class FunctionBinder extends DefaultExpressionRewriter<Env> {
public static final FunctionBinder INSTANCE = new FunctionBinder();
public <E extends Expression> E bind(E expression) {
return (E) expression.accept(this, null);
public <E extends Expression> E bind(E expression, Env env) {
return (E) expression.accept(this, env);
}
@Override
public Expression visitUnboundFunction(UnboundFunction unboundFunction, Void context) {
String name = unboundFunction.getName();
// TODO: lookup function in the function registry
if (name.equalsIgnoreCase("sum")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() != 1) {
return unboundFunction;
}
return new Sum(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("count")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() > 1 || (arguments.size() == 0 && !unboundFunction.isStar())) {
return unboundFunction;
}
if (unboundFunction.isStar() || arguments.stream().allMatch(Expression::isConstant)) {
return new Count();
}
return new Count(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("max")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() != 1) {
return unboundFunction;
}
return new Max(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("min")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() != 1) {
return unboundFunction;
}
return new Min(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("avg")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() != 1) {
return unboundFunction;
}
return new Avg(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("substr") || name.equalsIgnoreCase("substring")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() == 2) {
return new Substring(unboundFunction.getArguments().get(0),
unboundFunction.getArguments().get(1));
} else if (arguments.size() == 3) {
return new Substring(unboundFunction.getArguments().get(0), unboundFunction.getArguments().get(1),
unboundFunction.getArguments().get(2));
}
return unboundFunction;
} else if (name.equalsIgnoreCase("year")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() != 1) {
return unboundFunction;
}
return new Year(unboundFunction.getArguments().get(0));
} else if (name.equalsIgnoreCase("WeekOfYear")) {
List<Expression> arguments = unboundFunction.getArguments();
if (arguments.size() != 1) {
return unboundFunction;
}
return new WeekOfYear(unboundFunction.getArguments().get(0));
}
return unboundFunction;
public BoundFunction visitUnboundFunction(UnboundFunction unboundFunction, Env env) {
FunctionRegistry functionRegistry = env.getFunctionRegistry();
String functionName = unboundFunction.getName();
FunctionBuilder builder = functionRegistry.findFunctionBuilder(
functionName, unboundFunction.getArguments());
return builder.build(functionName, unboundFunction.getArguments());
}
@Override
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Void context) {
public Expression visitTimestampArithmetic(TimestampArithmetic arithmetic, Env env) {
String funcOpName;
if (arithmetic.getFuncName() == null) {
// e.g. YEARS_ADD, MONTHS_SUB
funcOpName = String.format("%sS_%s", arithmetic.getTimeUnit(),
(arithmetic.getOp() == Operator.ADD) ? "ADD" : "SUB");
} else {

View File

@ -21,7 +21,9 @@ import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DataType;
/** AggregateFunction. */
/**
* The function which consume arguments in lots of rows and product one value.
*/
public abstract class AggregateFunction extends BoundFunction {
private DataType intermediate;

View File

@ -0,0 +1,90 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.lang.reflect.Constructor;
import java.lang.reflect.Modifier;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
/**
* This class used to resolve a concrete BoundFunction's class and build BoundFunction by a list of Expressions.
*/
public class FunctionBuilder {
public final int arity;
// Concrete BoundFunction's constructor
private final Constructor<BoundFunction> builderMethod;
public FunctionBuilder(Constructor<BoundFunction> builderMethod) {
this.builderMethod = Objects.requireNonNull(builderMethod, "builderMethod can not be null");
this.arity = builderMethod.getParameterCount();
}
/**
* build a BoundFunction by function name and arguments.
* @param name function name which in the sql expression
* @param arguments the function's argument expressions
* @return the concrete bound function instance
*/
public BoundFunction build(String name, List<Expression> arguments) {
try {
return builderMethod.newInstance(arguments.toArray(new Expression[0]));
} catch (Throwable t) {
String argString = arguments.stream()
.map(arg -> arg == null ? "null" : arg.toSql())
.collect(Collectors.joining(", ", "(", ")"));
throw new IllegalStateException("Can not build function: '" + name
+ "', expression: " + name + argString, t);
}
}
@Override
public String toString() {
return Arrays.stream(builderMethod.getParameterTypes())
.map(type -> type.getSimpleName())
.collect(Collectors.joining(", ", "(", ")"));
}
/**
* resolve a Concrete boundFunction's class and convert the constructors to FunctionBuilder
* @param functionClass a class which is the child class of BoundFunction and can not be abstract class
* @return list of FunctionBuilder which contains the constructor
*/
public static List<FunctionBuilder> resolve(Class<? extends BoundFunction> functionClass) {
Preconditions.checkArgument(!Modifier.isAbstract(functionClass.getModifiers()),
"Can not resolve bind function which is abstract class: "
+ functionClass.getSimpleName());
return Arrays.stream(functionClass.getConstructors())
.filter(constructor -> Modifier.isPublic(constructor.getModifiers()))
.filter(constructor ->
// all arguments must be Expression
Arrays.stream(functionClass.getTypeParameters())
.allMatch(Expression.class::isInstance)
)
.map(constructor -> new FunctionBuilder((Constructor<BoundFunction>) constructor))
.collect(ImmutableList.toImmutableList());
}
}

View File

@ -0,0 +1,29 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
/**
* The function which consume zero or more arguments in a row and product one value.
*/
public abstract class ScalarFunction extends BoundFunction {
public ScalarFunction(String name, Expression... arguments) {
super(name, arguments);
}
}

View File

@ -19,7 +19,6 @@ package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.shape.TernaryExpression;
import org.apache.doris.nereids.trees.expressions.typecoercion.ImplicitCastInputTypes;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.IntegerType;
@ -31,11 +30,12 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Optional;
/**
* substring function.
*/
public class Substring extends BoundFunction implements TernaryExpression, ImplicitCastInputTypes {
public class Substring extends ScalarFunction implements ImplicitCastInputTypes {
// used in interface expectedInputTypes to avoid new list in each time it be called
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(
@ -49,10 +49,10 @@ public class Substring extends BoundFunction implements TernaryExpression, Impli
}
public Substring(Expression str, Expression pos) {
super("substring", str, pos, new IntegerLiteral(Integer.MAX_VALUE));
super("substring", str, pos);
}
public Expression getTarget() {
public Expression getSource() {
return child(0);
}
@ -60,21 +60,22 @@ public class Substring extends BoundFunction implements TernaryExpression, Impli
return child(1);
}
public Expression getLength() {
return child(2);
public Optional<Expression> getLength() {
return arity() == 3 ? Optional.of(child(2)) : Optional.empty();
}
@Override
public DataType getDataType() {
if (getLength() instanceof IntegerLiteral) {
return VarcharType.createVarcharType(((IntegerLiteral) getLength()).getValue());
Optional<Expression> length = getLength();
if (length.isPresent() && length.get() instanceof IntegerLiteral) {
return VarcharType.createVarcharType(((IntegerLiteral) length.get()).getValue());
}
return VarcharType.SYSTEM_DEFAULT;
}
@Override
public boolean nullable() {
return first().nullable();
return children().stream().anyMatch(Expression::nullable);
}
@Override

View File

@ -34,7 +34,7 @@ import java.util.List;
/**
* weekOfYear function
*/
public class WeekOfYear extends BoundFunction implements UnaryExpression, ImplicitCastInputTypes {
public class WeekOfYear extends ScalarFunction implements UnaryExpression, ImplicitCastInputTypes {
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(
new TypeCollection(DateTimeType.INSTANCE)

View File

@ -35,7 +35,7 @@ import java.util.List;
/**
* year function.
*/
public class Year extends BoundFunction implements UnaryExpression, ImplicitCastInputTypes {
public class Year extends ScalarFunction implements UnaryExpression, ImplicitCastInputTypes {
// used in interface expectedInputTypes to avoid new list in each time it be called
private static final List<AbstractDataType> EXPECTED_INPUT_TYPES = ImmutableList.of(

View File

@ -0,0 +1,153 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.catalog.FunctionRegistry;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.Substring;
import org.apache.doris.nereids.trees.expressions.functions.Year;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.util.MemoTestUtils;
import org.apache.doris.nereids.util.PatternMatchSupported;
import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.Test;
import java.util.List;
import java.util.Map;
// this ut will add more test case later
public class FunctionRegistryTest implements PatternMatchSupported {
private ConnectContext connectContext = MemoTestUtils.createConnectContext();
@Test
public void testDefaultFunctionNameIsClassName() {
// we register Year by the code in FunctionRegistry: scalar(Year.class).
// and default class name should be year.
PlanChecker.from(connectContext)
.analyze("select year('2021-01-01')")
.matchesFromRoot(
logicalOneRowRelation().when(r -> {
Year year = (Year) r.getProjects().get(0).child(0);
Assertions.assertEquals("2021-01-01",
((Literal) year.getArguments().get(0)).getValue());
return true;
})
);
}
@Test
public void testMultiName() {
// the substring function has 2 names:
// 1. substring
// 2. substr
PlanChecker.from(connectContext)
.analyze("select substring('abc', 1, 2), substr(substring('abcdefg', 4, 3), 1, 2)")
.matchesFromRoot(
logicalOneRowRelation().when(r -> {
Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
Assertions.assertEquals((byte) 2, ((Literal) firstSubstring.getLength().get()).getValue());
Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
Assertions.assertTrue(secondSubstring.getSource() instanceof Substring);
Assertions.assertEquals((byte) 1, ((Literal) secondSubstring.getPosition()).getValue());
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getLength().get()).getValue());
return true;
})
);
}
@Test
public void testOverrideArity() {
// the substring function has 2 override functions:
// 1. substring(string, position)
// 2. substring(string, position, length)
PlanChecker.from(connectContext)
.analyze("select substr('abc', 1), substring('def', 2, 3)")
.matchesFromRoot(
logicalOneRowRelation().when(r -> {
Substring firstSubstring = (Substring) r.getProjects().get(0).child(0);
Assertions.assertEquals("abc", ((Literal) firstSubstring.getSource()).getValue());
Assertions.assertEquals((byte) 1, ((Literal) firstSubstring.getPosition()).getValue());
Assertions.assertFalse(firstSubstring.getLength().isPresent());
Substring secondSubstring = (Substring) r.getProjects().get(1).child(0);
Assertions.assertEquals("def", ((Literal) secondSubstring.getSource()).getValue());
Assertions.assertEquals((byte) 2, ((Literal) secondSubstring.getPosition()).getValue());
Assertions.assertEquals((byte) 3, ((Literal) secondSubstring.getLength().get()).getValue());
return true;
})
);
}
@Test
public void testAddFunction() {
FunctionRegistry functionRegistry = new FunctionRegistry() {
@Override
protected void afterRegisterBuiltinFunctions(Map<String, List<FunctionBuilder>> name2builders) {
name2builders.put("foo", FunctionBuilder.resolve(ExtendFunction.class));
}
};
ImmutableList<Expression> arguments = ImmutableList.of(Literal.of(1));
FunctionBuilder functionBuilder = functionRegistry.findFunctionBuilder("foo", arguments);
BoundFunction function = functionBuilder.build("foo", arguments);
Assertions.assertTrue(function.getClass().equals(ExtendFunction.class));
Assertions.assertEquals(arguments, function.getArguments());
}
@Test
public void testOverrideDifferenceTypes() {
FunctionRegistry functionRegistry = new FunctionRegistry() {
@Override
protected void afterRegisterBuiltinFunctions(Map<String, List<FunctionBuilder>> name2builders) {
name2builders.put("abc", FunctionBuilder.resolve(AmbiguousFunction.class));
}
};
// currently we can not support the override same arity function with difference types
Assertions.assertThrowsExactly(AnalysisException.class, () -> {
functionRegistry.findFunctionBuilder("abc", ImmutableList.of(Literal.of(1)));
});
}
public static class ExtendFunction extends BoundFunction implements UnaryExpression {
public ExtendFunction(Expression a1) {
super("foo", a1);
}
}
public static class AmbiguousFunction extends BoundFunction implements UnaryExpression {
public AmbiguousFunction(Expression a1) {
super("abc", a1);
}
public AmbiguousFunction(Literal a1) {
super("abc", a1);
}
}
}