559 lines
16 KiB
Go
559 lines
16 KiB
Go
// Copyright 2017 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package expression
|
|
|
|
import (
|
|
"bytes"
|
|
"compress/zlib"
|
|
"crypto/md5"
|
|
"crypto/rand"
|
|
"crypto/sha1"
|
|
"crypto/sha256"
|
|
"crypto/sha512"
|
|
"fmt"
|
|
"hash"
|
|
|
|
"github.com/juju/errors"
|
|
"github.com/pingcap/tidb/context"
|
|
"github.com/pingcap/tidb/mysql"
|
|
"github.com/pingcap/tidb/util/encrypt"
|
|
"github.com/pingcap/tidb/util/types"
|
|
)
|
|
|
|
var (
|
|
_ functionClass = &aesDecryptFunctionClass{}
|
|
_ functionClass = &aesEncryptFunctionClass{}
|
|
_ functionClass = &compressFunctionClass{}
|
|
_ functionClass = &decodeFunctionClass{}
|
|
_ functionClass = &desDecryptFunctionClass{}
|
|
_ functionClass = &desEncryptFunctionClass{}
|
|
_ functionClass = &encodeFunctionClass{}
|
|
_ functionClass = &encryptFunctionClass{}
|
|
_ functionClass = &md5FunctionClass{}
|
|
_ functionClass = &oldPasswordFunctionClass{}
|
|
_ functionClass = &passwordFunctionClass{}
|
|
_ functionClass = &randomBytesFunctionClass{}
|
|
_ functionClass = &sha1FunctionClass{}
|
|
_ functionClass = &sha2FunctionClass{}
|
|
_ functionClass = &uncompressFunctionClass{}
|
|
_ functionClass = &uncompressedLengthFunctionClass{}
|
|
_ functionClass = &validatePasswordStrengthFunctionClass{}
|
|
)
|
|
|
|
var (
|
|
_ builtinFunc = &builtinAesDecryptSig{}
|
|
_ builtinFunc = &builtinAesEncryptSig{}
|
|
_ builtinFunc = &builtinCompressSig{}
|
|
_ builtinFunc = &builtinDecodeSig{}
|
|
_ builtinFunc = &builtinDesDecryptSig{}
|
|
_ builtinFunc = &builtinDesEncryptSig{}
|
|
_ builtinFunc = &builtinEncodeSig{}
|
|
_ builtinFunc = &builtinEncryptSig{}
|
|
_ builtinFunc = &builtinMD5Sig{}
|
|
_ builtinFunc = &builtinOldPasswordSig{}
|
|
_ builtinFunc = &builtinPasswordSig{}
|
|
_ builtinFunc = &builtinRandomBytesSig{}
|
|
_ builtinFunc = &builtinSHA1Sig{}
|
|
_ builtinFunc = &builtinSHA2Sig{}
|
|
_ builtinFunc = &builtinUncompressSig{}
|
|
_ builtinFunc = &builtinUncompressedLengthSig{}
|
|
_ builtinFunc = &builtinValidatePasswordStrengthSig{}
|
|
)
|
|
|
|
// aes128ecb names the AES mode (128-bit key, ECB) used by AES_ENCRYPT and
// AES_DECRYPT; it is the only mode supported so far.
// TODO: support other modes.
const aes128ecb string = "aes-128-ecb"
|
|
|
|
type aesDecryptFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *aesDecryptFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
err := errors.Trace(c.verifyArgs(args))
|
|
bt := &builtinAesDecryptSig{newBaseBuiltinFunc(args, ctx)}
|
|
bt.deterministic = true
|
|
return bt, errors.Trace(err)
|
|
}
|
|
|
|
type builtinAesDecryptSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt
|
|
func (b *builtinAesDecryptSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
args, err := b.evalArgs(row)
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
for _, arg := range args {
|
|
// If either function argument is NULL, the function returns NULL.
|
|
if arg.IsNull() {
|
|
return d, nil
|
|
}
|
|
}
|
|
|
|
cryptStr, err := args[0].ToBytes()
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
key, err := args[1].ToBytes()
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
key = handleAESKey(key, aes128ecb)
|
|
// By default these functions implement AES with a 128-bit key length.
|
|
// TODO: We only support aes-128-ecb now. We should support other mode latter.
|
|
data, err := encrypt.AESDecryptWithECB(cryptStr, key)
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
d.SetString(string(data))
|
|
return d, nil
|
|
}
|
|
|
|
type aesEncryptFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *aesEncryptFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
err := errors.Trace(c.verifyArgs(args))
|
|
bt := &builtinAesEncryptSig{newBaseBuiltinFunc(args, ctx)}
|
|
return bt, errors.Trace(err)
|
|
}
|
|
|
|
type builtinAesEncryptSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_aes-decrypt
|
|
// We only support aes-128-ecb mode.
|
|
// TODO: support other mode.
|
|
func (b *builtinAesEncryptSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
args, err := b.evalArgs(row)
|
|
if err != nil {
|
|
return types.Datum{}, errors.Trace(err)
|
|
}
|
|
for _, arg := range args {
|
|
// If either function argument is NULL, the function returns NULL.
|
|
if arg.IsNull() {
|
|
return
|
|
}
|
|
}
|
|
str, err := args[0].ToBytes()
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
key, err := args[1].ToBytes()
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
key = handleAESKey(key, aes128ecb)
|
|
crypted, err := encrypt.AESEncryptWithECB(str, key)
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
d.SetString(string(crypted))
|
|
return
|
|
}
|
|
|
|
// handleAESKey transforms an arbitrarily long key into a fixed-length AES key.
//
// The key length is derived from mode, which has the form
// "aes-<bits>-<cipher mode>":
//   - modes starting with "aes-192" get a 24-byte key;
//   - modes starting with "aes-256" get a 32-byte key;
//   - "aes-128-*" and any unrecognized mode get a 16-byte (128-bit) key.
//
// Byte i of key is XORed into position i % keySize of the result. Keys longer
// than keySize therefore wrap around, and shorter keys are zero-padded.
func handleAESKey(key []byte, mode string) []byte {
	keySize := 16
	// Only the "aes-<bits>" prefix of the mode determines the key length.
	const prefixLen = len("aes-128")
	if len(mode) >= prefixLen {
		switch mode[:prefixLen] {
		case "aes-192":
			keySize = 24
		case "aes-256":
			keySize = 32
		}
	}

	rKey := make([]byte, keySize)
	for i, k := range key {
		rKey[i%keySize] ^= k
	}
	return rKey
}
|
|
|
|
type compressFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *compressFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinCompressSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinCompressSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_compress
|
|
func (b *builtinCompressSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
args, err := b.evalArgs(row)
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
|
|
arg := args[0]
|
|
if arg.IsNull() {
|
|
return d, nil
|
|
}
|
|
|
|
compressStr, err := arg.ToBytes()
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
|
|
var in bytes.Buffer
|
|
w := zlib.NewWriter(&in)
|
|
w.Write(compressStr)
|
|
w.Close()
|
|
d.SetBytes(in.Bytes())
|
|
return d, nil
|
|
}
|
|
|
|
type decodeFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *decodeFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinDecodeSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinDecodeSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_decode
|
|
func (b *builtinDecodeSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("DECODE")
|
|
}
|
|
|
|
type desDecryptFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *desDecryptFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinDesDecryptSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinDesDecryptSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_des-decrypt
|
|
func (b *builtinDesDecryptSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("DES_DECRYPT")
|
|
}
|
|
|
|
type desEncryptFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *desEncryptFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinDesEncryptSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinDesEncryptSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_des-encrypt
|
|
func (b *builtinDesEncryptSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("DES_ENCRYPT")
|
|
}
|
|
|
|
type encodeFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *encodeFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinEncodeSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinEncodeSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_encode
|
|
func (b *builtinEncodeSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("ENCODE")
|
|
}
|
|
|
|
type encryptFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *encryptFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinEncryptSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinEncryptSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_encrypt
|
|
func (b *builtinEncryptSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("ENCRYPT")
|
|
}
|
|
|
|
type md5FunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *md5FunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinMD5Sig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinMD5Sig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_md5
|
|
func (b *builtinMD5Sig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
args, err := b.evalArgs(row)
|
|
if err != nil {
|
|
return types.Datum{}, errors.Trace(err)
|
|
}
|
|
// This function takes one argument.
|
|
arg := args[0]
|
|
if arg.IsNull() {
|
|
return
|
|
}
|
|
bin, err := arg.ToBytes()
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
sum := md5.Sum(bin)
|
|
hexStr := fmt.Sprintf("%x", sum)
|
|
d.SetString(hexStr)
|
|
return d, nil
|
|
}
|
|
|
|
type oldPasswordFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *oldPasswordFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinOldPasswordSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinOldPasswordSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_old-password
|
|
func (b *builtinOldPasswordSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("OLD_PASSWORD")
|
|
}
|
|
|
|
type passwordFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *passwordFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinPasswordSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinPasswordSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_password
|
|
func (b *builtinPasswordSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("PASSWORD")
|
|
}
|
|
|
|
type randomBytesFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *randomBytesFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinRandomBytesSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinRandomBytesSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_random-bytes
|
|
func (b *builtinRandomBytesSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
args, err := b.evalArgs(row)
|
|
if err != nil {
|
|
return types.Datum{}, errors.Trace(err)
|
|
}
|
|
arg := args[0]
|
|
if arg.IsNull() {
|
|
return d, nil
|
|
}
|
|
size := arg.GetInt64()
|
|
if size < 1 || size > 1024 {
|
|
return d, mysql.NewErr(mysql.ErrDataOutOfRange, "length", "random_bytes")
|
|
}
|
|
buf := make([]byte, size)
|
|
if n, err := rand.Read(buf); err != nil {
|
|
return d, errors.Trace(err)
|
|
} else if int64(n) != size {
|
|
return d, errors.New("fail to generate random bytes")
|
|
}
|
|
d.SetBytes(buf)
|
|
return d, nil
|
|
}
|
|
|
|
type sha1FunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *sha1FunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinSHA1Sig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinSHA1Sig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_sha1
|
|
// The value is returned as a string of 40 hexadecimal digits, or NULL if the argument was NULL.
|
|
func (b *builtinSHA1Sig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
args, err := b.evalArgs(row)
|
|
if err != nil {
|
|
return types.Datum{}, errors.Trace(err)
|
|
}
|
|
// SHA/SHA1 function only accept 1 parameter
|
|
arg := args[0]
|
|
if arg.IsNull() {
|
|
return d, nil
|
|
}
|
|
bin, err := arg.ToBytes()
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
hasher := sha1.New()
|
|
hasher.Write(bin)
|
|
data := fmt.Sprintf("%x", hasher.Sum(nil))
|
|
d.SetString(data)
|
|
return d, nil
|
|
}
|
|
|
|
type sha2FunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *sha2FunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinSHA2Sig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinSHA2Sig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// Hash lengths, in bits, accepted as the second argument of SHA2().
const (
	// SHA0 requests the default digest length.
	SHA0 int = 0
	// SHA224 requests a 224-bit digest.
	SHA224 int = 224
	// SHA256 requests a 256-bit digest.
	SHA256 int = 256
	// SHA384 requests a 384-bit digest.
	SHA384 int = 384
	// SHA512 requests a 512-bit digest.
	SHA512 int = 512
)
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_sha2
|
|
func (b *builtinSHA2Sig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
args, err := b.evalArgs(row)
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
for _, arg := range args {
|
|
if arg.IsNull() {
|
|
return d, nil
|
|
}
|
|
}
|
|
// Meaning of each argument:
|
|
// args[0]: the cleartext string to be hashed
|
|
// args[1]: desired bit length of result
|
|
bin, err := args[0].ToBytes()
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
hashLength, err := args[1].ToInt64(b.ctx.GetSessionVars().StmtCtx)
|
|
if err != nil {
|
|
return d, errors.Trace(err)
|
|
}
|
|
var hasher hash.Hash
|
|
switch int(hashLength) {
|
|
case SHA0, SHA256:
|
|
hasher = sha256.New()
|
|
case SHA224:
|
|
hasher = sha256.New224()
|
|
case SHA384:
|
|
hasher = sha512.New384()
|
|
case SHA512:
|
|
hasher = sha512.New()
|
|
}
|
|
if hasher != nil {
|
|
hasher.Write(bin)
|
|
data := fmt.Sprintf("%x", hasher.Sum(nil))
|
|
d.SetString(data)
|
|
}
|
|
return d, nil
|
|
}
|
|
|
|
type uncompressFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *uncompressFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinUncompressSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinUncompressSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_uncompress
|
|
func (b *builtinUncompressSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("UNCOMPRESS")
|
|
}
|
|
|
|
type uncompressedLengthFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *uncompressedLengthFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinUncompressedLengthSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinUncompressedLengthSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_uncompressed-length
|
|
func (b *builtinUncompressedLengthSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("UNCOMPRESSED_LENGTH")
|
|
}
|
|
|
|
type validatePasswordStrengthFunctionClass struct {
|
|
baseFunctionClass
|
|
}
|
|
|
|
func (c *validatePasswordStrengthFunctionClass) getFunction(args []Expression, ctx context.Context) (builtinFunc, error) {
|
|
return &builtinValidatePasswordStrengthSig{newBaseBuiltinFunc(args, ctx)}, errors.Trace(c.verifyArgs(args))
|
|
}
|
|
|
|
type builtinValidatePasswordStrengthSig struct {
|
|
baseBuiltinFunc
|
|
}
|
|
|
|
// See https://dev.mysql.com/doc/refman/5.7/en/encryption-functions.html#function_validate-password-strength
|
|
func (b *builtinValidatePasswordStrengthSig) eval(row []types.Datum) (d types.Datum, err error) {
|
|
return d, errFunctionNotExists.GenByArgs("VALIDATE_PASSWORD_STRENGTH")
|
|
}
|