[Function] Support utf-8 encoding in instr, locate, locate_pos, lpad, rpad (#3638)

Support utf-8 encoding for string function `instr`, `locate`, `locate_pos`, `lpad`, `rpad`
and add unit test for them
This commit is contained in:
yangzhg
2020-05-22 14:34:26 +08:00
committed by GitHub
parent 16deac96a9
commit ba7d2dbf7b
8 changed files with 289 additions and 27 deletions

View File

@ -49,6 +49,15 @@ size_t get_utf8_byte_length(unsigned char byte) {
}
return char_size;
}
size_t get_char_len(const StringVal& str, std::vector<size_t>* str_index) {
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
str_index->push_back(i);
++char_len;
}
return char_len;
}
// This behaves identically to the mysql implementation, namely:
// - 1-indexed positions
@ -73,8 +82,7 @@ StringVal StringFunctions::substring(
std::vector<size_t> index;
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
index.push_back(byte_pos);
byte_pos += char_size;
index.push_back(i);
if (pos.val > 0 && index.size() > pos.val + len.val) {
break;
}
@ -196,28 +204,45 @@ StringVal StringFunctions::lpad(
if (str.is_null || len.is_null || pad.is_null || len.val < 0) {
return StringVal::null();
}
std::vector<size_t> str_index;
size_t str_char_size = get_char_len(str, &str_index);
std::vector<size_t> pad_index;
size_t pad_char_size = get_char_len(pad, &pad_index);
// Corner cases: Shrink the original string, or leave it alone.
// TODO: Hive seems to go into an infinite loop if pad.len == 0,
// so we should pay attention to Hive's future solution to be compatible.
if (len.val <= str.len || pad.len == 0) {
return StringVal(str.ptr, len.val);
if (len.val <= str_char_size || pad.len == 0) {
if (len.val > str_index.size()) {
return StringVal::null();
}
if (len.val == str_index.size()) {
return StringVal(str.ptr, len.val);
}
return StringVal(str.ptr, str_index[len.val]);
}
// TODO pengyubing
// StringVal result = StringVal::create_temp_string_val(context, len.val);
StringVal result(context, len.val);
int32_t pad_byte_len = 0;
int32_t pad_times = (len.val - str_char_size) / pad_char_size;
int32_t pad_remainder = (len.val - str_char_size) % pad_char_size;
pad_byte_len = pad_times * pad.len;
pad_byte_len += pad_index[pad_remainder];
int32_t byte_len = str.len + pad_byte_len;
StringVal result(context, byte_len);
if (result.is_null) {
return result;
}
int padded_prefix_len = len.val - str.len;
int pad_index = 0;
int pad_idx = 0;
int result_index = 0;
uint8_t* ptr = result.ptr;
// Prepend chars of pad.
while (result_index < padded_prefix_len) {
ptr[result_index++] = pad.ptr[pad_index++];
pad_index = pad_index % pad.len;
while (result_index < pad_byte_len) {
ptr[result_index++] = pad.ptr[pad_idx++];
pad_idx = pad_idx % pad.len;
}
// Append given string.
@ -231,16 +256,34 @@ StringVal StringFunctions::rpad(
if (str.is_null || len.is_null || pad.is_null || len.val < 0) {
return StringVal::null();
}
std::vector<size_t> str_index;
size_t str_char_size = get_char_len(str, &str_index);
std::vector<size_t> pad_index;
size_t pad_char_size = get_char_len(pad, &pad_index);
// Corner cases: Shrink the original string, or leave it alone.
// TODO: Hive seems to go into an infinite loop if pad->len == 0,
// so we should pay attention to Hive's future solution to be compatible.
if (len.val <= str.len || pad.len == 0) {
return StringVal(str.ptr, len.val);
if (len.val <= str_char_size || pad.len == 0) {
if (len.val > str_index.size()) {
return StringVal::null();
}
if (len.val == str_index.size()) {
return StringVal(str.ptr, len.val);
}
return StringVal(str.ptr, str_index[len.val]);
}
// TODO pengyubing
// StringVal result = StringVal::create_temp_string_val(context, len.val);
StringVal result(context, len.val);
int32_t pad_byte_len = 0;
int32_t pad_times = (len.val - str_char_size) / pad_char_size;
int32_t pad_remainder = (len.val - str_char_size) % pad_char_size;
pad_byte_len = pad_times * pad.len;
pad_byte_len += pad_index[pad_remainder];
int32_t byte_len = str.len + pad_byte_len;
StringVal result(context, byte_len);
if (UNLIKELY(result.is_null)) {
return result;
}
@ -248,11 +291,11 @@ StringVal StringFunctions::rpad(
// Append chars of pad until desired length
uint8_t* ptr = result.ptr;
int pad_index = 0;
int pad_idx = 0;
int result_len = str.len;
while (result_len < len.val) {
ptr[result_len++] = pad.ptr[pad_index++];
pad_index = pad_index % pad.len;
while (result_len < byte_len) {
ptr[result_len++] = pad.ptr[pad_idx++];
pad_idx = pad_idx % pad.len;
}
return result;
}
@ -295,7 +338,6 @@ IntVal StringFunctions::char_utf8_length(FunctionContext* context, const StringV
return IntVal::null();
}
size_t char_len = 0;
std::vector<size_t> index;
for (size_t i = 0, char_size = 0; i < str.len; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
++char_len;
@ -412,11 +454,24 @@ IntVal StringFunctions::instr(
if (str.is_null || substr.is_null) {
return IntVal::null();
}
if (substr.len == 0) {
return IntVal(1);
}
StringValue str_sv = StringValue::from_string_val(str);
StringValue substr_sv = StringValue::from_string_val(substr);
StringSearch search(&substr_sv);
// Hive returns positions starting from 1.
return IntVal(search.search(&str_sv) + 1);
int loc = search.search(&str_sv);
if (loc > 0) {
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < loc; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(str.ptr)[i]);
++char_len;
}
loc = char_len;
}
return IntVal(loc + 1);
}
IntVal StringFunctions::locate(
@ -430,20 +485,34 @@ IntVal StringFunctions::locate_pos(
if (str.is_null || substr.is_null || start_pos.is_null) {
return IntVal::null();
}
if (substr.len == 0) {
if (str.len == 0 && start_pos.val > 1) {
return IntVal(0);
}
return IntVal(start_pos.val);
}
// Hive returns 0 for *start_pos <= 0,
// but throws an exception for *start_pos > str->len.
// Since returning 0 seems to be Hive's error condition, return 0.
if (start_pos.val <= 0 || start_pos.val > str.len) {
std::vector<size_t> index;
size_t char_len = get_char_len(str, &index);
if (start_pos.val <= 0 || start_pos.val > str.len || start_pos.val > char_len) {
return IntVal(0);
}
StringValue substr_sv = StringValue::from_string_val(substr);
StringSearch search(&substr_sv);
// Input start_pos.val starts from 1.
StringValue adjusted_str(
reinterpret_cast<char*>(str.ptr) + start_pos.val - 1, str.len - start_pos.val + 1);
reinterpret_cast<char*>(str.ptr) + index[start_pos.val - 1], str.len - index[start_pos.val - 1]);
int32_t match_pos = search.search(&adjusted_str);
if (match_pos >= 0) {
// Hive returns the position in the original string starting from 1.
size_t char_len = 0;
for (size_t i = 0, char_size = 0; i < match_pos; i += char_size) {
char_size = get_utf8_byte_length((unsigned)(adjusted_str.ptr)[i]);
++char_len;
}
match_pos = char_len;
return IntVal(start_pos.val + match_pos);
} else {
return IntVal(0);

View File

@ -68,9 +68,6 @@ TEST_F(StringFunctionsTest, money_format_large_int) {
ss << str;
__int128 value;
ss >> value;
std::cout << "value: " << value << std::endl;
StringVal result = StringFunctions::money_format(context, doris_udf::LargeIntVal(value));
StringVal expected = AnyValUtil::from_string_temp(context, std::string("170,141,183,460,469,231,731,687,303,715,884,105,727.00"));
ASSERT_EQ(expected, result);
@ -361,6 +358,92 @@ TEST_F(StringFunctionsTest, append_trailing_char_if_absent) {
StringVal("a"), StringVal("abc")));
}
TEST_F(StringFunctionsTest, instr) {
doris_udf::FunctionContext* context = new doris_udf::FunctionContext();
ASSERT_EQ(IntVal(4), StringFunctions::instr(context, StringVal("foobarbar"), StringVal("bar")));
ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal("foobar"), StringVal("xbar")));
ASSERT_EQ(IntVal(2), StringFunctions::instr(context, StringVal("123456234"), StringVal("234")));
ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal("123456"), StringVal("567")));
ASSERT_EQ(IntVal(2), StringFunctions::instr(context, StringVal("1.234"), StringVal(".234")));
ASSERT_EQ(IntVal(1), StringFunctions::instr(context, StringVal("1.234"), StringVal("")));
ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal(""), StringVal("123")));
ASSERT_EQ(IntVal(1), StringFunctions::instr(context, StringVal(""), StringVal("")));
ASSERT_EQ(IntVal(3), StringFunctions::instr(context, StringVal("你好世界"), StringVal("世界")));
ASSERT_EQ(IntVal(0), StringFunctions::instr(context, StringVal("你好世界"), StringVal("您好")));
ASSERT_EQ(IntVal(3), StringFunctions::instr(context, StringVal("你好abc"), StringVal("a")));
ASSERT_EQ(IntVal(3), StringFunctions::instr(context, StringVal("你好abc"), StringVal("abc")));
ASSERT_EQ(IntVal::null(), StringFunctions::instr(context, StringVal::null(), StringVal("2")));
ASSERT_EQ(IntVal::null(), StringFunctions::instr(context, StringVal(""), StringVal::null()));
ASSERT_EQ(IntVal::null(), StringFunctions::instr(context, StringVal::null(), StringVal::null()));
}
TEST_F(StringFunctionsTest, locate) {
doris_udf::FunctionContext* context = new doris_udf::FunctionContext();
ASSERT_EQ(IntVal(4), StringFunctions::locate(context, StringVal("bar"), StringVal("foobarbar")));
ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("xbar"), StringVal("foobar")));
ASSERT_EQ(IntVal(2), StringFunctions::locate(context, StringVal("234"), StringVal("123456234")));
ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("567"), StringVal("123456")));
ASSERT_EQ(IntVal(2), StringFunctions::locate(context, StringVal(".234"), StringVal("1.234")));
ASSERT_EQ(IntVal(1), StringFunctions::locate(context, StringVal(""), StringVal("1.234")));
ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("123"), StringVal("")));
ASSERT_EQ(IntVal(1), StringFunctions::locate(context, StringVal(""), StringVal("")));
ASSERT_EQ(IntVal(3), StringFunctions::locate(context, StringVal("世界"), StringVal("你好世界")));
ASSERT_EQ(IntVal(0), StringFunctions::locate(context, StringVal("您好"), StringVal("你好世界")));
ASSERT_EQ(IntVal(3), StringFunctions::locate(context, StringVal("a"), StringVal("你好abc")));
ASSERT_EQ(IntVal(3), StringFunctions::locate(context, StringVal("abc"), StringVal("你好abc")));
ASSERT_EQ(IntVal::null(), StringFunctions::locate(context, StringVal::null(), StringVal("2")));
ASSERT_EQ(IntVal::null(), StringFunctions::locate(context, StringVal(""), StringVal::null()));
ASSERT_EQ(IntVal::null(), StringFunctions::locate(context, StringVal::null(), StringVal::null()));
}
TEST_F(StringFunctionsTest, locate_pos) {
doris_udf::FunctionContext* context = new doris_udf::FunctionContext();
ASSERT_EQ(IntVal(7), StringFunctions::locate_pos(context, StringVal("bar"), StringVal("foobarbar"), IntVal(5)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("xbar"), StringVal("foobar"), IntVal(1)));
ASSERT_EQ(IntVal(2), StringFunctions::locate_pos(context, StringVal(""), StringVal("foobar"), IntVal(2)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("foobar"), StringVal(""), IntVal(1)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal(""), StringVal(""), IntVal(2)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("A"), StringVal("AAAAAA"), IntVal(0)));
ASSERT_EQ(IntVal(0), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(0)));
ASSERT_EQ(IntVal(2), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(1)));
ASSERT_EQ(IntVal(2), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(2)));
ASSERT_EQ(IntVal(5), StringFunctions::locate_pos(context, StringVal("A"), StringVal("大A写的A"), IntVal(3)));
ASSERT_EQ(IntVal(7), StringFunctions::locate_pos(context, StringVal("BaR"), StringVal("foobarBaR"), IntVal(5)));
ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal::null(), StringVal("2"), IntVal(1)));
ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal(""), StringVal::null(), IntVal(4)));
ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal::null(), StringVal::null(), IntVal(4)));
ASSERT_EQ(IntVal::null(), StringFunctions::locate_pos(context, StringVal::null(), StringVal::null(), IntVal(-1)));
}
TEST_F(StringFunctionsTest, lpad) {
ASSERT_EQ(StringVal("???hi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("?")));
ASSERT_EQ(StringVal("g8%7IgY%AHx7luNtf8Kh"), StringFunctions::lpad(ctx, StringVal("g8%7IgY%AHx7luNtf8Kh"), IntVal(20), StringVal("")));
ASSERT_EQ(StringVal("h"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(1), StringVal("?")));
ASSERT_EQ(StringVal(""), StringFunctions::lpad(ctx, StringVal("你好"), IntVal(1), StringVal("?")));
ASSERT_EQ(StringVal(""), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(0), StringVal("?")));
ASSERT_EQ(StringVal::null(), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(-1), StringVal("?")));
ASSERT_EQ(StringVal("h"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(1), StringVal("")));
ASSERT_EQ(StringVal::null(), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("")));
ASSERT_EQ(StringVal("abahi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("ab")));
ASSERT_EQ(StringVal("ababhi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(6), StringVal("ab")));
ASSERT_EQ(StringVal("呵呵呵hi"), StringFunctions::lpad(ctx, StringVal("hi"), IntVal(5), StringVal("呵呵")));
ASSERT_EQ(StringVal("hih呵呵"), StringFunctions::lpad(ctx, StringVal("呵呵"), IntVal(5), StringVal("hi")));
}
TEST_F(StringFunctionsTest, rpad) {
ASSERT_EQ(StringVal("hi???"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("?")));
ASSERT_EQ(StringVal("g8%7IgY%AHx7luNtf8Kh"), StringFunctions::rpad(ctx, StringVal("g8%7IgY%AHx7luNtf8Kh"), IntVal(20), StringVal("")));
ASSERT_EQ(StringVal("h"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(1), StringVal("?")));
ASSERT_EQ(StringVal(""), StringFunctions::rpad(ctx, StringVal("你好"), IntVal(1), StringVal("?")));
ASSERT_EQ(StringVal(""), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(0), StringVal("?")));
ASSERT_EQ(StringVal::null(), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(-1), StringVal("?")));
ASSERT_EQ(StringVal("h"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(1), StringVal("")));
ASSERT_EQ(StringVal::null(), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("")));
ASSERT_EQ(StringVal("hiaba"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("ab")));
ASSERT_EQ(StringVal("hiabab"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(6), StringVal("ab")));
ASSERT_EQ(StringVal("hi呵呵呵"), StringFunctions::rpad(ctx, StringVal("hi"), IntVal(5), StringVal("呵呵")));
ASSERT_EQ(StringVal("呵呵hih"), StringFunctions::rpad(ctx, StringVal("呵呵"), IntVal(5), StringVal("hi")));
}
}
int main(int argc, char** argv) {

View File

@ -221,6 +221,7 @@ module.exports = [
"regexp_replace",
"repeat",
"right",
"rpad",
"split_part",
"starts_with",
"strleft",

View File

@ -233,6 +233,7 @@ module.exports = [
"regexp_replace",
"repeat",
"right",
"rpad",
"split_part",
"starts_with",
"strleft",

View File

@ -28,10 +28,10 @@ under the License.
## Description
### Syntax
'VARCHAR lpad (VARCHAR str., INT len, VARCHAR pad)'
'VARCHAR lpad (VARCHAR str, INT len, VARCHAR pad)'
Returns a string of length len in str, starting with the initials. If len is longer than str, pad characters are added to STR until the length of the string reaches len. If len is less than str's length, the function is equivalent to truncating STR strings and returning only strings of len's length.
Returns a string of length len in str, starting with the initials. If len is longer than str, pad characters are added to STR until the length of the string reaches len. If len is less than str's length, the function is equivalent to truncating STR strings and returning only strings of len's length. The len is character length not the bye size.
## example

View File

@ -0,0 +1,54 @@
---
{
"title": "rpad",
"language": "en"
}
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# rpad
## Description
### Syntax
'VARCHAR rpad (VARCHAR str, INT len, VARCHAR pad)'
Returns a string of length len in str, starting with the initials. If len is longer than str, pad characters are added to the right of STR until the length of the string reaches len. If len is less than str's length, the function is equivalent to truncating STR strings and returning only strings of len's length. The len is character length not the bye size.
## example
```
mysql> SELECT rpad("hi", 5, "xy");
+---------------------+
| rpad('hi', 5, 'xy') |
+---------------------+
| hixyx |
+---------------------+
mysql> SELECT rpad("hi", 1, "xy");
+---------------------+
| rpad('hi', 1, 'xy') |
+---------------------+
| h |
+---------------------+
```
##keyword
RPAD

View File

@ -31,7 +31,7 @@ under the License.
`VARCHAR lpad(VARCHAR str, INT len, VARCHAR pad)`
返回 str 中长度为 len(从首字母开始算起)的字符串。如果 len 大于 str 的长度,则在 str 的前面不断补充 pad 字符,直到该字符串的长度达到 len 为止。如果 len 小于 str 的长度,该函数相当于截断 str 字符串,只返回长度为 len 的字符串。
返回 str 中长度为 len(从首字母开始算起)的字符串。如果 len 大于 str 的长度,则在 str 的前面不断补充 pad 字符,直到该字符串的长度达到 len 为止。如果 len 小于 str 的长度,该函数相当于截断 str 字符串,只返回长度为 len 的字符串。len 指的是字符长度而不是字节长度。
## example

View File

@ -0,0 +1,54 @@
---
{
"title": "rpad",
"language": "zh-CN"
}
---
<!--
Licensed to the Apache Software Foundation (ASF) under one
or more contributor license agreements. See the NOTICE file
distributed with this work for additional information
regarding copyright ownership. The ASF licenses this file
to you under the Apache License, Version 2.0 (the
"License"); you may not use this file except in compliance
with the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing,
software distributed under the License is distributed on an
"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
-->
# rpad
## description
### Syntax
`VARCHAR rpad(VARCHAR str, INT len, VARCHAR pad)`
返回 str 中长度为 len(从首字母开始算起)的字符串。如果 len 大于 str 的长度,则在 str 的后面不断补充 pad 字符,直到该字符串的长度达到 len 为止。如果 len 小于 str 的长度,该函数相当于截断 str 字符串,只返回长度为 len 的字符串。len 指的是字符长度而不是字节长度。
## example
```
mysql> SELECT rpad("hi", 5, "xy");
+---------------------+
| rpad('hi', 5, 'xy') |
+---------------------+
| hixyx |
+---------------------+
mysql> SELECT rpad("hi", 1, "xy");
+---------------------+
| rpad('hi', 1, 'xy') |
+---------------------+
| h |
+---------------------+
```
##keyword
RPAD