expression: convert charset by wrapping internal builtin function (#29736)
This commit is contained in:
@ -140,6 +140,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
|
||||
args[i] = WrapWithCastAsDecimal(ctx, args[i])
|
||||
case types.ETString:
|
||||
args[i] = WrapWithCastAsString(ctx, args[i])
|
||||
args[i] = WrapWithToBinary(ctx, args[i], funcName)
|
||||
case types.ETDatetime:
|
||||
args[i] = WrapWithCastAsTime(ctx, args[i], types.NewFieldType(mysql.TypeDatetime))
|
||||
case types.ETTimestamp:
|
||||
@ -879,6 +880,9 @@ var funcs = map[string]functionClass{
|
||||
ast.NextVal: &nextValFunctionClass{baseFunctionClass{ast.NextVal, 1, 1}},
|
||||
ast.LastVal: &lastValFunctionClass{baseFunctionClass{ast.LastVal, 1, 1}},
|
||||
ast.SetVal: &setValFunctionClass{baseFunctionClass{ast.SetVal, 2, 2}},
|
||||
|
||||
// TiDB implicit internal functions.
|
||||
InternalFuncToBinary: &tidbConvertCharsetFunctionClass{baseFunctionClass{InternalFuncToBinary, 1, 1}},
|
||||
}
|
||||
|
||||
// IsFunctionSupported check if given function name is a builtin sql function.
|
||||
@ -902,6 +906,7 @@ func GetDisplayName(name string) string {
|
||||
func GetBuiltinList() []string {
|
||||
res := make([]string, 0, len(funcs))
|
||||
notImplementedFunctions := []string{ast.RowFunc, ast.IsTruthWithNull}
|
||||
implicitFunctions := []string{InternalFuncToBinary}
|
||||
for funcName := range funcs {
|
||||
skipFunc := false
|
||||
// Skip not implemented functions
|
||||
@ -910,6 +915,11 @@ func GetBuiltinList() []string {
|
||||
skipFunc = true
|
||||
}
|
||||
}
|
||||
for _, implicitFunc := range implicitFunctions {
|
||||
if funcName == implicitFunc {
|
||||
skipFunc = true
|
||||
}
|
||||
}
|
||||
// Skip literal functions
|
||||
// (their names are not readable: 'tidb`.(dateliteral, for example)
|
||||
// See: https://github.com/pingcap/parser/pull/591
|
||||
|
||||
136
expression/builtin_convert_charset.go
Normal file
136
expression/builtin_convert_charset.go
Normal file
@ -0,0 +1,136 @@
|
||||
// Copyright 2021 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package expression
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pingcap/tidb/parser/ast"
|
||||
"github.com/pingcap/tidb/parser/charset"
|
||||
"github.com/pingcap/tidb/parser/model"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
"github.com/pingcap/tidb/types"
|
||||
"github.com/pingcap/tidb/util/chunk"
|
||||
"github.com/pingcap/tipb/go-tipb"
|
||||
)
|
||||
|
||||
// InternalFuncToBinary accepts a string and returns another string encoded in a given charset.
|
||||
const InternalFuncToBinary = "to_binary"
|
||||
|
||||
type tidbConvertCharsetFunctionClass struct {
|
||||
baseFunctionClass
|
||||
}
|
||||
|
||||
func (c *tidbConvertCharsetFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
|
||||
if err := c.verifyArgs(args); err != nil {
|
||||
return nil, c.verifyArgs(args)
|
||||
}
|
||||
argTp := args[0].GetType().EvalType()
|
||||
var sig builtinFunc
|
||||
switch argTp {
|
||||
case types.ETString:
|
||||
bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sig = &builtinInternalToBinarySig{bf}
|
||||
sig.setPbCode(tipb.ScalarFuncSig_ToBinary)
|
||||
default:
|
||||
return nil, fmt.Errorf("unexpected argTp: %d", argTp)
|
||||
}
|
||||
return sig, nil
|
||||
}
|
||||
|
||||
var _ builtinFunc = &builtinInternalToBinarySig{}
|
||||
|
||||
type builtinInternalToBinarySig struct {
|
||||
baseBuiltinFunc
|
||||
}
|
||||
|
||||
func (b *builtinInternalToBinarySig) Clone() builtinFunc {
|
||||
newSig := &builtinInternalToBinarySig{}
|
||||
newSig.cloneFrom(&b.baseBuiltinFunc)
|
||||
return newSig
|
||||
}
|
||||
|
||||
func (b *builtinInternalToBinarySig) evalString(row chunk.Row) (res string, isNull bool, err error) {
|
||||
val, isNull, err := b.args[0].EvalString(b.ctx, row)
|
||||
if isNull || err != nil {
|
||||
return res, isNull, err
|
||||
}
|
||||
tp := b.args[0].GetType()
|
||||
enc := charset.NewEncoding(tp.Charset)
|
||||
res, err = enc.EncodeString(val)
|
||||
return res, false, err
|
||||
}
|
||||
|
||||
func (b *builtinInternalToBinarySig) vectorized() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func (b *builtinInternalToBinarySig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error {
|
||||
n := input.NumRows()
|
||||
buf, err := b.bufAllocator.get()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer b.bufAllocator.put(buf)
|
||||
if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
enc := charset.NewEncoding(b.args[0].GetType().Charset)
|
||||
result.ReserveString(n)
|
||||
for i := 0; i < n; i++ {
|
||||
var str string
|
||||
if buf.IsNull(i) {
|
||||
result.AppendNull()
|
||||
continue
|
||||
}
|
||||
str = buf.GetString(i)
|
||||
str, err = enc.EncodeString(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
result.AppendString(str)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// toBinaryMap contains the builtin functions which arguments need to be converted to the correct charset.
|
||||
var toBinaryMap = map[string]struct{}{
|
||||
ast.Hex: {}, ast.Length: {}, ast.OctetLength: {}, ast.ASCII: {},
|
||||
ast.ToBase64: {},
|
||||
}
|
||||
|
||||
// WrapWithToBinary wraps `expr` with to_binary sig.
|
||||
func WrapWithToBinary(ctx sessionctx.Context, expr Expression, funcName string) Expression {
|
||||
exprTp := expr.GetType()
|
||||
if _, err := charset.GetDefaultCollationLegacy(exprTp.Charset); err != nil {
|
||||
if _, ok := toBinaryMap[funcName]; ok {
|
||||
fc := funcs[InternalFuncToBinary]
|
||||
sig, err := fc.getFunction(ctx, []Expression{expr})
|
||||
if err != nil {
|
||||
return expr
|
||||
}
|
||||
sf := &ScalarFunction{
|
||||
FuncName: model.NewCIStr(InternalFuncToBinary),
|
||||
RetType: exprTp,
|
||||
Function: sig,
|
||||
}
|
||||
return FoldConstant(sf)
|
||||
}
|
||||
}
|
||||
return expr
|
||||
}
|
||||
@ -222,15 +222,6 @@ func (b *builtinLengthSig) evalInt(row chunk.Row) (int64, bool, error) {
|
||||
if isNull || err != nil {
|
||||
return 0, isNull, err
|
||||
}
|
||||
|
||||
argTp := b.args[0].GetType()
|
||||
if !types.IsBinaryStr(argTp) {
|
||||
dBytes, err := charset.NewEncoding(argTp.Charset).EncodeString(val)
|
||||
if err == nil {
|
||||
return int64(len(dBytes)), false, nil
|
||||
}
|
||||
}
|
||||
|
||||
return int64(len([]byte(val))), false, nil
|
||||
}
|
||||
|
||||
@ -272,13 +263,6 @@ func (b *builtinASCIISig) evalInt(row chunk.Row) (int64, bool, error) {
|
||||
if len(val) == 0 {
|
||||
return 0, false, nil
|
||||
}
|
||||
argTp := b.args[0].GetType()
|
||||
if !types.IsBinaryStr(argTp) {
|
||||
dBytes, err := charset.NewEncoding(argTp.Charset).EncodeString(val)
|
||||
if err == nil {
|
||||
return int64(dBytes[0]), false, nil
|
||||
}
|
||||
}
|
||||
return int64(val[0]), false, nil
|
||||
}
|
||||
|
||||
@ -1664,7 +1648,7 @@ func (c *hexFunctionClass) getFunction(ctx sessionctx.Context, args []Expression
|
||||
argFieldTp := args[0].GetType()
|
||||
// Use UTF8MB4 as default.
|
||||
bf.tp.Flen = argFieldTp.Flen * 4 * 2
|
||||
sig := &builtinHexStrArgSig{bf, charset.NewEncoding(argFieldTp.Charset)}
|
||||
sig := &builtinHexStrArgSig{bf}
|
||||
sig.setPbCode(tipb.ScalarFuncSig_HexStrArg)
|
||||
return sig, nil
|
||||
case types.ETInt, types.ETReal, types.ETDecimal:
|
||||
@ -1684,15 +1668,11 @@ func (c *hexFunctionClass) getFunction(ctx sessionctx.Context, args []Expression
|
||||
|
||||
type builtinHexStrArgSig struct {
|
||||
baseBuiltinFunc
|
||||
encoding *charset.Encoding
|
||||
}
|
||||
|
||||
func (b *builtinHexStrArgSig) Clone() builtinFunc {
|
||||
newSig := &builtinHexStrArgSig{}
|
||||
newSig.cloneFrom(&b.baseBuiltinFunc)
|
||||
if b.encoding != nil {
|
||||
newSig.encoding = charset.NewEncoding(b.encoding.Name())
|
||||
}
|
||||
return newSig
|
||||
}
|
||||
|
||||
@ -1703,12 +1683,7 @@ func (b *builtinHexStrArgSig) evalString(row chunk.Row) (string, bool, error) {
|
||||
if isNull || err != nil {
|
||||
return d, isNull, err
|
||||
}
|
||||
dBytes := hack.Slice(d)
|
||||
dBytes, err = b.encoding.Encode(nil, dBytes)
|
||||
if err != nil {
|
||||
return d, false, err
|
||||
}
|
||||
return strings.ToUpper(hex.EncodeToString(dBytes)), false, nil
|
||||
return strings.ToUpper(hex.EncodeToString(hack.Slice(d))), false, nil
|
||||
}
|
||||
|
||||
type builtinHexIntArgSig struct {
|
||||
@ -3634,11 +3609,6 @@ func (b *builtinToBase64Sig) evalString(row chunk.Row) (d string, isNull bool, e
|
||||
if isNull || err != nil {
|
||||
return "", isNull, err
|
||||
}
|
||||
argTp := b.args[0].GetType()
|
||||
str, err = charset.NewEncoding(argTp.Charset).EncodeString(str)
|
||||
if err != nil {
|
||||
return "", false, err
|
||||
}
|
||||
needEncodeLen := base64NeededEncodedLength(len(str))
|
||||
if needEncodeLen == -1 {
|
||||
return "", true, nil
|
||||
|
||||
@ -447,7 +447,6 @@ func (b *builtinHexStrArgSig) vecEvalString(input *chunk.Chunk, result *chunk.Co
|
||||
return err
|
||||
}
|
||||
defer b.bufAllocator.put(buf0)
|
||||
var encodedBuf []byte
|
||||
if err := b.args[0].VecEvalString(b.ctx, input, buf0); err != nil {
|
||||
return err
|
||||
}
|
||||
@ -457,13 +456,7 @@ func (b *builtinHexStrArgSig) vecEvalString(input *chunk.Chunk, result *chunk.Co
|
||||
result.AppendNull()
|
||||
continue
|
||||
}
|
||||
buf0Bytes := buf0.GetBytes(i)
|
||||
encodedBuf, err = b.encoding.Encode(encodedBuf, buf0Bytes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
buf0Bytes = encodedBuf
|
||||
result.AppendString(strings.ToUpper(hex.EncodeToString(buf0Bytes)))
|
||||
result.AppendString(strings.ToUpper(hex.EncodeToString(buf0.GetBytes(i))))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -912,11 +905,6 @@ func (b *builtinASCIISig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e
|
||||
if err = b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
argTp := b.args[0].GetType()
|
||||
enc := charset.NewEncoding(argTp.Charset)
|
||||
isBinaryStr := types.IsBinaryStr(argTp)
|
||||
|
||||
result.ResizeInt64(n, false)
|
||||
result.MergeNulls(buf)
|
||||
i64s := result.Int64s()
|
||||
@ -929,14 +917,6 @@ func (b *builtinASCIISig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e
|
||||
i64s[i] = 0
|
||||
continue
|
||||
}
|
||||
if !isBinaryStr {
|
||||
dBytes, err := enc.EncodeString(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i64s[i] = int64(dBytes[0])
|
||||
continue
|
||||
}
|
||||
i64s[i] = int64(str[0])
|
||||
}
|
||||
return nil
|
||||
@ -2162,27 +2142,14 @@ func (b *builtinLengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column)
|
||||
return err
|
||||
}
|
||||
|
||||
argTp := b.args[0].GetType()
|
||||
enc := charset.NewEncoding(argTp.Charset)
|
||||
isBinaryStr := types.IsBinaryStr(argTp)
|
||||
|
||||
result.ResizeInt64(n, false)
|
||||
result.MergeNulls(buf)
|
||||
i64s := result.Int64s()
|
||||
var encodeBuf []byte
|
||||
for i := 0; i < n; i++ {
|
||||
if result.IsNull(i) {
|
||||
continue
|
||||
}
|
||||
str := buf.GetBytes(i)
|
||||
if !isBinaryStr {
|
||||
dBytes, err := enc.Encode(encodeBuf, str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
i64s[i] = int64(len(dBytes))
|
||||
continue
|
||||
}
|
||||
i64s[i] = int64(len(str))
|
||||
}
|
||||
return nil
|
||||
@ -2470,20 +2437,13 @@ func (b *builtinToBase64Sig) vecEvalString(input *chunk.Chunk, result *chunk.Col
|
||||
if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
argTp := b.args[0].GetType()
|
||||
enc := charset.NewEncoding(argTp.Charset)
|
||||
|
||||
result.ReserveString(n)
|
||||
for i := 0; i < n; i++ {
|
||||
if buf.IsNull(i) {
|
||||
result.AppendNull()
|
||||
continue
|
||||
}
|
||||
str, err := enc.EncodeString(buf.GetString(i))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
str := buf.GetString(i)
|
||||
needEncodeLen := base64NeededEncodedLength(len(str))
|
||||
if needEncodeLen == -1 {
|
||||
result.AppendNull()
|
||||
|
||||
@ -49,9 +49,20 @@ func newLonglong(value int64) *Constant {
|
||||
}
|
||||
}
|
||||
|
||||
func newString(value string, collation string) *Constant {
|
||||
return &Constant{
|
||||
Value: types.NewStringDatum(value),
|
||||
RetType: types.NewFieldTypeWithCollation(mysql.TypeVarchar, collation, 255),
|
||||
}
|
||||
}
|
||||
|
||||
func newFunction(funcName string, args ...Expression) Expression {
|
||||
typeLong := types.NewFieldType(mysql.TypeLonglong)
|
||||
return NewFunctionInternal(mock.NewContext(), funcName, typeLong, args...)
|
||||
return newFunctionWithType(funcName, mysql.TypeLonglong, args...)
|
||||
}
|
||||
|
||||
func newFunctionWithType(funcName string, tp byte, args ...Expression) Expression {
|
||||
ft := types.NewFieldType(tp)
|
||||
return NewFunctionInternal(mock.NewContext(), funcName, ft, args...)
|
||||
}
|
||||
|
||||
func TestConstantPropagation(t *testing.T) {
|
||||
@ -220,6 +231,31 @@ func TestConstantFolding(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestConstantFoldingCharsetConvert(t *testing.T) {
|
||||
t.Parallel()
|
||||
tests := []struct {
|
||||
condition Expression
|
||||
result string
|
||||
}{
|
||||
{
|
||||
condition: newFunction(ast.Length, newFunctionWithType(
|
||||
InternalFuncToBinary, mysql.TypeVarchar,
|
||||
newString("中文", "gbk_bin"))),
|
||||
result: "4",
|
||||
},
|
||||
{
|
||||
condition: newFunction(ast.Length, newFunctionWithType(
|
||||
InternalFuncToBinary, mysql.TypeVarchar,
|
||||
newString("中文", "utf8mb4_bin"))),
|
||||
result: "6",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
newConds := FoldConstant(tt.condition)
|
||||
require.Equalf(t, tt.result, newConds.String(), "different for expr %s", tt.condition)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeferredParamNotNull(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
|
||||
@ -21,7 +21,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/pingcap/errors"
|
||||
"github.com/pingcap/tidb/parser/charset"
|
||||
"github.com/pingcap/tidb/parser/model"
|
||||
"github.com/pingcap/tidb/parser/mysql"
|
||||
"github.com/pingcap/tidb/sessionctx"
|
||||
@ -965,11 +964,7 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
|
||||
case tipb.ScalarFuncSig_HexIntArg:
|
||||
f = &builtinHexIntArgSig{base}
|
||||
case tipb.ScalarFuncSig_HexStrArg:
|
||||
chs, args := "utf-8", base.getArgs()
|
||||
if len(args) == 1 {
|
||||
chs, _ = args[0].CharsetAndCollation()
|
||||
}
|
||||
f = &builtinHexStrArgSig{base, charset.NewEncoding(chs)}
|
||||
f = &builtinHexStrArgSig{base}
|
||||
case tipb.ScalarFuncSig_InsertUTF8:
|
||||
f = &builtinInsertUTF8Sig{base, maxAllowedPacket}
|
||||
case tipb.ScalarFuncSig_Insert:
|
||||
|
||||
2
go.mod
2
go.mod
@ -53,7 +53,7 @@ require (
|
||||
github.com/pingcap/sysutil v0.0.0-20210730114356-fcd8a63f68c5
|
||||
github.com/pingcap/tidb-tools v5.2.2-0.20211019062242-37a8bef2fa17+incompatible
|
||||
github.com/pingcap/tidb/parser v0.0.0-20211011031125-9b13dc409c5e
|
||||
github.com/pingcap/tipb v0.0.0-20211105090418-71142a4d40e3
|
||||
github.com/pingcap/tipb v0.0.0-20211116093845-e9b045a0bdf8
|
||||
github.com/prometheus/client_golang v1.5.1
|
||||
github.com/prometheus/client_model v0.2.0
|
||||
github.com/prometheus/common v0.9.1
|
||||
|
||||
4
go.sum
4
go.sum
@ -600,8 +600,8 @@ github.com/pingcap/tidb-dashboard v0.0.0-20211008050453-a25c25809529/go.mod h1:O
|
||||
github.com/pingcap/tidb-dashboard v0.0.0-20211031170437-08e58c069a2a/go.mod h1:OCXbZTBTIMRcIt0jFsuCakZP+goYRv6IjawKbwLS2TQ=
|
||||
github.com/pingcap/tidb-tools v5.2.2-0.20211019062242-37a8bef2fa17+incompatible h1:c7+izmker91NkjkZ6FgTlmD4k1A5FLOAq+li6Ki2/GY=
|
||||
github.com/pingcap/tidb-tools v5.2.2-0.20211019062242-37a8bef2fa17+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM=
|
||||
github.com/pingcap/tipb v0.0.0-20211105090418-71142a4d40e3 h1:xnp/Qkk5gELlB8TaY6oro0JNXMBXTafNVxU/vbrNU8I=
|
||||
github.com/pingcap/tipb v0.0.0-20211105090418-71142a4d40e3/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs=
|
||||
github.com/pingcap/tipb v0.0.0-20211116093845-e9b045a0bdf8 h1:Vu/6oq8EFNWgyXRHiclNzTKIu+YKHPCSI/Ba5oVrLtM=
|
||||
github.com/pingcap/tipb v0.0.0-20211116093845-e9b045a0bdf8/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs=
|
||||
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
|
||||
Reference in New Issue
Block a user