diff --git a/be/src/util/url_coding.cpp b/be/src/util/url_coding.cpp index b1119db058..3f914d720c 100644 --- a/be/src/util/url_coding.cpp +++ b/be/src/util/url_coding.cpp @@ -38,7 +38,9 @@ void url_encode(const std::string_view& in, std::string* out) { } else if (c == ' ') { os << '+'; } else { - os << '%' << to_hex(c >> 4) << to_hex(c % 16); + // TODO: Previously the shift and modulus were applied directly to a char, which is a signed type on most platforms. + // For byte values >= 128 (every byte of a multi-byte UTF-8 sequence) that produced negative operands and incorrect hex digits, hence the unsigned char casts. A third-party URL-encoding library would be a better long-term choice here. + os << '%' << to_hex((unsigned char)c >> 4) << to_hex((unsigned char)c % 16); } } diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index 921a0f689f..be8d06e1e7 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -1219,6 +1219,15 @@ struct StringAppendTrailingCharIfAbsent { using Offsets = ColumnString::Offsets; using ReturnType = DataTypeString; using ColumnType = ColumnString; + + static bool str_end_with(const StringRef& str, const StringRef& end) { + if (str.size < end.size) { + return false; + } + // StringRef::end_with requires end.size <= str.size, so the sizes are checked before calling it. 
+ return str.end_with(end); + } + static void vector_vector(FunctionContext* context, const Chars& ldata, const Offsets& loffsets, const Chars& rdata, const Offsets& roffsets, Chars& res_data, Offsets& res_offsets, NullMap& null_map_data) { @@ -1230,36 +1239,39 @@ struct StringAppendTrailingCharIfAbsent { for (size_t i = 0; i < input_rows_count; ++i) { buffer.clear(); - int l_size = loffsets[i] - loffsets[i - 1]; - const auto l_raw = reinterpret_cast(&ldata[loffsets[i - 1]]); + StringRef lstr = StringRef(reinterpret_cast(&ldata[loffsets[i - 1]]), + loffsets[i] - loffsets[i - 1]); + StringRef rstr = StringRef(reinterpret_cast(&rdata[roffsets[i - 1]]), + roffsets[i] - roffsets[i - 1]); + // The iterate_utf8_with_limit_length function iterates over a maximum of two UTF-8 characters. + auto [byte_len, char_len] = simd::VStringFunctions::iterate_utf8_with_limit_length( + rstr.begin(), rstr.end(), 2); - int r_size = roffsets[i] - roffsets[i - 1]; - const auto r_raw = reinterpret_cast(&rdata[roffsets[i - 1]]); - - if (r_size != 1) { + if (char_len != 1) { StringOP::push_null_string(i, res_data, res_offsets, null_map_data); continue; } - if (l_raw[l_size - 1] == r_raw[0]) { - StringOP::push_value_string(std::string_view(l_raw, l_size), i, res_data, - res_offsets); + if (str_end_with(lstr, rstr)) { + StringOP::push_value_string(lstr, i, res_data, res_offsets); continue; } - buffer.append(l_raw, l_raw + l_size); - buffer.append(r_raw, r_raw + 1); + buffer.append(lstr.begin(), lstr.end()); + buffer.append(rstr.begin(), rstr.end()); StringOP::push_value_string(std::string_view(buffer.data(), buffer.size()), i, res_data, res_offsets); } } static void vector_scalar(FunctionContext* context, const Chars& ldata, const Offsets& loffsets, - const StringRef& rdata, Chars& res_data, Offsets& res_offsets, + const StringRef& rstr, Chars& res_data, Offsets& res_offsets, NullMap& null_map_data) { size_t input_rows_count = loffsets.size(); res_offsets.resize(input_rows_count); 
fmt::memory_buffer buffer; - - if (rdata.size != 1) { + // The iterate_utf8_with_limit_length function iterates over a maximum of two UTF-8 characters. + auto [byte_len, char_len] = + simd::VStringFunctions::iterate_utf8_with_limit_length(rstr.begin(), rstr.end(), 2); + if (char_len != 1) { for (size_t i = 0; i < input_rows_count; ++i) { StringOP::push_null_string(i, res_data, res_offsets, null_map_data); } @@ -1268,23 +1280,21 @@ struct StringAppendTrailingCharIfAbsent { for (size_t i = 0; i < input_rows_count; ++i) { buffer.clear(); + StringRef lstr = StringRef(reinterpret_cast(&ldata[loffsets[i - 1]]), + loffsets[i] - loffsets[i - 1]); - int l_size = loffsets[i] - loffsets[i - 1]; - const auto l_raw = reinterpret_cast(&ldata[loffsets[i - 1]]); - - if (l_raw[l_size - 1] == rdata.data[0]) { - StringOP::push_value_string(std::string_view(l_raw, l_size), i, res_data, - res_offsets); + if (str_end_with(lstr, rstr)) { + StringOP::push_value_string(lstr, i, res_data, res_offsets); continue; } - buffer.append(l_raw, l_raw + l_size); - buffer.append(rdata.begin(), rdata.end()); + buffer.append(lstr.begin(), lstr.end()); + buffer.append(rstr.begin(), rstr.end()); StringOP::push_value_string(std::string_view(buffer.data(), buffer.size()), i, res_data, res_offsets); } } - static void scalar_vector(FunctionContext* context, const StringRef& ldata, const Chars& rdata, + static void scalar_vector(FunctionContext* context, const StringRef& lstr, const Chars& rdata, const Offsets& roffsets, Chars& res_data, Offsets& res_offsets, NullMap& null_map_data) { size_t input_rows_count = roffsets.size(); @@ -1294,20 +1304,23 @@ struct StringAppendTrailingCharIfAbsent { for (size_t i = 0; i < input_rows_count; ++i) { buffer.clear(); - int r_size = roffsets[i] - roffsets[i - 1]; - const auto r_raw = reinterpret_cast(&rdata[roffsets[i - 1]]); + StringRef rstr = StringRef(reinterpret_cast(&rdata[roffsets[i - 1]]), + roffsets[i] - roffsets[i - 1]); + // The iterate_utf8_with_limit_length 
function iterates over a maximum of two UTF-8 characters. + auto [byte_len, char_len] = simd::VStringFunctions::iterate_utf8_with_limit_length( + rstr.begin(), rstr.end(), 2); - if (r_size != 1) { + if (char_len != 1) { StringOP::push_null_string(i, res_data, res_offsets, null_map_data); continue; } - if (ldata.size == 0 || ldata.back() == r_raw[0]) { - StringOP::push_value_string(ldata.to_string_view(), i, res_data, res_offsets); + if (str_end_with(lstr, rstr)) { + StringOP::push_value_string(lstr, i, res_data, res_offsets); continue; } - buffer.append(ldata.begin(), ldata.end()); - buffer.append(r_raw, r_raw + 1); + buffer.append(lstr.begin(), lstr.end()); + buffer.append(rstr.begin(), rstr.end()); StringOP::push_value_string(std::string_view(buffer.data(), buffer.size()), i, res_data, res_offsets); } diff --git a/be/src/vec/functions/function_string.h b/be/src/vec/functions/function_string.h index a17f4b847d..69613662a4 100644 --- a/be/src/vec/functions/function_string.h +++ b/be/src/vec/functions/function_string.h @@ -1165,14 +1165,14 @@ public: auto str_col = block.get_by_position(arguments[0]).column->convert_to_full_column_if_const(); - const auto& str_offset = assert_cast(str_col.get())->get_offsets(); - + const auto* str_column = assert_cast(str_col.get()); auto pos_col = block.get_by_position(arguments[1]).column->convert_to_full_column_if_const(); const auto& pos_data = assert_cast(pos_col.get())->get_data(); for (int i = 0; i < input_rows_count; ++i) { - strlen_data[i] = str_offset[i] - str_offset[i - 1]; + auto str = str_column->get_data_at(i); + strlen_data[i] = simd::VStringFunctions::get_char_len(str.data, str.size); } for (int i = 0; i < input_rows_count; ++i) { diff --git a/be/test/vec/function/function_string_test.cpp b/be/test/vec/function/function_string_test.cpp index 224adc1937..3b0f0e2308 100644 --- a/be/test/vec/function/function_string_test.cpp +++ b/be/test/vec/function/function_string_test.cpp @@ -84,12 +84,34 @@ 
TEST(function_string_test, function_string_strright_test) { std::string func_name = "strright"; InputTypeSet input_types = {TypeIndex::String, TypeIndex::Int32}; - DataSet data_set = {{{std::string("asd"), 1}, std::string("d")}, - {{std::string("hello word"), -2}, std::string("ello word")}, - {{std::string("hello word"), 20}, std::string("hello word")}, - {{std::string("HELLO,!^%"), 2}, std::string("^%")}, - {{std::string(""), 3}, std::string("")}, - {{Null(), 3}, Null()}}; + DataSet data_set = { + {{std::string("asd"), 1}, std::string("d")}, + {{std::string("hello word"), -2}, std::string("ello word")}, + {{std::string("hello word"), 20}, std::string("hello word")}, + {{std::string("HELLO,!^%"), 2}, std::string("^%")}, + {{std::string(""), 3}, std::string("")}, + {{Null(), 3}, Null()}, + {{std::string("12345"), 10}, std::string("12345")}, + {{std::string("12345"), -10}, std::string("")}, + {{std::string(""), Null()}, Null()}, + {{Null(), -100}, Null()}, + {{std::string("12345"), 12345}, std::string("12345")}, + {{std::string(""), 1}, std::string()}, + {{std::string("a b c d _ %"), -3}, std::string("b c d _ %")}, + {{std::string(""), Null()}, Null()}, + {{std::string("hah hah"), -1}, std::string("hah hah")}, + {{std::string("🤣"), -1}, std::string("🤣")}, + {{std::string("🤣😃😄"), -2}, std::string("😃😄")}, + {{std::string("🐼abc🐼"), 100}, std::string("🐼abc🐼")}, + {{std::string("你好世界"), 5}, std::string("你好世界")}, + {{std::string("12345"), 6}, std::string("12345")}, + {{std::string("12345"), 12345}, std::string("12345")}, + {{std::string("-12345"), -1}, std::string("-12345")}, + {{std::string("-12345"), -12345}, std::string()}, + {{Null(), -12345}, Null()}, + {{std::string("😡"), Null()}, Null()}, + {{std::string("🤣"), 0}, std::string()}, + }; static_cast(check_function(func_name, input_types, data_set)); } @@ -248,7 +270,34 @@ TEST(function_string_test, function_append_trailing_char_if_absent_test) { DataSet data_set = {{{std::string("ASD"), std::string("D")}, 
std::string("ASD")}, {{std::string("AS"), std::string("D")}, std::string("ASD")}, {{std::string(""), std::string("")}, Null()}, - {{std::string(""), std::string("A")}, std::string("A")}}; + {{std::string(""), std::string("A")}, std::string("A")}, + {{std::string("AC"), std::string("BACBAC")}, Null()}, + {{Null(), Null()}, Null()}, + {{std::string("ABC"), Null()}, Null()}, + {{Null(), std::string("ABC")}, Null()}, + {{std::string(""), Null()}, Null()}, + {{std::string("中文"), std::string("文")}, std::string("中文")}, + {{std::string("中"), std::string("文")}, std::string("中文")}, + {{std::string(""), std::string("文")}, std::string("文")}, + {{Null(), std::string("")}, Null()}}; + static_cast(check_function(func_name, input_types, data_set)); +} + +TEST(function_string_test, function_url_encode_test) { + std::string func_name = "url_encode"; + + InputTypeSet input_types = {TypeIndex::String}; + + DataSet data_set = { + {{std::string("编码")}, std::string("%E7%BC%96%E7%A0%81")}, + {{std::string("http://www.baidu.com/?a=中文日文韩文俄文希伯来文Emoji")}, + std::string( + "http%3A%2F%2Fwww.baidu.com%2F%3Fa%3D%E4%B8%AD%E6%96%87%E6%97%A5%E6%96%87%E9%" + "9F%A9%E6%96%87%E4%BF%84%E6%96%87%E5%B8%8C%E4%BC%AF%E6%9D%A5%E6%96%87Emoji")}, + {{std::string("http://www.baidu.com?a=http%3A%2F%2Fexample.com%2F😊")}, + std::string("http%3A%2F%2Fwww.baidu.com%3Fa%3Dhttp%253A%252F%252Fexample.com%252F%F0%" + "9F%98%8A")}, + }; static_cast(check_function(func_name, input_types, data_set)); } diff --git a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_string_arithmatic.groovy b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_string_arithmatic.groovy index 7a064e5f0f..20d13b3697 100644 --- a/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_string_arithmatic.groovy +++ b/regression-test/suites/nereids_p0/expression/fold_constant/fold_constant_string_arithmatic.groovy @@ -50,7 +50,10 @@ suite("fold_constant_string_arithmatic") { 
testFoldConst("select append_trailing_char_if_absent('こんにちは', '!')") testFoldConst("select append_trailing_char_if_absent('\n\t', '\n')") testFoldConst("select append_trailing_char_if_absent('こんにちは', 'ちは')") - + testFoldConst("select append_trailing_char_if_absent('中文', '文')") + testFoldConst("select append_trailing_char_if_absent('中', '文')") + testFoldConst("select append_trailing_char_if_absent('', '文')") + // ascii testFoldConst("select ascii('!')") testFoldConst("select ascii('1')") @@ -684,7 +687,8 @@ suite("fold_constant_string_arithmatic") { testFoldConst("select right('Hello World', 5)") testFoldConst("select right('Hello World', 0)") testFoldConst("select right(NULL, 1)") - + testFoldConst("select right('🐼abc🐼', 100)") + testFoldConst("select right('你好世界',5)") // rpad testFoldConst("select rpad(cast('hi' as string), 1, cast('xy' as string))") testFoldConst("select rpad(cast('hi' as string), 5, cast('xy' as string))") @@ -1231,6 +1235,7 @@ suite("fold_constant_string_arithmatic") { testFoldConst("select url_decode('http%3A%2F%2Fwww.apache.org%2Flicenses%2FLICENSE-22.0')") testFoldConst("select url_encode('http://www.apache.org/licenses/LICENSE-2.0')") testFoldConst("select url_encode(' http://www.apache.org/licenses/LICENSE-2.0 ')") + testFoldConst("select url_encode(' http://www.baidu.com/?a=中文日文韩文俄文希伯来文Emoji')") // Normal Usage Test Cases