[feat](Nereids) add lambda func array_last and array_first (#24682)

This commit is contained in:
谢健
2023-09-21 12:23:34 +08:00
committed by GitHub
parent a65dbb097c
commit 5b590bbfcf
7 changed files with 611 additions and 12 deletions

View File

@ -45,6 +45,7 @@ public class LambdaFunctionCallExpr extends FunctionCallExpr {
public static final ImmutableSet<String> LAMBDA_MAPPED_FUNCTION_SET = new ImmutableSortedSet.Builder(
String.CASE_INSENSITIVE_ORDER).add("array_exists").add("array_sortby")
.add("array_first_index").add("array_last_index").add("array_first").add("array_last").add("array_count")
.add("element_at")
.build();
private static final Logger LOG = LogManager.getLogger(LambdaFunctionCallExpr.class);

View File

@ -36,9 +36,11 @@ import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayEnumerat
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExcept;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayExists;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFilter;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirst;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayFirstIndex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayJoin;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLast;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayLastIndex;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMax;
@ -413,9 +415,11 @@ public class BuiltinScalarFunctions implements FunctionHelper {
scalar(ArrayExcept.class, "array_except"),
scalar(ArrayExists.class, "array_exists"),
scalar(ArrayFilter.class, "array_filter"),
scalar(ArrayFirst.class, "array_first"),
scalar(ArrayFirstIndex.class, "array_first_index"),
scalar(ArrayIntersect.class, "array_intersect"),
scalar(ArrayJoin.class, "array_join"),
scalar(ArrayLast.class, "array_last"),
scalar(ArrayLastIndex.class, "array_last_index"),
scalar(ArrayMap.class, "array_map"),
scalar(ArrayMax.class, "array_max"),

View File

@ -87,6 +87,8 @@ import org.apache.doris.nereids.trees.expressions.functions.combinator.MergeComb
import org.apache.doris.nereids.trees.expressions.functions.combinator.StateCombinator;
import org.apache.doris.nereids.trees.expressions.functions.combinator.UnionCombinator;
import org.apache.doris.nereids.trees.expressions.functions.generator.TableGeneratingFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ArrayMap;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HighOrderFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Lambda;
import org.apache.doris.nereids.trees.expressions.functions.scalar.ScalarFunction;
import org.apache.doris.nereids.trees.expressions.functions.udf.JavaUdaf;
@ -405,9 +407,10 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
return new LambdaFunctionExpr(func, lambda.getLambdaArgumentNames(), arguments);
}
private Expr visitHighOrderFunction(ScalarFunction function, PlanTranslatorContext context) {
Lambda lambda = (Lambda) function.child(0);
List<Expr> arguments = new ArrayList<>(function.children().size());
@Override
public Expr visitArrayMap(ArrayMap arrayMap, PlanTranslatorContext context) {
Lambda lambda = (Lambda) arrayMap.child(0);
List<Expr> arguments = new ArrayList<>(arrayMap.children().size());
arguments.add(null);
int columnId = 0;
for (ArrayItemReference arrayItemReference : lambda.getLambdaArguments()) {
@ -424,7 +427,7 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
columnId += 1;
}
List<Type> argTypes = function.getArguments().stream()
List<Type> argTypes = arrayMap.getArguments().stream()
.map(Expression::getDataType)
.map(DataType::toCatalogDataType)
.collect(Collectors.toList());
@ -433,11 +436,11 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
.map(Expression::getDataType)
.map(DataType::toCatalogDataType)
.forEach(argTypes::add);
NullableMode nullableMode = function.nullable()
NullableMode nullableMode = arrayMap.nullable()
? NullableMode.ALWAYS_NULLABLE
: NullableMode.ALWAYS_NOT_NULLABLE;
org.apache.doris.catalog.Function catalogFunction = new Function(
new FunctionName(function.getName()), argTypes,
new FunctionName(arrayMap.getName()), argTypes,
ArrayType.create(lambda.getRetType().toCatalogDataType(), true),
true, true, nullableMode);
@ -449,10 +452,6 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
@Override
public Expr visitScalarFunction(ScalarFunction function, PlanTranslatorContext context) {
if (function.isHighOrder()) {
return visitHighOrderFunction(function, context);
}
List<Expr> arguments = function.getArguments().stream()
.map(arg -> arg.accept(this, context))
.collect(Collectors.toList());
@ -472,8 +471,7 @@ public class ExpressionTranslator extends DefaultExpressionVisitor<Expr, PlanTra
"", TFunctionBinaryType.BUILTIN, true, true, nullableMode);
// create catalog FunctionCallExpr without analyze again
if (LambdaFunctionCallExpr.LAMBDA_FUNCTION_SET.contains(function.getName())
|| LambdaFunctionCallExpr.LAMBDA_MAPPED_FUNCTION_SET.contains(function.getName())) {
if (function instanceof HighOrderFunction) {
return new LambdaFunctionCallExpr(catalogFunction, new FunctionParams(false, arguments));
}
return new FunctionCallExpr(catalogFunction, new FunctionParams(false, arguments));

View File

@ -0,0 +1,49 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import java.util.List;
/**
* ScalarFunction 'array_first'.
*/
public class ArrayFirst extends ElementAt
implements HighOrderFunction {
/**
* constructor with arguments.
* array_first(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), 1)
*/
public ArrayFirst(Expression arg) {
super(new ArrayFilter(arg), new BigIntLiteral(1));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
}
@Override
public List<FunctionSignature> getImplSignature() {
return SIGNATURES;
}
}

View File

@ -0,0 +1,49 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.trees.expressions.functions.scalar;
import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
import java.util.List;
/**
* ScalarFunction 'array_last'.
*/
public class ArrayLast extends ElementAt
implements HighOrderFunction {
/**
* constructor with arguments.
* array_last(lambda, a1, ...) = element_at(array_filter(lambda, a1, ...), -1)
*/
public ArrayLast(Expression arg) {
super(new ArrayFilter(arg), new BigIntLiteral(-1));
if (!(arg instanceof Lambda)) {
throw new AnalysisException(
String.format("The 1st arg of %s must be lambda but is %s", getName(), arg));
}
}
@Override
public List<FunctionSignature> getImplSignature() {
return SIGNATURES;
}
}

View File

@ -11570,3 +11570,467 @@ true
2
2
-- !sql_array_first_Double --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
1.1
1.2
-- !sql_array_first_Double_notnull --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
1.1
1.2
-- !sql_array_first_Float --
\N
\N
10.0
11.0
12.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
-- !sql_array_first_Float_notnull --
\N
10.0
11.0
12.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
-- !sql_array_first_LargeInt --
\N
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_first_LargeInt_notnull --
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_first_BigInt --
\N
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_first_BigInt_notnull --
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_first_SmallInt --
\N
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_first_SmallInt_notnull --
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_first_Integer --
\N
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_first_Integer_notnull --
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_first_TinyInt --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
-- !sql_array_first_TinyInt_notnull --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
-- !sql_array_first_DecimalV3 --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
1.100000000
1.200000000
-- !sql_array_first_DecimalV3_notnull --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
1.100000000
1.200000000
-- !sql_array_last_Double --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
1.1
1.2
-- !sql_array_last_Double_notnull --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
1.1
1.2
-- !sql_array_last_Float --
\N
\N
10.0
11.0
12.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
-- !sql_array_last_Float_notnull --
\N
10.0
11.0
12.0
2.0
3.0
4.0
5.0
6.0
7.0
8.0
9.0
-- !sql_array_last_LargeInt --
\N
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_last_LargeInt_notnull --
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_last_BigInt --
\N
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_last_BigInt_notnull --
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_last_SmallInt --
\N
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_last_SmallInt_notnull --
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_last_Integer --
\N
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_last_Integer_notnull --
\N
10
11
12
2
3
4
5
6
7
8
9
-- !sql_array_last_TinyInt --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
-- !sql_array_last_TinyInt_notnull --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
-- !sql_array_last_DecimalV3 --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
1.100000000
1.200000000
-- !sql_array_last_DecimalV3_notnull --
\N
\N
\N
\N
\N
\N
\N
\N
\N
\N
1.100000000
1.200000000

View File

@ -961,6 +961,40 @@ suite("nereids_scalar_fn_Array") {
order_qt_sql_array_last_index_TinyInt_notnull "select array_last_index(x -> x > 1, katint) from fn_test_not_nullable"
order_qt_sql_array_last_index_DecimalV3 "select array_last_index(x -> x > 1, kadcml) from fn_test"
order_qt_sql_array_last_index_DecimalV3_notnull "select array_last_index(x -> x > 1, kadcml) from fn_test_not_nullable"
order_qt_sql_array_first_Double "select array_first(x -> x > 1, kadbl) from fn_test"
// test array_first
order_qt_sql_array_first_Double_notnull "select array_first(x -> x > 1, kadbl) from fn_test_not_nullable"
order_qt_sql_array_first_Float "select array_first(x -> x > 1, kafloat) from fn_test"
order_qt_sql_array_first_Float_notnull "select array_first(x -> x > 1, kafloat) from fn_test_not_nullable"
order_qt_sql_array_first_LargeInt "select array_first(x -> x > 1, kalint) from fn_test"
order_qt_sql_array_first_LargeInt_notnull "select array_first(x -> x > 1, kalint) from fn_test_not_nullable"
order_qt_sql_array_first_BigInt "select array_first(x -> x > 1, kabint) from fn_test"
order_qt_sql_array_first_BigInt_notnull "select array_first(x -> x > 1, kabint) from fn_test_not_nullable"
order_qt_sql_array_first_SmallInt "select array_first(x -> x > 1, kasint) from fn_test"
order_qt_sql_array_first_SmallInt_notnull "select array_first(x -> x > 1, kasint) from fn_test_not_nullable"
order_qt_sql_array_first_Integer "select array_first(x -> x > 1, kaint) from fn_test"
order_qt_sql_array_first_Integer_notnull "select array_first(x -> x > 1, kaint) from fn_test_not_nullable"
order_qt_sql_array_first_TinyInt "select array_first(x -> x > 1, katint) from fn_test"
order_qt_sql_array_first_TinyInt_notnull "select array_first(x -> x > 1, katint) from fn_test_not_nullable"
order_qt_sql_array_first_DecimalV3 "select array_first(x -> x > 1, kadcml) from fn_test"
order_qt_sql_array_first_DecimalV3_notnull "select array_first(x -> x > 1, kadcml) from fn_test_not_nullable"
// test array_last
order_qt_sql_array_last_Double "select array_last(x -> x > 1, kadbl) from fn_test"
order_qt_sql_array_last_Double_notnull "select array_last(x -> x > 1, kadbl) from fn_test_not_nullable"
order_qt_sql_array_last_Float "select array_last(x -> x > 1, kafloat) from fn_test"
order_qt_sql_array_last_Float_notnull "select array_last(x -> x > 1, kafloat) from fn_test_not_nullable"
order_qt_sql_array_last_LargeInt "select array_last(x -> x > 1, kalint) from fn_test"
order_qt_sql_array_last_LargeInt_notnull "select array_last(x -> x > 1, kalint) from fn_test_not_nullable"
order_qt_sql_array_last_BigInt "select array_last(x -> x > 1, kabint) from fn_test"
order_qt_sql_array_last_BigInt_notnull "select array_last(x -> x > 1, kabint) from fn_test_not_nullable"
order_qt_sql_array_last_SmallInt "select array_last(x -> x > 1, kasint) from fn_test"
order_qt_sql_array_last_SmallInt_notnull "select array_last(x -> x > 1, kasint) from fn_test_not_nullable"
order_qt_sql_array_last_Integer "select array_last(x -> x > 1, kaint) from fn_test"
order_qt_sql_array_last_Integer_notnull "select array_last(x -> x > 1, kaint) from fn_test_not_nullable"
order_qt_sql_array_last_TinyInt "select array_last(x -> x > 1, katint) from fn_test"
order_qt_sql_array_last_TinyInt_notnull "select array_last(x -> x > 1, katint) from fn_test_not_nullable"
order_qt_sql_array_last_DecimalV3 "select array_last(x -> x > 1, kadcml) from fn_test"
order_qt_sql_array_last_DecimalV3_notnull "select array_last(x -> x > 1, kadcml) from fn_test_not_nullable"
test {
sql "select tokenize('arg1','xxx = yyy,zzz');"