server: support decoding prepared string args to character_set_client (#30723)

This commit is contained in:
tangenta
2021-12-15 12:14:34 +08:00
committed by GitHub
parent 22418cd8c4
commit 04a9618f5c
4 changed files with 60 additions and 3 deletions

View File

@ -191,6 +191,7 @@ type clientConn struct {
authPlugin string // default authentication plugin
isUnixSocket bool // connection is Unix Socket file
rsEncoder *resultEncoder // rsEncoder is used to encode the string result to different charsets.
inputDecoder *inputDecoder // inputDecoder is used to decode the different charsets of incoming strings to utf-8.
socketCredUID uint32 // UID from the other end of the Unix Socket
// mu is used for cancelling the execution of current transaction.
mu struct {
@ -964,6 +965,15 @@ func (cc *clientConn) initResultEncoder(ctx context.Context) {
cc.rsEncoder = newResultEncoder(chs)
}
func (cc *clientConn) initInputEncoder(ctx context.Context) {
chs, err := variable.GetSessionOrGlobalSystemVar(cc.ctx.GetSessionVars(), variable.CharacterSetClient)
if err != nil {
chs = ""
logutil.Logger(ctx).Warn("get character_set_client system variable failed", zap.Error(err))
}
cc.inputDecoder = newInputDecoder(chs)
}
// initConnect runs the initConnect SQL statement if it has been specified.
// The semantics are MySQL compatible.
func (cc *clientConn) initConnect(ctx context.Context) error {

View File

@ -46,6 +46,7 @@ import (
"github.com/pingcap/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/parser/charset"
"github.com/pingcap/tidb/parser/mysql"
"github.com/pingcap/tidb/parser/terror"
plannercore "github.com/pingcap/tidb/planner/core"
@ -167,6 +168,8 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramTypes []byte
paramValues []byte
)
cc.initInputEncoder(ctx)
defer cc.inputDecoder.clean()
numParams := stmt.NumParams()
args := make([]types.Datum, numParams)
if numParams > 0 {
@ -194,7 +197,7 @@ func (cc *clientConn) handleStmtExecute(ctx context.Context, data []byte) (err e
paramValues = data[pos+1:]
}
err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues)
err = parseExecArgs(cc.ctx.GetSessionVars().StmtCtx, args, stmt.BoundParams(), nullBitmaps, stmt.GetParamsType(), paramValues, cc.inputDecoder)
stmt.Reset()
if err != nil {
return errors.Annotate(err, cc.preparedStmt2String(stmtID))
@ -310,7 +313,8 @@ func parseStmtFetchCmd(data []byte) (uint32, uint32, error) {
return stmtID, fetchSize, nil
}
func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte) (err error) {
func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams [][]byte,
nullBitmap, paramTypes, paramValues []byte, enc *inputDecoder) (err error) {
pos := 0
var (
tmp interface{}
@ -318,6 +322,9 @@ func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams
n int
isNull bool
)
if enc == nil {
enc = newInputDecoder(charset.CharsetUTF8)
}
for i := 0; i < len(args); i++ {
// if params had received via ComStmtSendLongData, use them directly.
@ -543,6 +550,7 @@ func parseExecArgs(sc *stmtctx.StatementContext, args []types.Datum, boundParams
}
if !isNull {
v = enc.decodeInput(v)
tmp = string(hack.String(v))
} else {
tmp = nil

View File

@ -197,12 +197,25 @@ func TestParseExecArgs(t *testing.T) {
},
}
for _, tt := range tests {
err := parseExecArgs(&stmtctx.StatementContext{}, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues)
err := parseExecArgs(&stmtctx.StatementContext{}, tt.args.args, tt.args.boundParams, tt.args.nullBitmap, tt.args.paramTypes, tt.args.paramValues, nil)
require.Truef(t, terror.ErrorEqual(err, tt.err), "err %v", err)
require.Equal(t, tt.expect, tt.args.args[0].GetValue())
}
}
func TestParseExecArgsAndEncode(t *testing.T) {
dt := make([]types.Datum, 1)
err := parseExecArgs(&stmtctx.StatementContext{},
dt,
[][]byte{nil},
[]byte{0x0},
[]byte{mysql.TypeVarchar, 0},
[]byte{4, 178, 226, 202, 212},
newInputDecoder("gbk"))
require.NoError(t, err)
require.Equal(t, "测试", dt[0].GetValue())
}
func TestParseStmtFetchCmd(t *testing.T) {
tests := []struct {
arg []byte

View File

@ -290,6 +290,32 @@ func dumpBinaryRow(buffer []byte, columns []*ColumnInfo, row chunk.Row, d *resul
return buffer, nil
}
type inputDecoder struct {
encoding *charset.Encoding
buffer []byte
}
func newInputDecoder(chs string) *inputDecoder {
return &inputDecoder{
encoding: charset.NewEncoding(chs),
buffer: nil,
}
}
// clean prevents the inputDecoder from holding too much memory.
func (i *inputDecoder) clean() {
i.buffer = nil
}
func (i *inputDecoder) decodeInput(src []byte) []byte {
result, err := i.encoding.Decode(i.buffer, src)
if err != nil {
return src
}
return result
}
type resultEncoder struct {
// chsName and encoding are unchanged after the initialization from
// session variable @@character_set_results.