[feature](function) Support for aggregate function foreach combiner (#31526)

This commit is contained in:
Mryange
2024-03-06 10:22:05 +08:00
committed by yiguolei
parent 7f3a666fac
commit 4f174c4fb9
54 changed files with 691 additions and 68 deletions

View File

@ -88,6 +88,7 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
public static final String AGG_STATE_SUFFIX = "_state";
public static final String AGG_UNION_SUFFIX = "_union";
public static final String AGG_MERGE_SUFFIX = "_merge";
public static final String AGG_FOREACH_SUFFIX = "_foreach";
public static final String DEFAULT_EXPR_NAME = "expr";
protected boolean disableTableName = false;

View File

@ -43,6 +43,7 @@ import java.io.DataInput;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@ -878,4 +879,17 @@ public class Function implements Writable {
fnCall.setType(fnCall.getChildren().get(0).getType());
return fnCall;
}
/**
 * Rewrites the function bound to {@code fnCall} into its {@code _foreach} variant.
 *
 * <p>The function returned by {@code fnCall.getFn()} is modified in place: its name gets
 * {@link Expr#AGG_FOREACH_SUFFIX} appended, every argument type is wrapped in an
 * {@link ArrayType}, the return type is wrapped in an {@link ArrayType} built with
 * {@code fnCall.isNullable()} as its second constructor argument, and the nullable mode is
 * set to {@link NullableMode#ALWAYS_NULLABLE}.
 *
 * @param fnCall call expression whose bound function is converted
 * @return the same {@code fnCall} instance that was passed in
 */
public static FunctionCallExpr convertForEachCombinator(FunctionCallExpr fnCall) {
    Function aggFunction = fnCall.getFn();
    aggFunction.setName(new FunctionName(aggFunction.getFunctionName().getFunction() + Expr.AGG_FOREACH_SUFFIX));
    // Diamond operator instead of the raw ArrayList type, so the list is type-checked as List<Type>.
    List<Type> argTypes = new ArrayList<>();
    for (Type type : aggFunction.argTypes) {
        argTypes.add(new ArrayType(type));
    }
    aggFunction.setArgs(argTypes);
    aggFunction.setReturnType(new ArrayType(aggFunction.getReturnType(), fnCall.isNullable()));
    aggFunction.setNullableMode(NullableMode.ALWAYS_NULLABLE);
    return fnCall;
}
}

View File

