From 325a1d4b2842bdef5a0ec7d95ffe3a6ce783d48d Mon Sep 17 00:00:00 2001 From: Ziyu Wang <46886508+wzymumon@users.noreply.github.com> Date: Tue, 16 May 2023 17:00:01 +0800 Subject: [PATCH] [vectorized](function) support array_count function (#18557) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit support array_count function. array_count:Returns the number of non-zero and non-null elements in the given array. --- be/src/vec/CMakeLists.txt | 1 + .../functions/array/function_array_count.cpp | 115 ++++++++++++++++++ .../array/function_array_register.cpp | 2 + .../array-functions/array_count.md | 112 +++++++++++++++++ docs/sidebars.json | 1 + .../array-functions/array_count.md | 106 ++++++++++++++++ .../analysis/LambdaFunctionCallExpr.java | 7 +- gensrc/script/doris_builtins_functions.py | 1 + .../test_array_count_function.out | 91 ++++++++++++++ .../test_array_count_function.groovy | 67 ++++++++++ 10 files changed, 500 insertions(+), 3 deletions(-) create mode 100644 be/src/vec/functions/array/function_array_count.cpp create mode 100644 docs/en/docs/sql-manual/sql-functions/array-functions/array_count.md create mode 100644 docs/zh-CN/docs/sql-manual/sql-functions/array-functions/array_count.md create mode 100644 regression-test/data/query_p0/sql_functions/array_functions/test_array_count_function.out create mode 100644 regression-test/suites/query_p0/sql_functions/array_functions/test_array_count_function.groovy diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index e88a4072c3..4d22b3da4d 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -210,6 +210,7 @@ set(VEC_FILES functions/array/function_array_pushfront.cpp functions/array/function_array_first_index.cpp functions/array/function_array_cum_sum.cpp + functions/array/function_array_count.cpp functions/function_map.cpp functions/function_struct.cpp exprs/table_function/vexplode_json_array.cpp diff --git 
a/be/src/vec/functions/array/function_array_count.cpp b/be/src/vec/functions/array/function_array_count.cpp new file mode 100644 index 0000000000..bf0f9bb890 --- /dev/null +++ b/be/src/vec/functions/array/function_array_count.cpp @@ -0,0 +1,115 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include +#include +#include +#include +#include +#include +#include + +namespace doris::vectorized { + +// array_count([0, 1, 1, 1, 0, 0]) -> [3] +class FunctionArrayCount : public IFunction { +public: + static constexpr auto name = "array_count"; + + static FunctionPtr create() { return std::make_shared(); } + + String get_name() const override { return name; } + + bool is_variadic() const override { return false; } + + size_t get_number_of_arguments() const override { return 1; } + + bool use_default_implementation_for_nulls() const override { return false; } + + bool use_default_implementation_for_constants() const override { return true; } + + ColumnNumbers get_arguments_that_are_always_constant() const override { return {1}; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return std::make_shared(); + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) override { + const auto& [src_column, src_const] = + unpack_if_const(block.get_by_position(arguments[0]).column); + const ColumnArray* array_column = nullptr; + const UInt8* array_null_map = nullptr; + if (src_column->is_nullable()) { + auto nullable_array = assert_cast(src_column.get()); + array_column = assert_cast(&nullable_array->get_nested_column()); + array_null_map = nullable_array->get_null_map_column().get_data().data(); + } else { + array_column = assert_cast(src_column.get()); + } + + if (!array_column) { + return Status::RuntimeError("unsupported types for function {}({})", get_name(), + block.get_by_position(arguments[0]).type->get_name()); + } + + const auto& offsets = array_column->get_offsets(); + ColumnPtr nested_column = nullptr; + const UInt8* nested_null_map = nullptr; + if (array_column->get_data().is_nullable()) { + const auto& nested_null_column = + assert_cast(array_column->get_data()); + nested_null_map = 
nested_null_column.get_null_map_column().get_data().data(); + nested_column = nested_null_column.get_nested_column_ptr(); + } else { + nested_column = array_column->get_data_ptr(); + } + + const auto& nested_data = assert_cast(*nested_column).get_data(); + + auto dst_column = ColumnInt64::create(offsets.size()); + auto& dst_data = dst_column->get_data(); + + for (size_t row = 0; row < offsets.size(); ++row) { + Int64 res = 0; + if (array_null_map && array_null_map[row]) { + dst_data[row] = res; + continue; + } + size_t off = offsets[row - 1]; + size_t len = offsets[row] - off; + for (size_t pos = 0; pos < len; ++pos) { + if (nested_null_map && nested_null_map[pos + off]) { + continue; + } + if (nested_data[pos + off] != 0) { + ++res; + } + } + dst_data[row] = res; + } + + block.replace_by_position(result, std::move(dst_column)); + return Status::OK(); + } +}; + +void register_function_array_count(SimpleFunctionFactory& factory) { + factory.register_function(); +} +} // namespace doris::vectorized diff --git a/be/src/vec/functions/array/function_array_register.cpp b/be/src/vec/functions/array/function_array_register.cpp index 9ac2983ef7..1c2a4e128a 100644 --- a/be/src/vec/functions/array/function_array_register.cpp +++ b/be/src/vec/functions/array/function_array_register.cpp @@ -52,6 +52,7 @@ void register_function_array_zip(SimpleFunctionFactory&); void register_function_array_pushfront(SimpleFunctionFactory& factory); void register_function_array_first_index(SimpleFunctionFactory& factory); void register_function_array_cum_sum(SimpleFunctionFactory& factory); +void register_function_array_count(SimpleFunctionFactory&); void register_function_array(SimpleFunctionFactory& factory) { register_function_array_shuffle(factory); @@ -84,6 +85,7 @@ void register_function_array(SimpleFunctionFactory& factory) { register_function_array_pushfront(factory); register_function_array_first_index(factory); register_function_array_cum_sum(factory); + 
register_function_array_count(factory); } } // namespace doris::vectorized diff --git a/docs/en/docs/sql-manual/sql-functions/array-functions/array_count.md b/docs/en/docs/sql-manual/sql-functions/array-functions/array_count.md new file mode 100644 index 0000000000..0f3e8e6217 --- /dev/null +++ b/docs/en/docs/sql-manual/sql-functions/array-functions/array_count.md @@ -0,0 +1,112 @@ +--- +{ + "title": "array_count", + "language": "en" +} +--- + + + +## array_count + + + +array_count + + + +### description + +```sql +array_count(lambda, array1, ...) +``` + + +Use lambda expressions as input parameters to perform corresponding expression calculations on the internal data of other input ARRAY parameters. +Returns the number of elements such that the return value of `lambda(array1[i], ...)` is not 0. Returns 0 if no element is found that satisfies this condition. + +The lambda expression takes one or more parameters, and their number must match the number of array columns that follow. All input arrays must have the same number of elements. Any legal scalar function can be used inside the lambda; aggregate functions and the like are not supported. 
+ + +``` +array_count(x->x, array1); +array_count(x->(x%2 = 0), array1); +array_count(x->(abs(x)-1), array1); +array_count((x,y)->(x = y), array1, array2); +``` + +### notice + +`Only supported in vectorized engine` + +### example + +``` +mysql> select array_count(x -> x, [0, 1, 2, 3]); ++--------------------------------------------------------+ +| array_count(array_map([x] -> x(0), ARRAY(0, 1, 2, 3))) | ++--------------------------------------------------------+ +| 3 | ++--------------------------------------------------------+ +1 row in set (0.00 sec) + +mysql> select array_count(x -> x > 2, [0, 1, 2, 3]); ++------------------------------------------------------------+ +| array_count(array_map([x] -> x(0) > 2, ARRAY(0, 1, 2, 3))) | ++------------------------------------------------------------+ +| 1 | ++------------------------------------------------------------+ +1 row in set (0.01 sec) + +mysql> select array_count(x -> x is null, [null, null, null, 1, 2]); ++----------------------------------------------------------------------------+ +| array_count(array_map([x] -> x(0) IS NULL, ARRAY(NULL, NULL, NULL, 1, 2))) | ++----------------------------------------------------------------------------+ +| 3 | ++----------------------------------------------------------------------------+ +1 row in set (0.01 sec) + +mysql> select array_count(x -> power(x,2)>10, [1, 2, 3, 4, 5]); ++------------------------------------------------------------------------------+ +| array_count(array_map([x] -> power(x(0), 2.0) > 10.0, ARRAY(1, 2, 3, 4, 5))) | ++------------------------------------------------------------------------------+ +| 2 | ++------------------------------------------------------------------------------+ +1 row in set (0.01 sec) + +mysql> select *, array_count((x, y) -> x>y, c_array1, c_array2) from array_test; ++------+-----------------+-------------------------+-----------------------------------------------------------------------+ +| id | c_array1 | c_array2 | 
array_count(array_map([x, y] -> x(0) > y(1), `c_array1`, `c_array2`)) | ++------+-----------------+-------------------------+-----------------------------------------------------------------------+ +| 1 | [1, 2, 3, 4, 5] | [10, 20, -40, 80, -100] | 2 | +| 2 | [6, 7, 8] | [10, 12, 13] | 0 | +| 3 | [1] | [-100] | 1 | +| 4 | [1, NULL, 2] | [NULL, 3, 1] | 1 | +| 5 | [] | [] | 0 | +| 6 | NULL | NULL | 0 | ++------+-----------------+-------------------------+-----------------------------------------------------------------------+ +6 rows in set (0.02 sec) + +``` + +### keywords + +ARRAY, COUNT, ARRAY_COUNT + diff --git a/docs/sidebars.json b/docs/sidebars.json index 7c7e554166..62b7663ad8 100644 --- a/docs/sidebars.json +++ b/docs/sidebars.json @@ -312,6 +312,7 @@ "sql-manual/sql-functions/array-functions/array_first_index", "sql-manual/sql-functions/array-functions/array_last", "sql-manual/sql-functions/array-functions/arrays_overlap", + "sql-manual/sql-functions/array-functions/array_count", "sql-manual/sql-functions/array-functions/countequal", "sql-manual/sql-functions/array-functions/element_at" ] diff --git a/docs/zh-CN/docs/sql-manual/sql-functions/array-functions/array_count.md b/docs/zh-CN/docs/sql-manual/sql-functions/array-functions/array_count.md new file mode 100644 index 0000000000..b51f393a22 --- /dev/null +++ b/docs/zh-CN/docs/sql-manual/sql-functions/array-functions/array_count.md @@ -0,0 +1,106 @@ +--- +{ + "title": "array_count", + "language": "zh-CN" +} +--- + + + +## array_count + + + +array_count + + + +### description + +```sql +array_count(lambda, array1, ...) 
+``` + + +使用lambda表达式作为输入参数,对其他输入ARRAY参数的内部数据进行相应的表达式计算。 返回使得 `lambda(array1[i], ...)` 返回值不为 0 的元素数量。如果找不到满足此条件的元素,则返回 0。 + +lambda表达式中输入的参数为1个或多个,必须和后面输入的数组列数一致,且所有输入的array的元素个数必须相同。在lambda中可以执行合法的标量函数,不支持聚合函数等。 + +``` +array_count(x->x, array1); +array_count(x->(x%2 = 0), array1); +array_count(x->(abs(x)-1), array1); +array_count((x,y)->(x = y), array1, array2); +``` + +### example + +``` +mysql> select array_count(x -> x, [0, 1, 2, 3]); ++--------------------------------------------------------+ +| array_count(array_map([x] -> x(0), ARRAY(0, 1, 2, 3))) | ++--------------------------------------------------------+ +| 3 | ++--------------------------------------------------------+ +1 row in set (0.00 sec) + +mysql> select array_count(x -> x > 2, [0, 1, 2, 3]); ++------------------------------------------------------------+ +| array_count(array_map([x] -> x(0) > 2, ARRAY(0, 1, 2, 3))) | ++------------------------------------------------------------+ +| 1 | ++------------------------------------------------------------+ +1 row in set (0.01 sec) + +mysql> select array_count(x -> x is null, [null, null, null, 1, 2]); ++----------------------------------------------------------------------------+ +| array_count(array_map([x] -> x(0) IS NULL, ARRAY(NULL, NULL, NULL, 1, 2))) | ++----------------------------------------------------------------------------+ +| 3 | ++----------------------------------------------------------------------------+ +1 row in set (0.01 sec) + +mysql> select array_count(x -> power(x,2)>10, [1, 2, 3, 4, 5]); ++------------------------------------------------------------------------------+ +| array_count(array_map([x] -> power(x(0), 2.0) > 10.0, ARRAY(1, 2, 3, 4, 5))) | ++------------------------------------------------------------------------------+ +| 2 | ++------------------------------------------------------------------------------+ +1 row in set (0.01 sec) + +mysql> select *, array_count((x, y) -> x>y, c_array1, c_array2) from array_test; 
++------+-----------------+-------------------------+-----------------------------------------------------------------------+ +| id | c_array1 | c_array2 | array_count(array_map([x, y] -> x(0) > y(1), `c_array1`, `c_array2`)) | ++------+-----------------+-------------------------+-----------------------------------------------------------------------+ +| 1 | [1, 2, 3, 4, 5] | [10, 20, -40, 80, -100] | 2 | +| 2 | [6, 7, 8] | [10, 12, 13] | 0 | +| 3 | [1] | [-100] | 1 | +| 4 | [1, NULL, 2] | [NULL, 3, 1] | 1 | +| 5 | [] | [] | 0 | +| 6 | NULL | NULL | 0 | ++------+-----------------+-------------------------+-----------------------------------------------------------------------+ +6 rows in set (0.02 sec) + +``` + +### keywords + +ARRAY, COUNT, ARRAY_COUNT + diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java index d31e240b3d..7614d17f1a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/LambdaFunctionCallExpr.java @@ -35,13 +35,13 @@ import java.util.List; public class LambdaFunctionCallExpr extends FunctionCallExpr { public static final ImmutableSet LAMBDA_FUNCTION_SET = new ImmutableSortedSet.Builder( String.CASE_INSENSITIVE_ORDER).add("array_map").add("array_filter").add("array_exists").add("array_sortby") - .add("array_first_index").add("array_last").build(); + .add("array_first_index").add("array_last").add("array_count").build(); // The functions in this set are all normal array functions when implemented initially. 
// and then wants add lambda expr as the input param, so we rewrite it to contains an array_map lambda function // rather than reimplementing a lambda function, this will be reused the implementation of normal array function public static final ImmutableSet LAMBDA_MAPPED_FUNCTION_SET = new ImmutableSortedSet.Builder( String.CASE_INSENSITIVE_ORDER).add("array_exists").add("array_sortby") - .add("array_first_index").add("array_last") + .add("array_first_index").add("array_last").add("array_count") .build(); private static final Logger LOG = LogManager.getLogger(LambdaFunctionCallExpr.class); @@ -108,7 +108,8 @@ public class LambdaFunctionCallExpr extends FunctionCallExpr { } fn.setReturnType(ArrayType.create(lambda.getChild(0).getType(), true)); } else if (fnName.getFunction().equalsIgnoreCase("array_exists") - || fnName.getFunction().equalsIgnoreCase("array_first_index")) { + || fnName.getFunction().equalsIgnoreCase("array_first_index") + || fnName.getFunction().equalsIgnoreCase("array_count")) { if (fnParams.exprs() == null || fnParams.exprs().size() < 1) { throw new AnalysisException("The " + fnName.getFunction() + " function must have at least one param"); } diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index f019b4f97b..2cbd15db15 100644 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -689,6 +689,7 @@ visible_functions = [ [['array_exists'], 'ARRAY_BOOLEAN', ['ARRAY_STRING'], ''], [['array_first_index'], 'BIGINT', ['ARRAY_BOOLEAN'], 'ALWAYS_NOT_NULLABLE'], + [['array_count'], 'BIGINT', ['ARRAY_BOOLEAN'], 'ALWAYS_NOT_NULLABLE'], [['array_shuffle', 'shuffle'], 'ARRAY_BOOLEAN', ['ARRAY_BOOLEAN'], ''], [['array_shuffle', 'shuffle'], 'ARRAY_TINYINT', ['ARRAY_TINYINT'], ''], diff --git a/regression-test/data/query_p0/sql_functions/array_functions/test_array_count_function.out b/regression-test/data/query_p0/sql_functions/array_functions/test_array_count_function.out 
new file mode 100644 index 0000000000..60bd804580 --- /dev/null +++ b/regression-test/data/query_p0/sql_functions/array_functions/test_array_count_function.out @@ -0,0 +1,91 @@ +-- This file is automatically generated. You should know what you did if you want to edit this +-- !select -- +0 + +-- !select -- +0 + +-- !select -- +2 + +-- !select -- +3 + +-- !select -- +3 + +-- !select -- +2 + +-- !select -- +1 [1, 2, 3, 4, 5] [10, 20, -40, 80, -100] 1 +2 [6, 7, 8] [10, 12, 13] 1 +3 [1] [-100] 1 +4 [1, NULL, 2] [NULL, 3, 1] 1 +5 [] [] 1 +6 \N \N 1 + +-- !select -- +1 [1, 2, 3, 4, 5] [10, 20, -40, 80, -100] 3 +2 [6, 7, 8] [10, 12, 13] 3 +3 [1] [-100] 3 +4 [1, NULL, 2] [NULL, 3, 1] 3 +5 [] [] 3 +6 \N \N 3 + +-- !select -- +1 [1, 2, 3, 4, 5] [10, 20, -40, 80, -100] 1 +2 [6, 7, 8] [10, 12, 13] 1 +3 [1] [-100] 1 +4 [1, NULL, 2] [NULL, 3, 1] 1 +5 [] [] 1 +6 \N \N 1 + +-- !select -- +[1, 2, 3, 4, 5] 5 +[6, 7, 8] 3 +[1] 1 +[1, NULL, 2] 2 +[] 0 +\N 0 + +-- !select -- +[1, 2, 3, 4, 5] 2 +[6, 7, 8] 3 +[1] 0 +[1, NULL, 2] 0 +[] 0 +\N 0 + +-- !select -- +[10, 20, -40, 80, -100] 4 +[10, 12, 13] 2 +[-100] 1 +[NULL, 3, 1] 0 +[] 0 +\N 0 + +-- !select -- +[1, 2, 3, 4, 5] [10, 20, -40, 80, -100] 2 +[6, 7, 8] [10, 12, 13] 0 +[1] [-100] 1 +[1, NULL, 2] [NULL, 3, 1] 1 +[] [] 0 +\N \N 0 + +-- !select -- +[1, 2, 3, 4, 5] [10, 20, -40, 80, -100] 5 +[6, 7, 8] [10, 12, 13] 3 +[1] [-100] 1 +[1, NULL, 2] [NULL, 3, 1] 1 +[] [] 0 +\N \N 0 + +-- !select -- +[1, 2, 3, 4, 5] [10, 20, -40, 80, -100] 4 +[6, 7, 8] [10, 12, 13] 3 +[1] [-100] 1 +[1, NULL, 2] [NULL, 3, 1] 0 +[] [] 0 +\N \N 0 + diff --git a/regression-test/suites/query_p0/sql_functions/array_functions/test_array_count_function.groovy b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_count_function.groovy new file mode 100644 index 0000000000..5106f62e3a --- /dev/null +++ b/regression-test/suites/query_p0/sql_functions/array_functions/test_array_count_function.groovy @@ -0,0 +1,67 @@ +// Licensed to the Apache Software 
Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +suite("test_array_count_function") { + + def tableName = "test_array_count_function" + + sql "DROP TABLE IF EXISTS ${tableName}" + + sql """ + CREATE TABLE IF NOT EXISTS `${tableName}` ( + `id` int(11) NULL, + `c_array1` array NULL, + `c_array2` array NULL + ) ENGINE=OLAP + DUPLICATE KEY(`id`) + DISTRIBUTED BY HASH(`id`) BUCKETS 1 + PROPERTIES ( + "replication_allocation" = "tag.location.default: 1", + "storage_format" = "V2" + ) + """ + + + sql """ + INSERT INTO ${tableName} values + (1, [1,2,3,4,5], [10,20,-40,80,-100]), + (2, [6,7,8],[10,12,13]), (3, [1],[-100]), + (4, [1, null, 2], [null, 3, 1]), (5, [], []), (6, null, null) + """ + + + qt_select "select array_count(x -> x > 1, []);" + qt_select "select array_count(x -> x > 1, [null]);" + qt_select "select array_count(x -> x > 1, [0, 1, 2, 3])" + qt_select "select array_count(x -> x + 1 > 1, [0, 1, 2, 3])" + qt_select "select array_count(x -> x is null, [null, null, null, 0, 1, 2]);" + qt_select "select array_count(x -> x > 2, array_map(x->power(x,2),[1,2,3]));" + + qt_select "select *, array_count(x -> x>2, [1,2,3]) from ${tableName} order by id;" + qt_select "select *, array_count(x -> x+1, [1,2,3]) from ${tableName} order by id;" + qt_select 
"select *, array_count(x -> x%2=0, [1,2,3]) from ${tableName} order by id;" + + qt_select "select c_array1, array_count(x -> x, c_array1) from ${tableName} order by id;" + qt_select "select c_array1, array_count(x -> x>3, c_array1) from ${tableName} order by id;" + qt_select "select c_array2, array_count(x -> power(x,2)>100, c_array2) from ${tableName} order by id;" + + qt_select "select c_array1, c_array2, array_count((x,y) -> x>y, c_array1, c_array2) from ${tableName} order by id;" + qt_select "select c_array1, c_array2, array_count((x,y) -> x+y, c_array1, c_array2) from ${tableName} order by id;" + qt_select "select c_array1, c_array2, array_count((x,y) -> x*abs(y)>10, c_array1, c_array2) from ${tableName} order by id;" + + sql "DROP TABLE IF EXISTS ${tableName}" +} \ No newline at end of file