[feature] Suport national secret (national commercial password) algorithm SM3/SM4 (#7464)

SM3 is password hash algorithm
SM4 is a block cipher used to replace DES / AES and other international algorithms.
This commit is contained in:
Zhengguo Yang
2021-12-28 10:39:54 +08:00
committed by GitHub
parent 6e052f4ede
commit 07e2acb2f3
33 changed files with 2036 additions and 342 deletions

View File

@ -21,53 +21,177 @@
#include "exprs/expr.h"
#include "runtime/string_value.h"
#include "runtime/tuple_row.h"
#include "util/aes_util.h"
#include "util/debug_util.h"
#include "util/encryption_util.h"
#include "util/md5.h"
#include "util/sm3.h"
#include "util/string_util.h"
#include "util/url_coding.h"
namespace doris {
void EncryptionFunctions::init() {}
StringVal EncryptionFunctions::aes_encrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key) {
StringCaseUnorderedMap<EncryptionMode> aes_mode_map {
{"AES_128_ECB", AES_128_ECB}, {"AES_192_ECB", AES_192_ECB},
{"AES_256_ECB", AES_256_ECB}, {"AES_128_CBC", AES_128_CBC},
{"AES_192_CBC", AES_192_CBC}, {"AES_256_CBC", AES_256_CBC},
{"AES_128_CFB", AES_128_CFB}, {"AES_192_CFB", AES_192_CFB},
{"AES_256_CFB", AES_256_CFB}, {"AES_128_CFB1", AES_128_CFB1},
{"AES_192_CFB1", AES_192_CFB1}, {"AES_256_CFB1", AES_256_CFB1},
{"AES_128_CFB8", AES_128_CFB8}, {"AES_192_CFB8", AES_192_CFB8},
{"AES_256_CFB8", AES_256_CFB8}, {"AES_128_CFB128", AES_128_CFB128},
{"AES_192_CFB128", AES_192_CFB128}, {"AES_256_CFB128", AES_256_CFB128},
{"AES_128_CTR", AES_128_CTR}, {"AES_192_CTR", AES_192_CTR},
{"AES_256_CTR", AES_256_CTR}, {"AES_128_OFB", AES_128_OFB},
{"AES_192_OFB", AES_192_OFB}, {"AES_256_OFB", AES_256_OFB}};
StringCaseUnorderedMap<EncryptionMode> sm4_mode_map {{"SM4_128_ECB", SM4_128_ECB},
{"SM4_128_CBC", SM4_128_CBC},
{"SM4_128_CFB128", SM4_128_CFB128},
{"SM4_128_OFB", SM4_128_OFB},
{"SM4_128_CTR", SM4_128_CTR}};
StringVal encrypt(FunctionContext* ctx, const StringVal& src, const StringVal& key,
const StringVal& iv, EncryptionMode mode) {
if (src.len == 0 || src.is_null) {
return StringVal::null();
}
// cipher_len = (clearLen/16 + 1) * 16;
int cipher_len = src.len + 16;
std::unique_ptr<char[]> p;
p.reset(new char[cipher_len]);
int ret_code =
AesUtil::encrypt(AES_128_ECB, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr,
key.len, nullptr, true, (unsigned char*)p.get());
int ret_code = 0;
if (mode != AES_128_ECB && mode != AES_192_ECB && mode != AES_256_ECB && mode != AES_256_ECB &&
mode != SM4_128_ECB) {
if (iv.len == 0 || iv.is_null) {
return StringVal::null();
}
int iv_len = 32; // max key length 256 / 8
std::unique_ptr<char[]> init_vec;
init_vec.reset(new char[iv_len]);
std::memset(init_vec.get(), 0, iv.len + 1);
memcpy(init_vec.get(), iv.ptr, iv.len);
ret_code = EncryptionUtil::encrypt(
mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr, key.len,
(unsigned char*)init_vec.get(), true, (unsigned char*)p.get());
} else {
ret_code = EncryptionUtil::encrypt(mode, (unsigned char*)src.ptr, src.len,
(unsigned char*)key.ptr, key.len, nullptr, true,
(unsigned char*)p.get());
}
if (ret_code < 0) {
return StringVal::null();
}
return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
}
StringVal EncryptionFunctions::aes_decrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key) {
StringVal decrypt(FunctionContext* ctx, const StringVal& src, const StringVal& key,
const StringVal& iv, EncryptionMode mode) {
if (src.len == 0 || src.is_null) {
return StringVal::null();
}
int cipher_len = src.len;
std::unique_ptr<char[]> p;
p.reset(new char[cipher_len]);
int ret_code =
AesUtil::decrypt(AES_128_ECB, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr,
key.len, nullptr, true, (unsigned char*)p.get());
int ret_code = 0;
if (mode != AES_128_ECB && mode != AES_192_ECB && mode != AES_256_ECB && mode != AES_256_ECB &&
mode != SM4_128_ECB) {
if (iv.len == 0 || iv.is_null) {
return StringVal::null();
}
int iv_len = 32; // max key length 256 / 8
std::unique_ptr<char[]> init_vec;
init_vec.reset(new char[iv_len]);
std::memset(init_vec.get(), 0, iv.len + 1);
memcpy(init_vec.get(), iv.ptr, iv.len);
ret_code = EncryptionUtil::decrypt(
mode, (unsigned char*)src.ptr, src.len, (unsigned char*)key.ptr, key.len,
(unsigned char*)init_vec.get(), true, (unsigned char*)p.get());
} else {
ret_code = EncryptionUtil::decrypt(mode, (unsigned char*)src.ptr, src.len,
(unsigned char*)key.ptr, key.len, nullptr, true,
(unsigned char*)p.get());
}
if (ret_code < 0) {
return StringVal::null();
}
return AnyValUtil::from_buffer_temp(ctx, p.get(), ret_code);
}
StringVal EncryptionFunctions::aes_encrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key) {
return aes_encrypt(ctx, src, key, StringVal::null(), StringVal("AES_128_ECB"));
}
StringVal EncryptionFunctions::aes_decrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key) {
return aes_decrypt(ctx, src, key, StringVal::null(), StringVal("AES_128_ECB"));
}
StringVal EncryptionFunctions::aes_encrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key, const StringVal& iv,
const StringVal& mode) {
EncryptionMode encryption_mode = AES_128_ECB;
if (mode.len != 0 && !mode.is_null) {
std::string mode_str(reinterpret_cast<char*>(mode.ptr), mode.len);
if (aes_mode_map.count(mode_str) == 0) {
return StringVal::null();
}
encryption_mode = aes_mode_map.at(mode_str);
}
return encrypt(ctx, src, key, iv, encryption_mode);
}
StringVal EncryptionFunctions::aes_decrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key, const StringVal& iv,
const StringVal& mode) {
EncryptionMode encryption_mode = AES_128_ECB;
if (mode.len != 0 && !mode.is_null) {
std::string mode_str(reinterpret_cast<char*>(mode.ptr), mode.len);
if (aes_mode_map.count(mode_str) == 0) {
return StringVal::null();
}
encryption_mode = aes_mode_map.at(mode_str);
}
return decrypt(ctx, src, key, iv, encryption_mode);
}
StringVal EncryptionFunctions::sm4_encrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key) {
return sm4_encrypt(ctx, src, key, StringVal::null(), StringVal("SM4_128_ECB"));
}
StringVal EncryptionFunctions::sm4_decrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key) {
return sm4_decrypt(ctx, src, key, StringVal::null(), StringVal("SM4_128_ECB"));
}
StringVal EncryptionFunctions::sm4_encrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key, const StringVal& iv,
const StringVal& mode) {
EncryptionMode encryption_mode = SM4_128_ECB;
if (mode.len != 0 && !mode.is_null) {
std::string mode_str(reinterpret_cast<char*>(mode.ptr), mode.len);
if (sm4_mode_map.count(mode_str) == 0) {
return StringVal::null();
}
encryption_mode = sm4_mode_map.at(mode_str);
}
return encrypt(ctx, src, key, iv, encryption_mode);
}
StringVal EncryptionFunctions::sm4_decrypt(FunctionContext* ctx, const StringVal& src,
const StringVal& key, const StringVal& iv,
const StringVal& mode) {
EncryptionMode encryption_mode = SM4_128_ECB;
if (mode.len != 0 && !mode.is_null) {
std::string mode_str(reinterpret_cast<char*>(mode.ptr), mode.len);
if (sm4_mode_map.count(mode_str) == 0) {
return StringVal::null();
}
encryption_mode = sm4_mode_map.at(mode_str);
}
return decrypt(ctx, src, key, iv, encryption_mode);
}
StringVal EncryptionFunctions::from_base64(FunctionContext* ctx, const StringVal& src) {
if (src.len == 0 || src.is_null) {
return StringVal::null();
@ -123,4 +247,27 @@ StringVal EncryptionFunctions::md5(FunctionContext* ctx, const StringVal& src) {
return AnyValUtil::from_buffer_temp(ctx, digest.hex().c_str(), digest.hex().size());
}
StringVal EncryptionFunctions::sm3sum(FunctionContext* ctx, int num_args, const StringVal* args) {
SM3Digest digest;
for (int i = 0; i < num_args; ++i) {
const StringVal& arg = args[i];
if (arg.is_null) {
continue;
}
digest.update(arg.ptr, arg.len);
}
digest.digest();
return AnyValUtil::from_buffer_temp(ctx, digest.hex().c_str(), digest.hex().size());
}
StringVal EncryptionFunctions::sm3(FunctionContext* ctx, const StringVal& src) {
if (src.is_null) {
return StringVal::null();
}
SM3Digest digest;
digest.update(src.ptr, src.len);
digest.digest();
return AnyValUtil::from_buffer_temp(ctx, digest.hex().c_str(), digest.hex().size());
}
} // namespace doris