@ -20,7 +20,7 @@ package org.apache.doris.catalog;
import org.apache.doris.mysql.privilege.PrivPredicate;
import org.apache.doris.nereids.annotation.Developing;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder;
import org.apache.doris.nereids.types.DataType;
@ -93,13 +93,13 @@ public class FunctionRegistry {
if (StringUtils.isEmpty(dbName)) {
// search internal function only if dbName is empty
functionBuilders = name2InternalBuiltinBuilders.get(name.toLowerCase());
if (CollectionUtils.isEmpty(functionBuilders) && AggStateFunctionBuilder.isAggStateCombinator(name)) {
String nestedName = AggStateFunctionBuilder.getNestedName(name);
String combinatorSuffix = AggStateFunctionBuilder.getCombinatorSuffix(name);
if (CollectionUtils.isEmpty(functionBuilders) && AggCombinerFunctionBuilder.isAggStateCombinator(name)) {
String nestedName = AggCombinerFunctionBuilder.getNestedName(name);
String combinatorSuffix = AggCombinerFunctionBuilder.getCombinatorSuffix(name);
functionBuilders = name2InternalBuiltinBuilders.get(nestedName.toLowerCase());
if (functionBuilders != null) {
functionBuilders = functionBuilders.stream()
.map(builder -> new AggStateFunctionBuilder(combinatorSuffix, builder))
.map(builder -> new AggCombinerFunctionBuilder(combinatorSuffix, builder))
.filter(functionBuilder -> functionBuilder.canApply(arguments))
.collect(Collectors.toList());
}

View File

@ -83,6 +83,7 @@ import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
@ -612,6 +613,16 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
new FunctionParams(false, arguments)));
}
/**
 * Translates a {@link ForEachCombinator} into a function call expression.
 *
 * <p>Each child is represented by a {@link SlotRef} built from the child's catalog data type and
 * nullability. The nested aggregate function is translated to obtain its bound function, which is
 * then converted by {@code Function.convertForEachCombinator}.
 */
@Override
public Expr visitForEachCombinator(ForEachCombinator combinator, PlanTranslatorContext context) {
    // One slot per child, carrying only the child's catalog type and nullability.
    List<Expr> slotArgs = combinator.children().stream()
            .map(child -> new SlotRef(child.getDataType().toCatalogDataType(), child.nullable()))
            .collect(ImmutableList.toImmutableList());
    // Translate the nested aggregate to get the function it is bound to.
    Function nestedFn = visitAggregateFunction(combinator.getNestedFunction(), context).getFn();
    FunctionCallExpr nestedCall = new FunctionCallExpr(nestedFn, new FunctionParams(false, slotArgs));
    return Function.convertForEachCombinator(nestedCall);
}
@Override
public Expr visitAggregateFunction(AggregateFunction function, PlanTranslatorContext context) {
List<Expr> arguments = function.children().stream()

View File

@ -18,11 +18,15 @@
package org.apache.doris.nereids.trees.expressions.functions;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.types.AggStateType;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import java.util.List;
import java.util.Objects;
@ -31,28 +35,30 @@ import java.util.stream.Collectors;
/**
* This class is used to resolve aggregate function combinators (state, merge, union and foreach).
*/
public class AggStateFunctionBuilder extends FunctionBuilder {
public class AggCombinerFunctionBuilder extends FunctionBuilder {
public static final String COMBINATOR_LINKER = "_";
public static final String STATE = "state";
public static final String MERGE = "merge";
public static final String UNION = "union";
public static final String FOREACH = "foreach";
public static final String STATE_SUFFIX = COMBINATOR_LINKER + STATE;
public static final String MERGE_SUFFIX = COMBINATOR_LINKER + MERGE;
public static final String UNION_SUFFIX = COMBINATOR_LINKER + UNION;
public static final String FOREACH_SUFFIX = COMBINATOR_LINKER + FOREACH;
private final FunctionBuilder nestedBuilder;
private final String combinatorSuffix;
public AggStateFunctionBuilder(String combinatorSuffix, FunctionBuilder nestedBuilder) {
public AggCombinerFunctionBuilder(String combinatorSuffix, FunctionBuilder nestedBuilder) {
this.combinatorSuffix = Objects.requireNonNull(combinatorSuffix, "combinatorSuffix can not be null");
this.nestedBuilder = Objects.requireNonNull(nestedBuilder, "nestedBuilder can not be null");
}
@Override
public boolean canApply(List<? extends Object> arguments) {
if (combinatorSuffix.equals(STATE)) {
if (combinatorSuffix.equals(STATE) || combinatorSuffix.equals(FOREACH)) {
return nestedBuilder.canApply(arguments);
} else {
if (arguments.size() != 1) {
@ -71,6 +77,23 @@ public class AggStateFunctionBuilder extends FunctionBuilder {
return (AggregateFunction) nestedBuilder.build(nestedName, arguments);
}
/**
 * Builds the nested aggregate function for a {@code foreach} combinator.
 *
 * <p>Every argument must be a {@link SlotReference} whose data type is an {@link ArrayType}. Each
 * argument is replaced by a placeholder slot named {@code "mocked"} that has the array's item type,
 * and the nested builder is invoked with those placeholders.
 *
 * @param nestedName name of the nested aggregate function
 * @param arguments arguments passed to the foreach combinator
 * @return the nested aggregate function built from the item-typed placeholders
 * @throws IllegalStateException if an argument is not a slot reference or is not of array type
 */
private AggregateFunction buildForEach(String nestedName, List<? extends Object> arguments) {
    List<Expression> forEachArgs = arguments.stream().map(expr -> {
        if (!(expr instanceof SlotReference)) {
            // The closing quote was missing from the original message.
            throw new IllegalStateException(
                    "Can not build foreach nested function: '" + nestedName + "'");
        }
        DataType arrayType = ((Expression) expr).getDataType();
        if (!(arrayType instanceof ArrayType)) {
            throw new IllegalStateException(
                    "foreach must be input array type: '" + nestedName + "'");
        }
        ArrayType inputArray = (ArrayType) arrayType;
        // The placeholder's nullability is taken from the array's containsNull flag.
        return new SlotReference("mocked", inputArray.getItemType(), inputArray.containsNull());
    }).collect(Collectors.toList());
    return (AggregateFunction) nestedBuilder.build(nestedName, forEachArgs);
}
private AggregateFunction buildMergeOrUnion(String nestedName, List<? extends Object> arguments) {
if (arguments.size() != 1 || !(arguments.get(0) instanceof Expression)
|| !((Expression) arguments.get(0)).getDataType().isAggStateType()) {
@ -105,13 +128,16 @@ public class AggStateFunctionBuilder extends FunctionBuilder {
} else if (combinatorSuffix.equals(UNION)) {
AggregateFunction nestedFunction = buildMergeOrUnion(nestedName, arguments);
return new UnionCombinator((List<Expression>) arguments, nestedFunction);
} else if (combinatorSuffix.equals(FOREACH)) {
AggregateFunction nestedFunction = buildForEach(nestedName, arguments);
return new ForEachCombinator((List<Expression>) arguments, nestedFunction);
}
return null;
}
public static boolean isAggStateCombinator(String name) {
return name.toLowerCase().endsWith(STATE_SUFFIX) || name.toLowerCase().endsWith(MERGE_SUFFIX)
|| name.toLowerCase().endsWith(UNION_SUFFIX);
|| name.toLowerCase().endsWith(UNION_SUFFIX) || name.toLowerCase().endsWith(FOREACH_SUFFIX);
}
public static String getNestedName(String name) {

View File

@ -0,0 +1,90 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.Objects;
/**
 * Aggregate function combinator named {@code <nested>_foreach}.
 *
 * <p>It wraps a nested {@link AggregateFunction} and exposes that function's signatures with every
 * argument type and the return type wrapped in an {@link ArrayType}.
 * NOTE(review): presumably the nested aggregate is applied element-wise over the input arrays;
 * the execution semantics are not defined in this class.
 */
public class ForEachCombinator extends AggregateFunction
        implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {

    private final AggregateFunction nested;

    /**
     * constructor of ForEachCombinator
     *
     * @param arguments arguments of the combinator
     * @param nested the wrapped aggregate function, must not be null
     * @throws NullPointerException if {@code nested} is null
     */
    public ForEachCombinator(List<Expression> arguments, AggregateFunction nested) {
        // Validate before dereferencing, so a null nested function fails with the descriptive message
        // instead of a bare NullPointerException raised by nested.getName().
        super(Objects.requireNonNull(nested, "nested can not be null").getName()
                + AggCombinerFunctionBuilder.FOREACH_SUFFIX, arguments);
        this.nested = nested;
    }

    /**
     * Creates a combinator around {@code nested}, reusing the nested function's own arguments.
     */
    public static ForEachCombinator create(AggregateFunction nested) {
        return new ForEachCombinator(nested.getArguments(), nested);
    }

    @Override
    public ForEachCombinator withChildren(List<Expression> children) {
        return new ForEachCombinator(children, nested);
    }

    /**
     * Returns the nested function's signatures with the return type and each argument type
     * wrapped in an array type.
     */
    @Override
    public List<FunctionSignature> getSignatures() {
        return nested.getSignatures().stream()
                .map(sig -> sig.withReturnType(ArrayType.of(sig.returnType))
                        .withArgumentTypes(false, sig.argumentsTypes.stream()
                                .map(arg -> ArrayType.of(arg))
                                .collect(ImmutableList.toImmutableList())))
                .collect(ImmutableList.toImmutableList());
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitForEachCombinator(this, context);
    }

    /**
     * Returns an array type built from the nested function's data type and its nullability.
     */
    @Override
    public DataType getDataType() {
        return ArrayType.of(nested.getDataType(), nested.nullable());
    }

    /**
     * Returns the wrapped aggregate function.
     */
    public AggregateFunction getNestedFunction() {
        return nested;
    }

    /**
     * Not supported by this combinator.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public AggregateFunction withDistinctAndChildren(boolean distinct, List<Expression> children) {
        throw new UnsupportedOperationException("Unimplemented method 'withDistinctAndChildren'");
    }
}

View File

@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.ComputeNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
@ -43,7 +43,7 @@ public class MergeCombinator extends AggregateFunction
private final AggStateType inputType;
public MergeCombinator(List<Expression> arguments, AggregateFunction nested) {
super(nested.getName() + AggStateFunctionBuilder.MERGE_SUFFIX, arguments);
super(nested.getName() + AggCombinerFunctionBuilder.MERGE_SUFFIX, arguments);
this.nested = Objects.requireNonNull(nested, "nested can not be null");
inputType = (AggStateType) arguments.get(0).getDataType();

View File

@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
@ -47,7 +47,7 @@ public class StateCombinator extends ScalarFunction
* constructor of StateCombinator
*/
public StateCombinator(List<Expression> arguments, AggregateFunction nested) {
super(nested.getName() + AggStateFunctionBuilder.STATE_SUFFIX, arguments);
super(nested.getName() + AggCombinerFunctionBuilder.STATE_SUFFIX, arguments);
this.nested = Objects.requireNonNull(nested, "nested can not be null");
this.returnType = new AggStateType(nested.getName(), arguments.stream().map(arg -> {

View File

@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions.combinator;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
@ -43,7 +43,7 @@ public class UnionCombinator extends AggregateFunction
private final AggStateType inputType;
public UnionCombinator(List<Expression> arguments, AggregateFunction nested) {
super(nested.getName() + AggStateFunctionBuilder.UNION_SUFFIX, arguments);
super(nested.getName() + AggCombinerFunctionBuilder.UNION_SUFFIX, arguments);
this.nested = Objects.requireNonNull(nested, "nested can not be null");
inputType = (AggStateType) arguments.get(0).getDataType();

View File

@ -72,6 +72,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.TopNWeighted;
import org.apache.doris.nereids.trees.expressions.functions.agg.Variance;
import org.apache.doris.nereids.trees.expressions.functions.agg.VarianceSamp;
import org.apache.doris.nereids.trees.expressions.functions.agg.WindowFunnel;
import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
@ -305,6 +306,10 @@ public interface AggregateFunctionVisitor<R, C> {
return visitAggregateFunction(combinator, context);
}
/**
 * Visits a {@link ForEachCombinator}. The default implementation delegates to
 * {@code visitAggregateFunction}, so a visitor that does not override this method
 * handles the combinator the same way as that method handles its argument.
 */
default R visitForEachCombinator(ForEachCombinator combinator, C context) {
return visitAggregateFunction(combinator, context);
}
default R visitJavaUdaf(JavaUdaf javaUdaf, C context) {
return visitAggregateFunction(javaUdaf, context);
}