[Function] Support utf-8 encoding in instr, locate, locate_pos, lpad, rpad (#3638)
Support UTF-8 encoding for the string functions `instr`, `locate`, `locate_pos`, `lpad`, and `rpad`, and add unit tests for them.
This commit is contained in:
@ -49,6 +49,15 @@ size_t get_utf8_byte_length(unsigned char byte) {
|
||||
}
|
||||
return char_size;
|
||||
}
|
||||
size_t get_char_len(const StringVal& str, std::vector<size_t>* str_index) {
|
||||
size_t char_len = 0;
|
||||
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
|
||||
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
|
||||
str_index->push_back(i);
|
||||
++char_len;
|
||||
}
|
||||
return char_len;
|
||||
}
|
||||
|
||||
// This behaves identically to the mysql implementation, namely:
|
||||
// - 1-indexed positions
|
||||
@ -73,8 +82,7 @@ StringVal StringFunctions::substring(
|
||||
std::vector<size_t> index;
|
||||
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
|
||||
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
|
||||
index.push_back(byte_pos);
|
||||
byte_pos += char_size;
|
||||
index.push_back(i);
|
||||
if (pos.val > 0 && index.size() > pos.val + len.val) {
|
||||
break;
|
||||
}
|
||||
@ -196,28 +204,45 @@ StringVal StringFunctions::lpad(
|
||||
if (str.is_null || len.is_null || pad.is_null || len.val < 0) {
|
||||
return StringVal::null();
|
||||
}
|
||||
|
||||
std::vector<size_t> str_index;
|
||||
size_t str_char_size = get_char_len(str, &str_index);
|
||||
std::vector<size_t> pad_index;
|
||||
size_t pad_char_size = get_char_len(pad, &pad_index);
|
||||
|
||||
// Corner cases: Shrink the original string, or leave it alone.
|
||||
// TODO: Hive seems to go into an infinite loop if pad.len == 0,
|
||||
// so we should pay attention to Hive's future solution to be compatible.
|
||||
if (len.val <= str.len || pad.len == 0) {
|
||||
return StringVal(str.ptr, len.val);
|
||||
if (len.val <= str_char_size || pad.len == 0) {
|
||||
if (len.val > str_index.size()) {
|
||||
return StringVal::null();
|
||||
}
|
||||
if (len.val == str_index.size()) {
|
||||
return StringVal(str.ptr, len.val);
|
||||
}
|
||||
return StringVal(str.ptr, str_index[len.val]);
|
||||
}
|
||||
|
||||
// TODO pengyubing
|
||||
// StringVal result = StringVal::create_temp_string_val(context, len.val);
|
||||
StringVal result(context, len.val);
|
||||
int32_t pad_byte_len = 0;
|
||||
int32_t pad_times = (len.val - str_char_size) / pad_char_size;
|
||||
int32_t pad_remainder = (len.val - str_char_size) % pad_char_size;
|
||||
pad_byte_len = pad_times * pad.len;
|
||||
pad_byte_len += pad_index[pad_remainder];
|
||||
int32_t byte_len = str.len + pad_byte_len;
|
||||
StringVal result(context, byte_len);
|
||||
if (result.is_null) {
|
||||
return result;
|
||||
}
|
||||
int padded_prefix_len = len.val - str.len;
|
||||
int pad_index = 0;
|
||||
int pad_idx = 0;
|
||||
int result_index = 0;
|
||||
uint8_t* ptr = result.ptr;
|
||||
|
||||
// Prepend chars of pad.
|
||||
while (result_index < padded_prefix_len) {
|
||||
ptr[result_index++] = pad.ptr[pad_index++];
|
||||
pad_index = pad_index % pad.len;
|
||||
while (result_index < pad_byte_len) {
|
||||
ptr[result_index++] = pad.ptr[pad_idx++];
|
||||
pad_idx = pad_idx % pad.len;
|
||||
}
|
||||
|
||||
// Append given string.
|
||||
@ -231,16 +256,34 @@ StringVal StringFunctions::rpad(
|
||||
if (str.is_null || len.is_null || pad.is_null || len.val < 0) {
|
||||
return StringVal::null();
|
||||
}
|
||||
|
||||
std::vector<size_t> str_index;
|
||||
size_t str_char_size = get_char_len(str, &str_index);
|
||||
std::vector<size_t> pad_index;
|
||||
size_t pad_char_size = get_char_len(pad, &pad_index);
|
||||
|
||||
// Corner cases: Shrink the original string, or leave it alone.
|
||||
// TODO: Hive seems to go into an infinite loop if pad->len == 0,
|
||||
// so we should pay attention to Hive's future solution to be compatible.
|
||||
if (len.val <= str.len || pad.len == 0) {
|
||||
return StringVal(str.ptr, len.val);
|
||||
if (len.val <= str_char_size || pad.len == 0) {
|
||||
if (len.val > str_index.size()) {
|
||||
return StringVal::null();
|
||||
}
|
||||
if (len.val == str_index.size()) {
|
||||
return StringVal(str.ptr, len.val);
|
||||
}
|
||||
return StringVal(str.ptr, str_index[len.val]);
|
||||
}
|
||||
|
||||
// TODO pengyubing
|
||||
// StringVal result = StringVal::create_temp_string_val(context, len.val);
|
||||
StringVal result(context, len.val);
|
||||
int32_t pad_byte_len = 0;
|
||||
int32_t pad_times = (len.val - str_char_size) / pad_char_size;
|
||||
int32_t pad_remainder = (len.val - str_char_size) % pad_char_size;
|
||||
pad_byte_len = pad_times * pad.len;
|
||||
pad_byte_len += pad_index[pad_remainder];
|
||||
int32_t byte_len = str.len + pad_byte_len;
|
||||
StringVal result(context, byte_len);
|
||||
if (UNLIKELY(result.is_null)) {
|
||||
return result;
|
||||
}
|
||||
@ -248,11 +291,11 @@ StringVal StringFunctions::rpad(
|
||||
|
||||
// Append chars of pad until desired length
|
||||
uint8_t* ptr = result.ptr;
|
||||
int pad_index = 0;
|
||||
int pad_idx = 0;
|
||||
int result_len = str.len;
|
||||
while (result_len < len.val) {
|
||||
ptr[result_len++] = pad.ptr[pad_index++];
|
||||
pad_index = pad_index % pad.len;
|
||||
while (result_len < byte_len) {
|
||||
ptr[result_len++] = pad.ptr[pad_idx++];
|
||||
pad_idx = pad_idx % pad.len;
|
||||
}
|
||||
return result;
|
||||
}
|
||||
@ -295,7 +338,6 @@ IntVal StringFunctions::char_utf8_length(FunctionContext* context, const StringV
|
||||
return IntVal::null();
|
||||
}
|
||||
size_t char_len = 0;
|
||||
std::vector<size_t> index;
|
||||
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
|
||||
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
|
||||
++char_len;
|
||||
@ -412,11 +454,24 @@ IntVal StringFunctions::instr(
|
||||
if (str.is_null || substr.is_null) {
|
||||
return IntVal::null();
|
||||
}
|
||||
if (substr.len == 0) {
|
||||
return IntVal(1);
|
||||
}
|
||||
StringValue str_sv = StringValue::from_string_val(str);
|
||||
StringValue substr_sv = StringValue::from_string_val(substr);
|
||||
StringSearch search(&substr_sv);
|
||||
// Hive returns positions starting from 1.
|
||||
return IntVal(search.search(&str_sv) + 1);
|
||||
int loc = search.search(&str_sv);
|
||||
if (loc > 0) {
|
||||
size_t char_len = 0;
|
||||
for (size_t i = 0, char_size = 0; i < loc; i += char_size) {
|
||||
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
|
||||
++char_len;
|
||||
}
|
||||
loc = char_len;
|
||||
}
|
||||
|
||||
return IntVal(loc + 1);
|
||||
}
|
||||
|
||||
IntVal StringFunctions::locate(
|
||||
@ -430,20 +485,34 @@ IntVal StringFunctions::locate_pos(
|
||||
if (str.is_null || substr.is_null || start_pos.is_null) {
|
||||
return IntVal::null();
|
||||
}
|
||||
if (substr.len == 0) {
|
||||
if (str.len == 0 && start_pos.val > 1) {
|
||||
return IntVal(0);
|
||||
}
|
||||
return IntVal(start_pos.val);
|
||||
}
|
||||
// Hive returns 0 for *start_pos <= 0,
|
||||
// but throws an exception for *start_pos > str->len.
|
||||
// Since returning 0 seems to be Hive's error condition, return 0.
|
||||
if (start_pos.val <= 0 || start_pos.val > str.len) {
|
||||
std::vector<size_t> index;
|
||||
size_t char_len = get_char_len(str, &index);
|
||||
if (start_pos.val <= 0 || start_pos.val > str.len || start_pos.val > char_len) {
|
||||
return IntVal(0);
|
||||
}
|
||||
StringValue substr_sv = StringValue::from_string_val(substr);
|
||||
StringSearch search(&substr_sv);
|
||||
// Input start_pos.val starts from 1.
|
||||
StringValue adjusted_str(
|
||||
reinterpret_cast<char*>(str.ptr) + start_pos.val - 1, str.len - start_pos.val + 1);
|
||||
reinterpret_cast<char*>(str.ptr) + index[start_pos.val - 1], str.len - index[start_pos.val - 1]);
|
||||
int32_t match_pos = search.search(&adjusted_str);
|
||||
if (match_pos >= 0) {
|
||||
// Hive returns the position in the original string starting from 1.
|
||||
size_t char_len = 0;
|
||||
for (size_t i = 0, char_size = 0; i < match_pos; i += char_size) {
|
||||
char_size = get_utf8_byte_length((unsigned)(adjusted_str.ptr)[i]);
|
||||
++char_len;
|
||||
}
|
||||
match_pos = char_len;
|
||||
return IntVal(start_pos.val + match_pos);
|
||||
} else {
|
||||
return IntVal(0);
|
||||
|
||||
Reference in New Issue
Block a user