From 48aaaa80058a86ce93eefed28803123dfb19a88b Mon Sep 17 00:00:00 2001 From: koarz <66543806+koarz@users.noreply.github.com> Date: Sun, 4 Feb 2024 22:19:30 +0800 Subject: [PATCH] [Enhancement](fuction) change function REPEAT nullable mode (#30743) --- be/src/vec/functions/function_string.cpp | 1 + be/src/vec/functions/function_string.h | 123 +++++++++++++++--- .../expressions/functions/scalar/Repeat.java | 4 +- gensrc/script/doris_builtins_functions.py | 4 +- 4 files changed, 112 insertions(+), 20 deletions(-) diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index 6965139a1c..c5ce208d26 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -1021,6 +1021,7 @@ void register_function_string(SimpleFunctionFactory& factory) { factory.register_alternative_function(); factory.register_alternative_function(); factory.register_alternative_function(); + factory.register_alternative_function(); factory.register_alias(FunctionLeft::name, "strleft"); factory.register_alias(FunctionRight::name, "strright"); diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index 6fc84074dd..4794d28e0e 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -1031,7 +1031,6 @@ public: DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { return std::make_shared(); } - bool use_default_implementation_for_nulls() const override { return true; } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { @@ -1051,7 +1050,7 @@ public: for (int i = 0; i < argument_size; ++i) { argument_columns[i] = block.get_by_position(arguments[i]).column->convert_to_full_column_if_const(); - auto col_str = assert_cast(argument_columns[i].get()); + const auto* col_str = assert_cast(argument_columns[i].get()); offsets_list[i] = &col_str->get_offsets(); chars_list[i] = &col_str->get_chars(); } @@ -1084,8 +1083,8 @@ public: for (size_t i = 0; i < input_rows_count; ++i) { int current_length = 0; for (size_t j = 0; j < offsets_list.size(); ++j) { - auto& current_offsets = *offsets_list[j]; - auto& current_chars = *chars_list[j]; + const auto& current_offsets = *offsets_list[j]; + const auto& current_chars = *chars_list[j]; int size = current_offsets[i] - current_offsets[i - 1]; if (size > 0) { @@ -1431,6 +1430,103 @@ public: String get_name() const override { return name; } size_t get_number_of_arguments() const override { return 2; } + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return std::make_shared(); + } + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + DCHECK_EQ(arguments.size(), 2); + auto res = ColumnString::create(); + + ColumnPtr argument_ptr[2]; + argument_ptr[0] = + block.get_by_position(arguments[0]).column->convert_to_full_column_if_const(); + argument_ptr[1] = block.get_by_position(arguments[1]).column; + + if (auto* col1 = check_and_get_column(*argument_ptr[0])) { + if (auto* col2 = check_and_get_column(*argument_ptr[1])) { + vector_vector(col1->get_chars(), col1->get_offsets(), col2->get_data(), + res->get_chars(), res->get_offsets(), + context->state()->repeat_max_num()); + block.replace_by_position(result, std::move(res)); + return Status::OK(); + } else if (auto* col2_const = check_and_get_column(*argument_ptr[1])) { + DCHECK(check_and_get_column(col2_const->get_data_column())); + int repeat = 0; + repeat = std::min(col2_const->get_int(0), context->state()->repeat_max_num()); + + if (repeat <= 0) { + res->insert_many_defaults(input_rows_count); + } else { + vector_const(col1->get_chars(), col1->get_offsets(), repeat, res->get_chars(), + res->get_offsets()); + } + block.replace_by_position(result, std::move(res)); + return Status::OK(); + } + } + + return Status::RuntimeError("repeat function get error param: {}, {}", + argument_ptr[0]->get_name(), argument_ptr[1]->get_name()); + } + + void vector_vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets, + const ColumnInt32::Container& repeats, ColumnString::Chars& res_data, + ColumnString::Offsets& res_offsets, const int repeat_max_num) const { + size_t input_row_size = offsets.size(); + + fmt::memory_buffer buffer; + res_offsets.resize(input_row_size); + for (ssize_t i = 0; i < input_row_size; ++i) { + buffer.clear(); + const char* raw_str = reinterpret_cast(&data[offsets[i - 1]]); + size_t size = offsets[i] - offsets[i - 1]; + int repeat = 0; + repeat = std::min(repeats[i], repeat_max_num); + + if (repeat <= 0) { + StringOP::push_empty_string(i, res_data, res_offsets); + } else { + for (int j = 0; j < repeat; ++j) { + buffer.append(raw_str, raw_str + size); + } + StringOP::push_value_string(std::string_view(buffer.data(), buffer.size()), i, + res_data, res_offsets); + } + } + } + + // TODO: 1. use pmr::vector replace fmt_buffer may speed up the code + // 2. abstract the `vector_vector` and `vector_const` + // 3. rethink we should use `DEFAULT_MAX_STRING_SIZE` to bigger here + void vector_const(const ColumnString::Chars& data, const ColumnString::Offsets& offsets, + int repeat, ColumnString::Chars& res_data, + ColumnString::Offsets& res_offsets) const { + size_t input_row_size = offsets.size(); + + fmt::memory_buffer buffer; + res_offsets.resize(input_row_size); + for (ssize_t i = 0; i < input_row_size; ++i) { + buffer.clear(); + const char* raw_str = reinterpret_cast(&data[offsets[i - 1]]); + size_t size = offsets[i] - offsets[i - 1]; + + for (int j = 0; j < repeat; ++j) { + buffer.append(raw_str, raw_str + size); + } + StringOP::push_value_string(std::string_view(buffer.data(), buffer.size()), i, res_data, + res_offsets); + } + } +}; + +class FunctionStringRepeatOld : public IFunction { +public: + static constexpr auto name = "repeat"; + static FunctionPtr create() { return std::make_shared(); } + String get_name() const override { return name; } + size_t get_number_of_arguments() const override { return 2; } + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { return make_nullable(std::make_shared()); } @@ -1545,7 +1641,6 @@ public: DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { return make_nullable(std::make_shared()); } - bool use_default_implementation_for_nulls() const override { return true; } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { @@ -1688,8 +1783,6 @@ public: return make_nullable(std::make_shared()); } - bool use_default_implementation_for_nulls() const override { return true; } - Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { DCHECK_EQ(arguments.size(), 3); @@ -2003,7 +2096,7 @@ public: class FunctionSubstringIndexOld : public IFunction { public: static constexpr auto name = "substring_index"; - static FunctionPtr create() { return std::make_shared(); } + static FunctionPtr create() { return std::make_shared(); } String get_name() const override { return name; } size_t get_number_of_arguments() const override { return 3; } @@ -2160,7 +2253,6 @@ public: return Status::OK(); } }; - class FunctionSplitByString : public IFunction { public: static constexpr auto name = "split_by_string"; @@ -2205,17 +2297,17 @@ public: dest_offsets.reserve(0); NullMapType* dest_nested_null_map = nullptr; - ColumnNullable* dest_nullable_col = reinterpret_cast(dest_nested_column); + auto* dest_nullable_col = reinterpret_cast(dest_nested_column); dest_nested_column = dest_nullable_col->get_nested_column_ptr(); dest_nested_null_map = &dest_nullable_col->get_null_map_column().get_data(); - auto col_left = check_and_get_column(src_column.get()); + const auto* col_left = check_and_get_column(src_column.get()); if (!col_left) { return Status::InternalError("Left operator of function {} can not be {}", get_name(), src_column_type->get_name()); } - auto col_right = check_and_get_column(right_column.get()); + const auto* col_right = check_and_get_column(right_column.get()); if (!col_right) { return Status::InternalError("Right operator of function {} can not be {}", get_name(), right_column_type->get_name()); @@ -2245,7 +2337,7 @@ private: const StringRef& delimiter_ref, IColumn& dest_nested_column, ColumnArray::Offsets64& dest_offsets, NullMapType* dest_nested_null_map) const { - ColumnString& dest_column_string = reinterpret_cast(dest_nested_column); + auto& dest_column_string = reinterpret_cast(dest_nested_column); ColumnString::Chars& column_string_chars = dest_column_string.get_chars(); ColumnString::Offsets& column_string_offsets = dest_column_string.get_offsets(); column_string_chars.reserve(0); @@ -2312,7 +2404,7 @@ private: const ColumnString& delimiter_column, IColumn& dest_nested_column, ColumnArray::Offsets64& dest_offsets, NullMapType* dest_nested_null_map) const { - ColumnString& dest_column_string = reinterpret_cast(dest_nested_column); + auto& dest_column_string = reinterpret_cast(dest_nested_column); ColumnString::Chars& column_string_chars = dest_column_string.get_chars(); ColumnString::Offsets& column_string_offsets = dest_column_string.get_offsets(); column_string_chars.reserve(0); @@ -2369,7 +2461,7 @@ private: IColumn& dest_nested_column, ColumnArray::Offsets64& dest_offsets, NullMapType* dest_nested_null_map) const { - ColumnString& dest_column_string = reinterpret_cast(dest_nested_column); + auto& dest_column_string = reinterpret_cast(dest_nested_column); ColumnString::Chars& column_string_chars = dest_column_string.get_chars(); ColumnString::Offsets& column_string_offsets = dest_column_string.get_offsets(); column_string_chars.reserve(0); @@ -2659,7 +2751,6 @@ public: DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { return make_nullable(std::make_shared()); } - bool use_default_implementation_for_nulls() const override { return true; } Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) const override { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Repeat.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Repeat.java index b85a812197..918443e816 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Repeat.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Repeat.java @@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.BinaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.IntegerType; @@ -35,7 +35,7 @@ import java.util.List; * ScalarFunction 'repeat'. This class is generated by GenerateFunction. */ public class Repeat extends ScalarFunction - implements BinaryExpression, ExplicitlyCastableSignature, AlwaysNullable { + implements BinaryExpression, ExplicitlyCastableSignature, PropagateNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(StringType.INSTANCE).args(StringType.INSTANCE, IntegerType.INSTANCE) diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index 722715da2c..bd52ffe789 100644 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -1564,7 +1564,7 @@ visible_functions = { [['null_or_empty'], 'BOOLEAN', ['VARCHAR'], 'ALWAYS_NOT_NULLABLE'], [['not_null_or_empty'], 'BOOLEAN', ['VARCHAR'], 'ALWAYS_NOT_NULLABLE'], [['space'], 'VARCHAR', ['INT'], ''], - [['repeat'], 'VARCHAR', ['VARCHAR', 'INT'], 'ALWAYS_NULLABLE'], + [['repeat'], 'VARCHAR', ['VARCHAR', 'INT'], 'DEPEND_ON_ARGUMENT'], [['lpad'], 'VARCHAR', ['VARCHAR', 'INT', 'VARCHAR'], 'ALWAYS_NULLABLE'], [['rpad'], 'VARCHAR', ['VARCHAR', 'INT', 'VARCHAR'], 'ALWAYS_NULLABLE'], [['append_trailing_char_if_absent'], 'VARCHAR', ['VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'], @@ -1624,7 +1624,7 @@ visible_functions = { [['null_or_empty'], 'BOOLEAN', ['STRING'], 'ALWAYS_NOT_NULLABLE'], [['not_null_or_empty'], 'BOOLEAN', ['STRING'], 'ALWAYS_NOT_NULLABLE'], [['space'], 'STRING', ['INT'], ''], - [['repeat'], 'STRING', ['STRING', 'INT'], 'ALWAYS_NULLABLE'], + [['repeat'], 'STRING', ['STRING', 'INT'], 'DEPEND_ON_ARGUMENT'], [['lpad'], 'STRING', ['STRING', 'INT', 'STRING'], 'ALWAYS_NULLABLE'], [['rpad'], 'STRING', ['STRING', 'INT', 'STRING'], 'ALWAYS_NULLABLE'], [['append_trailing_char_if_absent'], 'STRING', ['STRING', 'STRING'], 'ALWAYS_NULLABLE'],