pick #40004 to branch-2.1
This commit is contained in:
@ -25,6 +25,7 @@
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <string>
|
||||
#include <unordered_map>
|
||||
|
||||
namespace doris {
|
||||
|
||||
@ -80,6 +81,12 @@ const EVP_CIPHER* get_evp_type(const EncryptionMode mode) {
|
||||
return EVP_aes_256_ctr();
|
||||
case EncryptionMode::AES_256_OFB:
|
||||
return EVP_aes_256_ofb();
|
||||
case EncryptionMode::AES_128_GCM:
|
||||
return EVP_aes_128_gcm();
|
||||
case EncryptionMode::AES_192_GCM:
|
||||
return EVP_aes_192_gcm();
|
||||
case EncryptionMode::AES_256_GCM:
|
||||
return EVP_aes_256_gcm();
|
||||
case EncryptionMode::SM4_128_CBC:
|
||||
return EVP_sm4_cbc();
|
||||
case EncryptionMode::SM4_128_ECB:
|
||||
@ -95,41 +102,29 @@ const EVP_CIPHER* get_evp_type(const EncryptionMode mode) {
|
||||
}
|
||||
}
|
||||
|
||||
static uint mode_key_sizes[] = {
|
||||
128 /* AES_128_ECB */,
|
||||
192 /* AES_192_ECB */,
|
||||
256 /* AES_256_ECB */,
|
||||
128 /* AES_128_CBC */,
|
||||
192 /* AES_192_CBC */,
|
||||
256 /* AES_256_CBC */,
|
||||
128 /* AES_128_CFB */,
|
||||
192 /* AES_192_CFB */,
|
||||
256 /* AES_256_CFB */,
|
||||
128 /* AES_128_CFB1 */,
|
||||
192 /* AES_192_CFB1 */,
|
||||
256 /* AES_256_CFB1 */,
|
||||
128 /* AES_128_CFB8 */,
|
||||
192 /* AES_192_CFB8 */,
|
||||
256 /* AES_256_CFB8 */,
|
||||
128 /* AES_128_CFB128 */,
|
||||
192 /* AES_192_CFB128 */,
|
||||
256 /* AES_256_CFB128 */,
|
||||
128 /* AES_128_CTR */,
|
||||
192 /* AES_192_CTR */,
|
||||
256 /* AES_256_CTR */,
|
||||
128 /* AES_128_OFB */,
|
||||
192 /* AES_192_OFB */,
|
||||
256 /* AES_256_OFB */,
|
||||
128 /* SM4_128_ECB */,
|
||||
128 /* SM4_128_CBC */,
|
||||
128 /* SM4_128_CFB128 */,
|
||||
128 /* SM4_128_OFB */,
|
||||
128 /* SM4_128_CTR */
|
||||
};
|
||||
static std::unordered_map<EncryptionMode, uint> mode_key_sizes = {
|
||||
{EncryptionMode::AES_128_ECB, 128}, {EncryptionMode::AES_192_ECB, 192},
|
||||
{EncryptionMode::AES_256_ECB, 256}, {EncryptionMode::AES_128_CBC, 128},
|
||||
{EncryptionMode::AES_192_CBC, 192}, {EncryptionMode::AES_256_CBC, 256},
|
||||
{EncryptionMode::AES_128_CFB, 128}, {EncryptionMode::AES_192_CFB, 192},
|
||||
{EncryptionMode::AES_256_CFB, 256}, {EncryptionMode::AES_128_CFB1, 128},
|
||||
{EncryptionMode::AES_192_CFB1, 192}, {EncryptionMode::AES_256_CFB1, 256},
|
||||
{EncryptionMode::AES_128_CFB8, 128}, {EncryptionMode::AES_192_CFB8, 192},
|
||||
{EncryptionMode::AES_256_CFB8, 256}, {EncryptionMode::AES_128_CFB128, 128},
|
||||
{EncryptionMode::AES_192_CFB128, 192}, {EncryptionMode::AES_256_CFB128, 256},
|
||||
{EncryptionMode::AES_128_CTR, 128}, {EncryptionMode::AES_192_CTR, 192},
|
||||
{EncryptionMode::AES_256_CTR, 256}, {EncryptionMode::AES_128_OFB, 128},
|
||||
{EncryptionMode::AES_192_OFB, 192}, {EncryptionMode::AES_256_OFB, 256},
|
||||
{EncryptionMode::AES_128_GCM, 128}, {EncryptionMode::AES_192_GCM, 192},
|
||||
{EncryptionMode::AES_256_GCM, 256},
|
||||
|
||||
{EncryptionMode::SM4_128_ECB, 128}, {EncryptionMode::SM4_128_CBC, 128},
|
||||
{EncryptionMode::SM4_128_CFB128, 128}, {EncryptionMode::SM4_128_OFB, 128},
|
||||
{EncryptionMode::SM4_128_CTR, 128}};
|
||||
|
||||
static void create_key(const unsigned char* origin_key, uint32_t key_length, uint8_t* encrypt_key,
|
||||
EncryptionMode mode) {
|
||||
const uint key_size = mode_key_sizes[int(mode)] / 8;
|
||||
const uint key_size = mode_key_sizes[mode] / 8;
|
||||
uint8_t* origin_key_end = ((uint8_t*)origin_key) + key_length; /* origin key boundary*/
|
||||
|
||||
uint8_t* encrypt_key_end; /* encrypt key boundary */
|
||||
@ -172,10 +167,58 @@ static int do_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
|
||||
return ret;
|
||||
}
|
||||
|
||||
static int do_gcm_encrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
|
||||
const unsigned char* source, uint32_t source_length,
|
||||
const unsigned char* encrypt_key, const unsigned char* iv, int iv_length,
|
||||
unsigned char* encrypt, int* length_ptr, const unsigned char* aad,
|
||||
uint32_t aad_length) {
|
||||
int ret = EVP_EncryptInit_ex(cipher_ctx, cipher, nullptr, nullptr, nullptr);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_length, nullptr);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
ret = EVP_EncryptInit_ex(cipher_ctx, nullptr, nullptr, encrypt_key, iv);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
if (aad) {
|
||||
int tmp_len = 0;
|
||||
ret = EVP_EncryptUpdate(cipher_ctx, nullptr, &tmp_len, aad, aad_length);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
std::memcpy(encrypt, iv, iv_length);
|
||||
encrypt += iv_length;
|
||||
|
||||
int u_len = 0;
|
||||
ret = EVP_EncryptUpdate(cipher_ctx, encrypt, &u_len, source, source_length);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
encrypt += u_len;
|
||||
|
||||
int f_len = 0;
|
||||
ret = EVP_EncryptFinal_ex(cipher_ctx, encrypt, &f_len);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
encrypt += f_len;
|
||||
|
||||
ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_GET_TAG, EncryptionUtil::GCM_TAG_SIZE,
|
||||
encrypt);
|
||||
*length_ptr = iv_length + u_len + f_len + EncryptionUtil::GCM_TAG_SIZE;
|
||||
return ret;
|
||||
}
|
||||
|
||||
int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source,
|
||||
uint32_t source_length, const unsigned char* key, uint32_t key_length,
|
||||
const char* iv_str, int iv_input_length, bool padding,
|
||||
unsigned char* encrypt) {
|
||||
unsigned char* encrypt, const unsigned char* aad, uint32_t aad_length) {
|
||||
const EVP_CIPHER* cipher = get_evp_type(mode);
|
||||
/* The encrypt key to be used for encryption */
|
||||
unsigned char encrypt_key[ENCRYPTION_MAX_KEY_LENGTH / 8];
|
||||
@ -196,8 +239,16 @@ int EncryptionUtil::encrypt(EncryptionMode mode, const unsigned char* source,
|
||||
EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
|
||||
EVP_CIPHER_CTX_reset(cipher_ctx);
|
||||
int length = 0;
|
||||
int ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key,
|
||||
int ret = 0;
|
||||
if (is_gcm_mode(mode)) {
|
||||
ret = do_gcm_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key,
|
||||
reinterpret_cast<unsigned char*>(init_vec), iv_length, encrypt,
|
||||
&length, aad, aad_length);
|
||||
} else {
|
||||
ret = do_encrypt(cipher_ctx, cipher, source, source_length, encrypt_key,
|
||||
reinterpret_cast<unsigned char*>(init_vec), padding, encrypt, &length);
|
||||
}
|
||||
|
||||
EVP_CIPHER_CTX_free(cipher_ctx);
|
||||
if (ret == 0) {
|
||||
ERR_clear_error();
|
||||
@ -230,10 +281,61 @@ static int do_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
|
||||
return ret;
|
||||
}
|
||||
|
||||
static int do_gcm_decrypt(EVP_CIPHER_CTX* cipher_ctx, const EVP_CIPHER* cipher,
|
||||
const unsigned char* encrypt, uint32_t encrypt_length,
|
||||
const unsigned char* encrypt_key, int iv_length,
|
||||
unsigned char* decrypt_content, int* length_ptr, const unsigned char* aad,
|
||||
uint32_t aad_length) {
|
||||
if (encrypt_length < iv_length + EncryptionUtil::GCM_TAG_SIZE) {
|
||||
return -1;
|
||||
}
|
||||
int ret = EVP_DecryptInit_ex(cipher_ctx, cipher, nullptr, nullptr, nullptr);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_IVLEN, iv_length, nullptr);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
ret = EVP_DecryptInit_ex(cipher_ctx, nullptr, nullptr, encrypt_key, encrypt);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
encrypt += iv_length;
|
||||
if (aad) {
|
||||
int tmp_len = 0;
|
||||
ret = EVP_DecryptUpdate(cipher_ctx, nullptr, &tmp_len, aad, aad_length);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
}
|
||||
|
||||
uint32_t real_encrypt_length = encrypt_length - iv_length - EncryptionUtil::GCM_TAG_SIZE;
|
||||
int u_len = 0;
|
||||
ret = EVP_DecryptUpdate(cipher_ctx, decrypt_content, &u_len, encrypt, real_encrypt_length);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
encrypt += real_encrypt_length;
|
||||
decrypt_content += u_len;
|
||||
|
||||
void* tag = const_cast<void*>(reinterpret_cast<const void*>(encrypt));
|
||||
ret = EVP_CIPHER_CTX_ctrl(cipher_ctx, EVP_CTRL_GCM_SET_TAG, EncryptionUtil::GCM_TAG_SIZE, tag);
|
||||
if (ret != 1) {
|
||||
return ret;
|
||||
}
|
||||
|
||||
int f_len = 0;
|
||||
ret = EVP_DecryptFinal_ex(cipher_ctx, decrypt_content, &f_len);
|
||||
*length_ptr = u_len + f_len;
|
||||
return ret;
|
||||
}
|
||||
|
||||
int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt,
|
||||
uint32_t encrypt_length, const unsigned char* key, uint32_t key_length,
|
||||
const char* iv_str, int iv_input_length, bool padding,
|
||||
unsigned char* decrypt_content) {
|
||||
unsigned char* decrypt_content, const unsigned char* aad,
|
||||
uint32_t aad_length) {
|
||||
const EVP_CIPHER* cipher = get_evp_type(mode);
|
||||
|
||||
/* The encrypt key to be used for decryption */
|
||||
@ -255,9 +357,15 @@ int EncryptionUtil::decrypt(EncryptionMode mode, const unsigned char* encrypt,
|
||||
EVP_CIPHER_CTX* cipher_ctx = EVP_CIPHER_CTX_new();
|
||||
EVP_CIPHER_CTX_reset(cipher_ctx);
|
||||
int length = 0;
|
||||
int ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key,
|
||||
int ret = 0;
|
||||
if (is_gcm_mode(mode)) {
|
||||
ret = do_gcm_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key, iv_length,
|
||||
decrypt_content, &length, aad, aad_length);
|
||||
} else {
|
||||
ret = do_decrypt(cipher_ctx, cipher, encrypt, encrypt_length, encrypt_key,
|
||||
reinterpret_cast<unsigned char*>(init_vec), padding, decrypt_content,
|
||||
&length);
|
||||
}
|
||||
EVP_CIPHER_CTX_free(cipher_ctx);
|
||||
if (ret > 0) {
|
||||
return length;
|
||||
|
||||
@ -46,6 +46,9 @@ enum class EncryptionMode {
|
||||
AES_128_OFB,
|
||||
AES_192_OFB,
|
||||
AES_256_OFB,
|
||||
AES_128_GCM,
|
||||
AES_192_GCM,
|
||||
AES_256_GCM,
|
||||
SM4_128_ECB,
|
||||
SM4_128_CBC,
|
||||
SM4_128_CFB128,
|
||||
@ -57,13 +60,23 @@ enum EncryptionState { AES_SUCCESS = 0, AES_BAD_DATA = -1 };
|
||||
|
||||
class EncryptionUtil {
|
||||
public:
|
||||
static bool is_gcm_mode(EncryptionMode mode) {
|
||||
return mode == EncryptionMode::AES_128_GCM || mode == EncryptionMode::AES_192_GCM ||
|
||||
mode == EncryptionMode::AES_256_GCM;
|
||||
}
|
||||
|
||||
// https://tools.ietf.org/html/rfc5116#section-5.1
|
||||
static const int GCM_TAG_SIZE = 16;
|
||||
|
||||
static int encrypt(EncryptionMode mode, const unsigned char* source, uint32_t source_length,
|
||||
const unsigned char* key, uint32_t key_length, const char* iv_str,
|
||||
int iv_input_length, bool padding, unsigned char* encrypt);
|
||||
int iv_input_length, bool padding, unsigned char* encrypt,
|
||||
const unsigned char* aad = nullptr, uint32_t aad_length = 0);
|
||||
|
||||
static int decrypt(EncryptionMode mode, const unsigned char* encrypt, uint32_t encrypt_length,
|
||||
const unsigned char* key, uint32_t key_length, const char* iv_str,
|
||||
int iv_input_length, bool padding, unsigned char* decrypt_content);
|
||||
int iv_input_length, bool padding, unsigned char* decrypt_content,
|
||||
const unsigned char* aad = nullptr, uint32_t aad_length = 0);
|
||||
};
|
||||
|
||||
} // namespace doris
|
||||
|
||||
@ -79,7 +79,10 @@ inline StringCaseUnorderedMap<EncryptionMode> aes_mode_map {
|
||||
{"AES_256_CTR", EncryptionMode::AES_256_CTR},
|
||||
{"AES_128_OFB", EncryptionMode::AES_128_OFB},
|
||||
{"AES_192_OFB", EncryptionMode::AES_192_OFB},
|
||||
{"AES_256_OFB", EncryptionMode::AES_256_OFB}};
|
||||
{"AES_256_OFB", EncryptionMode::AES_256_OFB},
|
||||
{"AES_128_GCM", EncryptionMode::AES_128_GCM},
|
||||
{"AES_192_GCM", EncryptionMode::AES_192_GCM},
|
||||
{"AES_256_GCM", EncryptionMode::AES_256_GCM}};
|
||||
inline StringCaseUnorderedMap<EncryptionMode> sm4_mode_map {
|
||||
{"SM4_128_ECB", EncryptionMode::SM4_128_ECB},
|
||||
{"SM4_128_CBC", EncryptionMode::SM4_128_CBC},
|
||||
@ -120,7 +123,7 @@ void execute_result_vector(std::vector<const ColumnString::Offsets*>& offsets_li
|
||||
std::vector<const ColumnString::Chars*>& chars_list, size_t i,
|
||||
EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
|
||||
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
|
||||
NullMap& null_map) {
|
||||
NullMap& null_map, const char* aad, int aad_length) {
|
||||
int src_size = (*offsets_list[0])[i] - (*offsets_list[0])[i - 1];
|
||||
const auto* src_raw =
|
||||
reinterpret_cast<const char*>(&(*chars_list[0])[(*offsets_list[0])[i - 1]]);
|
||||
@ -128,7 +131,8 @@ void execute_result_vector(std::vector<const ColumnString::Offsets*>& offsets_li
|
||||
const auto* key_raw =
|
||||
reinterpret_cast<const char*>(&(*chars_list[1])[(*offsets_list[1])[i - 1]]);
|
||||
execute_result<Impl, is_encrypt>(src_raw, src_size, key_raw, key_size, i, encryption_mode,
|
||||
iv_raw, iv_length, result_data, result_offset, null_map);
|
||||
iv_raw, iv_length, result_data, result_offset, null_map, aad,
|
||||
aad_length);
|
||||
}
|
||||
|
||||
template <typename Impl, bool is_encrypt>
|
||||
@ -136,19 +140,19 @@ void execute_result_const(const ColumnString::Offsets* offsets_column,
|
||||
const ColumnString::Chars* chars_column, StringRef key_arg, size_t i,
|
||||
EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
|
||||
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
|
||||
NullMap& null_map) {
|
||||
NullMap& null_map, const char* aad, int aad_length) {
|
||||
int src_size = (*offsets_column)[i] - (*offsets_column)[i - 1];
|
||||
const auto* src_raw = reinterpret_cast<const char*>(&(*chars_column)[(*offsets_column)[i - 1]]);
|
||||
execute_result<Impl, is_encrypt>(src_raw, src_size, key_arg.data, key_arg.size, i,
|
||||
encryption_mode, iv_raw, iv_length, result_data, result_offset,
|
||||
null_map);
|
||||
null_map, aad, aad_length);
|
||||
}
|
||||
|
||||
template <typename Impl, bool is_encrypt>
|
||||
void execute_result(const char* src_raw, int src_size, const char* key_raw, int key_size, size_t i,
|
||||
EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
|
||||
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
|
||||
NullMap& null_map) {
|
||||
NullMap& null_map, const char* aad, int aad_length) {
|
||||
if (src_size == 0) {
|
||||
StringOP::push_null_string(i, result_data, result_offset, null_map);
|
||||
return;
|
||||
@ -156,6 +160,10 @@ void execute_result(const char* src_raw, int src_size, const char* key_raw, int
|
||||
int cipher_len = src_size;
|
||||
if constexpr (is_encrypt) {
|
||||
cipher_len += 16;
|
||||
// for output AEAD tag
|
||||
if (EncryptionUtil::is_gcm_mode(encryption_mode)) {
|
||||
cipher_len += EncryptionUtil::GCM_TAG_SIZE;
|
||||
}
|
||||
}
|
||||
std::unique_ptr<char[]> p;
|
||||
p.reset(new char[cipher_len]);
|
||||
@ -163,7 +171,7 @@ void execute_result(const char* src_raw, int src_size, const char* key_raw, int
|
||||
|
||||
ret_code = Impl::execute_impl(encryption_mode, (unsigned char*)src_raw, src_size,
|
||||
(unsigned char*)key_raw, key_size, iv_raw, iv_length, true,
|
||||
(unsigned char*)p.get());
|
||||
(unsigned char*)p.get(), (unsigned char*)aad, aad_length);
|
||||
|
||||
if (ret_code < 0) {
|
||||
StringOP::push_null_string(i, result_data, result_offset, null_map);
|
||||
@ -248,7 +256,7 @@ struct EncryptionAndDecryptTwoImpl {
|
||||
}
|
||||
execute_result_const<Impl, is_encrypt>(offsets_column, chars_column, key_arg, i,
|
||||
encryption_mode, nullptr, 0, result_data,
|
||||
result_offset, null_map);
|
||||
result_offset, null_map, nullptr, 0);
|
||||
}
|
||||
}
|
||||
|
||||
@ -275,16 +283,22 @@ struct EncryptionAndDecryptTwoImpl {
|
||||
}
|
||||
execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode,
|
||||
nullptr, 0, result_data, result_offset,
|
||||
null_map);
|
||||
null_map, nullptr, 0);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Impl, EncryptionMode mode, bool is_encrypt, bool is_sm_mode>
|
||||
struct EncryptionAndDecryptFourImpl {
|
||||
template <typename Impl, EncryptionMode mode, bool is_encrypt, bool is_sm_mode, int arg_num = 4>
|
||||
struct EncryptionAndDecryptMultiImpl {
|
||||
static DataTypes get_variadic_argument_types_impl() {
|
||||
return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
|
||||
std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()};
|
||||
if constexpr (arg_num == 5) {
|
||||
return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
|
||||
std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
|
||||
std::make_shared<DataTypeString>()};
|
||||
} else {
|
||||
return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
|
||||
std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()};
|
||||
}
|
||||
}
|
||||
|
||||
static Status execute_impl_inner(FunctionContext* context, Block& block,
|
||||
@ -292,8 +306,8 @@ struct EncryptionAndDecryptFourImpl {
|
||||
size_t input_rows_count) {
|
||||
auto result_column = ColumnString::create();
|
||||
auto result_null_map_column = ColumnUInt8::create(input_rows_count, 0);
|
||||
DCHECK_EQ(4, arguments.size());
|
||||
const size_t argument_size = 4;
|
||||
DCHECK_EQ(arguments.size(), arg_num);
|
||||
const size_t argument_size = arg_num;
|
||||
bool col_const[argument_size];
|
||||
ColumnPtr argument_columns[argument_size];
|
||||
for (int i = 0; i < argument_size; ++i) {
|
||||
@ -304,8 +318,13 @@ struct EncryptionAndDecryptFourImpl {
|
||||
.convert_to_full_column()
|
||||
: block.get_by_position(arguments[0]).column;
|
||||
|
||||
default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3}, block,
|
||||
arguments);
|
||||
if constexpr (arg_num == 5) {
|
||||
default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3, 4}, block,
|
||||
arguments);
|
||||
} else {
|
||||
default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3}, block,
|
||||
arguments);
|
||||
}
|
||||
|
||||
for (int i = 0; i < argument_size; i++) {
|
||||
check_set_nullable(argument_columns[i], result_null_map_column, col_const[i]);
|
||||
@ -314,11 +333,17 @@ struct EncryptionAndDecryptFourImpl {
|
||||
auto& result_offset = result_column->get_offsets();
|
||||
result_offset.resize(input_rows_count);
|
||||
|
||||
if (col_const[1] && col_const[2] && col_const[3]) {
|
||||
if ((arg_num == 5) && col_const[1] && col_const[2] && col_const[3] && col_const[4]) {
|
||||
vector_const(assert_cast<const ColumnString*>(argument_columns[0].get()),
|
||||
argument_columns[1]->get_data_at(0), argument_columns[2]->get_data_at(0),
|
||||
argument_columns[3]->get_data_at(0), input_rows_count, result_data,
|
||||
result_offset, result_null_map_column->get_data());
|
||||
result_offset, result_null_map_column->get_data(),
|
||||
argument_columns[4]->get_data_at(0));
|
||||
} else if ((arg_num == 4) && col_const[1] && col_const[2] && col_const[3]) {
|
||||
vector_const(assert_cast<const ColumnString*>(argument_columns[0].get()),
|
||||
argument_columns[1]->get_data_at(0), argument_columns[2]->get_data_at(0),
|
||||
argument_columns[3]->get_data_at(0), input_rows_count, result_data,
|
||||
result_offset, result_null_map_column->get_data(), StringRef());
|
||||
} else {
|
||||
std::vector<const ColumnString::Offsets*> offsets_list(argument_size);
|
||||
std::vector<const ColumnString::Chars*> chars_list(argument_size);
|
||||
@ -338,7 +363,7 @@ struct EncryptionAndDecryptFourImpl {
|
||||
static void vector_const(const ColumnString* column, StringRef key_arg, StringRef iv_arg,
|
||||
StringRef mode_arg, size_t input_rows_count,
|
||||
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
|
||||
NullMap& null_map) {
|
||||
NullMap& null_map, StringRef aad_arg) {
|
||||
EncryptionMode encryption_mode = mode;
|
||||
bool all_insert_null = false;
|
||||
if (mode_arg.size != 0) {
|
||||
@ -363,9 +388,9 @@ struct EncryptionAndDecryptFourImpl {
|
||||
StringOP::push_null_string(i, result_data, result_offset, null_map);
|
||||
continue;
|
||||
}
|
||||
execute_result_const<Impl, is_encrypt>(offsets_column, chars_column, key_arg, i,
|
||||
encryption_mode, iv_arg.data, iv_arg.size,
|
||||
result_data, result_offset, null_map);
|
||||
execute_result_const<Impl, is_encrypt>(
|
||||
offsets_column, chars_column, key_arg, i, encryption_mode, iv_arg.data,
|
||||
iv_arg.size, result_data, result_offset, null_map, aad_arg.data, aad_arg.size);
|
||||
}
|
||||
}
|
||||
|
||||
@ -403,9 +428,16 @@ struct EncryptionAndDecryptFourImpl {
|
||||
}
|
||||
}
|
||||
|
||||
int aad_size = 0;
|
||||
const char* aad = nullptr;
|
||||
if constexpr (arg_num == 5) {
|
||||
aad_size = (*offsets_list[4])[i] - (*offsets_list[4])[i - 1];
|
||||
aad = reinterpret_cast<const char*>(&(*chars_list[4])[(*offsets_list[4])[i - 1]]);
|
||||
}
|
||||
|
||||
execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode,
|
||||
iv_raw, iv_size, result_data, result_offset,
|
||||
null_map);
|
||||
null_map, aad, aad_size);
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -413,18 +445,20 @@ struct EncryptionAndDecryptFourImpl {
|
||||
struct EncryptImpl {
|
||||
static int execute_impl(EncryptionMode mode, const unsigned char* source,
|
||||
uint32_t source_length, const unsigned char* key, uint32_t key_length,
|
||||
const char* iv, int iv_length, bool padding, unsigned char* encrypt) {
|
||||
const char* iv, int iv_length, bool padding, unsigned char* encrypt,
|
||||
const unsigned char* aad, int aad_length) {
|
||||
return EncryptionUtil::encrypt(mode, source, source_length, key, key_length, iv, iv_length,
|
||||
true, encrypt);
|
||||
true, encrypt, aad, aad_length);
|
||||
}
|
||||
};
|
||||
|
||||
struct DecryptImpl {
|
||||
static int execute_impl(EncryptionMode mode, const unsigned char* source,
|
||||
uint32_t source_length, const unsigned char* key, uint32_t key_length,
|
||||
const char* iv, int iv_length, bool padding, unsigned char* encrypt) {
|
||||
const char* iv, int iv_length, bool padding, unsigned char* encrypt,
|
||||
const unsigned char* aad, int aad_length) {
|
||||
return EncryptionUtil::decrypt(mode, source, source_length, key, key_length, iv, iv_length,
|
||||
true, encrypt);
|
||||
true, encrypt, aad, aad_length);
|
||||
}
|
||||
};
|
||||
|
||||
@ -459,16 +493,24 @@ void register_function_encryption(SimpleFunctionFactory& factory) {
|
||||
AESDecryptName>>();
|
||||
|
||||
factory.register_function<FunctionEncryptionAndDecrypt<
|
||||
EncryptionAndDecryptFourImpl<EncryptImpl, EncryptionMode::SM4_128_ECB, true, true>,
|
||||
EncryptionAndDecryptMultiImpl<EncryptImpl, EncryptionMode::SM4_128_ECB, true, true>,
|
||||
SM4EncryptName>>();
|
||||
factory.register_function<FunctionEncryptionAndDecrypt<
|
||||
EncryptionAndDecryptFourImpl<DecryptImpl, EncryptionMode::SM4_128_ECB, false, true>,
|
||||
EncryptionAndDecryptMultiImpl<DecryptImpl, EncryptionMode::SM4_128_ECB, false, true>,
|
||||
SM4DecryptName>>();
|
||||
factory.register_function<FunctionEncryptionAndDecrypt<
|
||||
EncryptionAndDecryptFourImpl<EncryptImpl, EncryptionMode::AES_128_ECB, true, false>,
|
||||
EncryptionAndDecryptMultiImpl<EncryptImpl, EncryptionMode::AES_128_ECB, true, false>,
|
||||
AESEncryptName>>();
|
||||
factory.register_function<FunctionEncryptionAndDecrypt<
|
||||
EncryptionAndDecryptFourImpl<DecryptImpl, EncryptionMode::AES_128_ECB, false, false>,
|
||||
EncryptionAndDecryptMultiImpl<DecryptImpl, EncryptionMode::AES_128_ECB, false, false>,
|
||||
AESDecryptName>>();
|
||||
|
||||
factory.register_function<FunctionEncryptionAndDecrypt<
|
||||
EncryptionAndDecryptMultiImpl<EncryptImpl, EncryptionMode::AES_128_GCM, true, false, 5>,
|
||||
AESEncryptName>>();
|
||||
factory.register_function<FunctionEncryptionAndDecrypt<
|
||||
EncryptionAndDecryptMultiImpl<DecryptImpl, EncryptionMode::AES_128_GCM, false, false,
|
||||
5>,
|
||||
AESDecryptName>>();
|
||||
}
|
||||
|
||||
|
||||
@ -19,6 +19,7 @@ package org.apache.doris.nereids.trees.expressions.functions.scalar;
|
||||
|
||||
import org.apache.doris.nereids.exceptions.AnalysisException;
|
||||
import org.apache.doris.nereids.trees.expressions.Expression;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.StringLikeLiteral;
|
||||
import org.apache.doris.nereids.trees.expressions.literal.StringLiteral;
|
||||
|
||||
import com.google.common.collect.ImmutableSet;
|
||||
@ -52,7 +53,16 @@ public abstract class AesCryptoFunction extends CryptoFunction {
|
||||
"AES_256_CTR",
|
||||
"AES_128_OFB",
|
||||
"AES_192_OFB",
|
||||
"AES_256_OFB"
|
||||
"AES_256_OFB",
|
||||
"AES_128_GCM",
|
||||
"AES_192_GCM",
|
||||
"AES_256_GCM"
|
||||
);
|
||||
|
||||
public static final Set<String> AES_GCM_MODES = ImmutableSet.of(
|
||||
"AES_128_GCM",
|
||||
"AES_192_GCM",
|
||||
"AES_256_GCM"
|
||||
);
|
||||
|
||||
public AesCryptoFunction(String name, Expression... arguments) {
|
||||
@ -72,4 +82,17 @@ public abstract class AesCryptoFunction extends CryptoFunction {
|
||||
}
|
||||
return encryptionMode;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void checkLegalityAfterRewrite() {
|
||||
if (arity() >= 4 && child(3) instanceof StringLikeLiteral) {
|
||||
String mode = ((StringLikeLiteral) child(3)).getValue().toUpperCase();
|
||||
if (!AES_MODES.contains(mode)) {
|
||||
throw new AnalysisException("mode " + mode + " is not supported");
|
||||
}
|
||||
if (arity() == 5 && !AES_GCM_MODES.contains(mode)) {
|
||||
throw new AnalysisException("only GCM mode support AAD(the 5th arg)");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -50,7 +50,16 @@ public class AesDecrypt extends AesCryptoFunction {
|
||||
VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(StringType.INSTANCE)
|
||||
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE)
|
||||
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
|
||||
.args(VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(StringType.INSTANCE)
|
||||
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE,
|
||||
StringType.INSTANCE)
|
||||
);
|
||||
|
||||
/**
|
||||
@ -68,18 +77,25 @@ public class AesDecrypt extends AesCryptoFunction {
|
||||
super("aes_decrypt", arg0, arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
public AesDecrypt(Expression arg0, Expression arg1, Expression arg2, Expression arg3, Expression arg4) {
|
||||
super("aes_decrypt", arg0, arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
/**
|
||||
* withChildren.
|
||||
*/
|
||||
@Override
|
||||
public AesDecrypt withChildren(List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4);
|
||||
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 5);
|
||||
if (children.size() == 2) {
|
||||
return new AesDecrypt(children.get(0), children.get(1));
|
||||
} else if (children().size() == 3) {
|
||||
return new AesDecrypt(children.get(0), children.get(1), children.get(2));
|
||||
} else {
|
||||
} else if (children().size() == 4) {
|
||||
return new AesDecrypt(children.get(0), children.get(1), children.get(2), children.get(3));
|
||||
} else {
|
||||
return new AesDecrypt(children.get(0), children.get(1), children.get(2), children.get(3),
|
||||
children.get(4));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -50,7 +50,16 @@ public class AesEncrypt extends AesCryptoFunction {
|
||||
VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(StringType.INSTANCE)
|
||||
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE)
|
||||
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE),
|
||||
FunctionSignature.ret(VarcharType.SYSTEM_DEFAULT)
|
||||
.args(VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT,
|
||||
VarcharType.SYSTEM_DEFAULT),
|
||||
FunctionSignature.ret(StringType.INSTANCE)
|
||||
.args(StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE, StringType.INSTANCE,
|
||||
StringType.INSTANCE)
|
||||
);
|
||||
|
||||
/**
|
||||
@ -68,18 +77,25 @@ public class AesEncrypt extends AesCryptoFunction {
|
||||
super("aes_encrypt", arg0, arg1, arg2, arg3);
|
||||
}
|
||||
|
||||
public AesEncrypt(Expression arg0, Expression arg1, Expression arg2, Expression arg3, Expression arg4) {
|
||||
super("aes_encrypt", arg0, arg1, arg2, arg3, arg4);
|
||||
}
|
||||
|
||||
/**
|
||||
* withChildren.
|
||||
*/
|
||||
@Override
|
||||
public AesEncrypt withChildren(List<Expression> children) {
|
||||
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 4);
|
||||
Preconditions.checkArgument(children.size() >= 2 && children.size() <= 5);
|
||||
if (children.size() == 2) {
|
||||
return new AesEncrypt(children.get(0), children.get(1));
|
||||
} else if (children().size() == 3) {
|
||||
return new AesEncrypt(children.get(0), children.get(1), children.get(2));
|
||||
} else {
|
||||
} else if (children().size() == 4) {
|
||||
return new AesEncrypt(children.get(0), children.get(1), children.get(2), children.get(3));
|
||||
} else {
|
||||
return new AesEncrypt(children.get(0), children.get(1), children.get(2), children.get(3),
|
||||
children.get(4));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -1919,6 +1919,8 @@ visible_functions = {
|
||||
[['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
|
||||
[['aes_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
|
||||
[['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
|
||||
[['aes_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
|
||||
[['aes_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
|
||||
[['sm4_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
|
||||
[['sm4_decrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
|
||||
[['sm4_encrypt'], 'VARCHAR', ['VARCHAR', 'VARCHAR', 'VARCHAR', 'VARCHAR'], 'ALWAYS_NULLABLE'],
|
||||
@ -1928,6 +1930,8 @@ visible_functions = {
|
||||
[['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
|
||||
[['aes_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
|
||||
[['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
|
||||
[['aes_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
|
||||
[['aes_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
|
||||
[['sm4_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
|
||||
[['sm4_decrypt'], 'STRING', ['STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
|
||||
[['sm4_encrypt'], 'STRING', ['STRING', 'STRING', 'STRING', 'STRING'], 'ALWAYS_NULLABLE'],
|
||||
|
||||
@ -56,3 +56,29 @@ text
|
||||
-- !sql --
|
||||
82ec580fe6d36ae4f81cae3c73f4a5b3b5a09c943172dc9053c69fd8e18dca1e
|
||||
|
||||
-- !sql_gcm_1 --
|
||||
MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA==
|
||||
|
||||
-- !sql_gcm_2 --
|
||||
Spark SQL
|
||||
|
||||
-- !sql_gcm_3 --
|
||||
AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4
|
||||
|
||||
-- !sql_gcm_4 --
|
||||
Spark
|
||||
|
||||
-- !sql_gcm_5 --
|
||||
1 MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA==
|
||||
2 AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4
|
||||
|
||||
-- !sql_gcm_6 --
|
||||
1 Spark SQL
|
||||
2 Spark
|
||||
|
||||
-- !sql_gcm_7 --
|
||||
1 MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA==
|
||||
|
||||
-- !sql_gcm_8 --
|
||||
Spark SQL
|
||||
|
||||
|
||||
@ -57,4 +57,36 @@ suite("test_encryption_function") {
|
||||
qt_sql "SELECT SM3(\"abc\");"
|
||||
qt_sql "select sm3(\"abcd\");"
|
||||
qt_sql "select sm3sum(\"ab\",\"cd\");"
|
||||
|
||||
qt_sql_gcm_1 "SELECT TO_BASE64(AES_ENCRYPT('Spark SQL', '1234567890abcdef', '123456789012', 'aes_128_gcm', 'Some AAD'))"
|
||||
qt_sql_gcm_2 "SELECT AES_DECRYPT(FROM_BASE64('MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA=='), '1234567890abcdef', '', 'aes_128_gcm', 'Some AAD')"
|
||||
|
||||
qt_sql_gcm_3 "select to_base64(aes_encrypt('Spark','abcdefghijklmnop12345678ABCDEFGH',unhex('000000000000000000000000'),'aes_256_gcm', 'This is an AAD mixed into the input'));"
|
||||
qt_sql_gcm_4 "SELECT AES_DECRYPT(FROM_BASE64('AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4'), 'abcdefghijklmnop12345678ABCDEFGH', '', 'aes_256_gcm', 'This is an AAD mixed into the input');"
|
||||
|
||||
sql "DROP TABLE IF EXISTS aes_encrypt_decrypt_tbl"
|
||||
sql """
|
||||
CREATE TABLE IF NOT EXISTS aes_encrypt_decrypt_tbl (
|
||||
id int,
|
||||
plain_txt varchar(255),
|
||||
enc_txt varchar(255),
|
||||
k varchar(255),
|
||||
iv varchar(255),
|
||||
mode varchar(255),
|
||||
aad varchar(255)
|
||||
) DISTRIBUTED BY HASH(id) BUCKETS 1
|
||||
PROPERTIES (
|
||||
"replication_num" = "1"
|
||||
)
|
||||
"""
|
||||
sql """ insert into aes_encrypt_decrypt_tbl values(1,'Spark SQL','MTIzNDU2Nzg5MDEyMdXvR41sJqwZ6hnTU8FRTTtXbL8yeChIZA==','1234567890abcdef','123456789012','aes_128_gcm','Some AAD');"""
|
||||
sql """ insert into aes_encrypt_decrypt_tbl values(2,'Spark','AAAAAAAAAAAAAAAAQiYi+sTLm7KD9UcZ2nlRdYDe/PX4','abcdefghijklmnop12345678ABCDEFGH',unhex('000000000000000000000000'),'aes_256_gcm','This is an AAD mixed into the input');"""
|
||||
sql """ sync """
|
||||
|
||||
qt_sql_gcm_5 "SELECT id,TO_BASE64(AES_ENCRYPT(plain_txt,k,iv,mode,aad)) from aes_encrypt_decrypt_tbl order by id;"
|
||||
qt_sql_gcm_6 "SELECT id,AES_DECRYPT(FROM_BASE64(enc_txt),k,'',mode,aad) from aes_encrypt_decrypt_tbl order by id;"
|
||||
|
||||
// test for const opt branch, only first column is not const
|
||||
qt_sql_gcm_7 "SELECT id,TO_BASE64(AES_ENCRYPT(plain_txt, '1234567890abcdef', '123456789012', 'aes_128_gcm', 'Some AAD')) from aes_encrypt_decrypt_tbl where id=1"
|
||||
qt_sql_gcm_8 "SELECT AES_DECRYPT(FROM_BASE64(enc_txt), '1234567890abcdef', '', 'aes_128_gcm', 'Some AAD') from aes_encrypt_decrypt_tbl where id=1"
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user