[Function] Support utf-8 encoding in instr, locate, locate_pos, lpad, rpad (#3638)

Support UTF-8 encoding for the string functions `instr`, `locate`, `locate_pos`, `lpad`, and `rpad`,
and add unit tests for them.
This commit is contained in:
yangzhg
2020-05-22 14:34:26 +08:00
committed by GitHub
parent 16deac96a9
commit ba7d2dbf7b
8 changed files with 289 additions and 27 deletions

View File

@ -49,6 +49,15 @@ size_t get_utf8_byte_length(unsigned char byte) {
}
return char_size;
}
// Counts the characters in `str` and records where each one starts.
//
// Walks `str` from byte 0, asking get_utf8_byte_length() how many bytes the
// character beginning at the current byte occupies, and stepping forward by
// that amount until the end of the string is reached.
//
// str:        the string to scan.
// str_index:  output; the starting byte offset of every character found is
//             appended in order. The vector is not cleared first, so any
//             existing entries are kept.
//
// Returns the number of characters counted by this call (equal to the number
// of offsets appended).
size_t get_char_len(const StringVal& str, std::vector<size_t>* str_index) {
    size_t num_chars = 0;
    size_t offset = 0;
    while (offset < str.len) {
        // Width in bytes of the character that starts at `offset`.
        const size_t width = get_utf8_byte_length(static_cast<unsigned>(str.ptr[offset]));
        str_index->push_back(offset);
        ++num_chars;
        offset += width;
    }
    return num_chars;
}
// This behaves identically to the mysql implementation, namely:
// - 1-indexed positions
@ -73,8 +82,7 @@ StringVal StringFunctions::substring(
std::vector<size_t> index;
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
index.push_back(byte_pos);
byte_pos += char_size;
index.push_back(i);
if (pos.val > 0 && index.size() > pos.val + len.val) {
break;
}
@ -196,28 +204,45 @@ StringVal StringFunctions::lpad(
if (str.is_null || len.is_null || pad.is_null || len.val < 0) {
return StringVal::null();
}
std::vector<size_t> str_index;
size_t str_char_size = get_char_len(str, &str_index);
std::vector<size_t> pad_index;
size_t pad_char_size = get_char_len(pad, &pad_index);
// Corner cases: Shrink the original string, or leave it alone.
// TODO: Hive seems to go into an infinite loop if pad.len == 0,
// so we should pay attention to Hive's future solution to be compatible.
if (len.val <= str.len || pad.len == 0) {
return StringVal(str.ptr, len.val);
if (len.val <= str_char_size || pad.len == 0) {
if (len.val > str_index.size()) {
return StringVal::null();
}
if (len.val == str_index.size()) {
return StringVal(str.ptr, len.val);
}
return StringVal(str.ptr, str_index[len.val]);
}
// TODO pengyubing
// StringVal result = StringVal::create_temp_string_val(context, len.val);
StringVal result(context, len.val);
int32_t pad_byte_len = 0;
int32_t pad_times = (len.val - str_char_size) / pad_char_size;
int32_t pad_remainder = (len.val - str_char_size) % pad_char_size;
pad_byte_len = pad_times * pad.len;
pad_byte_len += pad_index[pad_remainder];
int32_t byte_len = str.len + pad_byte_len;
StringVal result(context, byte_len);
if (result.is_null) {
return result;
}
int padded_prefix_len = len.val - str.len;
int pad_index = 0;
int pad_idx = 0;
int result_index = 0;
uint8_t* ptr = result.ptr;
// Prepend chars of pad.
while (result_index < padded_prefix_len) {
ptr[result_index++] = pad.ptr[pad_index++];
pad_index = pad_index % pad.len;
while (result_index < pad_byte_len) {
ptr[result_index++] = pad.ptr[pad_idx++];
pad_idx = pad_idx % pad.len;
}
// Append given string.
@ -231,16 +256,34 @@ StringVal StringFunctions::rpad(
if (str.is_null || len.is_null || pad.is_null || len.val < 0) {
return StringVal::null();
}
std::vector<size_t> str_index;
size_t str_char_size = get_char_len(str, &str_index);
std::vector<size_t> pad_index;
size_t pad_char_size = get_char_len(pad, &pad_index);
// Corner cases: Shrink the original string, or leave it alone.
// TODO: Hive seems to go into an infinite loop if pad->len == 0,
// so we should pay attention to Hive's future solution to be compatible.
if (len.val <= str.len || pad.len == 0) {
return StringVal(str.ptr, len.val);
if (len.val <= str_char_size || pad.len == 0) {
if (len.val > str_index.size()) {
return StringVal::null();
}
if (len.val == str_index.size()) {
return StringVal(str.ptr, len.val);
}
return StringVal(str.ptr, str_index[len.val]);
}
// TODO pengyubing
// StringVal result = StringVal::create_temp_string_val(context, len.val);
StringVal result(context, len.val);
int32_t pad_byte_len = 0;
int32_t pad_times = (len.val - str_char_size) / pad_char_size;
int32_t pad_remainder = (len.val - str_char_size) % pad_char_size;
pad_byte_len = pad_times * pad.len;
pad_byte_len += pad_index[pad_remainder];
int32_t byte_len = str.len + pad_byte_len;
StringVal result(context, byte_len);
if (UNLIKELY(result.is_null)) {
return result;
}
@ -248,11 +291,11 @@ StringVal StringFunctions::rpad(
// Append chars of pad until desired length
uint8_t* ptr = result.ptr;
int pad_index = 0;
int pad_idx = 0;
int result_len = str.len;
while (result_len < len.val) {
ptr[result_len++] = pad.ptr[pad_index++];
pad_index = pad_index % pad.len;
while (result_len < byte_len) {
ptr[result_len++] = pad.ptr[pad_idx++];
pad_idx = pad_idx % pad.len;
}
return result;
}
@ -295,7 +338,6 @@ IntVal StringFunctions::char_utf8_length(FunctionContext* context, const StringV
return IntVal::null();
}
size_t char_len = 0;
std::vector<size_t> index;
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
++char_len;
@ -412,11 +454,24 @@ IntVal StringFunctions::instr(
if (str.is_null || substr.is_null) {
return IntVal::null();
}
if (substr.len == 0) {
return IntVal(1);
}
StringValue str_sv = StringValue::from_string_val(str);
StringValue substr_sv = StringValue::from_string_val(substr);
StringSearch search(&substr_sv);
// Hive returns positions starting from 1.
return IntVal(search.search(&str_sv) + 1);
int loc = search.search(&str_sv);
if (loc > 0) {
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < loc; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
++char_len;
}
loc = char_len;
}
return IntVal(loc + 1);
}
IntVal StringFunctions::locate(
@ -430,20 +485,34 @@ IntVal StringFunctions::locate_pos(
if (str.is_null || substr.is_null || start_pos.is_null) {
return IntVal::null();
}
if (substr.len == 0) {
if (str.len == 0 && start_pos.val > 1) {
return IntVal(0);
}
return IntVal(start_pos.val);
}
// Hive returns 0 for *start_pos <= 0,
// but throws an exception for *start_pos > str->len.
// Since returning 0 seems to be Hive's error condition, return 0.
if (start_pos.val <= 0 || start_pos.val > str.len) {
std::vector<size_t> index;
size_t char_len = get_char_len(str, &index);
if (start_pos.val <= 0 || start_pos.val > str.len || start_pos.val > char_len) {
return IntVal(0);
}
StringValue substr_sv = StringValue::from_string_val(substr);
StringSearch search(&substr_sv);
// Input start_pos.val starts from 1.
StringValue adjusted_str(
reinterpret_cast<char*>(str.ptr) + start_pos.val - 1, str.len - start_pos.val + 1);
reinterpret_cast<char*>(str.ptr) + index[start_pos.val - 1], str.len - index[start_pos.val - 1]);
int32_t match_pos = search.search(&adjusted_str);
if (match_pos >= 0) {
// Hive returns the position in the original string starting from 1.
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < match_pos; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(adjusted_str.ptr)[i]);
++char_len;
}
match_pos = char_len;
return IntVal(start_pos.val + match_pos);
} else {
return IntVal(0);