From 79aa079cc6f4f0be08de6e831026913fd91446db Mon Sep 17 00:00:00 2001 From: amory Date: Sat, 17 Aug 2024 10:56:44 +0800 Subject: [PATCH] [fix](array-funcs) array min/max #39307 (#39484) --- .../array/function_array_aggregation.cpp | 8 ++++- .../functions/scalar/ArrayMax.java | 9 ++++++ .../functions/scalar/ArrayMin.java | 9 ++++++ .../functions/scalar/ArrayReverseSort.java | 10 ++++++ .../functions/scalar/ArraySort.java | 10 ++++++ .../scalar_function/Array.groovy | 32 ++++++++++++++++++- .../test_array_functions_by_literal.groovy | 31 ++++++++++++++++++ 7 files changed, 107 insertions(+), 2 deletions(-) diff --git a/be/src/vec/functions/array/function_array_aggregation.cpp b/be/src/vec/functions/array/function_array_aggregation.cpp index e8a2fd9e95..d2edfe34fb 100644 --- a/be/src/vec/functions/array/function_array_aggregation.cpp +++ b/be/src/vec/functions/array/function_array_aggregation.cpp @@ -147,7 +147,13 @@ struct ArrayAggregateImpl { const DataTypeArray* data_type_array = static_cast(remove_nullable(arguments[0]).get()); auto function = Function::create(data_type_array->get_nested_type()); - return function->get_return_type(); + if (function) { + return function->get_return_type(); + } else { + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, + "Unexpected type {} for aggregation {}", + data_type_array->get_nested_type()->get_name(), operation); + } } static Status execute(Block& block, const ColumnNumbers& arguments, size_t result, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMax.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMax.java index c1d0eff1b2..f8a2920521 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMax.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMax.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; @@ -50,6 +51,14 @@ public class ArrayMax extends ScalarFunction implements ExplicitlyCastableSignat super("array_max", arg); } + @Override + public void checkLegalityBeforeTypeCoercion() { + DataType argType = child().getDataType(); + if (((ArrayType) argType).getItemType().isComplexType()) { + throw new AnalysisException("array_max does not support complex types: " + toSql()); + } + } + @Override public DataType getDataType() { return ((ArrayType) (child().getDataType())).getItemType(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMin.java index dbfba39a2c..642b86f575 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayMin.java @@ -18,6 +18,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; @@ -50,6 +51,14 @@ public class ArrayMin extends ScalarFunction implements ExplicitlyCastableSignat super("array_min", arg); } + @Override + public void checkLegalityBeforeTypeCoercion() { + DataType argType = child().getDataType(); + if (((ArrayType) argType).getItemType().isComplexType()) { + throw new AnalysisException("array_min does not support complex types: " + toSql()); + } + } + @Override public DataType getDataType() { return ((ArrayType) (child().getDataType())).getItemType(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java index bcdc1d852f..dd62fdb7e4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayReverseSort.java @@ -18,12 +18,14 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AnyDataType; import com.google.common.base.Preconditions; @@ -48,6 +50,14 @@ public class ArrayReverseSort extends ScalarFunction super("array_reverse_sort", arg); } + @Override + public void checkLegalityBeforeTypeCoercion() { + DataType argType = child().getDataType(); + if (((ArrayType) argType).getItemType().isComplexType()) { + throw new AnalysisException("array_reverse_sort does not support complex types: " + toSql()); + } + } + /** * withChildren. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArraySort.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArraySort.java index 5953d69b66..80f359b61f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArraySort.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArraySort.java @@ -18,12 +18,14 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; +import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.coercion.AnyDataType; import com.google.common.base.Preconditions; @@ -48,6 +50,14 @@ public class ArraySort extends ScalarFunction super("array_sort", arg); } + @Override + public void checkLegalityBeforeTypeCoercion() { + DataType argType = child().getDataType(); + if (((ArrayType) argType).getItemType().isComplexType()) { + throw new AnalysisException("array_sort does not support complex types: " + toSql()); + } + } + /** * withChildren. */ diff --git a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy index 506742588b..9fa1af3d63 100644 --- a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy +++ b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy @@ -1309,7 +1309,7 @@ suite("nereids_scalar_fn_Array") { } } - // with array empty + sql """ set enable_fold_constant_by_be=true; """ qt_array_empty_fe """select array()""" // array_map with string is can be succeed @@ -1321,4 +1321,34 @@ suite("nereids_scalar_fn_Array") { exception("errCode = 2") } + // array_min/max with nested array for args + test { + sql "select array_min(array(1,2,3),array(4,5,6));" + check{result, exception, startTime, endTime -> + assertTrue(exception != null) + logger.info(exception.message) + } + } + test { + sql "select array_max(array(1,2,3),array(4,5,6));" + check{result, exception, startTime, endTime -> + assertTrue(exception != null) + logger.info(exception.message) + } + } + + test { + sql "select array_min(array(split_by_string('a,b,c',',')));" + check{result, exception, startTime, endTime -> + assertTrue(exception != null) + logger.info(exception.message) + } + } + test { + sql "select array_max(array(split_by_string('a,b,c',',')));" + check{result, exception, startTime, endTime -> + assertTrue(exception != null) + logger.info(exception.message) + } + } } diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy index 81689b6335..14c6ed5b92 100644 --- a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy @@ -419,4 +419,35 @@ suite("test_array_functions_by_literal") { sql """select array_apply(split_by_string("amory,is,better,committing", ","), '!=', '');""" exception("No matching function") } + // array_min/max with nested array for args + test { + sql "select array_min(array(1,2,3),array(4,5,6));" + check{result, exception, startTime, endTime -> + assertTrue(exception != null) + logger.info(exception.message) + } + } + test { + sql "select array_max(array(1,2,3),array(4,5,6));" + check{result, exception, startTime, endTime -> + assertTrue(exception != null) + logger.info(exception.message) + } + } + + test { + sql "select array_min(array(split_by_string('a,b,c',',')));" + check{result, exception, startTime, endTime -> + assertTrue(exception != null) + logger.info(exception.message) + } + } + test { + sql "select array_max(array(split_by_string('a,b,c',',')));" + check{result, exception, startTime, endTime -> + assertTrue(exception != null) + logger.info(exception.message) + } + } + }