From 226e01889c4e36a9ed0c7cd5273f4fa02a3bfe2f Mon Sep 17 00:00:00 2001 From: amory Date: Wed, 14 Aug 2024 18:52:29 +0800 Subject: [PATCH] [fix](array_apply) pick array apply fix (#39328) ## Proposed changes backport: https://github.com/apache/doris/pull/39105 Issue Number: close #xxx --- .../functions/array/function_array_apply.cpp | 87 ++++++++++--------- .../functions/scalar/ArrayApply.java | 10 +++ .../scalar_function/Array.out | 5 +- .../test_array_functions_by_literal.out | 3 + .../scalar_function/Array.groovy | 13 +++ .../test_array_functions_by_literal.groovy | 9 ++ 6 files changed, 84 insertions(+), 43 deletions(-) diff --git a/be/src/vec/functions/array/function_array_apply.cpp b/be/src/vec/functions/array/function_array_apply.cpp index 2ad680635a..426347c449 100644 --- a/be/src/vec/functions/array/function_array_apply.cpp +++ b/be/src/vec/functions/array/function_array_apply.cpp @@ -173,48 +173,51 @@ private: } // need exception safety -#define APPLY_ALL_TYPES(src_column, src_offsets, OP, cmp, dst) \ - do { \ - WhichDataType which(remove_nullable(nested_type)); \ - if (which.is_uint8()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_int8()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_int16()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_int32()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_int64()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_int128()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_float32()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_float64()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_date()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_date_time()) { \ - *dst = _apply_internal(src_column, 
src_offsets, cmp); \ - } else if (which.is_date_v2()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_date_time_v2()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_date_time_v2()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_decimal32()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_decimal64()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_decimal128v2()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_decimal128v3()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else if (which.is_decimal256()) { \ - *dst = _apply_internal(src_column, src_offsets, cmp); \ - } else { \ - LOG(FATAL) << "unsupported type " << nested_type->get_name(); \ - } \ +#define APPLY_ALL_TYPES(src_column, src_offsets, OP, cmp, dst) \ + do { \ + WhichDataType which(remove_nullable(nested_type)); \ + if (which.is_uint8()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_int8()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_int16()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_int32()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_int64()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_int128()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_float32()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_float64()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_date()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_date_time()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_date_v2()) { \ + *dst = 
_apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_date_time_v2()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_date_time_v2()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_decimal32()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_decimal64()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_decimal128v2()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_decimal128v3()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else if (which.is_decimal256()) { \ + *dst = _apply_internal(src_column, src_offsets, cmp); \ + } else { \ + throw doris::Exception(ErrorCode::INVALID_ARGUMENT, \ + "array_apply only accept array with nested type which is " \ + "uint/int/decimal/float/date but got : " + \ + nested_type->get_name()); \ + } \ } while (0) // need exception safety diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayApply.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayApply.java index 82bad4e486..07e4c16d77 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayApply.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ArrayApply.java @@ -26,6 +26,7 @@ import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral; import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.ArrayType; +import org.apache.doris.nereids.types.DataType; import org.apache.doris.nereids.types.VarcharType; import org.apache.doris.nereids.types.coercion.AnyDataType; import org.apache.doris.nereids.types.coercion.FollowToAnyDataType; @@ -67,6 +68,15 
@@ public class ArrayApply extends ScalarFunction } } + @Override + public void checkLegalityBeforeTypeCoercion() { + DataType argType = ((ArrayType) child(0).getDataType()).getItemType(); + if (!(argType.isIntegralType() || argType.isFloatLikeType() || argType.isDecimalLikeType() + || argType.isDateLikeType() || argType.isBooleanType())) { + throw new AnalysisException("array_apply does not support type: " + toSql()); + } + } + @Override public ArrayApply withChildren(List children) { Preconditions.checkArgument(children.size() == 3, diff --git a/regression-test/data/nereids_function_p0/scalar_function/Array.out b/regression-test/data/nereids_function_p0/scalar_function/Array.out index cfad441a49..220924df04 100644 --- a/regression-test/data/nereids_function_p0/scalar_function/Array.out +++ b/regression-test/data/nereids_function_p0/scalar_function/Array.out @@ -14448,6 +14448,9 @@ true -- !array_empty_fe -- [] --- !array_empty_be -- +-- !array_empty_fe -- [] +-- !sql_array_map -- +[1, 1, 1, 1] + diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions_by_literal.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions_by_literal.out index 6f3b1756d4..bddcebea70 100644 --- a/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions_by_literal.out +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_functions_by_literal.out @@ -977,3 +977,6 @@ _ -- !sql -- [11.9999, 34.0000] +-- !sql_array_map -- +[1, 1, 1, 1] + diff --git a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy index e1eb2bab51..506742588b 100644 --- a/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy +++ b/regression-test/suites/nereids_function_p0/scalar_function/Array.groovy @@ -1308,4 +1308,17 @@ suite("nereids_scalar_fn_Array") { logger.info(exception.message) } } + + 
// with array empty + qt_array_empty_fe """select array()""" + + // array_map with string can succeed + qt_sql_array_map """select array_map(x->x!='', split_by_string('amory,is,better,committing', ','))""" + + // array_apply with string should fail + test { + sql """select array_apply(split_by_string("amory,is,better,committing", ","), '!=', '');""" + exception("errCode = 2") + } + } diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy index f335cd7211..81689b6335 100644 --- a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_functions_by_literal.groovy @@ -410,4 +410,13 @@ suite("test_array_functions_by_literal") { } catch (Exception ex) { assert("${ex}".contains("errCode = 2, detailMessage = No matching function with signature: array_intersect")) } + + // array_map with string can succeed + qt_sql_array_map """ select array_map(x->x!='', split_by_string('amory,is,better,committing', ',')) """ + + // array_apply with string should fail + test { + sql """select array_apply(split_by_string("amory,is,better,committing", ","), '!=', '');""" + exception("No matching function") + } }