From 871c19bb46dde4d6dffff1cc16057e1fbba00cb4 Mon Sep 17 00:00:00 2001 From: Zhang Zhiyi <909645105@qq.com> Date: Mon, 18 Nov 2019 22:00:24 +0800 Subject: [PATCH] expression: implement vectorized evaluation for `builtinAesEncryptIVSig` (#13521) --- expression/bench_test.go | 16 ++-- expression/builtin_encryption_vec.go | 91 ++++++++++++++++++++++- expression/builtin_encryption_vec_test.go | 11 ++- 3 files changed, 103 insertions(+), 15 deletions(-) diff --git a/expression/bench_test.go b/expression/bench_test.go index b8e60d2fa9..8381ee3a37 100644 --- a/expression/bench_test.go +++ b/expression/bench_test.go @@ -693,6 +693,8 @@ type vecExprBenchCase struct { // geners[gen1, gen2] will be regarded as geners[gen1, gen2, nil]. // This field is optional. geners []dataGenerator + // aesModeAttr information, needed by encryption functions + aesModes string // constants are used to generate constant data for children[i]. constants []*Constant } @@ -1002,10 +1004,8 @@ func testVectorizedBuiltinFunc(c *C, vecExprCases vecExprBenchCases) { for funcName, testCases := range vecExprCases { for _, testCase := range testCases { ctx := mock.NewContext() - if funcName == ast.AesDecrypt || funcName == ast.AesEncrypt { - err := ctx.GetSessionVars().SetSystemVar(variable.BlockEncryptionMode, "aes-128-ecb") - c.Assert(err, IsNil) - } + err := ctx.GetSessionVars().SetSystemVar(variable.BlockEncryptionMode, testCase.aesModes) + c.Assert(err, IsNil) if funcName == ast.CurrentUser || funcName == ast.User { ctx.GetSessionVars().User = &auth.UserIdentity{ Username: "tidb", @@ -1205,11 +1205,9 @@ func benchmarkVectorizedBuiltinFunc(b *testing.B, vecExprCases vecExprBenchCases } for funcName, testCases := range vecExprCases { for _, testCase := range testCases { - if funcName == ast.AesDecrypt || funcName == ast.AesEncrypt { - err := ctx.GetSessionVars().SetSystemVar(variable.BlockEncryptionMode, "aes-128-ecb") - if err != nil { - panic(err) - } + err := ctx.GetSessionVars().SetSystemVar(variable.BlockEncryptionMode, testCase.aesModes) + if err != nil { + panic(err) } if funcName == ast.CurrentUser || funcName == ast.User { ctx.GetSessionVars().User = &auth.UserIdentity{ diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index 3165da6798..6752de4d06 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -15,6 +15,7 @@ package expression import ( "bytes" + "crypto/aes" "crypto/md5" "crypto/rand" "crypto/sha1" @@ -101,11 +102,97 @@ func (b *builtinAesDecryptSig) vecEvalString(input *chunk.Chunk, result *chunk.C } func (b *builtinAesEncryptIVSig) vectorized() bool { - return false + return true } +// evalString evals AES_ENCRYPT(str, key_str, iv). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt func (b *builtinAesEncryptIVSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { - return errors.Errorf("not implemented") + n := input.NumRows() + strBuf, err := b.bufAllocator.get(types.ETString, n) + if err != nil { + return err + } + defer b.bufAllocator.put(strBuf) + if err := b.args[0].VecEvalString(b.ctx, input, strBuf); err != nil { + return err + } + + keyBuf, err := b.bufAllocator.get(types.ETString, n) + if err != nil { + return err + } + defer b.bufAllocator.put(keyBuf) + if err := b.args[1].VecEvalString(b.ctx, input, keyBuf); err != nil { + return err + } + + ivBuf, err := b.bufAllocator.get(types.ETString, n) + if err != nil { + return err + } + defer b.bufAllocator.put(ivBuf) + if err := b.args[2].VecEvalString(b.ctx, input, ivBuf); err != nil { + return err + } + + isCBC := false + isOFB := false + isCFB := false + switch b.modeName { + case "cbc": + isCBC = true + case "ofb": + isOFB = true + case "cfb": + isCFB = true + default: + return errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } + + isConst := b.args[1].ConstItem() + var key []byte + if isConst { + key = encrypt.DeriveKeyMySQL(keyBuf.GetBytes(0), b.keySize) + } + + result.ReserveString(n) + for i := 0; i < n; i++ { + // According to doc: If either function argument is NULL, the function returns NULL. + if strBuf.IsNull(i) || keyBuf.IsNull(i) || ivBuf.IsNull(i) { + result.AppendNull() + continue + } + + iv := ivBuf.GetBytes(i) + if len(iv) < aes.BlockSize { + return errIncorrectArgs.GenWithStack("The initialization vector supplied to aes_decrypt is too short. Must be at least %d bytes long", aes.BlockSize) + } + // init_vector must be 16 bytes or longer (bytes in excess of 16 are ignored) + iv = iv[0:aes.BlockSize] + if !isConst { + key = encrypt.DeriveKeyMySQL(keyBuf.GetBytes(i), b.keySize) + } + var cipherText []byte + + // ANNOTATION: + // we can't use GetBytes here because GetBytes return raw memory in strBuf, + // and the memory will be modified in AESEncryptWithCBC & AESEncryptWithOFB & AESEncryptWithCFB + if isCBC { + cipherText, err = encrypt.AESEncryptWithCBC([]byte(strBuf.GetString(i)), key, iv) + } + if isOFB { + cipherText, err = encrypt.AESEncryptWithOFB([]byte(strBuf.GetString(i)), key, iv) + } + if isCFB { + cipherText, err = encrypt.AESEncryptWithCFB([]byte(strBuf.GetString(i)), key, iv) + } + if err != nil { + result.AppendNull() + } + result.AppendBytes(cipherText) + } + return nil } func (b *builtinDecodeSig) vectorized() bool { diff --git a/expression/builtin_encryption_vec_test.go b/expression/builtin_encryption_vec_test.go index 24ac2f69c0..09c44166f0 100644 --- a/expression/builtin_encryption_vec_test.go +++ b/expression/builtin_encryption_vec_test.go @@ -23,15 +23,18 @@ import ( var vecBuiltinEncryptionCases = map[string][]vecExprBenchCase{ ast.AesEncrypt: { - {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}}, - {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString, types.ETString}, geners: []dataGenerator{nil, nil, &constStrGener{"iv"}}}, + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}, aesModes: "aes-128-ecb"}, + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString, types.ETString}, geners: []dataGenerator{nil, nil, &constStrGener{"iv"}}, aesModes: "aes-128-ecb"}, + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString, types.ETString}, geners: []dataGenerator{nil, nil, &randLenStrGener{16, 17}}, aesModes: "aes-128-cbc"}, + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString, types.ETString}, geners: []dataGenerator{nil, nil, &randLenStrGener{16, 17}}, aesModes: "aes-128-ofb"}, + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString, types.ETString}, geners: []dataGenerator{nil, nil, &randLenStrGener{16, 17}}, aesModes: "aes-128-cfb"}, }, ast.Uncompress: { {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString}}, }, ast.AesDecrypt: { - {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}}, - {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString, types.ETString}, geners: []dataGenerator{nil, nil, &constStrGener{"iv"}}}, + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}, aesModes: "aes-128-ecb"}, + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString, types.ETString}, geners: []dataGenerator{nil, nil, &constStrGener{"iv"}}, aesModes: "aes-128-ecb"}, }, ast.Compress: { {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString}},