diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index d87969fb76..280eb93b2a 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -174,37 +174,132 @@ struct NameInstr { static constexpr auto name = "instr"; }; +// LeftDataType and RightDataType are DataTypeString +template +struct StringInStrImpl { + using ResultDataType = DataTypeInt32; + using ResultPaddedPODArray = PaddedPODArray; + + static Status scalar_vector(const StringRef& ldata, const ColumnString::Chars& rdata, + const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { + StringRef lstr_ref(ldata.data, ldata.size); + + auto size = roffsets.size(); + res.resize(size); + for (int i = 0; i < size; ++i) { + const char* r_raw_str = reinterpret_cast(&rdata[roffsets[i - 1]]); + int r_str_size = roffsets[i] - roffsets[i - 1]; + + StringRef rstr_ref(r_raw_str, r_str_size); + + res[i] = execute(lstr_ref, rstr_ref); + } + + return Status::OK(); + } + + static Status vector_scalar(const ColumnString::Chars& ldata, + const ColumnString::Offsets& loffsets, const StringRef& rdata, + ResultPaddedPODArray& res) { + auto size = loffsets.size(); + res.resize(size); + + if (rdata.size == 0) { + for (int i = 0; i < size; ++i) { + res[i] = 1; + } + return Status::OK(); + } + + StringRef rstr_ref(rdata.data, rdata.size); + StringSearch search(&rstr_ref); + + for (int i = 0; i < size; ++i) { + const char* l_raw_str = reinterpret_cast(&ldata[loffsets[i - 1]]); + int l_str_size = loffsets[i] - loffsets[i - 1]; + + StringRef lstr_ref(l_raw_str, l_str_size); + + // Hive returns positions starting from 1. + int loc = search.search(&lstr_ref); + if (loc > 0) { + loc = get_char_len(lstr_ref, loc); + } + res[i] = loc + 1; + } + + return Status::OK(); + } + + static Status vector_vector(const ColumnString::Chars& ldata, + const ColumnString::Offsets& loffsets, + const ColumnString::Chars& rdata, + const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { + DCHECK_EQ(loffsets.size(), roffsets.size()); + + auto size = loffsets.size(); + res.resize(size); + for (int i = 0; i < size; ++i) { + const char* l_raw_str = reinterpret_cast(&ldata[loffsets[i - 1]]); + int l_str_size = loffsets[i] - loffsets[i - 1]; + StringRef lstr_ref(l_raw_str, l_str_size); + + const char* r_raw_str = reinterpret_cast(&rdata[roffsets[i - 1]]); + int r_str_size = roffsets[i] - roffsets[i - 1]; + StringRef rstr_ref(r_raw_str, r_str_size); + + res[i] = execute(lstr_ref, rstr_ref); + } + + return Status::OK(); + } + + static int execute(const StringRef& strl, const StringRef& strr) { + if (strr.size == 0) { + return 1; + } + + StringSearch search(&strr); + // Hive returns positions starting from 1. + int loc = search.search(&strl); + if (loc > 0) { + loc = get_char_len(strl, loc); + } + + return loc + 1; + } +}; + // the same impl as instr struct NameLocate { static constexpr auto name = "locate"; }; -struct InStrOP { +// LeftDataType and RightDataType are DataTypeString +template +struct StringLocateImpl { using ResultDataType = DataTypeInt32; using ResultPaddedPODArray = PaddedPODArray; - static void execute(const std::string_view& strl, const std::string_view& strr, int32_t& res) { - if (strr.length() == 0) { - res = 1; - return; - } - StringRef str_sv(strl.data(), strl.length()); - StringRef substr_sv(strr.data(), strr.length()); - StringSearch search(&substr_sv); - // Hive returns positions starting from 1. - int loc = search.search(&str_sv); - if (loc > 0) { - loc = get_char_len(str_sv, loc); - } - - res = loc + 1; + static Status scalar_vector(const StringRef& ldata, const ColumnString::Chars& rdata, + const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { + return StringInStrImpl::vector_scalar(rdata, roffsets, ldata, + res); } -}; -struct LocateOP { - using ResultDataType = DataTypeInt32; - using ResultPaddedPODArray = PaddedPODArray; - static void execute(const std::string_view& strl, const std::string_view& strr, int32_t& res) { - InStrOP::execute(strr, strl, res); + + static Status vector_scalar(const ColumnString::Chars& ldata, + const ColumnString::Offsets& loffsets, const StringRef& rdata, + ResultPaddedPODArray& res) { + return StringInStrImpl::scalar_vector(rdata, ldata, loffsets, + res); + } + + static Status vector_vector(const ColumnString::Chars& ldata, + const ColumnString::Offsets& loffsets, + const ColumnString::Chars& rdata, + const ColumnString::Offsets& roffsets, ResultPaddedPODArray& res) { + return StringInStrImpl::vector_vector(rdata, roffsets, ldata, + loffsets, res); } }; @@ -783,12 +878,6 @@ using StringStartsWithImpl = StringFunctionImpl using StringEndsWithImpl = StringFunctionImpl; -template -using StringInstrImpl = StringFunctionImpl; - -template -using StringLocateImpl = StringFunctionImpl; - template using StringFindInSetImpl = StringFunctionImpl; @@ -802,7 +891,7 @@ using FunctionStringStartsWith = using FunctionStringEndsWith = FunctionBinaryToType; using FunctionStringInstr = - FunctionBinaryToType; + FunctionBinaryToType; using FunctionStringLocate = FunctionBinaryToType; using FunctionStringFindInSet = diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index 8f9ba90ae8..5bc415dd2a 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -2325,30 +2325,95 @@ public: Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, size_t result, size_t input_rows_count) override { - auto col_substr = - block.get_by_position(arguments[0]).column->convert_to_full_column_if_const(); - auto col_str = - block.get_by_position(arguments[1]).column->convert_to_full_column_if_const(); - auto col_pos = - block.get_by_position(arguments[2]).column->convert_to_full_column_if_const(); + DCHECK_EQ(arguments.size(), 3); + bool col_const[3]; + ColumnPtr argument_columns[3]; + for (int i = 0; i < 3; ++i) { + col_const[i] = is_column_const(*block.get_by_position(arguments[i]).column); + } + argument_columns[2] = col_const[2] ? static_cast( + *block.get_by_position(arguments[2]).column) + .convert_to_full_column() + : block.get_by_position(arguments[2]).column; + default_preprocess_parameter_columns(argument_columns, col_const, {0, 1}, block, arguments); + + auto col_left = assert_cast(argument_columns[0].get()); + auto col_right = assert_cast(argument_columns[1].get()); + auto col_pos = assert_cast*>(argument_columns[2].get()); ColumnInt32::MutablePtr col_res = ColumnInt32::create(); - - auto& vec_pos = reinterpret_cast(col_pos.get())->get_data(); auto& vec_res = col_res->get_data(); - vec_res.resize(input_rows_count); + vec_res.resize(block.rows()); - for (int i = 0; i < input_rows_count; ++i) { - vec_res[i] = - locate_pos(col_substr->get_data_at(i), col_str->get_data_at(i), vec_pos[i]); + if (col_const[0] && col_const[1]) { + scalar_search(col_left->get_data_at(0), col_right, col_pos->get_data(), vec_res); + } else if (col_const[0] && !col_const[1]) { + scalar_search(col_left->get_data_at(0), col_right, col_pos->get_data(), vec_res); + } else if (!col_const[0] && col_const[1]) { + vector_search(col_left, col_right, col_pos->get_data(), vec_res); + } else { + vector_search(col_left, col_right, col_pos->get_data(), vec_res); } - block.replace_by_position(result, std::move(col_res)); return Status::OK(); } private: - int locate_pos(StringRef substr, StringRef str, int start_pos) { + template + void scalar_search(const StringRef& ldata, const ColumnString* col_right, + const PaddedPODArray& posdata, PaddedPODArray& res) { + const ColumnString::Chars& rdata = col_right->get_chars(); + const ColumnString::Offsets& roffsets = col_right->get_offsets(); + + auto size = posdata.size(); + res.resize(size); + StringRef substr(ldata.data, ldata.size); + std::shared_ptr search_ptr(new StringSearch(&substr)); + + for (int i = 0; i < size; ++i) { + if constexpr (!Const) { + const char* r_raw_str = reinterpret_cast(&rdata[roffsets[i - 1]]); + int r_str_size = roffsets[i] - roffsets[i - 1]; + + StringRef str(r_raw_str, r_str_size); + res[i] = locate_pos(substr, str, search_ptr, posdata[i]); + } else { + res[i] = locate_pos(substr, col_right->get_data_at(0), search_ptr, posdata[i]); + } + } + } + + template + void vector_search(const ColumnString* col_left, const ColumnString* col_right, + const PaddedPODArray& posdata, PaddedPODArray& res) { + const ColumnString::Chars& rdata = col_right->get_chars(); + const ColumnString::Offsets& roffsets = col_right->get_offsets(); + + const ColumnString::Chars& ldata = col_left->get_chars(); + const ColumnString::Offsets& loffsets = col_left->get_offsets(); + + auto size = posdata.size(); + res.resize(size); + std::shared_ptr search_ptr; + for (int i = 0; i < size; ++i) { + const char* l_raw_str = reinterpret_cast(&ldata[loffsets[i - 1]]); + int l_str_size = loffsets[i] - loffsets[i - 1]; + + StringRef substr(l_raw_str, l_str_size); + if constexpr (!Const) { + const char* r_raw_str = reinterpret_cast(&rdata[roffsets[i - 1]]); + int r_str_size = roffsets[i] - roffsets[i - 1]; + + StringRef str(r_raw_str, r_str_size); + res[i] = locate_pos(substr, str, search_ptr, posdata[i]); + } else { + res[i] = locate_pos(substr, col_right->get_data_at(0), search_ptr, posdata[i]); + } + } + } + + int locate_pos(StringRef substr, StringRef str, std::shared_ptr search_ptr, + int start_pos) { if (substr.size == 0) { if (start_pos <= 0) { return 0; @@ -2368,11 +2433,12 @@ private: if (start_pos <= 0 || start_pos > str.size || start_pos > char_len) { return 0; } - StringRef substr_sv = StringRef(substr); - StringSearch search(&substr_sv); + if (!search_ptr) { + search_ptr.reset(new StringSearch(&substr)); + } // Input start_pos starts from 1. StringRef adjusted_str(str.data + index[start_pos - 1], str.size - index[start_pos - 1]); - int32_t match_pos = search.search(&adjusted_str); + int32_t match_pos = search_ptr->search(&adjusted_str); if (match_pos >= 0) { // Hive returns the position in the original string starting from 1. return start_pos + get_char_len(adjusted_str, match_pos);