[Improve](array) support array_enumerate_uniq and array_suffle for nereids (#29936)

This commit is contained in:
amory
2024-01-15 16:54:47 +08:00
committed by yiguolei
parent c9cf9ab841
commit d5dcdf3e07
6 changed files with 204 additions and 13 deletions

View File

@ -37,6 +37,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCumSum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDifference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDistinct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerateUniq;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExists;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter;
@ -59,6 +60,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRange;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRemove;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRepeat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayShuffle;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySlice;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy;
@ -445,6 +447,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ArrayDifference.class, "array_difference"),
scalar(ArrayDistinct.class, "array_distinct"),
scalar(ArrayEnumerate.class, "array_enumerate"),
scalar(ArrayEnumerateUniq.class, "array_enumerate_uniq"),
scalar(ArrayExcept.class, "array_except"),
scalar(ArrayExists.class, "array_exists"),
scalar(ArrayFilter.class, "array_filter"),
@ -470,6 +473,7 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ArraySlice.class, "array_slice"),
scalar(ArraySort.class, "array_sort"),
scalar(ArraySortBy.class, "array_sortby"),
scalar(ArrayShuffle.class, "array_shuffle", "shuffle"),
scalar(ArraySum.class, "array_sum"),
scalar(ArrayUnion.class, "array_union"),
scalar(ArrayWithConstant.class, "array_with_constant"),

View File

@ -0,0 +1,73 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* ScalarFunction 'array_enumerate_uniq'.
* more than 0 array as args
*/
public class ArrayEnumerateUniq extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(BigIntType.INSTANCE)).varArgs(ArrayType.of(new AnyDataType(0)))
);
/**
* constructor with more than 0 arguments.
*/
public ArrayEnumerateUniq(Expression arg, Expression ...varArgs) {
super("array_enumerate_uniq", ExpressionUtils.mergeArguments(arg, varArgs));
}
/**
* withChildren.
*/
@Override
public ArrayEnumerateUniq withChildren(List<Expression> children) {
Preconditions.checkArgument(!children.isEmpty());
return new ArrayEnumerateUniq(children.get(0), children.subList(1, children.size()).toArray(new Expression[0]));
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitArrayEnumerateUniq(this, context);
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -0,0 +1,86 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import java.util.List;
/**
* ScalarFunction 'array_shuffle'
* with 1 or 2 arguments : array_shuffle(arr) or array_shuffle(arr, seed)
*/
public class ArrayShuffle extends ScalarFunction
implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable {
public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.retArgType(0).args(ArrayType.of(new AnyDataType(0))),
FunctionSignature.retArgType(0)
.args(ArrayType.of(new AnyDataType(0)), BigIntType.INSTANCE)
);
/**
* constructor with 1 arguments.
*/
public ArrayShuffle(Expression arg) {
super("array_shuffle", arg);
}
/**
* constructor with 2 arguments.
*/
public ArrayShuffle(Expression arg, Expression arg1) {
super("array_shuffle", arg, arg1);
}
/**
* withChildren.
*/
@Override
public ArrayShuffle withChildren(List<Expression> children) {
Preconditions.checkArgument(children.size() == 1
|| children.size() == 2);
if (children.size() == 1) {
return new ArrayShuffle(children.get(0));
} else {
return new ArrayShuffle(children.get(0), children.get(1));
}
}
@Override
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitArrayShuffle(this, context);
}
@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}
}

View File

@ -39,6 +39,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayCumSum;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDifference;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayDistinct;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerate;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerateUniq;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExists;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter;
@ -59,6 +60,7 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRange;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRemove;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayRepeat;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayReverseSort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayShuffle;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySlice;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySort;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArraySortBy;
@ -484,6 +486,10 @@ public interface ScalarFunctionVisitor<R, C> {
return visitScalarFunction(arrayEnumerate, context);
}
default R visitArrayEnumerateUniq(ArrayEnumerateUniq arrayEnumerateUniq, C context) {
return visitScalarFunction(arrayEnumerateUniq, context);
}
default R visitArrayExcept(ArrayExcept arrayExcept, C context) {
return visitScalarFunction(arrayExcept, context);
}
@ -564,6 +570,10 @@ public interface ScalarFunctionVisitor<R, C> {
return visitScalarFunction(arraySortBy, context);
}
default R visitArrayShuffle(ArrayShuffle arrayShuffle, C context) {
return visitScalarFunction(arrayShuffle, context);
}
default R visitArrayMap(ArrayMap arraySort, C context) {
return visitScalarFunction(arraySort, context);
}