[feature](aes_encrypt) support GCM mode for aes_encrypt and aes_decrypt (#40004) (#40672)

pick #40004 to branch-2.1
This commit is contained in:
camby
2024-09-11 23:28:28 +08:00
committed by GitHub
parent bf156d1665
commit 361a59dec8
9 changed files with 357 additions and 77 deletions

View File

@ -79,7 +79,10 @@ inline StringCaseUnorderedMap<EncryptionMode> aes_mode_map {
{"AES_256_CTR", EncryptionMode::AES_256_CTR},
{"AES_128_OFB", EncryptionMode::AES_128_OFB},
{"AES_192_OFB", EncryptionMode::AES_192_OFB},
{"AES_256_OFB", EncryptionMode::AES_256_OFB}};
{"AES_256_OFB", EncryptionMode::AES_256_OFB},
{"AES_128_GCM", EncryptionMode::AES_128_GCM},
{"AES_192_GCM", EncryptionMode::AES_192_GCM},
{"AES_256_GCM", EncryptionMode::AES_256_GCM}};
inline StringCaseUnorderedMap<EncryptionMode> sm4_mode_map {
{"SM4_128_ECB", EncryptionMode::SM4_128_ECB},
{"SM4_128_CBC", EncryptionMode::SM4_128_CBC},
@ -120,7 +123,7 @@ void execute_result_vector(std::vector<const ColumnString::Offsets*>& offsets_li
std::vector<const ColumnString::Chars*>& chars_list, size_t i,
EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
NullMap& null_map) {
NullMap& null_map, const char* aad, int aad_length) {
int src_size = (*offsets_list[0])[i] - (*offsets_list[0])[i - 1];
const auto* src_raw =
reinterpret_cast<const char*>(&(*chars_list[0])[(*offsets_list[0])[i - 1]]);
@ -128,7 +131,8 @@ void execute_result_vector(std::vector<const ColumnString::Offsets*>& offsets_li
const auto* key_raw =
reinterpret_cast<const char*>(&(*chars_list[1])[(*offsets_list[1])[i - 1]]);
execute_result<Impl, is_encrypt>(src_raw, src_size, key_raw, key_size, i, encryption_mode,
iv_raw, iv_length, result_data, result_offset, null_map);
iv_raw, iv_length, result_data, result_offset, null_map, aad,
aad_length);
}
template <typename Impl, bool is_encrypt>
@ -136,19 +140,19 @@ void execute_result_const(const ColumnString::Offsets* offsets_column,
const ColumnString::Chars* chars_column, StringRef key_arg, size_t i,
EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
NullMap& null_map) {
NullMap& null_map, const char* aad, int aad_length) {
int src_size = (*offsets_column)[i] - (*offsets_column)[i - 1];
const auto* src_raw = reinterpret_cast<const char*>(&(*chars_column)[(*offsets_column)[i - 1]]);
execute_result<Impl, is_encrypt>(src_raw, src_size, key_arg.data, key_arg.size, i,
encryption_mode, iv_raw, iv_length, result_data, result_offset,
null_map);
null_map, aad, aad_length);
}
template <typename Impl, bool is_encrypt>
void execute_result(const char* src_raw, int src_size, const char* key_raw, int key_size, size_t i,
EncryptionMode& encryption_mode, const char* iv_raw, int iv_length,
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
NullMap& null_map) {
NullMap& null_map, const char* aad, int aad_length) {
if (src_size == 0) {
StringOP::push_null_string(i, result_data, result_offset, null_map);
return;
@ -156,6 +160,10 @@ void execute_result(const char* src_raw, int src_size, const char* key_raw, int
int cipher_len = src_size;
if constexpr (is_encrypt) {
cipher_len += 16;
// for output AEAD tag
if (EncryptionUtil::is_gcm_mode(encryption_mode)) {
cipher_len += EncryptionUtil::GCM_TAG_SIZE;
}
}
std::unique_ptr<char[]> p;
p.reset(new char[cipher_len]);
@ -163,7 +171,7 @@ void execute_result(const char* src_raw, int src_size, const char* key_raw, int
ret_code = Impl::execute_impl(encryption_mode, (unsigned char*)src_raw, src_size,
(unsigned char*)key_raw, key_size, iv_raw, iv_length, true,
(unsigned char*)p.get());
(unsigned char*)p.get(), (unsigned char*)aad, aad_length);
if (ret_code < 0) {
StringOP::push_null_string(i, result_data, result_offset, null_map);
@ -248,7 +256,7 @@ struct EncryptionAndDecryptTwoImpl {
}
execute_result_const<Impl, is_encrypt>(offsets_column, chars_column, key_arg, i,
encryption_mode, nullptr, 0, result_data,
result_offset, null_map);
result_offset, null_map, nullptr, 0);
}
}
@ -275,16 +283,22 @@ struct EncryptionAndDecryptTwoImpl {
}
execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode,
nullptr, 0, result_data, result_offset,
null_map);
null_map, nullptr, 0);
}
}
};
template <typename Impl, EncryptionMode mode, bool is_encrypt, bool is_sm_mode>
struct EncryptionAndDecryptFourImpl {
template <typename Impl, EncryptionMode mode, bool is_encrypt, bool is_sm_mode, int arg_num = 4>
struct EncryptionAndDecryptMultiImpl {
static DataTypes get_variadic_argument_types_impl() {
return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()};
if constexpr (arg_num == 5) {
return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
std::make_shared<DataTypeString>()};
} else {
return {std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>(),
std::make_shared<DataTypeString>(), std::make_shared<DataTypeString>()};
}
}
static Status execute_impl_inner(FunctionContext* context, Block& block,
@ -292,8 +306,8 @@ struct EncryptionAndDecryptFourImpl {
size_t input_rows_count) {
auto result_column = ColumnString::create();
auto result_null_map_column = ColumnUInt8::create(input_rows_count, 0);
DCHECK_EQ(4, arguments.size());
const size_t argument_size = 4;
DCHECK_EQ(arguments.size(), arg_num);
const size_t argument_size = arg_num;
bool col_const[argument_size];
ColumnPtr argument_columns[argument_size];
for (int i = 0; i < argument_size; ++i) {
@ -304,8 +318,13 @@ struct EncryptionAndDecryptFourImpl {
.convert_to_full_column()
: block.get_by_position(arguments[0]).column;
default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3}, block,
arguments);
if constexpr (arg_num == 5) {
default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3, 4}, block,
arguments);
} else {
default_preprocess_parameter_columns(argument_columns, col_const, {1, 2, 3}, block,
arguments);
}
for (int i = 0; i < argument_size; i++) {
check_set_nullable(argument_columns[i], result_null_map_column, col_const[i]);
@ -314,11 +333,17 @@ struct EncryptionAndDecryptFourImpl {
auto& result_offset = result_column->get_offsets();
result_offset.resize(input_rows_count);
if (col_const[1] && col_const[2] && col_const[3]) {
if ((arg_num == 5) && col_const[1] && col_const[2] && col_const[3] && col_const[4]) {
vector_const(assert_cast<const ColumnString*>(argument_columns[0].get()),
argument_columns[1]->get_data_at(0), argument_columns[2]->get_data_at(0),
argument_columns[3]->get_data_at(0), input_rows_count, result_data,
result_offset, result_null_map_column->get_data());
result_offset, result_null_map_column->get_data(),
argument_columns[4]->get_data_at(0));
} else if ((arg_num == 4) && col_const[1] && col_const[2] && col_const[3]) {
vector_const(assert_cast<const ColumnString*>(argument_columns[0].get()),
argument_columns[1]->get_data_at(0), argument_columns[2]->get_data_at(0),
argument_columns[3]->get_data_at(0), input_rows_count, result_data,
result_offset, result_null_map_column->get_data(), StringRef());
} else {
std::vector<const ColumnString::Offsets*> offsets_list(argument_size);
std::vector<const ColumnString::Chars*> chars_list(argument_size);
@ -338,7 +363,7 @@ struct EncryptionAndDecryptFourImpl {
static void vector_const(const ColumnString* column, StringRef key_arg, StringRef iv_arg,
StringRef mode_arg, size_t input_rows_count,
ColumnString::Chars& result_data, ColumnString::Offsets& result_offset,
NullMap& null_map) {
NullMap& null_map, StringRef aad_arg) {
EncryptionMode encryption_mode = mode;
bool all_insert_null = false;
if (mode_arg.size != 0) {
@ -363,9 +388,9 @@ struct EncryptionAndDecryptFourImpl {
StringOP::push_null_string(i, result_data, result_offset, null_map);
continue;
}
execute_result_const<Impl, is_encrypt>(offsets_column, chars_column, key_arg, i,
encryption_mode, iv_arg.data, iv_arg.size,
result_data, result_offset, null_map);
execute_result_const<Impl, is_encrypt>(
offsets_column, chars_column, key_arg, i, encryption_mode, iv_arg.data,
iv_arg.size, result_data, result_offset, null_map, aad_arg.data, aad_arg.size);
}
}
@ -403,9 +428,16 @@ struct EncryptionAndDecryptFourImpl {
}
}
int aad_size = 0;
const char* aad = nullptr;
if constexpr (arg_num == 5) {
aad_size = (*offsets_list[4])[i] - (*offsets_list[4])[i - 1];
aad = reinterpret_cast<const char*>(&(*chars_list[4])[(*offsets_list[4])[i - 1]]);
}
execute_result_vector<Impl, is_encrypt>(offsets_list, chars_list, i, encryption_mode,
iv_raw, iv_size, result_data, result_offset,
null_map);
null_map, aad, aad_size);
}
}
};
@ -413,18 +445,20 @@ struct EncryptionAndDecryptFourImpl {
struct EncryptImpl {
static int execute_impl(EncryptionMode mode, const unsigned char* source,
uint32_t source_length, const unsigned char* key, uint32_t key_length,
const char* iv, int iv_length, bool padding, unsigned char* encrypt) {
const char* iv, int iv_length, bool padding, unsigned char* encrypt,
const unsigned char* aad, int aad_length) {
return EncryptionUtil::encrypt(mode, source, source_length, key, key_length, iv, iv_length,
true, encrypt);
true, encrypt, aad, aad_length);
}
};
struct DecryptImpl {
static int execute_impl(EncryptionMode mode, const unsigned char* source,
uint32_t source_length, const unsigned char* key, uint32_t key_length,
const char* iv, int iv_length, bool padding, unsigned char* encrypt) {
const char* iv, int iv_length, bool padding, unsigned char* encrypt,
const unsigned char* aad, int aad_length) {
return EncryptionUtil::decrypt(mode, source, source_length, key, key_length, iv, iv_length,
true, encrypt);
true, encrypt, aad, aad_length);
}
};
@ -459,16 +493,24 @@ void register_function_encryption(SimpleFunctionFactory& factory) {
AESDecryptName>>();
factory.register_function<FunctionEncryptionAndDecrypt<
EncryptionAndDecryptFourImpl<EncryptImpl, EncryptionMode::SM4_128_ECB, true, true>,
EncryptionAndDecryptMultiImpl<EncryptImpl, EncryptionMode::SM4_128_ECB, true, true>,
SM4EncryptName>>();
factory.register_function<FunctionEncryptionAndDecrypt<
EncryptionAndDecryptFourImpl<DecryptImpl, EncryptionMode::SM4_128_ECB, false, true>,
EncryptionAndDecryptMultiImpl<DecryptImpl, EncryptionMode::SM4_128_ECB, false, true>,
SM4DecryptName>>();
factory.register_function<FunctionEncryptionAndDecrypt<
EncryptionAndDecryptFourImpl<EncryptImpl, EncryptionMode::AES_128_ECB, true, false>,
EncryptionAndDecryptMultiImpl<EncryptImpl, EncryptionMode::AES_128_ECB, true, false>,
AESEncryptName>>();
factory.register_function<FunctionEncryptionAndDecrypt<
EncryptionAndDecryptFourImpl<DecryptImpl, EncryptionMode::AES_128_ECB, false, false>,
EncryptionAndDecryptMultiImpl<DecryptImpl, EncryptionMode::AES_128_ECB, false, false>,
AESDecryptName>>();
factory.register_function<FunctionEncryptionAndDecrypt<
EncryptionAndDecryptMultiImpl<EncryptImpl, EncryptionMode::AES_128_GCM, true, false, 5>,
AESEncryptName>>();
factory.register_function<FunctionEncryptionAndDecrypt<
EncryptionAndDecryptMultiImpl<DecryptImpl, EncryptionMode::AES_128_GCM, false, false,
5>,
AESDecryptName>>();
}