[feature](agg) support aggregate function group_array_intersect (#33265)

This commit is contained in:
Chester
2024-04-16 16:25:48 +08:00
committed by yiguolei
parent 07a8f44443
commit 3096150d1b
12 changed files with 1115 additions and 2 deletions

View File

@ -55,7 +55,8 @@ public class AggregateFunction extends Function {
FunctionSet.COUNT, "approx_count_distinct", "ndv", FunctionSet.BITMAP_UNION_INT,
FunctionSet.BITMAP_UNION_COUNT, "ndv_no_finalize", FunctionSet.WINDOW_FUNNEL, FunctionSet.RETENTION,
FunctionSet.SEQUENCE_MATCH, FunctionSet.SEQUENCE_COUNT, FunctionSet.MAP_AGG, FunctionSet.BITMAP_AGG,
FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET);
FunctionSet.ARRAY_AGG, FunctionSet.COLLECT_LIST, FunctionSet.COLLECT_SET,
FunctionSet.GROUP_ARRAY_INTERSECT);
public static ImmutableSet<String> ALWAYS_NULLABLE_AGGREGATE_FUNCTION_NAME_SET =
ImmutableSet.of("stddev_samp", "variance_samp", "var_samp", "percentile_approx", "first_value",

View File

@ -33,6 +33,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Covar;
import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitOr;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitXor;
@ -103,6 +104,7 @@ public class BuiltinAggregateFunctions implements FunctionHelper {
agg(CountByEnum.class, "count_by_enum"),
agg(Covar.class, "covar", "covar_pop"),
agg(CovarSamp.class, "covar_samp"),
agg(GroupArrayIntersect.class, "group_array_intersect"),
agg(GroupBitAnd.class, "group_bit_and"),
agg(GroupBitOr.class, "group_bit_or"),
agg(GroupBitXor.class, "group_bit_xor"),

View File

@ -612,6 +612,8 @@ public class FunctionSet<T> {
public static final String GROUP_ARRAY = "group_array";
public static final String GROUP_ARRAY_INTERSECT = "group_array_intersect";
public static final String ARRAY_AGG = "array_agg";
// Populate all the aggregate builtins in the catalog.
@ -1503,7 +1505,9 @@ public class FunctionSet<T> {
addBuiltin(
AggregateFunction.createBuiltin(GROUP_ARRAY, Lists.newArrayList(t, Type.INT), new ArrayType(t),
t, "", "", "", "", "", true, false, true, true));
addBuiltin(
AggregateFunction.createBuiltin(GROUP_ARRAY_INTERSECT, Lists.newArrayList(new ArrayType(t)),
new ArrayType(t), t, "", "", "", "", "", true, false, true, true));
addBuiltin(AggregateFunction.createBuiltin(ARRAY_AGG, Lists.newArrayList(t), new ArrayType(t), t, "", "", "", "", "",
true, false, true, true));

View File

@ -0,0 +1,76 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.agg;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.AlwaysNotNullable;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* AggregateFunction 'group_array_intersect'.
*/
public class GroupArrayIntersect extends AggregateFunction
implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNotNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0))));
/**
* constructor with 1 argument.
*/
public GroupArrayIntersect(Expression arg) {
super("group_array_intersect", arg);
}
/**
* constructor with 1 argument.
*/
public GroupArrayIntersect(boolean distinct, Expression arg) {
super("group_array_intersect", false, arg);
}
/**
* withChildren.
*/
@Override
public AggregateFunction withDistinctAndChildren(boolean distinct, List<Expression> children) {
Preconditions.checkArgument(children.size() == 1);
return new GroupArrayIntersect(distinct, children.get(0));
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitGroupArrayIntersect(this, context);
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -34,6 +34,7 @@ import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Covar;
import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitOr;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitXor;
@ -168,6 +169,10 @@ public interface AggregateFunctionVisitor<R, C> {
return visitAggregateFunction(multiDistinctSum0, context);
}
default R visitGroupArrayIntersect(GroupArrayIntersect groupArrayIntersect, C context) {
return visitAggregateFunction(groupArrayIntersect, context);
}
default R visitGroupBitAnd(GroupBitAnd groupBitAnd, C context) {
return visitNullableAggregateFunction(groupBitAnd, context);
}