[feature](function) Support for aggregate function foreach combiner (#31526)
This commit is contained in:
@ -88,6 +88,7 @@ public abstract class Expr extends TreeNode<Expr> implements ParseNode, Cloneabl
|
||||
public static final String AGG_STATE_SUFFIX = "_state";
|
||||
public static final String AGG_UNION_SUFFIX = "_union";
|
||||
public static final String AGG_MERGE_SUFFIX = "_merge";
|
||||
public static final String AGG_FOREACH_SUFFIX = "_foreach";
|
||||
public static final String DEFAULT_EXPR_NAME = "expr";
|
||||
|
||||
protected boolean disableTableName = false;
|
||||
|
||||
@ -43,6 +43,7 @@ import java.io.DataInput;
|
||||
import java.io.DataOutput;
|
||||
import java.io.DataOutputStream;
|
||||
import java.io.IOException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
@ -878,4 +879,17 @@ public class Function implements Writable {
|
||||
fnCall.setType(fnCall.getChildren().get(0).getType());
|
||||
return fnCall;
|
||||
}
|
||||
|
||||
public static FunctionCallExpr convertForEachCombinator(FunctionCallExpr fnCall) {
|
||||
Function aggFunction = fnCall.getFn();
|
||||
aggFunction.setName(new FunctionName(aggFunction.getFunctionName().getFunction() + Expr.AGG_FOREACH_SUFFIX));
|
||||
List<Type> argTypes = new ArrayList();
|
||||
for (Type type : aggFunction.argTypes) {
|
||||
argTypes.add(new ArrayType(type));
|
||||
}
|
||||
aggFunction.setArgs(argTypes);
|
||||
aggFunction.setReturnType(new ArrayType(aggFunction.getReturnType(), fnCall.isNullable()));
|
||||
aggFunction.setNullableMode(NullableMode.ALWAYS_NULLABLE);
|
||||
return fnCall;
|
||||
}
|
||||
}
|
||||
|
||||
@ -20,7 +20,7 @@ package org.apache.doris.catalog;
|
||||
import org.apache.doris.mysql.privilege.PrivPredicate;
|
||||
import org.apache.doris.nereids.annotation.Developing;
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.FunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.udf.UdfBuilder;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
@ -93,13 +93,13 @@ public class FunctionRegistry {
|
||||
if (StringUtils.isEmpty(dbName)) {
|
||||
// search internal function only if dbName is empty
|
||||
functionBuilders = name2InternalBuiltinBuilders.get(name.toLowerCase());
|
||||
if (CollectionUtils.isEmpty(functionBuilders) && AggStateFunctionBuilder.isAggStateCombinator(name)) {
|
||||
String nestedName = AggStateFunctionBuilder.getNestedName(name);
|
||||
String combinatorSuffix = AggStateFunctionBuilder.getCombinatorSuffix(name);
|
||||
if (CollectionUtils.isEmpty(functionBuilders) && AggCombinerFunctionBuilder.isAggStateCombinator(name)) {
|
||||
String nestedName = AggCombinerFunctionBuilder.getNestedName(name);
|
||||
String combinatorSuffix = AggCombinerFunctionBuilder.getCombinatorSuffix(name);
|
||||
functionBuilders = name2InternalBuiltinBuilders.get(nestedName.toLowerCase());
|
||||
if (functionBuilders != null) {
|
||||
functionBuilders = functionBuilders.stream()
|
||||
.map(builder -> new AggStateFunctionBuilder(combinatorSuffix, builder))
|
||||
.map(builder -> new AggCombinerFunctionBuilder(combinatorSuffix, builder))
|
||||
.filter(functionBuilder -> functionBuilder.canApply(arguments))
|
||||
.collect(Collectors.toList());
|
||||
}
|
||||
|
||||
@ -83,6 +83,7 @@ import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateParam;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
|
||||
@ -612,6 +613,16 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
|
||||
new FunctionParams(false, arguments)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expr visitForEachCombinator(ForEachCombinator combinator, PlanTranslatorContext context) {
|
||||
List<Expr> arguments = combinator.children().stream()
|
||||
.map(arg -> new SlotRef(arg.getDataType().toCatalogDataType(), arg.nullable()))
|
||||
.collect(ImmutableList.toImmutableList());
|
||||
return Function.convertForEachCombinator(
|
||||
new FunctionCallExpr(visitAggregateFunction(combinator.getNestedFunction(), context).getFn(),
|
||||
new FunctionParams(false, arguments)));
|
||||
}
|
||||
|
||||
@Override
|
||||
public Expr visitAggregateFunction(AggregateFunction function, PlanTranslatorContext context) {
|
||||
List<Expr> arguments = function.children().stream()
|
||||
|
||||
@ -18,11 +18,15 @@
|
||||
package org.apache.doris.nereids.trees.expressions.functions;
|
||||
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.SlotReference;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
|
||||
import org.apache.doris.nereids.types.AggStateType;
|
||||
import org.apache.doris.nereids.types.ArrayType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
@ -31,28 +35,30 @@ import java.util.stream.Collectors;
|
||||
/**
|
||||
* This class used to resolve AggState's combinators
|
||||
*/
|
||||
public class AggStateFunctionBuilder extends FunctionBuilder {
|
||||
public class AggCombinerFunctionBuilder extends FunctionBuilder {
|
||||
public static final String COMBINATOR_LINKER = "_";
|
||||
public static final String STATE = "state";
|
||||
public static final String MERGE = "merge";
|
||||
public static final String UNION = "union";
|
||||
public static final String FOREACH = "foreach";
|
||||
|
||||
public static final String STATE_SUFFIX = COMBINATOR_LINKER + STATE;
|
||||
public static final String MERGE_SUFFIX = COMBINATOR_LINKER + MERGE;
|
||||
public static final String UNION_SUFFIX = COMBINATOR_LINKER + UNION;
|
||||
public static final String FOREACH_SUFFIX = COMBINATOR_LINKER + FOREACH;
|
||||
|
||||
private final FunctionBuilder nestedBuilder;
|
||||
|
||||
private final String combinatorSuffix;
|
||||
|
||||
public AggStateFunctionBuilder(String combinatorSuffix, FunctionBuilder nestedBuilder) {
|
||||
public AggCombinerFunctionBuilder(String combinatorSuffix, FunctionBuilder nestedBuilder) {
|
||||
this.combinatorSuffix = Objects.requireNonNull(combinatorSuffix, "combinatorSuffix can not be null");
|
||||
this.nestedBuilder = Objects.requireNonNull(nestedBuilder, "nestedBuilder can not be null");
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean canApply(List<? extends Object> arguments) {
|
||||
if (combinatorSuffix.equals(STATE)) {
|
||||
if (combinatorSuffix.equals(STATE) || combinatorSuffix.equals(FOREACH)) {
|
||||
return nestedBuilder.canApply(arguments);
|
||||
} else {
|
||||
if (arguments.size() != 1) {
|
||||
@ -71,6 +77,23 @@ public class AggStateFunctionBuilder extends FunctionBuilder {
|
||||
return (AggregateFunction) nestedBuilder.build(nestedName, arguments);
|
||||
}
|
||||
|
||||
private AggregateFunction buildForEach(String nestedName, List<? extends Object> arguments) {
|
||||
List<Expression> forEachargs = arguments.stream().map(expr -> {
|
||||
if (!(expr instanceof SlotReference)) {
|
||||
throw new IllegalStateException(
|
||||
"Can not build foreach nested function: '" + nestedName);
|
||||
}
|
||||
DataType arrayType = (((Expression) expr).getDataType());
|
||||
if (!(arrayType instanceof ArrayType)) {
|
||||
throw new IllegalStateException(
|
||||
"foreach must be input array type: '" + nestedName);
|
||||
}
|
||||
DataType itemType = ((ArrayType) arrayType).getItemType();
|
||||
return new SlotReference("mocked", itemType, (((ArrayType) arrayType).containsNull()));
|
||||
}).collect(Collectors.toList());
|
||||
return (AggregateFunction) nestedBuilder.build(nestedName, forEachargs);
|
||||
}
|
||||
|
||||
private AggregateFunction buildMergeOrUnion(String nestedName, List<? extends Object> arguments) {
|
||||
if (arguments.size() != 1 || !(arguments.get(0) instanceof Expression)
|
||||
|| !((Expression) arguments.get(0)).getDataType().isAggStateType()) {
|
||||
@ -105,13 +128,16 @@ public class AggStateFunctionBuilder extends FunctionBuilder {
|
||||
} else if (combinatorSuffix.equals(UNION)) {
|
||||
AggregateFunction nestedFunction = buildMergeOrUnion(nestedName, arguments);
|
||||
return new UnionCombinator((List<Expression>) arguments, nestedFunction);
|
||||
} else if (combinatorSuffix.equals(FOREACH)) {
|
||||
AggregateFunction nestedFunction = buildForEach(nestedName, arguments);
|
||||
return new ForEachCombinator((List<Expression>) arguments, nestedFunction);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
public static boolean isAggStateCombinator(String name) {
|
||||
return name.toLowerCase().endsWith(STATE_SUFFIX) || name.toLowerCase().endsWith(MERGE_SUFFIX)
|
||||
|| name.toLowerCase().endsWith(UNION_SUFFIX);
|
||||
|| name.toLowerCase().endsWith(UNION_SUFFIX) || name.toLowerCase().endsWith(FOREACH_SUFFIX);
|
||||
}
|
||||
|
||||
public static String getNestedName(String name) {
|
||||
@ -0,0 +1,90 @@
|
||||
// Licensed to the Apache Software Foundation (ASF) under one
|
||||
// or more contributor license agreements. See the NOTICE file
|
||||
// distributed with this work for additional information
|
||||
// regarding copyright ownership. The ASF licenses this file
|
||||
// to you under the Apache License, Version 2.0 (the
|
||||
// "License"); you may not use this file except in compliance
|
||||
// with the License. You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing,
|
||||
// software distributed under the License is distributed on an
|
||||
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
|
||||
// KIND, either express or implied. See the License for the
|
||||
// specific language governing permissions and limitations
|
||||
// under the License.
|
||||
|
||||
package org.apache.doris.nereids.trees.expressions.functions.combinator;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
|
||||
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
|
||||
import org.apache.doris.nereids.types.ArrayType;
|
||||
import org.apache.doris.nereids.types.DataType;
|
||||
|
||||
import com.google.common.collect.ImmutableList;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Objects;
|
||||
|
||||
/**
|
||||
* combinator foreach
|
||||
*/
|
||||
public class ForEachCombinator extends AggregateFunction
|
||||
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable {
|
||||
|
||||
private final AggregateFunction nested;
|
||||
|
||||
/**
|
||||
* constructor of ForEachCombinator
|
||||
*/
|
||||
public ForEachCombinator(List<Expression> arguments, AggregateFunction nested) {
|
||||
super(nested.getName() + AggCombinerFunctionBuilder.FOREACH_SUFFIX, arguments);
|
||||
|
||||
this.nested = Objects.requireNonNull(nested, "nested can not be null");
|
||||
}
|
||||
|
||||
public static ForEachCombinator create(AggregateFunction nested) {
|
||||
return new ForEachCombinator(nested.getArguments(), nested);
|
||||
}
|
||||
|
||||
@Override
|
||||
public ForEachCombinator withChildren(List<Expression> children) {
|
||||
return new ForEachCombinator(children, nested);
|
||||
}
|
||||
|
||||
@Override
|
||||
public List<FunctionSignature> getSignatures() {
|
||||
return nested.getSignatures().stream().map(sig -> {
|
||||
return sig.withReturnType(ArrayType.of(sig.returnType)).withArgumentTypes(false,
|
||||
sig.argumentsTypes.stream().map(arg -> {
|
||||
return ArrayType.of(arg);
|
||||
}).collect(ImmutableList.toImmutableList()));
|
||||
}).collect(ImmutableList.toImmutableList());
|
||||
}
|
||||
|
||||
@Override
|
||||
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
|
||||
return visitor.visitForEachCombinator(this, context);
|
||||
}
|
||||
|
||||
@Override
|
||||
public DataType getDataType() {
|
||||
return ArrayType.of(nested.getDataType(), nested.nullable());
|
||||
}
|
||||
|
||||
public AggregateFunction getNestedFunction() {
|
||||
return nested;
|
||||
}
|
||||
|
||||
@Override
|
||||
public AggregateFunction withDistinctAndChildren(boolean distinct, List<Expression> children) {
|
||||
throw new UnsupportedOperationException("Unimplemented method 'withDistinctAndChildren'");
|
||||
}
|
||||
}
|
||||
@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions.combinator;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ComputeNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
@ -43,7 +43,7 @@ public class MergeCombinator extends AggregateFunction
|
||||
private final AggStateType inputType;
|
||||
|
||||
public MergeCombinator(List<Expression> arguments, AggregateFunction nested) {
|
||||
super(nested.getName() + AggStateFunctionBuilder.MERGE_SUFFIX, arguments);
|
||||
super(nested.getName() + AggCombinerFunctionBuilder.MERGE_SUFFIX, arguments);
|
||||
|
||||
this.nested = Objects.requireNonNull(nested, "nested can not be null");
|
||||
inputType = (AggStateType) arguments.get(0).getDataType();
|
||||
|
||||
@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions.combinator;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
@ -47,7 +47,7 @@ public class StateCombinator extends ScalarFunction
|
||||
* constructor of StateCombinator
|
||||
*/
|
||||
public StateCombinator(List<Expression> arguments, AggregateFunction nested) {
|
||||
super(nested.getName() + AggStateFunctionBuilder.STATE_SUFFIX, arguments);
|
||||
super(nested.getName() + AggCombinerFunctionBuilder.STATE_SUFFIX, arguments);
|
||||
|
||||
this.nested = Objects.requireNonNull(nested, "nested can not be null");
|
||||
this.returnType = new AggStateType(nested.getName(), arguments.stream().map(arg -> {
|
||||
|
||||
@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions.combinator;
|
||||
|
||||
import org.apache.doris.catalog.FunctionSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggStateFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AggCombinerFunctionBuilder;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
|
||||
@ -43,7 +43,7 @@ public class UnionCombinator extends AggregateFunction
|
||||
private final AggStateType inputType;
|
||||
|
||||
public UnionCombinator(List<Expression> arguments, AggregateFunction nested) {
|
||||
super(nested.getName() + AggStateFunctionBuilder.UNION_SUFFIX, arguments);
|
||||
super(nested.getName() + AggCombinerFunctionBuilder.UNION_SUFFIX, arguments);
|
||||
|
||||
this.nested = Objects.requireNonNull(nested, "nested can not be null");
|
||||
inputType = (AggStateType) arguments.get(0).getDataType();
|
||||
|
||||
@ -72,6 +72,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.TopNWeighted;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.Variance;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.VarianceSamp;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.agg.WindowFunnel;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.ForEachCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
|
||||
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
|
||||
@ -305,6 +306,10 @@ public interface AggregateFunctionVisitor<R, C> {
|
||||
return visitAggregateFunction(combinator, context);
|
||||
}
|
||||
|
||||
default R visitForEachCombinator(ForEachCombinator combinator, C context) {
|
||||
return visitAggregateFunction(combinator, context);
|
||||
}
|
||||
|
||||
default R visitJavaUdaf(JavaUdaf javaUdaf, C context) {
|
||||
return visitAggregateFunction(javaUdaf, context);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user