diff --git a/expression/bench_test.go b/expression/bench_test.go index 7ef15dc639..12b0728009 100644 --- a/expression/bench_test.go +++ b/expression/bench_test.go @@ -31,6 +31,7 @@ import ( "github.com/pingcap/parser/mysql" "github.com/pingcap/parser/terror" "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/types" "github.com/pingcap/tidb/types/json" "github.com/pingcap/tidb/util/chunk" @@ -844,6 +845,10 @@ func testVectorizedBuiltinFunc(c *C, vecExprCases vecExprBenchCases) { for funcName, testCases := range vecExprCases { for _, testCase := range testCases { ctx := mock.NewContext() + if funcName == ast.AesEncrypt { + err := ctx.GetSessionVars().SetSystemVar(variable.BlockEncryptionMode, "aes-128-ecb") + c.Assert(err, IsNil) + } baseFunc, fts, input, output := genVecBuiltinFuncBenchCase(ctx, funcName, testCase) baseFuncName := fmt.Sprintf("%v", reflect.TypeOf(baseFunc)) tmp := strings.Split(baseFuncName, ".") @@ -1034,6 +1039,12 @@ func benchmarkVectorizedBuiltinFunc(b *testing.B, vecExprCases vecExprBenchCases } for funcName, testCases := range vecExprCases { for _, testCase := range testCases { + if funcName == ast.AesEncrypt { + err := ctx.GetSessionVars().SetSystemVar(variable.BlockEncryptionMode, "aes-128-ecb") + if err != nil { + panic(err) + } + } baseFunc, _, input, output := genVecBuiltinFuncBenchCase(ctx, funcName, testCase) baseFuncName := fmt.Sprintf("%v", reflect.TypeOf(baseFunc)) tmp := strings.Split(baseFuncName, ".") diff --git a/expression/builtin_encryption_vec.go b/expression/builtin_encryption_vec.go index f3b6b6d1a9..11fda0404f 100644 --- a/expression/builtin_encryption_vec.go +++ b/expression/builtin_encryption_vec.go @@ -366,11 +366,62 @@ func (b *builtinCompressSig) vecEvalString(input *chunk.Chunk, result *chunk.Col } func (b *builtinAesEncryptSig) vectorized() bool { - return false + return true } +// evalString evals AES_ENCRYPT(str, key_str). +// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt func (b *builtinAesEncryptSig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error { - return errors.Errorf("not implemented") + n := input.NumRows() + strBuf, err := b.bufAllocator.get(types.ETString, n) + if err != nil { + return err + } + defer b.bufAllocator.put(strBuf) + if err := b.args[0].VecEvalString(b.ctx, input, strBuf); err != nil { + return err + } + + keyBuf, err := b.bufAllocator.get(types.ETString, n) + if err != nil { + return err + } + defer b.bufAllocator.put(keyBuf) + if err := b.args[1].VecEvalString(b.ctx, input, keyBuf); err != nil { + return err + } + + if b.modeName != "ecb" { + // For modes that do not require init_vector, it is ignored and a warning is generated if it is specified. + return errors.Errorf("unsupported block encryption mode - %v", b.modeName) + } + + isWarning := !b.ivRequired && len(b.args) == 3 + + result.ReserveString(n) + for i := 0; i < n; i++ { + // According to doc: If either function argument is NULL, the function returns NULL. + if strBuf.IsNull(i) || keyBuf.IsNull(i) { + result.AppendNull() + continue + } + if isWarning { + b.ctx.GetSessionVars().StmtCtx.AppendWarning(errWarnOptionIgnored.GenWithStackByArgs("IV")) + } + key := encrypt.DeriveKeyMySQL(keyBuf.GetBytes(i), b.keySize) + + // NOTE: we can't use GetBytes, because in AESEncryptWithECB padding is automatically + // added to str and this will damange the data layout in chunk.Column + str := []byte(strBuf.GetString(i)) + cipherText, err := encrypt.AESEncryptWithECB(str, key) + if err != nil { + result.AppendNull() + continue + } + result.AppendBytes(cipherText) + } + + return nil } func (b *builtinPasswordSig) vectorized() bool { diff --git a/expression/builtin_encryption_vec_test.go b/expression/builtin_encryption_vec_test.go index e79bd13e38..6650734c8b 100644 --- a/expression/builtin_encryption_vec_test.go +++ b/expression/builtin_encryption_vec_test.go @@ -22,7 +22,10 @@ import ( ) var vecBuiltinEncryptionCases = map[string][]vecExprBenchCase{ - ast.AesEncrypt: {}, + ast.AesEncrypt: { + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString}}, + {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString, types.ETString, types.ETString}, geners: []dataGenerator{nil, nil, &constStrGener{"iv"}}}, + }, ast.Uncompress: { {retEvalType: types.ETString, childrenTypes: []types.EvalType{types.ETString}}, }, diff --git a/util/encrypt/aes.go b/util/encrypt/aes.go index 7b1d39644b..1a9f3e2815 100644 --- a/util/encrypt/aes.go +++ b/util/encrypt/aes.go @@ -237,6 +237,7 @@ func aesDecrypt(cryptStr []byte, mode cipher.BlockMode) ([]byte, error) { } // aesEncrypt encrypts data using AES. +// NOTE: if len(str)