From 4d516bece8adba01c2e0b3a71e331cd9356d023a Mon Sep 17 00:00:00 2001 From: camby <104178625@qq.com> Date: Sat, 2 Apr 2022 12:03:56 +0800 Subject: [PATCH] [feature-wip](array-type)Add element_at and subscript functions (#8597) Describe the overview of changes. 1. add the function element_at; 2. support element_subscript([]) to get an element of an array: col_array[N] <==> element_at(col_array, N); 3. return an error message instead of crashing the BE when an array function fails to execute; element_at(array, index) desc: > Returns the element of the array at the given **(1-based)** index. If **index < 0**, accesses elements from the last to the first. Returns NULL if the index is 0 or its absolute value exceeds the length of the array, or if the array is NULL. Usage example: 1. create a table with an ARRAY type column and insert some data: ``` +------+------+--------+ | k1 | k2 | k3 | +------+------+--------+ | 1 | 2 | [1, 2] | | 2 | 3 | NULL | | 4 | NULL | [] | | 3 | NULL | NULL | +------+------+--------+ ``` 2. enable the vectorized engine: ``` set enable_vectorized_engine=true; ``` 3. element_subscript([]) usage example: ``` > select k1,k3,k3[1] from array_test; +------+--------+----------------------------+ | k1 | k3 | %element_extract%(`k3`, 1) | +------+--------+----------------------------+ | 3 | NULL | NULL | | 1 | [1, 2] | 1 | | 2 | NULL | NULL | | 4 | [] | NULL | +------+--------+----------------------------+ ``` 4. 
element_at function usage example: ``` > select k1,k3 from array_test where element_at(k3, -1) = 2; +------+--------+ | k1 | k3 | +------+--------+ | 1 | [1, 2] | +------+--------+ ``` --- be/src/vec/CMakeLists.txt | 1 + .../array/function_array_element.cpp | 29 +++ .../functions/array/function_array_element.h | 229 ++++++++++++++++++ .../functions/array/function_array_index.cpp | 9 +- .../functions/array/function_array_index.h | 104 +++++--- .../array/function_array_register.cpp | 2 + be/src/vec/functions/function.cpp | 4 +- be/test/vec/function/CMakeLists.txt | 1 + .../function/function_array_element_test.cpp | 84 +++++++ .../function/function_array_index_test.cpp | 48 +++- .../org/apache/doris/catalog/FunctionSet.java | 24 +- .../java/org/apache/doris/catalog/Type.java | 7 +- gensrc/script/doris_builtins_functions.py | 22 +- gensrc/script/gen_builtins_functions.py | 21 +- 14 files changed, 519 insertions(+), 66 deletions(-) create mode 100644 be/src/vec/functions/array/function_array_element.cpp create mode 100644 be/src/vec/functions/array/function_array_element.h create mode 100644 be/test/vec/function/function_array_element_test.cpp diff --git a/be/src/vec/CMakeLists.txt b/be/src/vec/CMakeLists.txt index be9e788280..e952d98bf4 100644 --- a/be/src/vec/CMakeLists.txt +++ b/be/src/vec/CMakeLists.txt @@ -112,6 +112,7 @@ set(VEC_FILES exprs/table_function/vexplode_split.cpp exprs/table_function/vexplode_numbers.cpp functions/array/function_array_index.cpp + functions/array/function_array_element.cpp functions/array/function_array_register.cpp functions/math.cpp functions/function_bitmap.cpp diff --git a/be/src/vec/functions/array/function_array_element.cpp b/be/src/vec/functions/array/function_array_element.cpp new file mode 100644 index 0000000000..f1868e0395 --- /dev/null +++ b/be/src/vec/functions/array/function_array_element.cpp @@ -0,0 +1,29 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. 
See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "vec/functions/array/function_array_element.h" + +#include "vec/functions/simple_function_factory.h" + +namespace doris::vectorized { + +void register_function_array_element(SimpleFunctionFactory& factory) { + factory.register_function(); + factory.register_alias(FunctionArrayElement::name, "%element_extract%"); +} + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/array/function_array_element.h b/be/src/vec/functions/array/function_array_element.h new file mode 100644 index 0000000000..eae3375bca --- /dev/null +++ b/be/src/vec/functions/array/function_array_element.h @@ -0,0 +1,229 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. +// This file is copied from +// https://github.com/ClickHouse/ClickHouse/blob/master/src/Functions/array/arrayElement.cpp +// and modified by Doris +#pragma once + +#include + +#include "vec/columns/column_array.h" +#include "vec/columns/column_const.h" +#include "vec/columns/column_string.h" +#include "vec/data_types/data_type_array.h" +#include "vec/data_types/data_type_number.h" +#include "vec/functions/function.h" +#include "vec/functions/function_helpers.h" + +namespace doris::vectorized { + +class FunctionArrayElement : public IFunction { +public: + static constexpr auto name = "element_at"; + static FunctionPtr create() { return std::make_shared(); } + + /// Get function name. 
+ String get_name() const override { return name; } + + bool is_variadic() const override { return false; } + + size_t get_number_of_arguments() const override { return 2; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + DCHECK(is_array(arguments[0])) + << "first argument for function: " << name << " should be DataTypeArray"; + DCHECK(is_integer(arguments[1])) + << "second argument for function: " << name << " should be Integer"; + return make_nullable( + check_and_get_data_type(arguments[0].get())->get_nested_type()); + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) override { + auto dst_null_column = ColumnUInt8::create(input_rows_count); + UInt8* dst_null_map = dst_null_column->get_data().data(); + const UInt8* src_null_map = nullptr; + ColumnsWithTypeAndName args; + auto col_left = block.get_by_position(arguments[0]); + if (col_left.column->is_nullable()) { + auto null_col = check_and_get_column(*col_left.column); + src_null_map = null_col->get_null_map_column().get_data().data(); + args = {{null_col->get_nested_column_ptr(), remove_nullable(col_left.type), + col_left.name}, + block.get_by_position(arguments[1])}; + } else { + args = {col_left, block.get_by_position(arguments[1])}; + } + + auto result_type = remove_nullable( + check_and_get_data_type(args[0].type.get())->get_nested_type()); + + auto res_column = _execute_non_nullable(args, result_type, input_rows_count, src_null_map, + dst_null_map); + if (!res_column) { + return Status::RuntimeError( + fmt::format("unsupported types for function {}({}, {})", get_name(), + block.get_by_position(arguments[0]).type->get_name(), + block.get_by_position(arguments[1]).type->get_name())); + } + block.replace_by_position( + result, ColumnNullable::create(std::move(res_column), std::move(dst_null_column))); + return Status::OK(); + } + +private: + template + ColumnPtr _execute_number(const 
ColumnArray::Offsets& offsets, const IColumn& nested_column, + const UInt8* arr_null_map, const IColumn& indices, + const UInt8* nested_null_map, UInt8* dst_null_map) { + const auto& nested_data = check_and_get_column(nested_column)->get_data(); + auto dst_column = ColumnType::create(offsets.size()); + auto& dst_data = dst_column->get_data(); + + // process + for (size_t row = 0; row < offsets.size(); ++row) { + size_t off = offsets[row - 1]; + size_t len = offsets[row] - off; + auto index = indices.get_int(row); + // array is nullable + bool null_flag = bool(arr_null_map && arr_null_map[row]); + // calc index in nested column + if (!null_flag && index > 0 && index <= len) { + index += off - 1; + } else if (!null_flag && index < 0 && -index <= len) { + index += off + len; + } else { + null_flag = true; + } + // nested column nullable check + if (!null_flag && nested_null_map && nested_null_map[index]) { + null_flag = true; + } + // actual data copy + if (null_flag) { + dst_null_map[row] = true; + dst_data[row] = typename ColumnType::value_type(); + } else { + DCHECK(index >= 0 && index < nested_data.size()); + dst_null_map[row] = false; + dst_data[row] = nested_data[index]; + } + } + return dst_column; + } + + ColumnPtr _execute_string(const ColumnArray::Offsets& offsets, const IColumn& nested_column, + const UInt8* arr_null_map, const IColumn& indices, + const UInt8* nested_null_map, UInt8* dst_null_map) { + const auto& src_str_offs = check_and_get_column(nested_column)->get_offsets(); + const auto& src_str_chars = check_and_get_column(nested_column)->get_chars(); + + // prepare return data + auto dst_column = ColumnString::create(); + auto& dst_str_offs = dst_column->get_offsets(); + dst_str_offs.resize(offsets.size()); + auto& dst_str_chars = dst_column->get_chars(); + dst_str_chars.reserve(src_str_chars.size()); + + // process + for (size_t row = 0; row < offsets.size(); ++row) { + size_t off = offsets[row - 1]; + size_t len = offsets[row] - off; + auto index = 
indices.get_int(row); + // array is nullable + bool null_flag = bool(arr_null_map && arr_null_map[row]); + // calc index in nested column + if (!null_flag && index > 0 && index <= len) { + index += off - 1; + } else if (!null_flag && index < 0 && -index <= len) { + index += off + len; + } else { + null_flag = true; + } + // nested column nullable check + if (!null_flag && nested_null_map && nested_null_map[index]) { + null_flag = true; + } + // actual string copy + if (!null_flag) { + DCHECK(index >= 0 && index < src_str_offs.size()); + dst_null_map[row] = false; + auto element_size = src_str_offs[index] - src_str_offs[index - 1]; + dst_str_offs[row] = dst_str_offs[row - 1] + element_size; + auto src_string_pos = src_str_offs[index - 1]; + auto dst_string_pos = dst_str_offs[row - 1]; + dst_str_chars.resize(dst_string_pos + element_size); + memcpy(&dst_str_chars[dst_string_pos], &src_str_chars[src_string_pos], + element_size); + } else { + dst_null_map[row] = true; + dst_str_offs[row] = dst_str_offs[row - 1]; + } + } + return dst_column; + } + + ColumnPtr _execute_non_nullable(const ColumnsWithTypeAndName& arguments, + const DataTypePtr& result_type, size_t input_rows_count, + const UInt8* src_null_map, UInt8* dst_null_map) { + // check array nested column type and get data + auto array_column = check_and_get_column(*arguments[0].column); + DCHECK(array_column != nullptr); + const auto& offsets = array_column->get_offsets(); + DCHECK(offsets.size() == input_rows_count); + const UInt8* nested_null_map = nullptr; + ColumnPtr nested_column = nullptr; + if (is_column_nullable(array_column->get_data())) { + const auto& nested_null_column = + check_and_get_column(array_column->get_data()); + nested_null_map = nested_null_column->get_null_map_column().get_data().data(); + nested_column = nested_null_column->get_nested_column_ptr(); + } else { + nested_column = array_column->get_data_ptr(); + } + + ColumnPtr res = nullptr; + if (check_column(*nested_column)) { + res = 
_execute_number(offsets, *nested_column, src_null_map, + *arguments[1].column, nested_null_map, dst_null_map); + } else if (check_column(*nested_column)) { + res = _execute_number(offsets, *nested_column, src_null_map, + *arguments[1].column, nested_null_map, dst_null_map); + } else if (check_column(*nested_column)) { + res = _execute_number(offsets, *nested_column, src_null_map, + *arguments[1].column, nested_null_map, dst_null_map); + } else if (check_column(*nested_column)) { + res = _execute_number(offsets, *nested_column, src_null_map, + *arguments[1].column, nested_null_map, dst_null_map); + } else if (check_column(*nested_column)) { + res = _execute_number(offsets, *nested_column, src_null_map, + *arguments[1].column, nested_null_map, + dst_null_map); + } else if (check_column(*nested_column)) { + res = _execute_number(offsets, *nested_column, src_null_map, + *arguments[1].column, nested_null_map, + dst_null_map); + } else if (check_column(*nested_column)) { + res = _execute_string(offsets, *nested_column, src_null_map, *arguments[1].column, + nested_null_map, dst_null_map); + } + + return res; + } +}; + +} // namespace doris::vectorized diff --git a/be/src/vec/functions/array/function_array_index.cpp b/be/src/vec/functions/array/function_array_index.cpp index 474500ed89..0c59b05b77 100644 --- a/be/src/vec/functions/array/function_array_index.cpp +++ b/be/src/vec/functions/array/function_array_index.cpp @@ -16,12 +16,17 @@ // under the License. 
#include "vec/functions/array/function_array_index.h" + #include "vec/functions/simple_function_factory.h" namespace doris::vectorized { -struct NameArrayContains { static constexpr auto name = "array_contains"; }; -struct NameArrayPosition { static constexpr auto name = "array_position"; }; +struct NameArrayContains { + static constexpr auto name = "array_contains"; +}; +struct NameArrayPosition { + static constexpr auto name = "array_position"; +}; void register_function_array_index(SimpleFunctionFactory& factory) { factory.register_function>(); diff --git a/be/src/vec/functions/array/function_array_index.h b/be/src/vec/functions/array/function_array_index.h index f0948112cd..435fa7a1d5 100644 --- a/be/src/vec/functions/array/function_array_index.h +++ b/be/src/vec/functions/array/function_array_index.h @@ -29,23 +29,20 @@ namespace doris::vectorized { -struct ArrayContainsAction -{ +struct ArrayContainsAction { using ResultType = UInt8; static constexpr const bool resume_execution = false; static constexpr void apply(ResultType& current, size_t) noexcept { current = 1; } }; -struct ArrayPositionAction -{ +struct ArrayPositionAction { using ResultType = Int64; static constexpr const bool resume_execution = false; static constexpr void apply(ResultType& current, size_t j) noexcept { current = j + 1; } }; template -class FunctionArrayIndex : public IFunction -{ +class FunctionArrayIndex : public IFunction { public: using ResultType = typename ConcreteAction::ResultType; @@ -60,21 +57,32 @@ public: size_t get_number_of_arguments() const override { return 2; } DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { - DCHECK(WhichDataType(arguments[0]).is_array()); + DCHECK(is_array(arguments[0])); return std::make_shared>(); } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) override { - return execute_non_nullable(block, arguments, result, input_rows_count); + 
return _execute_non_nullable(block, arguments, result, input_rows_count); } private: - static bool execute_string(Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) { + static bool _execute_string(Block& block, const ColumnNumbers& arguments, size_t result, + size_t input_rows_count) { // check array nested column type and get data - auto array_column = check_and_get_column(*block.get_by_position(arguments[0]).column); + auto array_column = + check_and_get_column(*block.get_by_position(arguments[0]).column); DCHECK(array_column != nullptr); - auto nested_column = check_and_get_column(array_column->get_data()); + const ColumnString* nested_column = nullptr; + const UInt8* nested_null_map = nullptr; + auto nested_null_column = check_and_get_column(array_column->get_data()); + if (nested_null_column) { + nested_null_map = nested_null_column->get_null_map_column().get_data().data(); + nested_column = + check_and_get_column(nested_null_column->get_nested_column()); + } else { + nested_column = check_and_get_column(array_column->get_data()); + } if (!nested_column) { return false; } @@ -92,7 +100,8 @@ private: } // expand const column and get data - auto right_column = check_and_get_column(*block.get_by_position(arguments[1]).column->convert_to_full_column_if_const()); + auto right_column = check_and_get_column( + *block.get_by_position(arguments[1]).column->convert_to_full_column_if_const()); const auto& right_offs = right_column->get_offsets(); const auto& right_chars = right_column->get_chars(); @@ -110,12 +119,16 @@ private: size_t right_off = right_offs[row - 1]; size_t right_len = right_offs[row] - right_off; for (size_t pos = 0; pos < len; ++pos) { + if (nested_null_map && nested_null_map[pos + off]) { + continue; + } + size_t str_pos = str_offs[pos + off - 1]; size_t str_len = str_offs[pos + off] - str_pos; - const char* left_raw_v = reinterpret_cast(&str_chars[str_pos]); const char* right_raw_v = 
reinterpret_cast(&right_chars[right_off]); - if (std::string_view(left_raw_v, str_len) == std::string_view(right_raw_v, right_len)) { + if (std::string_view(left_raw_v, str_len) == + std::string_view(right_raw_v, right_len)) { ConcreteAction::apply(res, pos); break; } @@ -126,21 +139,37 @@ private: return true; } -#define INTEGRAL_TPL_PACK UInt8, UInt16, UInt32, UInt64, Int8, Int16, Int32, Int64, Float32, Float64 - template - static bool execute_integral(Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) { - return (execute_integral_expanded(block, arguments, result, input_rows_count) || ...); +#define NUMBER_TPL_PACK Int8, Int16, Int32, Int64, Float32, Float64 + template + static bool _execute_number(Block& block, const ColumnNumbers& arguments, size_t result, + size_t input_rows_count) { + return (_execute_number_expanded(block, arguments, result, + input_rows_count) || + ...); } template - static bool execute_integral_expanded(Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) { - return (execute_integral_impl(block, arguments, result, input_rows_count) || ...); + static bool _execute_number_expanded(Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) { + return (_execute_number_impl(block, arguments, result, input_rows_count) || ...); } - template - static bool execute_integral_impl(Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) { + template + static bool _execute_number_impl(Block& block, const ColumnNumbers& arguments, size_t result, + size_t input_rows_count) { // check array nested column type and get data - auto array_column = check_and_get_column(*block.get_by_position(arguments[0]).column); + auto array_column = + check_and_get_column(*block.get_by_position(arguments[0]).column); DCHECK(array_column != nullptr); - auto nested_column = check_and_get_column>(array_column->get_data()); + const ColumnVector* 
nested_column = nullptr; + const UInt8* nested_null_map = nullptr; + auto nested_null_column = check_and_get_column(array_column->get_data()); + if (nested_null_column) { + nested_null_map = nested_null_column->get_null_map_column().get_data().data(); + nested_column = check_and_get_column>( + nested_null_column->get_nested_column()); + } else { + nested_column = + check_and_get_column>(array_column->get_data()); + } if (!nested_column) { return false; } @@ -152,13 +181,15 @@ private: if (is_column_const(*ptr)) { ptr = check_and_get_column(ptr)->get_data_column_ptr(); } - if (!check_and_get_column>(*ptr)) { + if (!check_and_get_column>(*ptr)) { return false; } // expand const column and get data - auto right_column = block.get_by_position(arguments[1]).column->convert_to_full_column_if_const(); - const auto& right_data = check_and_get_column>(*right_column)->get_data(); + auto right_column = + block.get_by_position(arguments[1]).column->convert_to_full_column_if_const(); + const auto& right_data = + check_and_get_column>(*right_column)->get_data(); // prepare return data auto dst = ColumnVector::create(); @@ -171,6 +202,10 @@ private: size_t off = offsets[row - 1]; size_t len = offsets[row] - off; for (size_t pos = 0; pos < len; ++pos) { + if (nested_null_map && nested_null_map[pos + off]) { + continue; + } + if (nested_data[pos + off] == right_data[row]) { ConcreteAction::apply(res, pos); break; @@ -182,15 +217,20 @@ private: return true; } - Status execute_non_nullable(Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) { + Status _execute_non_nullable(Block& block, const ColumnNumbers& arguments, size_t result, + size_t input_rows_count) { WhichDataType right_type(block.get_by_position(arguments[1]).type); - if ((right_type.is_string() && execute_string(block, arguments, result, input_rows_count)) || - execute_integral(block, arguments, result, input_rows_count)) { + if ((right_type.is_string() && + _execute_string(block, 
arguments, result, input_rows_count)) || + _execute_number(block, arguments, result, input_rows_count)) { return Status::OK(); } - return Status::OK(); + return Status::RuntimeError( + fmt::format("unsupported types for function {}({}, {})", get_name(), + block.get_by_position(arguments[0]).type->get_name(), + block.get_by_position(arguments[1]).type->get_name())); } -#undef INTEGRAL_TPL_PACK +#undef NUMBER_TPL_PACK }; } // namespace doris::vectorized diff --git a/be/src/vec/functions/array/function_array_register.cpp b/be/src/vec/functions/array/function_array_register.cpp index e9ab7630fe..90bd5dcb5f 100644 --- a/be/src/vec/functions/array/function_array_register.cpp +++ b/be/src/vec/functions/array/function_array_register.cpp @@ -22,9 +22,11 @@ namespace doris::vectorized { +void register_function_array_element(SimpleFunctionFactory&); void register_function_array_index(SimpleFunctionFactory&); void register_function_array(SimpleFunctionFactory& factory) { + register_function_array_element(factory); register_function_array_index(factory); } diff --git a/be/src/vec/functions/function.cpp b/be/src/vec/functions/function.cpp index 5c671284a4..5c019d49bf 100644 --- a/be/src/vec/functions/function.cpp +++ b/be/src/vec/functions/function.cpp @@ -269,9 +269,7 @@ Status PreparedFunctionImpl::execute(FunctionContext* context, Block& block, // res.column = block_without_low_cardinality.safe_get_by_position(result).column; // } // } else - execute_without_low_cardinality_columns(context, block, args, result, input_rows_count, - dry_run); - return Status::OK(); + return execute_without_low_cardinality_columns(context, block, args, result, input_rows_count, dry_run); } void FunctionBuilderImpl::check_number_of_arguments(size_t number_of_arguments) const { diff --git a/be/test/vec/function/CMakeLists.txt b/be/test/vec/function/CMakeLists.txt index 827bfb889e..9b3c1162ff 100644 --- a/be/test/vec/function/CMakeLists.txt +++ b/be/test/vec/function/CMakeLists.txt @@ -18,6 +18,7 
@@ # where to put generated libraries set(EXECUTABLE_OUTPUT_PATH "${BUILD_DIR}/test/vec/function") +ADD_BE_TEST(function_array_element_test) ADD_BE_TEST(function_array_index_test) ADD_BE_TEST(function_bitmap_test) ADD_BE_TEST(function_comparison_test) diff --git a/be/test/vec/function/function_array_element_test.cpp b/be/test/vec/function/function_array_element_test.cpp new file mode 100644 index 0000000000..67b953dafb --- /dev/null +++ b/be/test/vec/function/function_array_element_test.cpp @@ -0,0 +1,84 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +#include +#include + +#include + +#include "function_test_util.h" +#include "runtime/tuple_row.h" +#include "util/url_coding.h" +#include "vec/core/field.h" + +namespace doris::vectorized { + +TEST(function_array_element_test, element_at) { + std::string func_name = "element_at"; + Array empty_arr; + + // element_at(Array, Int32) + { + InputTypeSet input_types = {TypeIndex::Array, TypeIndex::Int32, TypeIndex::Int32}; + + Array vec = {Int32(1), Int32(2), Int32(3)}; + DataSet data_set = { + {{vec, 0}, Null()}, {{vec, 1}, Int32(1)}, {{vec, 4}, Null()}, + {{vec, -1}, Int32(3)}, {{vec, -3}, Int32(1)}, {{vec, -4}, Null()}, + {{Null(), 1}, Null()}, {{empty_arr, 0}, Null()}, {{empty_arr, 1}, Null()}}; + + check_function(func_name, input_types, data_set); + } + + // element_at(Array, Int32) + { + InputTypeSet input_types = {TypeIndex::Array, TypeIndex::Int8, TypeIndex::Int32}; + + Array vec = {Int8(1), Int8(2), Int8(3)}; + DataSet data_set = { + {{vec, 0}, Null()}, {{vec, 1}, Int8(1)}, {{vec, 4}, Null()}, + {{vec, -1}, Int8(3)}, {{vec, -3}, Int8(1)}, {{vec, -4}, Null()}, + {{Null(), 1}, Null()}, {{empty_arr, 0}, Null()}, {{empty_arr, 1}, Null()}}; + + check_function(func_name, input_types, data_set); + } + + // element_at(Array, Int32) + { + InputTypeSet input_types = {TypeIndex::Array, TypeIndex::String, TypeIndex::Int32}; + + Array vec = {Field("abc", 3), Field("", 0), Field("def", 3)}; + DataSet data_set = {{{vec, 1}, std::string("abc")}, + {{vec, 2}, std::string("")}, + {{vec, 10}, Null()}, + {{vec, -2}, std::string("")}, + {{vec, 0}, Null()}, + {{vec, -10}, Null()}, + {{Null(), 1}, Null()}, + {{empty_arr, 0}, Null()}, + {{empty_arr, 1}, Null()}}; + + check_function(func_name, input_types, data_set); + } +} + +} // namespace doris::vectorized + +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/be/test/vec/function/function_array_index_test.cpp b/be/test/vec/function/function_array_index_test.cpp 
index 7c34c3850c..311468df10 100644 --- a/be/test/vec/function/function_array_index_test.cpp +++ b/be/test/vec/function/function_array_index_test.cpp @@ -36,7 +36,10 @@ TEST(function_array_index_test, array_contains) { InputTypeSet input_types = {TypeIndex::Array, TypeIndex::Int32, TypeIndex::Int32}; Array vec = {Int32(1), Int32(2), Int32(3)}; - DataSet data_set = {{{vec, 2}, UInt8(1)}, {{vec, 4}, UInt8(0)}, {{Null(), 1}, Null()}, {{empty_arr, 1}, UInt8(0)}}; + DataSet data_set = {{{vec, 2}, UInt8(1)}, + {{vec, 4}, UInt8(0)}, + {{Null(), 1}, Null()}, + {{empty_arr, 1}, UInt8(0)}}; check_function(func_name, input_types, data_set); } @@ -46,7 +49,10 @@ TEST(function_array_index_test, array_contains) { InputTypeSet input_types = {TypeIndex::Array, TypeIndex::Int32, TypeIndex::Int8}; Array vec = {Int32(1), Int32(2), Int32(3)}; - DataSet data_set = {{{vec, Int8(2)}, UInt8(1)}, {{vec, Int8(4)}, UInt8(0)}, {{Null(), Int8(1)}, Null()}, {{empty_arr, Int8(1)}, UInt8(0)}}; + DataSet data_set = {{{vec, Int8(2)}, UInt8(1)}, + {{vec, Int8(4)}, UInt8(0)}, + {{Null(), Int8(1)}, Null()}, + {{empty_arr, Int8(1)}, UInt8(0)}}; check_function(func_name, input_types, data_set); } @@ -56,7 +62,10 @@ TEST(function_array_index_test, array_contains) { InputTypeSet input_types = {TypeIndex::Array, TypeIndex::Int8, TypeIndex::Int64}; Array vec = {Int8(1), Int8(2), Int8(3)}; - DataSet data_set = {{{vec, Int64(2)}, UInt8(1)}, {{vec, Int64(4)}, UInt8(0)}, {{Null(), Int64(1)}, Null()}, {{empty_arr, Int64(1)}, UInt8(0)}}; + DataSet data_set = {{{vec, Int64(2)}, UInt8(1)}, + {{vec, Int64(4)}, UInt8(0)}, + {{Null(), Int64(1)}, Null()}, + {{empty_arr, Int64(1)}, UInt8(0)}}; check_function(func_name, input_types, data_set); } @@ -65,9 +74,12 @@ TEST(function_array_index_test, array_contains) { { InputTypeSet input_types = {TypeIndex::Array, TypeIndex::String, TypeIndex::String}; - Array vec = {Field("abc", 3), Field("", 0), Field("def",3)}; - DataSet data_set = {{{vec, std::string("abc")}, UInt8(1)}, 
{{vec, std::string("aaa")}, UInt8(0)}, - {{vec, std::string("")}, UInt8(1)}, {{Null(), std::string("abc")}, Null()}, {{empty_arr, std::string("")}, UInt8(0)}}; + Array vec = {Field("abc", 3), Field("", 0), Field("def", 3)}; + DataSet data_set = {{{vec, std::string("abc")}, UInt8(1)}, + {{vec, std::string("aaa")}, UInt8(0)}, + {{vec, std::string("")}, UInt8(1)}, + {{Null(), std::string("abc")}, Null()}, + {{empty_arr, std::string("")}, UInt8(0)}}; check_function(func_name, input_types, data_set); } @@ -82,7 +94,10 @@ TEST(function_array_index_test, array_position) { InputTypeSet input_types = {TypeIndex::Array, TypeIndex::Int32, TypeIndex::Int32}; Array vec = {Int32(1), Int32(2), Int32(3)}; - DataSet data_set = {{{vec, 2}, Int64(2)}, {{vec, 4}, Int64(0)}, {{Null(), 1}, Null()}, {{empty_arr, 1}, Int64(0)}}; + DataSet data_set = {{{vec, 2}, Int64(2)}, + {{vec, 4}, Int64(0)}, + {{Null(), 1}, Null()}, + {{empty_arr, 1}, Int64(0)}}; check_function(func_name, input_types, data_set); } @@ -92,7 +107,10 @@ TEST(function_array_index_test, array_position) { InputTypeSet input_types = {TypeIndex::Array, TypeIndex::Int32, TypeIndex::Int8}; Array vec = {Int32(1), Int32(2), Int32(3)}; - DataSet data_set = {{{vec, Int8(2)}, Int64(2)}, {{vec, Int8(4)}, Int64(0)}, {{Null(), Int8(1)}, Null()}, {{empty_arr, Int8(1)}, Int64(0)}}; + DataSet data_set = {{{vec, Int8(2)}, Int64(2)}, + {{vec, Int8(4)}, Int64(0)}, + {{Null(), Int8(1)}, Null()}, + {{empty_arr, Int8(1)}, Int64(0)}}; check_function(func_name, input_types, data_set); } @@ -102,7 +120,10 @@ TEST(function_array_index_test, array_position) { InputTypeSet input_types = {TypeIndex::Array, TypeIndex::Int8, TypeIndex::Int64}; Array vec = {Int8(1), Int8(2), Int8(3)}; - DataSet data_set = {{{vec, Int64(2)}, Int64(2)}, {{vec, Int64(4)}, Int64(0)}, {{Null(), Int64(1)}, Null()}, {{empty_arr, Int64(1)}, Int64(0)}}; + DataSet data_set = {{{vec, Int64(2)}, Int64(2)}, + {{vec, Int64(4)}, Int64(0)}, + {{Null(), Int64(1)}, Null()}, + {{empty_arr, 
Int64(1)}, Int64(0)}}; check_function(func_name, input_types, data_set); } @@ -111,9 +132,12 @@ TEST(function_array_index_test, array_position) { { InputTypeSet input_types = {TypeIndex::Array, TypeIndex::String, TypeIndex::String}; - Array vec = {Field("abc", 3), Field("", 0), Field("def",3)}; - DataSet data_set = {{{vec, std::string("abc")}, Int64(1)}, {{vec, std::string("aaa")}, Int64(0)}, - {{vec, std::string("")}, Int64(2)}, {{Null(), std::string("abc")}, Null()}, {{empty_arr, std::string("")}, Int64(0)}}; + Array vec = {Field("abc", 3), Field("", 0), Field("def", 3)}; + DataSet data_set = {{{vec, std::string("abc")}, Int64(1)}, + {{vec, std::string("aaa")}, Int64(0)}, + {{vec, std::string("")}, Int64(2)}, + {{Null(), std::string("abc")}, Null()}, + {{empty_arr, std::string("")}, Int64(0)}}; check_function(func_name, input_types, data_set); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java index 42e4f21a2f..fb05ed1d4e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionSet.java @@ -1100,6 +1100,10 @@ public class FunctionSet argsType = new ArrayList(); - for (PrimitiveType type : args) { - argsType.add(Type.fromPrimitiveType(type)); + for (Type type : args) { + argsType.add(type); } addBuiltin(ScalarFunction.createBuiltin( - fnName, Type.fromPrimitiveType(retType), nullableMode, argsType, varArgs, + fnName, retType, nullableMode, argsType, varArgs, symbol, prepareFnSymbol, closeFnSymbol, userVisible)); } public void addScalarAndVectorizedBuiltin(String fnName, String symbol, boolean userVisible, String prepareFnSymbol, String closeFnSymbol, - Function.NullableMode nullableMode, PrimitiveType retType, - boolean varArgs, PrimitiveType ... args) { + Function.NullableMode nullableMode, Type retType, + boolean varArgs, Type ... 
args) { ArrayList argsType = new ArrayList(); - for (PrimitiveType type : args) { - argsType.add(Type.fromPrimitiveType(type)); + for (Type type : args) { + argsType.add(type); } addBuiltinBothScalaAndVectorized(ScalarFunction.createBuiltin( - fnName, Type.fromPrimitiveType(retType), nullableMode, argsType, varArgs, + fnName, retType, nullableMode, argsType, varArgs, symbol, prepareFnSymbol, closeFnSymbol, userVisible)); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java index be3f69f992..53a2ab980e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java +++ b/fe/fe-core/src/main/java/org/apache/doris/catalog/Type.java @@ -81,8 +81,9 @@ public abstract class Type { public static final ScalarType QUANTILE_STATE = new ScalarType(PrimitiveType.QUANTILE_STATE); // Only used for alias function, to represent any type in function args public static final ScalarType ALL = new ScalarType(PrimitiveType.ALL); - public static final MapType Map = new MapType(); + public static final MapType MAP = new MapType(); public static final ArrayType ARRAY = ArrayType.create(); + public static final StructType STRUCT = new StructType(); private static ArrayList integerTypes; private static ArrayList numericTypes; @@ -380,7 +381,9 @@ public abstract class Type { } if (t1.isComplexType() || t2.isComplexType()) { if (t1.isArrayType() && t2.isArrayType()) { - return true; + // Subtype of Array do not support cast now, for example: + // Array can not cast to Array + return t1.matchesType(t2); } else if (t1.isMapType() && t2.isMapType()) { return true; } else if (t1.isStructType() && t2.isStructType()) { diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index 316e7b9a7a..0d63e49776 100755 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -114,12 +114,30 @@ visible_functions = [ [['array'], 
'ARRAY', ['ARRAY', '...'], '', '', '', '', ''], [['array'], 'ARRAY', ['MAP', '...'], '', '', '', '', ''], [['array'], 'ARRAY', ['STRUCT', '...'], '', '', '', '', ''], - [['%element_extract%'], 'VARCHAR', ['ARRAY', 'INT'], '', '', '', '', ''], - [['%element_extract%'], 'VARCHAR', ['ARRAY', 'VARCHAR'], '', '', '', '', ''], [['%element_extract%'], 'VARCHAR', ['MAP', 'VARCHAR'], '', '', '', '', ''], [['%element_extract%'], 'VARCHAR', ['MAP', 'INT'], '', '', '', '', ''], [['%element_extract%'], 'VARCHAR', ['STRUCT', 'INT'], '', '', '', '', ''], [['%element_extract%'], 'VARCHAR', ['STRUCT', 'VARCHAR'], '', '', '', '', ''], + + [['element_at', '%element_extract%'], 'TINYINT', ['ARRAY_TINYINT', 'INT'], + '_ZN5doris10vectorized20FunctionArrayElement12execute_implEPN9doris_udf15FunctionContextERNS0_5BlockERKSt6vectorImSaImEEmm', + '', '', 'vec', 'ALWAYS_NULLABLE'], + [['element_at', '%element_extract%'], 'SMALLINT', ['ARRAY_SMALLINT', 'INT'], + '_ZN5doris10vectorized20FunctionArrayElement12execute_implEPN9doris_udf15FunctionContextERNS0_5BlockERKSt6vectorImSaImEEmm', + '', '', 'vec', 'ALWAYS_NULLABLE'], + [['element_at', '%element_extract%'], 'INT', ['ARRAY_INT', 'INT'], + '_ZN5doris10vectorized20FunctionArrayElement12execute_implEPN9doris_udf15FunctionContextERNS0_5BlockERKSt6vectorImSaImEEmm', + '', '', 'vec', 'ALWAYS_NULLABLE'], + [['element_at', '%element_extract%'], 'BIGINT', ['ARRAY_BIGINT', 'INT'], + '_ZN5doris10vectorized20FunctionArrayElement12execute_implEPN9doris_udf15FunctionContextERNS0_5BlockERKSt6vectorImSaImEEmm', + '', '', 'vec', 'ALWAYS_NULLABLE'], + [['element_at', '%element_extract%'], 'VARCHAR', ['ARRAY_VARCHAR', 'INT'], + '_ZN5doris10vectorized20FunctionArrayElement12execute_implEPN9doris_udf15FunctionContextERNS0_5BlockERKSt6vectorImSaImEEmm', + '', '', 'vec', 'ALWAYS_NULLABLE'], + [['element_at', '%element_extract%'], 'STRING', ['ARRAY_STRING', 'INT'], + 
'_ZN5doris10vectorized20FunctionArrayElement12execute_implEPN9doris_udf15FunctionContextERNS0_5BlockERKSt6vectorImSaImEEmm', + '', '', 'vec', 'ALWAYS_NULLABLE'], + [['array_contains'], 'BOOLEAN', ['ARRAY', 'TINYINT'], '_ZN5doris10vectorized18FunctionArrayIndexINS0_19ArrayContainsActionENS0_17NameArrayContainsEE12execute_implEPN9doris_udf15FunctionContextERNS0_5BlockERKSt6vectorImSaImEEmm', '', '', 'vec', ''], diff --git a/gensrc/script/gen_builtins_functions.py b/gensrc/script/gen_builtins_functions.py index 9d212312b9..1f9f2d7570 100755 --- a/gensrc/script/gen_builtins_functions.py +++ b/gensrc/script/gen_builtins_functions.py @@ -52,7 +52,8 @@ java_registry_preamble = '\ \n\ package org.apache.doris.builtins;\n\ \n\ -import org.apache.doris.catalog.PrimitiveType;\n\ +import org.apache.doris.catalog.ArrayType;\n\ +import org.apache.doris.catalog.Type;\n\ import org.apache.doris.catalog.Function;\n\ import org.apache.doris.catalog.FunctionSet;\n\ import com.google.common.collect.Sets;\n\ @@ -100,6 +101,20 @@ def add_function(fn_meta_data, user_visible): meta_data_entries.append(entry) +""" +generate fe data type, support nested ARRAY type. +for example: + in[TINYINT] --> out[Type.TINYINT] + in[INT] --> out[Type.INT] + in[ARRAY_INT] --> out[new ArrayType(Type.INT)] +""" +def generate_fe_datatype(str_type): + if str_type.startswith("ARRAY_"): + vec_type = str_type.split('_', 1); + if len(vec_type) > 1 and vec_type[0] == "ARRAY": + return "new ArrayType(" + generate_fe_datatype(vec_type[1]) + ")" + return "Type." + str_type + """ Order of params: name, symbol, user_visible, prepare, close, nullable_mode, ret_type, has_var_args, args @@ -124,7 +139,7 @@ def generate_fe_entry(entry, name): java_output += ', null' java_output += ", Function.NullableMode." + entry["nullable_mode"] - java_output += ", PrimitiveType." + entry["ret_type"] + java_output += ", " + generate_fe_datatype(entry["ret_type"]) # Check the last entry for varargs indicator. 
if entry["args"] and entry["args"][-1] == "...": @@ -133,7 +148,7 @@ def generate_fe_entry(entry, name): else: java_output += ", false" for arg in entry["args"]: - java_output += ", PrimitiveType." + arg + java_output += ", " + generate_fe_datatype(arg) return java_output # Generates the FE builtins init file that registers all the builtins.