diff --git a/be/src/agent/be_exec_version_manager.h b/be/src/agent/be_exec_version_manager.h index 1cabc38eba..f5213c5408 100644 --- a/be/src/agent/be_exec_version_manager.h +++ b/be/src/agent/be_exec_version_manager.h @@ -64,7 +64,7 @@ private: * c. cleared old version of Version 2. * d. unix_timestamp function support timestamp with float for datetimev2, and change nullable mode. * e. change shuffle serialize/deserialize way - * f. the right function outputs NULL when the function contains NULL, substr function returns empty if start > str.length, and change some function nullable mode. + * f. shrink some function's nullable mode. */ constexpr inline int BeExecVersionManager::max_be_exec_version = 3; constexpr inline int BeExecVersionManager::min_be_exec_version = 0; diff --git a/be/src/util/url_coding.cpp b/be/src/util/url_coding.cpp index 0a5c4144c5..f091e63812 100644 --- a/be/src/util/url_coding.cpp +++ b/be/src/util/url_coding.cpp @@ -91,7 +91,7 @@ static void encode_base64_internal(const std::string& in, std::string* out, size_t len = in.size(); // Every 3 source bytes will be encoded into 4 bytes. std::unique_ptr buf(new unsigned char[(((len + 2) / 3) * 4)]); - const unsigned char* s = reinterpret_cast(in.data()); + const auto* s = reinterpret_cast(in.data()); unsigned char* d = buf.get(); while (len > 2) { *d++ = basis[(s[0] >> 2) & 0x3f]; @@ -157,7 +157,7 @@ static short decoding_table[256] = { static int mod_table[] = {0, 2, 1}; size_t base64_encode(const unsigned char* data, size_t length, unsigned char* encoded_data) { - size_t output_length = (size_t)(4.0 * ceil((double)length / 3.0)); + auto output_length = (size_t)(4.0 * ceil((double)length / 3.0)); if (encoded_data == nullptr) { return 0; @@ -267,7 +267,7 @@ bool base64_decode(const std::string& in, std::string* out) { } void escape_for_html(const std::string& in, std::stringstream* out) { - for (auto& c : in) { + for (const auto& c : in) { switch (c) { case '<': (*out) << "<"; diff --git a/be/src/vec/functions/function_string.cpp b/be/src/vec/functions/function_string.cpp index cce325e89e..ab4ac6c86a 100644 --- a/be/src/vec/functions/function_string.cpp +++ b/be/src/vec/functions/function_string.cpp @@ -140,8 +140,8 @@ struct FindInSetOp { using ResultDataType = DataTypeInt32; using ResultPaddedPODArray = PaddedPODArray; static void execute(const std::string_view& strl, const std::string_view& strr, int32_t& res) { - for (int i = 0; i < strl.length(); ++i) { - if (strl[i] == ',') { + for (const auto& c : strl) { + if (c == ',') { res = 0; return; } @@ -635,18 +635,12 @@ struct UnHexImpl { } static Status vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets, - ColumnString::Chars& dst_data, ColumnString::Offsets& dst_offsets, - NullMap& null_map) { + ColumnString::Chars& dst_data, ColumnString::Offsets& dst_offsets) { auto rows_count = offsets.size(); dst_offsets.resize(rows_count); for (int i = 0; i < rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, dst_data, dst_offsets, null_map); - continue; - } - - auto source = reinterpret_cast(&data[offsets[i - 1]]); + const auto* source = reinterpret_cast(&data[offsets[i - 1]]); size_t srclen = offsets[i] - offsets[i - 1]; if (srclen == 0) { @@ -666,17 +660,103 @@ struct UnHexImpl { int outlen = hex_decode(source, srclen, dst); - if (outlen < 0) { - StringOP::push_null_string(i, dst_data, dst_offsets, null_map); - } else { - StringOP::push_value_string(std::string_view(dst, outlen), i, dst_data, - dst_offsets); - } + StringOP::push_value_string(std::string_view(dst, outlen), i, dst_data, dst_offsets); } return Status::OK(); } }; + +struct UnHexOldImpl { + static constexpr auto name = "unhex"; + using ReturnType = DataTypeString; + using ColumnType = ColumnString; + + static bool check_and_decode_one(char& c, const char src_c, bool flag) { + int k = flag ? 16 : 1; + int value = src_c - '0'; + // 9 = ('9'-'0') + if (value >= 0 && value <= 9) { + c += value * k; + return true; + } + + value = src_c - 'A'; + // 5 = ('F'-'A') + if (value >= 0 && value <= 5) { + c += (value + 10) * k; + return true; + } + + value = src_c - 'a'; + // 5 = ('f'-'a') + if (value >= 0 && value <= 5) { + c += (value + 10) * k; + return true; + } + // not in ( ['0','9'], ['a','f'], ['A','F'] ) + return false; + } + + static int hex_decode(const char* src_str, size_t src_len, char* dst_str) { + // if str length is odd or 0, return empty string like mysql dose. + if ((src_len & 1) != 0 or src_len == 0) { + return 0; + } + //check and decode one character at the same time + // character in ( ['0','9'], ['a','f'], ['A','F'] ), return 'NULL' like mysql dose. + for (auto i = 0, dst_index = 0; i < src_len; i += 2, dst_index++) { + char c = 0; + // combine two character into dst_str one character + bool left_4bits_flag = check_and_decode_one(c, *(src_str + i), true); + bool right_4bits_flag = check_and_decode_one(c, *(src_str + i + 1), false); + + if (!left_4bits_flag || !right_4bits_flag) { + return 0; + } + *(dst_str + dst_index) = c; + } + return src_len / 2; + } + + static Status vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets, + ColumnString::Chars& dst_data, ColumnString::Offsets& dst_offsets, + NullMap& null_map) { + auto rows_count = offsets.size(); + dst_offsets.resize(rows_count); + + for (int i = 0; i < rows_count; ++i) { + if (null_map[i]) { + StringOP::push_null_string(i, dst_data, dst_offsets, null_map); + continue; + } + const auto* source = reinterpret_cast(&data[offsets[i - 1]]); + size_t srclen = offsets[i] - offsets[i - 1]; + + if (srclen == 0) { + StringOP::push_empty_string(i, dst_data, dst_offsets); + continue; + } + + char dst_array[MAX_STACK_CIPHER_LEN]; + char* dst = dst_array; + + int cipher_len = srclen / 2; + std::unique_ptr dst_uptr; + if (cipher_len > MAX_STACK_CIPHER_LEN) { + dst_uptr.reset(new char[cipher_len]); + dst = dst_uptr.get(); + } + + int outlen = hex_decode(source, srclen, dst); + + StringOP::push_value_string(std::string_view(dst, outlen), i, dst_data, dst_offsets); + } + + return Status::OK(); + } +}; + struct NameStringSpace { static constexpr auto name = "space"; }; @@ -714,22 +794,16 @@ struct ToBase64Impl { using ColumnType = ColumnString; static Status vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets, - ColumnString::Chars& dst_data, ColumnString::Offsets& dst_offsets, - NullMap& null_map) { + ColumnString::Chars& dst_data, ColumnString::Offsets& dst_offsets) { auto rows_count = offsets.size(); dst_offsets.resize(rows_count); for (int i = 0; i < rows_count; ++i) { - if (null_map[i]) { - StringOP::push_null_string(i, dst_data, dst_offsets, null_map); - continue; - } - - auto source = reinterpret_cast(&data[offsets[i - 1]]); + const auto* source = reinterpret_cast(&data[offsets[i - 1]]); size_t srclen = offsets[i] - offsets[i - 1]; if (srclen == 0) { - StringOP::push_null_string(i, dst_data, dst_offsets, null_map); + StringOP::push_empty_string(i, dst_data, dst_offsets); continue; } @@ -745,12 +819,50 @@ struct ToBase64Impl { int outlen = base64_encode((const unsigned char*)source, srclen, (unsigned char*)dst); - if (outlen < 0) { + StringOP::push_value_string(std::string_view(dst, outlen), i, dst_data, dst_offsets); + } + return Status::OK(); + } +}; + +struct ToBase64OldImpl { + static constexpr auto name = "to_base64"; + using ReturnType = DataTypeString; + using ColumnType = ColumnString; + + static Status vector(const ColumnString::Chars& data, const ColumnString::Offsets& offsets, + ColumnString::Chars& dst_data, ColumnString::Offsets& dst_offsets, + NullMap& null_map) { + auto rows_count = offsets.size(); + dst_offsets.resize(rows_count); + + for (int i = 0; i < rows_count; ++i) { + if (null_map[i]) { StringOP::push_null_string(i, dst_data, dst_offsets, null_map); - } else { - StringOP::push_value_string(std::string_view(dst, outlen), i, dst_data, - dst_offsets); + continue; } + + const auto* source = reinterpret_cast(&data[offsets[i - 1]]); + size_t srclen = offsets[i] - offsets[i - 1]; + + if (srclen == 0) { + StringOP::push_empty_string(i, dst_data, dst_offsets); + continue; + } + + char dst_array[MAX_STACK_CIPHER_LEN]; + char* dst = dst_array; + + int cipher_len = (int)(4.0 * ceil((double)srclen / 3.0)); + std::unique_ptr dst_uptr; + if (cipher_len > MAX_STACK_CIPHER_LEN) { + dst_uptr.reset(new char[cipher_len]); + dst = dst_uptr.get(); + } + + int outlen = base64_encode((const unsigned char*)source, srclen, (unsigned char*)dst); + + StringOP::push_value_string(std::string_view(dst, outlen), i, dst_data, dst_offsets); } return Status::OK(); } @@ -773,11 +885,11 @@ struct FromBase64Impl { continue; } - auto source = reinterpret_cast(&data[offsets[i - 1]]); + const auto* source = reinterpret_cast(&data[offsets[i - 1]]); size_t srclen = offsets[i] - offsets[i - 1]; if (srclen == 0) { - StringOP::push_null_string(i, dst_data, dst_offsets, null_map); + StringOP::push_empty_string(i, dst_data, dst_offsets); continue; } @@ -946,8 +1058,11 @@ using FunctionToUpper = FunctionStringToString, NameTo using FunctionToInitcap = FunctionStringToString; -using FunctionUnHex = FunctionStringOperateToNullType; -using FunctionToBase64 = FunctionStringOperateToNullType; +using FunctionUnHex = FunctionStringEncode; +using FunctionToBase64 = FunctionStringEncode; + +using FunctionUnHexOld = FunctionStringOperateToNullType; +using FunctionToBase64Old = FunctionStringOperateToNullType; using FunctionFromBase64 = FunctionStringOperateToNullType; using FunctionStringAppendTrailingCharIfAbsent = @@ -1023,6 +1138,8 @@ void register_function_string(SimpleFunctionFactory& factory) { factory.register_alternative_function(); factory.register_alternative_function(); factory.register_alternative_function(); + factory.register_alternative_function(); + factory.register_alternative_function(); factory.register_alias(FunctionLeft::name, "strleft"); factory.register_alias(FunctionRight::name, "strright"); diff --git a/be/src/vec/functions/function_totype.h b/be/src/vec/functions/function_totype.h index 39c8cba09e..e6722a6941 100644 --- a/be/src/vec/functions/function_totype.h +++ b/be/src/vec/functions/function_totype.h @@ -474,7 +474,7 @@ public: auto& col_ptr = block.get_by_position(arguments[0]).column; auto res = Impl::ColumnType::create(); - if (const ColumnString* col = check_and_get_column(col_ptr.get())) { + if (const auto* col = check_and_get_column(col_ptr.get())) { auto col_res = Impl::ColumnType::create(); static_cast(Impl::vector(col->get_chars(), col->get_offsets(), col_res->get_chars(), col_res->get_offsets(), @@ -490,4 +490,38 @@ public: } }; +template +class FunctionStringEncode : public IFunction { +public: + static constexpr auto name = Impl::name; + + static FunctionPtr create() { return std::make_shared(); } + + String get_name() const override { return name; } + + size_t get_number_of_arguments() const override { return 1; } + + DataTypePtr get_return_type_impl(const DataTypes& arguments) const override { + return std::make_shared(); + } + + Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments, + size_t result, size_t input_rows_count) const override { + auto& col_ptr = block.get_by_position(arguments[0]).column; + + auto res = Impl::ColumnType::create(); + if (const auto* col = check_and_get_column(col_ptr.get())) { + auto col_res = Impl::ColumnType::create(); + static_cast(Impl::vector(col->get_chars(), col->get_offsets(), + col_res->get_chars(), col_res->get_offsets())); + block.replace_by_position(result, std::move(col_res)); + } else { + return Status::RuntimeError("Illegal column {} of argument of function {}", + block.get_by_position(arguments[0]).column->get_name(), + get_name()); + } + return Status::OK(); + } +}; + } // namespace doris::vectorized diff --git a/be/test/vec/function/function_string_test.cpp b/be/test/vec/function/function_string_test.cpp index 5425b0cd08..39a9dca901 100644 --- a/be/test/vec/function/function_string_test.cpp +++ b/be/test/vec/function/function_string_test.cpp @@ -482,7 +482,7 @@ TEST(function_string_test, function_to_base64_test) { DataSet data_set = {{{std::string("asd你好")}, {std::string("YXNk5L2g5aW9")}}, {{std::string("hello world")}, {std::string("aGVsbG8gd29ybGQ=")}}, {{std::string("HELLO,!^%")}, {std::string("SEVMTE8sIV4l")}}, - {{std::string("")}, {Null()}}, + {{std::string("")}, {std::string("")}}, {{std::string("MYtestSTR")}, {std::string("TVl0ZXN0U1RS")}}, {{std::string("ò&ø")}, {std::string("w7Imw7g=")}}}; @@ -496,7 +496,7 @@ TEST(function_string_test, function_from_base64_test) { DataSet data_set = {{{std::string("YXNk5L2g5aW9")}, {std::string("asd你好")}}, {{std::string("aGVsbG8gd29ybGQ=")}, {std::string("hello world")}}, {{std::string("SEVMTE8sIV4l")}, {std::string("HELLO,!^%")}}, - {{std::string("")}, {Null()}}, + {{std::string("")}, {std::string("")}}, {{std::string("TVl0ZXN0U1RS")}, {std::string("MYtestSTR")}}, {{std::string("w7Imw7g=")}, {std::string("ò&ø")}}, {{std::string("ò&ø")}, {Null()}}, diff --git a/docs/en/docs/sql-manual/sql-functions/string-functions/from-base64.md b/docs/en/docs/sql-manual/sql-functions/string-functions/from-base64.md index bb2d285a08..ea6cfe3745 100644 --- a/docs/en/docs/sql-manual/sql-functions/string-functions/from-base64.md +++ b/docs/en/docs/sql-manual/sql-functions/string-functions/from-base64.md @@ -31,7 +31,7 @@ under the License. `VARCHAR from_base64(VARCHAR str)` -Returns the result of Base64 decoding the input string +Returns the result of Base64 decoding the input string, NULL is returned when the input string is incorrect (with non-Base64 encoded characters). ### example diff --git a/docs/zh-CN/docs/sql-manual/sql-functions/string-functions/from-base64.md b/docs/zh-CN/docs/sql-manual/sql-functions/string-functions/from-base64.md index 34a6c488e5..f708e8a2fb 100644 --- a/docs/zh-CN/docs/sql-manual/sql-functions/string-functions/from-base64.md +++ b/docs/zh-CN/docs/sql-manual/sql-functions/string-functions/from-base64.md @@ -31,7 +31,7 @@ under the License. `VARCHAR from_base64(VARCHAR str)` -返回对输入的字符串进行Base64解码后的结果 +返回对输入的字符串进行Base64解码后的结果,当输入字符串不正确时(出现非Base64编码后可能出现的字符)将会返回NULL ### example diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ToBase64.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ToBase64.java index 99a1722dae..dab8e46ec3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ToBase64.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/ToBase64.java @@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.StringType; @@ -34,7 +34,7 @@ import java.util.List; * ScalarFunction 'to_base64'. This class is generated by GenerateFunction. */ public class ToBase64 extends ScalarFunction - implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { + implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(StringType.INSTANCE).args(StringType.INSTANCE) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Unhex.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Unhex.java index d500eef41e..4fb1cc37be 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Unhex.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/scalar/Unhex.java @@ -19,8 +19,8 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar; import org.apache.doris.catalog.FunctionSignature; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.functions.AlwaysNullable; import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature; +import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable; import org.apache.doris.nereids.trees.expressions.shape.UnaryExpression; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.types.StringType; @@ -35,7 +35,7 @@ import java.util.List; * ScalarFunction 'unhex'. This class is generated by GenerateFunction. */ public class Unhex extends ScalarFunction - implements UnaryExpression, ExplicitlyCastableSignature, AlwaysNullable { + implements UnaryExpression, ExplicitlyCastableSignature, PropagateNullable { public static final List SIGNATURES = ImmutableList.of( FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT).args(VarcharType.SYSTEM_DEFAULT), diff --git a/gensrc/script/doris_builtins_functions.py b/gensrc/script/doris_builtins_functions.py index 3fbe079eb3..1277f72db6 100644 --- a/gensrc/script/doris_builtins_functions.py +++ b/gensrc/script/doris_builtins_functions.py @@ -1405,8 +1405,8 @@ visible_functions = { [['truncate'], 'DECIMAL64', ['DECIMAL64', 'INT'], ''], [['truncate'], 'DECIMAL128', ['DECIMAL128', 'INT'], ''], - [['unhex'], 'VARCHAR', ['VARCHAR'], 'ALWAYS_NULLABLE'], - [['unhex'], 'STRING', ['STRING'], 'ALWAYS_NULLABLE'] + [['unhex'], 'VARCHAR', ['VARCHAR'], 'DEPEND_ON_ARGUMENT'], + [['unhex'], 'STRING', ['STRING'], 'DEPEND_ON_ARGUMENT'] ], # Conditional Functions @@ -1908,8 +1908,8 @@ visible_functions = { [['sm4_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], [['sm4_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'], [['from_base64'], 'STRING', ['STRING'], 'ALWAYS_NULLABLE'], - [['to_base64'], 'STRING', ['STRING'], 'ALWAYS_NULLABLE'], - [['to_base64'], 'VARCHAR', ['VARCHAR'], 'ALWAYS_NULLABLE'] + [['to_base64'], 'STRING', ['STRING'], 'DEPEND_ON_ARGUMENT'], + [['to_base64'], 'VARCHAR', ['VARCHAR'], 'DEPEND_ON_ARGUMENT'] ], # for compatable with MySQL diff --git a/regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out b/regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out index c0bffd9e92..7f2e30b5c1 100644 --- a/regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out +++ b/regression-test/data/nereids_p0/sql_functions/string_functions/test_string_function.out @@ -104,6 +104,12 @@ A -- !sql -- AB +-- !sql -- + + +-- !sql -- +\N + -- !sql -- 2 diff --git a/regression-test/data/query_p0/sql_functions/string_functions/test_string_function.out b/regression-test/data/query_p0/sql_functions/string_functions/test_string_function.out index fef380ebc7..e3ca494c63 100644 Binary files a/regression-test/data/query_p0/sql_functions/string_functions/test_string_function.out and b/regression-test/data/query_p0/sql_functions/string_functions/test_string_function.out differ diff --git a/regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy b/regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy index b1eb8aeefa..64fcc2e372 100644 --- a/regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy +++ b/regression-test/suites/nereids_p0/sql_functions/string_functions/test_string_function.groovy @@ -73,6 +73,8 @@ suite("test_string_function") { qt_sql "select unhex('68656C6C6F2C646F726973');" qt_sql "select unhex('41');" qt_sql "select unhex('4142');" + qt_sql "select unhex('');" + qt_sql "select unhex(NULL);" qt_sql "select instr(\"abc\", \"b\");" qt_sql "select instr(\"abc\", \"d\");" diff --git a/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function.groovy b/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function.groovy index f43941870b..2f5b49aa22 100644 --- a/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function.groovy +++ b/regression-test/suites/query_p0/sql_functions/string_functions/test_string_function.groovy @@ -68,6 +68,8 @@ suite("test_string_function", "arrow_flight_sql") { qt_sql "select unhex('68656C6C6F2C646F726973');" qt_sql "select unhex('41');" qt_sql "select unhex('4142');" + qt_sql "select unhex('');" + qt_sql "select unhex(NULL);" qt_sql_instr "select instr(\"abc\", \"b\");" qt_sql_instr "select instr(\"abc\", \"d\");"