expression: convert charset by wrapping internal builtin function (#29736)

This commit is contained in:
tangenta
2021-11-17 15:37:47 +08:00
committed by GitHub
parent fec2938c13
commit 7889f445a6
8 changed files with 192 additions and 85 deletions

View File

@ -140,6 +140,7 @@ func newBaseBuiltinFuncWithTp(ctx sessionctx.Context, funcName string, args []Ex
args[i] = WrapWithCastAsDecimal(ctx, args[i])
case types.ETString:
args[i] = WrapWithCastAsString(ctx, args[i])
args[i] = WrapWithToBinary(ctx, args[i], funcName)
case types.ETDatetime:
args[i] = WrapWithCastAsTime(ctx, args[i], types.NewFieldType(mysql.TypeDatetime))
case types.ETTimestamp:
@ -879,6 +880,9 @@ var funcs = map[string]functionClass{
ast.NextVal: &nextValFunctionClass{baseFunctionClass{ast.NextVal, 1, 1}},
ast.LastVal: &lastValFunctionClass{baseFunctionClass{ast.LastVal, 1, 1}},
ast.SetVal: &setValFunctionClass{baseFunctionClass{ast.SetVal, 2, 2}},
// TiDB implicit internal functions.
InternalFuncToBinary: &tidbConvertCharsetFunctionClass{baseFunctionClass{InternalFuncToBinary, 1, 1}},
}
// IsFunctionSupported checks whether the given function name is a builtin SQL function.
@ -902,6 +906,7 @@ func GetDisplayName(name string) string {
func GetBuiltinList() []string {
res := make([]string, 0, len(funcs))
notImplementedFunctions := []string{ast.RowFunc, ast.IsTruthWithNull}
implicitFunctions := []string{InternalFuncToBinary}
for funcName := range funcs {
skipFunc := false
// Skip not implemented functions
@ -910,6 +915,11 @@ func GetBuiltinList() []string {
skipFunc = true
}
}
for _, implicitFunc := range implicitFunctions {
if funcName == implicitFunc {
skipFunc = true
}
}
// Skip literal functions
// (their names are not readable: 'tidb`.(dateliteral, for example)
// See: https://github.com/pingcap/parser/pull/591

View File

@ -0,0 +1,136 @@
// Copyright 2021 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package expression
import (
"fmt"
"github.com/pingcap/tidb/parser/ast"
"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/types"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tipb/go-tipb"
)
// InternalFuncToBinary is the name of the implicit internal function that
// accepts a string and returns that string encoded in the charset recorded in
// the argument's field type.
const InternalFuncToBinary = "to_binary"
// tidbConvertCharsetFunctionClass is the function class of the internal
// to_binary function; its getFunction produces a builtinInternalToBinarySig.
type tidbConvertCharsetFunctionClass struct {
baseFunctionClass
}
// getFunction builds the builtin signature of the internal to_binary function.
//
// The arguments are first checked with verifyArgs; on failure that error is
// returned unchanged. Only an argument whose eval type is types.ETString is
// supported; any other eval type yields an "unexpected argTp" error.
func (c *tidbConvertCharsetFunctionClass) getFunction(ctx sessionctx.Context, args []Expression) (builtinFunc, error) {
	// Return the error already obtained instead of running verifyArgs a second time.
	if err := c.verifyArgs(args); err != nil {
		return nil, err
	}
	argTp := args[0].GetType().EvalType()
	var sig builtinFunc
	switch argTp {
	case types.ETString:
		bf, err := newBaseBuiltinFuncWithTp(ctx, c.funcName, args, types.ETString, types.ETString)
		if err != nil {
			return nil, err
		}
		sig = &builtinInternalToBinarySig{bf}
		sig.setPbCode(tipb.ScalarFuncSig_ToBinary)
	default:
		return nil, fmt.Errorf("unexpected argTp: %d", argTp)
	}
	return sig, nil
}
// Compile-time check that *builtinInternalToBinarySig implements builtinFunc.
var _ builtinFunc = (*builtinInternalToBinarySig)(nil)
// builtinInternalToBinarySig is the signature of the internal to_binary
// function. It evaluates its single string argument and encodes the value
// with the charset taken from that argument's field type.
type builtinInternalToBinarySig struct {
baseBuiltinFunc
}
// Clone returns a new builtinInternalToBinarySig whose base state is copied
// from b via cloneFrom.
func (b *builtinInternalToBinarySig) Clone() builtinFunc {
	cloned := new(builtinInternalToBinarySig)
	cloned.cloneFrom(&b.baseBuiltinFunc)
	return cloned
}
// evalString evaluates the first argument as a string for the given row and
// returns it encoded with the charset of that argument's field type.
// A NULL argument or an evaluation error is returned as-is with an empty result.
func (b *builtinInternalToBinarySig) evalString(row chunk.Row) (res string, isNull bool, err error) {
	arg := b.args[0]
	input, isNull, err := arg.EvalString(b.ctx, row)
	if err != nil || isNull {
		return res, isNull, err
	}
	// The target encoding is derived from the argument's own charset.
	encoder := charset.NewEncoding(arg.GetType().Charset)
	res, err = encoder.EncodeString(input)
	return res, false, err
}
// vectorized always reports true; the vectorized implementation for this
// signature is vecEvalString.
func (b *builtinInternalToBinarySig) vectorized() bool {
return true
}
// vecEvalString is the vectorized form of evalString: it evaluates the first
// argument for every row of input into a scratch column, then appends to
// result either NULL (for NULL rows) or the value encoded with the charset of
// the argument's field type. The first evaluation or encoding error aborts
// the loop and is returned.
func (b *builtinInternalToBinarySig) vecEvalString(input *chunk.Chunk, result *chunk.Column) error {
	rowCount := input.NumRows()
	argBuf, err := b.bufAllocator.get()
	if err != nil {
		return err
	}
	// Hand the scratch column back to the allocator on every exit path.
	defer b.bufAllocator.put(argBuf)
	if err = b.args[0].VecEvalString(b.ctx, input, argBuf); err != nil {
		return err
	}
	encoder := charset.NewEncoding(b.args[0].GetType().Charset)
	result.ReserveString(rowCount)
	for row := 0; row < rowCount; row++ {
		if argBuf.IsNull(row) {
			result.AppendNull()
			continue
		}
		encoded, encErr := encoder.EncodeString(argBuf.GetString(row))
		if encErr != nil {
			return encErr
		}
		result.AppendString(encoded)
	}
	return nil
}
// toBinaryMap is the set of builtin functions whose string arguments need to
// be converted to the correct charset; WrapWithToBinary consults it by
// function name.
var toBinaryMap = map[string]struct{}{
	ast.Hex:         {},
	ast.Length:      {},
	ast.OctetLength: {},
	ast.ASCII:       {},
	ast.ToBase64:    {},
}
// WrapWithToBinary wraps `expr` with the to_binary signature when both of the
// following hold: charset.GetDefaultCollationLegacy returns an error for the
// expression's charset, and funcName is listed in toBinaryMap. In every other
// case, including a failure to build the to_binary signature, `expr` is
// returned unchanged. The wrapped expression keeps the field type of `expr`
// and is passed through FoldConstant before being returned.
func WrapWithToBinary(ctx sessionctx.Context, expr Expression, funcName string) Expression {
	retTp := expr.GetType()
	if _, legacyErr := charset.GetDefaultCollationLegacy(retTp.Charset); legacyErr == nil {
		return expr
	}
	if _, listed := toBinaryMap[funcName]; !listed {
		return expr
	}
	toBinaryClass := funcs[InternalFuncToBinary]
	toBinarySig, err := toBinaryClass.getFunction(ctx, []Expression{expr})
	if err != nil {
		// Best effort: fall back to the unwrapped expression.
		return expr
	}
	wrapped := &ScalarFunction{
		FuncName: model.NewCIStr(InternalFuncToBinary),
		RetType:  retTp,
		Function: toBinarySig,
	}
	return FoldConstant(wrapped)
}

View File

@ -222,15 +222,6 @@ func (b *builtinLengthSig) evalInt(row chunk.Row) (int64, bool, error) {
if isNull || err != nil {
return 0, isNull, err
}
argTp := b.args[0].GetType()
if !types.IsBinaryStr(argTp) {
dBytes, err := charset.NewEncoding(argTp.Charset).EncodeString(val)
if err == nil {
return int64(len(dBytes)), false, nil
}
}
return int64(len([]byte(val))), false, nil
}
@ -272,13 +263,6 @@ func (b *builtinASCIISig) evalInt(row chunk.Row) (int64, bool, error) {
if len(val) == 0 {
return 0, false, nil
}
argTp := b.args[0].GetType()
if !types.IsBinaryStr(argTp) {
dBytes, err := charset.NewEncoding(argTp.Charset).EncodeString(val)
if err == nil {
return int64(dBytes[0]), false, nil
}
}
return int64(val[0]), false, nil
}
@ -1664,7 +1648,7 @@ func (c *hexFunctionClass) getFunction(ctx sessionctx.Context, args []Expression
argFieldTp := args[0].GetType()
// Use UTF8MB4 as default.
bf.tp.Flen = argFieldTp.Flen * 4 * 2
sig := &builtinHexStrArgSig{bf, charset.NewEncoding(argFieldTp.Charset)}
sig := &builtinHexStrArgSig{bf}
sig.setPbCode(tipb.ScalarFuncSig_HexStrArg)
return sig, nil
case types.ETInt, types.ETReal, types.ETDecimal:
@ -1684,15 +1668,11 @@ func (c *hexFunctionClass) getFunction(ctx sessionctx.Context, args []Expression
type builtinHexStrArgSig struct {
baseBuiltinFunc
encoding *charset.Encoding
}
func (b *builtinHexStrArgSig) Clone() builtinFunc {
newSig := &builtinHexStrArgSig{}
newSig.cloneFrom(&b.baseBuiltinFunc)
if b.encoding != nil {
newSig.encoding = charset.NewEncoding(b.encoding.Name())
}
return newSig
}
@ -1703,12 +1683,7 @@ func (b *builtinHexStrArgSig) evalString(row chunk.Row) (string, bool, error) {
if isNull || err != nil {
return d, isNull, err
}
dBytes := hack.Slice(d)
dBytes, err = b.encoding.Encode(nil, dBytes)
if err != nil {
return d, false, err
}
return strings.ToUpper(hex.EncodeToString(dBytes)), false, nil
return strings.ToUpper(hex.EncodeToString(hack.Slice(d))), false, nil
}
type builtinHexIntArgSig struct {
@ -3634,11 +3609,6 @@ func (b *builtinToBase64Sig) evalString(row chunk.Row) (d string, isNull bool, e
if isNull || err != nil {
return "", isNull, err
}
argTp := b.args[0].GetType()
str, err = charset.NewEncoding(argTp.Charset).EncodeString(str)
if err != nil {
return "", false, err
}
needEncodeLen := base64NeededEncodedLength(len(str))
if needEncodeLen == -1 {
return "", true, nil

View File

@ -447,7 +447,6 @@ func (b *builtinHexStrArgSig) vecEvalString(input *chunk.Chunk, result *chunk.Co
return err
}
defer b.bufAllocator.put(buf0)
var encodedBuf []byte
if err := b.args[0].VecEvalString(b.ctx, input, buf0); err != nil {
return err
}
@ -457,13 +456,7 @@ func (b *builtinHexStrArgSig) vecEvalString(input *chunk.Chunk, result *chunk.Co
result.AppendNull()
continue
}
buf0Bytes := buf0.GetBytes(i)
encodedBuf, err = b.encoding.Encode(encodedBuf, buf0Bytes)
if err != nil {
return err
}
buf0Bytes = encodedBuf
result.AppendString(strings.ToUpper(hex.EncodeToString(buf0Bytes)))
result.AppendString(strings.ToUpper(hex.EncodeToString(buf0.GetBytes(i))))
}
return nil
}
@ -912,11 +905,6 @@ func (b *builtinASCIISig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e
if err = b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
return err
}
argTp := b.args[0].GetType()
enc := charset.NewEncoding(argTp.Charset)
isBinaryStr := types.IsBinaryStr(argTp)
result.ResizeInt64(n, false)
result.MergeNulls(buf)
i64s := result.Int64s()
@ -929,14 +917,6 @@ func (b *builtinASCIISig) vecEvalInt(input *chunk.Chunk, result *chunk.Column) e
i64s[i] = 0
continue
}
if !isBinaryStr {
dBytes, err := enc.EncodeString(str)
if err != nil {
return err
}
i64s[i] = int64(dBytes[0])
continue
}
i64s[i] = int64(str[0])
}
return nil
@ -2162,27 +2142,14 @@ func (b *builtinLengthSig) vecEvalInt(input *chunk.Chunk, result *chunk.Column)
return err
}
argTp := b.args[0].GetType()
enc := charset.NewEncoding(argTp.Charset)
isBinaryStr := types.IsBinaryStr(argTp)
result.ResizeInt64(n, false)
result.MergeNulls(buf)
i64s := result.Int64s()
var encodeBuf []byte
for i := 0; i < n; i++ {
if result.IsNull(i) {
continue
}
str := buf.GetBytes(i)
if !isBinaryStr {
dBytes, err := enc.Encode(encodeBuf, str)
if err != nil {
return err
}
i64s[i] = int64(len(dBytes))
continue
}
i64s[i] = int64(len(str))
}
return nil
@ -2470,20 +2437,13 @@ func (b *builtinToBase64Sig) vecEvalString(input *chunk.Chunk, result *chunk.Col
if err := b.args[0].VecEvalString(b.ctx, input, buf); err != nil {
return err
}
argTp := b.args[0].GetType()
enc := charset.NewEncoding(argTp.Charset)
result.ReserveString(n)
for i := 0; i < n; i++ {
if buf.IsNull(i) {
result.AppendNull()
continue
}
str, err := enc.EncodeString(buf.GetString(i))
if err != nil {
return err
}
str := buf.GetString(i)
needEncodeLen := base64NeededEncodedLength(len(str))
if needEncodeLen == -1 {
result.AppendNull()

View File

@ -49,9 +49,20 @@ func newLonglong(value int64) *Constant {
}
}
// newString builds a string Constant holding value, typed as a
// mysql.TypeVarchar field with the given collation and a length of 255.
func newString(value string, collation string) *Constant {
	strConst := new(Constant)
	strConst.Value = types.NewStringDatum(value)
	strConst.RetType = types.NewFieldTypeWithCollation(mysql.TypeVarchar, collation, 255)
	return strConst
}
// newFunction builds a function expression named funcName over args with a
// mysql.TypeLonglong field type by delegating to newFunctionWithType.
func newFunction(funcName string, args ...Expression) Expression {
	return newFunctionWithType(funcName, mysql.TypeLonglong, args...)
}
// newFunctionWithType builds a function expression named funcName over args
// via NewFunctionInternal, using a fresh mock context and a field type
// created from tp.
func newFunctionWithType(funcName string, tp byte, args ...Expression) Expression {
	retTp := types.NewFieldType(tp)
	sctx := mock.NewContext()
	return NewFunctionInternal(sctx, funcName, retTp, args...)
}
func TestConstantPropagation(t *testing.T) {
@ -220,6 +231,31 @@ func TestConstantFolding(t *testing.T) {
}
}
// TestConstantFoldingCharsetConvert checks that folding
// length(to_binary(<string constant>)) yields the expected string form for a
// gbk_bin and a utf8mb4_bin constant.
func TestConstantFoldingCharsetConvert(t *testing.T) {
	t.Parallel()
	// lengthOfToBinary builds length(to_binary(literal)) for a varchar
	// constant with the given collation.
	lengthOfToBinary := func(literal, collation string) Expression {
		inner := newFunctionWithType(InternalFuncToBinary, mysql.TypeVarchar, newString(literal, collation))
		return newFunction(ast.Length, inner)
	}
	cases := []struct {
		condition Expression
		result    string
	}{
		{condition: lengthOfToBinary("中文", "gbk_bin"), result: "4"},
		{condition: lengthOfToBinary("中文", "utf8mb4_bin"), result: "6"},
	}
	for _, tc := range cases {
		folded := FoldConstant(tc.condition)
		require.Equalf(t, tc.result, folded.String(), "different for expr %s", tc.condition)
	}
}
func TestDeferredParamNotNull(t *testing.T) {
t.Parallel()

View File

@ -21,7 +21,6 @@ import (
"time"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/model"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/sessionctx"
@ -965,11 +964,7 @@ func getSignatureByPB(ctx sessionctx.Context, sigCode tipb.ScalarFuncSig, tp *ti
case tipb.ScalarFuncSig_HexIntArg:
f = &builtinHexIntArgSig{base}
case tipb.ScalarFuncSig_HexStrArg:
chs, args := "utf-8", base.getArgs()
if len(args) == 1 {
chs, _ = args[0].CharsetAndCollation()
}
f = &builtinHexStrArgSig{base, charset.NewEncoding(chs)}
f = &builtinHexStrArgSig{base}
case tipb.ScalarFuncSig_InsertUTF8:
f = &builtinInsertUTF8Sig{base, maxAllowedPacket}
case tipb.ScalarFuncSig_Insert:

2
go.mod
View File

@ -53,7 +53,7 @@ require (
github.com/pingcap/sysutil v0.0.0-20210730114356-fcd8a63f68c5
github.com/pingcap/tidb-tools v5.2.2-0.20211019062242-37a8bef2fa17+incompatible
github.com/pingcap/tidb/parser v0.0.0-20211011031125-9b13dc409c5e
github.com/pingcap/tipb v0.0.0-20211105090418-71142a4d40e3
github.com/pingcap/tipb v0.0.0-20211116093845-e9b045a0bdf8
github.com/prometheus/client_golang v1.5.1
github.com/prometheus/client_model v0.2.0
github.com/prometheus/common v0.9.1

4
go.sum
View File

@ -600,8 +600,8 @@ github.com/pingcap/tidb-dashboard v0.0.0-20211008050453-a25c25809529/go.mod h1:O
github.com/pingcap/tidb-dashboard v0.0.0-20211031170437-08e58c069a2a/go.mod h1:OCXbZTBTIMRcIt0jFsuCakZP+goYRv6IjawKbwLS2TQ=
github.com/pingcap/tidb-tools v5.2.2-0.20211019062242-37a8bef2fa17+incompatible h1:c7+izmker91NkjkZ6FgTlmD4k1A5FLOAq+li6Ki2/GY=
github.com/pingcap/tidb-tools v5.2.2-0.20211019062242-37a8bef2fa17+incompatible/go.mod h1:XGdcy9+yqlDSEMTpOXnwf3hiTeqrV6MN/u1se9N8yIM=
github.com/pingcap/tipb v0.0.0-20211105090418-71142a4d40e3 h1:xnp/Qkk5gELlB8TaY6oro0JNXMBXTafNVxU/vbrNU8I=
github.com/pingcap/tipb v0.0.0-20211105090418-71142a4d40e3/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs=
github.com/pingcap/tipb v0.0.0-20211116093845-e9b045a0bdf8 h1:Vu/6oq8EFNWgyXRHiclNzTKIu+YKHPCSI/Ba5oVrLtM=
github.com/pingcap/tipb v0.0.0-20211116093845-e9b045a0bdf8/go.mod h1:A7mrd7WHBl1o63LE2bIBGEJMTNWXqhgmYiOvMLxozfs=
github.com/pkg/errors v0.8.0/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=