// Reviewer preface (free text placed ahead of the first "diff --git" header; the patch text
// that follows is unchanged).
//
// What this patch does:
//   * be/src/exprs/math_functions.cpp: hex_string() drops the std::stringstream loop; it now
//     requests a result of s.len * 2 bytes from StringVal::create_temp_string_val and fills it
//     with simd::VStringFunctions::hex_encode.
//   * be/src/exprs/string_functions.cpp: lower/upper/reverse/trim/ltrim/rtrim call into
//     simd::VStringFunctions; the file-local get_utf8_byte_length() is removed.
//   * be/src/util/vectorized-tool/lower_upper_impl.h is renamed to
//     be/src/util/simd/lower_upper_impl.h and its namespace changes from doris to doris::simd.
//   * be/src/util/simd/vstring_function.h is added: get_utf8_byte_length() plus class
//     doris::simd::VStringFunctions (rtrim, ltrim, trim, is_ascii, reverse, hex_encode,
//     to_lower, to_upper).
//   * be/src/util/vectorized-tool/lower.h and upper.h are deleted.
//   * be/test/exprs/string_functions_test.cpp gains ltrim and rtrim test cases.
//
// NOTE(review): this copy of the patch looks damaged -- hunks are folded onto a few long lines
// and template/include arguments are missing (bare "#include", "static_cast(",
// "reinterpret_cast(") -- so it presumably cannot be applied as-is; confirm against the
// original commit.
// NOTE(review): ltrim() evaluates __builtin_ctz(mask ^ 0xffff). When all 16 loaded bytes are
// blanks the argument is 0, for which GCC documents the result as undefined; the
// "offset > REGISTER_SIZE" branch depends on that result.
// NOTE(review): hex_encode() calls _mm_shuffle_epi8 inside "#if defined(__SSE2__)". That
// intrinsic belongs to SSSE3, so the guard looks too weak -- confirm against the build flags.
// NOTE(review): hex_string() hands result.ptr to hex_encode without the result.is_null check
// that lower() and upper() apply to their result -- verify whether create_temp_string_val can
// return a null value here.
diff --git a/be/src/exprs/math_functions.cpp b/be/src/exprs/math_functions.cpp index f2f5aad1c1..28424f575d 100644 --- a/be/src/exprs/math_functions.cpp +++ b/be/src/exprs/math_functions.cpp @@ -26,10 +26,9 @@ #include "common/compiler_util.h" #include "exprs/anyval_util.h" -#include "exprs/expr.h" #include "runtime/decimalv2_value.h" -#include "runtime/tuple_row.h" #include "util/string_parser.hpp" +#include "util/simd/vstring_function.h" namespace doris { @@ -351,14 +350,10 @@ StringVal MathFunctions::hex_string(FunctionContext* ctx, const StringVal& s) { if (s.is_null) { return StringVal::null(); } - std::stringstream ss; - ss << std::hex << std::uppercase << std::setfill('0'); - for (int i = 0; i < s.len; ++i) { - // setw is not sticky. std::stringstream only converts integral values, - // so a cast to int is required, but only convert the least significant byte to hex. - ss << std::setw(2) << (static_cast(s.ptr[i]) & 0xFF); - } - return AnyValUtil::from_string_temp(ctx, ss.str()); + + StringVal result = StringVal::create_temp_string_val(ctx, s.len * 2); + simd::VStringFunctions::hex_encode(s.ptr, s.len, reinterpret_cast(result.ptr)); + return result; } StringVal MathFunctions::unhex(FunctionContext* ctx, const StringVal& s) { diff --git a/be/src/exprs/string_functions.cpp b/be/src/exprs/string_functions.cpp index aedf6331a3..00f26438c6 100644 --- a/be/src/exprs/string_functions.cpp +++ b/be/src/exprs/string_functions.cpp @@ -22,37 +22,17 @@ #include #include "exprs/anyval_util.h" -#include "exprs/expr.h" #include "fmt/format.h" #include "math_functions.h" #include "runtime/string_value.hpp" -#include "runtime/tuple_row.h" +#include "util/simd/vstring_function.h" #include "util/url_parser.h" -#include "util/vectorized-tool/lower.h" -#include "util/vectorized-tool/upper.h" // NOTE: be careful not to use string::append. It is not performant. 
namespace doris { void StringFunctions::init() {} -size_t get_utf8_byte_length(unsigned char byte) { - size_t char_size = 0; - if (byte >= 0xFC) { - char_size = 6; - } else if (byte >= 0xF8) { - char_size = 5; - } else if (byte >= 0xF0) { - char_size = 4; - } else if (byte >= 0xE0) { - char_size = 3; - } else if (byte >= 0xC0) { - char_size = 2; - } else { - char_size = 1; - } - return char_size; -} size_t get_char_len(const StringVal& str, std::vector* str_index) { size_t char_len = 0; for (size_t i = 0, char_size = 0; i < str.len; i += char_size) { @@ -353,7 +333,7 @@ StringVal StringFunctions::lower(FunctionContext* context, const StringVal& str) if (UNLIKELY(result.is_null)) { return result; } - Lower::to_lower(str.ptr, str.len, result.ptr); + simd::VStringFunctions::to_lower(str.ptr, str.len, result.ptr); return result; } @@ -365,7 +345,7 @@ StringVal StringFunctions::upper(FunctionContext* context, const StringVal& str) if (UNLIKELY(result.is_null)) { return result; } - Upper::to_upper(str.ptr, str.len, result.ptr); + simd::VStringFunctions::to_upper(str.ptr, str.len, result.ptr); return result; } @@ -379,57 +359,20 @@ StringVal StringFunctions::reverse(FunctionContext* context, const StringVal& st return result; } - for (size_t i = 0, char_size = 0; i < str.len; i += char_size) { - char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]); - std::copy(str.ptr + i, str.ptr + i + char_size, result.ptr + result.len - i - char_size); - } - + simd::VStringFunctions::reverse(str, result); return result; } StringVal StringFunctions::trim(FunctionContext* context, const StringVal& str) { - if (str.is_null) { - return StringVal::null(); - } - // Find new starting position. - int32_t begin = 0; - while (begin < str.len && str.ptr[begin] == ' ') { - ++begin; - } - // Find new ending position. 
- int32_t end = str.len - 1; - while (end > begin && str.ptr[end] == ' ') { - --end; - } - return StringVal(str.ptr + begin, end - begin + 1); + return simd::VStringFunctions::trim(str); } StringVal StringFunctions::ltrim(FunctionContext* context, const StringVal& str) { - if (str.is_null) { - return StringVal::null(); - } - // Find new starting position. - int32_t begin = 0; - while (begin < str.len && str.ptr[begin] == ' ') { - ++begin; - } - return StringVal(str.ptr + begin, str.len - begin); + return simd::VStringFunctions::ltrim(str); } StringVal StringFunctions::rtrim(FunctionContext* context, const StringVal& str) { - if (str.is_null) { - return StringVal::null(); - } - if (str.len == 0) { - return str; - } - // Find new ending position. - int32_t end = str.len - 1; - while (end > 0 && str.ptr[end] == ' ') { - --end; - } - DCHECK_GE(end, 0); - return StringVal(str.ptr, (str.ptr[end] == ' ') ? end : end + 1); + return simd::VStringFunctions::rtrim(str); } IntVal StringFunctions::ascii(FunctionContext* context, const StringVal& str) { diff --git a/be/src/util/vectorized-tool/lower_upper_impl.h b/be/src/util/simd/lower_upper_impl.h similarity index 99% rename from be/src/util/vectorized-tool/lower_upper_impl.h rename to be/src/util/simd/lower_upper_impl.h index 6d7853fe0a..c8a2572f7f 100644 --- a/be/src/util/vectorized-tool/lower_upper_impl.h +++ b/be/src/util/simd/lower_upper_impl.h @@ -25,7 +25,7 @@ // the code refer: https://clickhouse.tech/codebrowser/html_report//ClickHouse/src/Functions/LowerUpperImpl.h.html // Doris only handle one character at a time, this function use SIMD to more characters at a time -namespace doris { +namespace doris::simd { template class LowerUpperImpl { diff --git a/be/src/util/simd/vstring_function.h b/be/src/util/simd/vstring_function.h new file mode 100644 index 0000000000..5ed4fb051a --- /dev/null +++ b/be/src/util/simd/vstring_function.h @@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or 
more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "util/simd/lower_upper_impl.h" + +#include +#include +#include "runtime/string_value.hpp" + +namespace doris { + +static size_t get_utf8_byte_length(unsigned char byte) { + size_t char_size = 0; + if (byte >= 0xFC) { + char_size = 6; + } else if (byte >= 0xF8) { + char_size = 5; + } else if (byte >= 0xF0) { + char_size = 4; + } else if (byte >= 0xE0) { + char_size = 3; + } else if (byte >= 0xC0) { + char_size = 2; + } else { + char_size = 1; + } + return char_size; +} + +namespace simd { + +class VStringFunctions { +public: +#ifdef __SSE2__ + /// n equals to 16 chars length + static constexpr auto REGISTER_SIZE = sizeof(__m128i); +#endif +public: + static StringVal rtrim(const StringVal& str) { + if (str.is_null || str.len == 0) { + return str; + } + auto begin = 0; + auto end = str.len - 1; +#ifdef __SSE2__ + char blank = ' '; + const auto pattern = _mm_set1_epi8(blank); + while (end - begin + 1 >= REGISTER_SIZE) { + const auto v_haystack = _mm_loadu_si128(reinterpret_cast(str.ptr + end + 1 - REGISTER_SIZE)); + const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); + const auto mask = _mm_movemask_epi8(v_against_pattern); + int offset = __builtin_clz(~(mask << REGISTER_SIZE)); + /// means not 
found + if (offset == 0) + { + return StringVal(str.ptr + begin, end - begin + 1); + } else { + end -= offset; + } + } +#endif + while (end >= begin && str.ptr[end] == ' ') { + --end; + } + if (end < 0) { + return StringVal(""); + } + return StringVal(str.ptr + begin, end - begin + 1); + } + + static StringVal ltrim(const StringVal& str) { + if (str.is_null || str.len == 0) { + return str; + } + auto begin = 0; + auto end = str.len - 1; +#ifdef __SSE2__ + char blank = ' '; + const auto pattern = _mm_set1_epi8(blank); + while (end - begin + 1 >= REGISTER_SIZE) { + const auto v_haystack = _mm_loadu_si128(reinterpret_cast(str.ptr + begin)); + const auto v_against_pattern = _mm_cmpeq_epi8(v_haystack, pattern); + const auto mask = _mm_movemask_epi8(v_against_pattern); + const auto offset = __builtin_ctz(mask ^ 0xffff); + /// means not found + if (offset == 0) + { + return StringVal(str.ptr + begin, end - begin + 1); + } else if (offset > REGISTER_SIZE) { + begin += REGISTER_SIZE; + } else { + begin += offset; + return StringVal(str.ptr + begin, end - begin + 1); + } + } +#endif + while (begin <= end && str.ptr[begin] == ' ') { + ++begin; + } + return StringVal(str.ptr + begin, end - begin + 1); + } + + static StringVal trim(const StringVal& str) { + if (str.is_null || str.len == 0) { + return str; + } + return rtrim(ltrim(str)); + } + + // Gcc will do auto simd in this function + static bool is_ascii(const StringVal& str) { + char or_code = 0; + for (size_t i = 0; i < str.len; i++) { + or_code |= str.ptr[i]; + } + return !(or_code & 0x80); + } + + static void reverse(const StringVal& str, StringVal dst) { + if (is_ascii(str)) { + int64_t begin = 0; + int64_t end = str.len; + int64_t result_end = dst.len - 1; + + // auto SIMD here + auto* __restrict l = dst.ptr; + auto* __restrict r = str.ptr; + for (; begin < end; ++begin, --result_end) { + l[result_end] = r[begin]; + } + } else { + for (size_t i = 0, char_size = 0; i < str.len; i += char_size) { + char_size = 
get_utf8_byte_length((unsigned)(str.ptr)[i]); + std::copy(str.ptr + i, str.ptr + i + char_size, dst.ptr + str.len - i - char_size); + } + } + } + + static void hex_encode(const unsigned char* src_str, size_t length, char* dst_str) { + static constexpr auto hex_table = "0123456789ABCDEF"; + auto src_str_end = src_str + length; + +#if defined(__SSE2__) + constexpr auto step = sizeof(uint64); + if (src_str + step < src_str_end) { + const auto hex_map = _mm_loadu_si128(reinterpret_cast(hex_table)); + const auto mask_map = _mm_set1_epi8(0x0F); + + do { + auto data = _mm_loadu_si64(src_str); + auto hex_loc = _mm_and_si128(_mm_unpacklo_epi8(_mm_srli_epi64(data, 4), data), mask_map); + _mm_storeu_si128(reinterpret_cast<__m128i *>(dst_str), _mm_shuffle_epi8(hex_map, hex_loc)); + + src_str += step; + dst_str += step * 2; + } while (src_str + step < src_str_end); + } +#endif + char res[2]; + // hex(str) str length is n, result must be 2 * n length + for (; src_str < src_str_end; src_str += 1, dst_str += 2) { + // low 4 bits + *(res + 1) = hex_table[src_str[0] & 0x0F]; + // high 4 bits + *res = hex_table[(src_str[0] >> 4)]; + std::copy(res, res + 2, dst_str); + } + } + + static void to_lower(uint8_t * src, int64_t len, uint8_t * dst) { + if (len <= 0) { + return; + } + LowerUpperImpl<'A', 'Z'> lowerUpper; + lowerUpper.transfer(src, src + len, dst); + } + + static void to_upper(uint8_t * src, int64_t len, uint8_t * dst) { + if (len <= 0) { + return; + } + LowerUpperImpl<'a', 'z'> lowerUpper; + lowerUpper.transfer(src, src + len, dst); + } +}; +} +} diff --git a/be/src/util/vectorized-tool/lower.h b/be/src/util/vectorized-tool/lower.h deleted file mode 100644 index c28308f365..0000000000 --- a/be/src/util/vectorized-tool/lower.h +++ /dev/null @@ -1,37 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. 
The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. -#pragma once -#ifndef BE_LOWER_H -#define BE_LOWER_H - -#include "lower_upper_impl.h" -#include - -namespace doris { -class Lower { -public: - static void to_lower(uint8_t * src, int64_t len, uint8_t * dst) { - if (len <= 0) { - return; - } - LowerUpperImpl<'A', 'Z'> lowerUpper; - lowerUpper.transfer(src, src + len, dst); - } -}; -} - -#endif //BE_LOWER_H diff --git a/be/src/util/vectorized-tool/upper.h b/be/src/util/vectorized-tool/upper.h deleted file mode 100644 index a1b7b7a72d..0000000000 --- a/be/src/util/vectorized-tool/upper.h +++ /dev/null @@ -1,36 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
-#pragma once -#ifndef BE_UPPER_H -#define BE_UPPER_H - -#include "lower_upper_impl.h" -#include - -namespace doris { - class Upper { - public: - static void to_upper(uint8_t * src, int64_t len, uint8_t * dst) { - if (len <= 0) { - return; - } - LowerUpperImpl<'a', 'z'> lowerUpper; - lowerUpper.transfer(src, src + len, dst); - } - }; -} -#endif //BE_UPPER_H diff --git a/be/test/exprs/string_functions_test.cpp b/be/test/exprs/string_functions_test.cpp index 010c1e3ef3..330d99c563 100644 --- a/be/test/exprs/string_functions_test.cpp +++ b/be/test/exprs/string_functions_test.cpp @@ -677,6 +677,76 @@ TEST_F(StringFunctionsTest, upper) { ASSERT_EQ(StringVal(""), StringFunctions::upper(ctx, StringVal(""))); } +TEST_F(StringFunctionsTest, ltrim) { + // no blank + StringVal src("hello worldaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + StringVal res = StringFunctions::ltrim(ctx, src); + ASSERT_EQ(src, res); + // empty string + StringVal src1(""); + res = StringFunctions::ltrim(ctx, src1); + ASSERT_EQ(src1, res); + // null string + StringVal src2(StringVal::null()); + res = StringFunctions::ltrim(ctx, src2); + ASSERT_EQ(src2, res); + // less than 16 blanks + StringVal src3(" hello worldaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + res = StringFunctions::ltrim(ctx, src3); + ASSERT_EQ(src, res); + // more than 16 blanks + StringVal src4(" hello worldaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + res = StringFunctions::ltrim(ctx, src4); + ASSERT_EQ(src, res); + // all are blanks, less than 16 blanks + StringVal src5(" "); + res = StringFunctions::ltrim(ctx, src5); + ASSERT_EQ(StringVal(""), res); + // all are blanks, more than 16 blanks + StringVal src6(" "); + res = StringFunctions::ltrim(ctx, src6); + ASSERT_EQ(StringVal(""), res); + // src less than 16 length + StringVal src7(" 12345678910"); + res = StringFunctions::ltrim(ctx, src7); + ASSERT_EQ(StringVal("12345678910"), res); +} + +TEST_F(StringFunctionsTest, rtrim) { + // no blank + StringVal src("hello worldaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa"); + 
StringVal res = StringFunctions::rtrim(ctx, src); + ASSERT_EQ(src, res); + // empty string + StringVal src1(""); + res = StringFunctions::rtrim(ctx, src1); + ASSERT_EQ(src1, res); + // null string + StringVal src2(StringVal::null()); + res = StringFunctions::rtrim(ctx, src2); + ASSERT_EQ(src2, res); + // less than 16 blanks + StringVal src3("hello worldaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa "); + res = StringFunctions::rtrim(ctx, src3); + ASSERT_EQ(src, res); + // more than 16 blanks + StringVal src4("hello worldaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa "); + res = StringFunctions::rtrim(ctx, src4); + ASSERT_EQ(src, res); + // all are blanks, less than 16 blanks + StringVal src5(" "); + res = StringFunctions::rtrim(ctx, src5); + ASSERT_EQ(StringVal(""), res); + // all are blanks, more than 16 blanks + StringVal src6(" "); + res = StringFunctions::rtrim(ctx, src6); + ASSERT_EQ(StringVal(""), res); + // src less than 16 length + StringVal src7("12345678910 "); + res = StringFunctions::rtrim(ctx, src7); + ASSERT_EQ(StringVal("12345678910"), res); +} + } // namespace doris int main(int argc, char** argv) {