148 lines
4.6 KiB
Go
148 lines
4.6 KiB
Go
// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
|
|
"github.com/pingcap/tidb/pkg/errno"
|
|
"github.com/pingcap/tidb/pkg/param"
|
|
"github.com/pingcap/tidb/pkg/parser/charset"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
util2 "github.com/pingcap/tidb/pkg/server/internal/util"
|
|
"github.com/pingcap/tidb/pkg/types"
|
|
"github.com/pingcap/tidb/pkg/util/dbterror"
|
|
)
|
|
|
|
// errUnknownFieldType is the server-class error built from
// errno.ErrUnknownFieldType. parseBinaryParams returns it (with the offending
// type byte in the message) when a parameter declares a type it does not handle.
var errUnknownFieldType = dbterror.ClassServer.NewStd(errno.ErrUnknownFieldType)
|
|
|
|
// parseBinaryParams decodes the binary params according to the protocol
|
|
func parseBinaryParams(params []param.BinaryParam, boundParams [][]byte, nullBitmap, paramTypes, paramValues []byte, enc *util2.InputDecoder) (err error) {
|
|
pos := 0
|
|
if enc == nil {
|
|
enc = util2.NewInputDecoder(charset.CharsetUTF8)
|
|
}
|
|
|
|
for i := 0; i < len(params); i++ {
|
|
// if params had received via ComStmtSendLongData, use them directly.
|
|
// ref https://dev.mysql.com/doc/internals/en/com-stmt-send-long-data.html
|
|
// see clientConn#handleStmtSendLongData
|
|
if boundParams[i] != nil {
|
|
params[i] = param.BinaryParam{
|
|
Tp: mysql.TypeBlob,
|
|
Val: boundParams[i],
|
|
}
|
|
|
|
// The legacy logic is kept: if the `paramTypes` somehow didn't contain the type information, it will be treated as
|
|
// BLOB type. We didn't return `mysql.ErrMalformPacket` to keep compatibility with older versions, though it's
|
|
// meaningless if every clients work properly.
|
|
if (i<<1)+1 < len(paramTypes) {
|
|
// Only TEXT or BLOB type will be sent through `SEND_LONG_DATA`.
|
|
tp := paramTypes[i<<1]
|
|
|
|
switch tp {
|
|
case mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString, mysql.TypeBit:
|
|
params[i].Tp = tp
|
|
params[i].Val = enc.DecodeInput(boundParams[i])
|
|
case mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
|
|
params[i].Tp = tp
|
|
params[i].Val = boundParams[i]
|
|
}
|
|
}
|
|
continue
|
|
}
|
|
|
|
// check nullBitMap to determine the NULL arguments.
|
|
// ref https://dev.mysql.com/doc/internals/en/com-stmt-execute.html
|
|
// notice: some client(e.g. mariadb) will set nullBitMap even if data had be sent via ComStmtSendLongData,
|
|
// so this check need place after boundParam's check.
|
|
if nullBitmap[i>>3]&(1<<(uint(i)%8)) > 0 {
|
|
var nilDatum types.Datum
|
|
nilDatum.SetNull()
|
|
params[i] = param.BinaryParam{
|
|
Tp: mysql.TypeNull,
|
|
}
|
|
continue
|
|
}
|
|
|
|
if (i<<1)+1 >= len(paramTypes) {
|
|
return mysql.ErrMalformPacket
|
|
}
|
|
|
|
tp := paramTypes[i<<1]
|
|
isUnsigned := (paramTypes[(i<<1)+1] & 0x80) > 0
|
|
isNull := false
|
|
|
|
decodeWithDecoder := false
|
|
|
|
var length uint64
|
|
switch tp {
|
|
case mysql.TypeNull:
|
|
length = 0
|
|
isNull = true
|
|
case mysql.TypeTiny:
|
|
length = 1
|
|
case mysql.TypeShort, mysql.TypeYear:
|
|
length = 2
|
|
case mysql.TypeInt24, mysql.TypeLong, mysql.TypeFloat:
|
|
length = 4
|
|
case mysql.TypeLonglong, mysql.TypeDouble:
|
|
length = 8
|
|
case mysql.TypeDate, mysql.TypeTimestamp, mysql.TypeDatetime, mysql.TypeDuration:
|
|
if len(paramValues) < (pos + 1) {
|
|
err = mysql.ErrMalformPacket
|
|
return
|
|
}
|
|
length = uint64(paramValues[pos])
|
|
pos++
|
|
case mysql.TypeNewDecimal, mysql.TypeBlob, mysql.TypeTinyBlob, mysql.TypeMediumBlob, mysql.TypeLongBlob:
|
|
if len(paramValues) < (pos + 1) {
|
|
err = mysql.ErrMalformPacket
|
|
return
|
|
}
|
|
var n int
|
|
length, isNull, n = util2.ParseLengthEncodedInt(paramValues[pos:])
|
|
pos += n
|
|
case mysql.TypeUnspecified, mysql.TypeVarchar, mysql.TypeVarString, mysql.TypeString,
|
|
mysql.TypeEnum, mysql.TypeSet, mysql.TypeGeometry, mysql.TypeBit:
|
|
if len(paramValues) < (pos + 1) {
|
|
err = mysql.ErrMalformPacket
|
|
return
|
|
}
|
|
var n int
|
|
length, isNull, n = util2.ParseLengthEncodedInt(paramValues[pos:])
|
|
pos += n
|
|
decodeWithDecoder = true
|
|
default:
|
|
err = errUnknownFieldType.GenWithStack("stmt unknown field type %d", tp)
|
|
return
|
|
}
|
|
|
|
if len(paramValues) < (pos + int(length)) {
|
|
err = mysql.ErrMalformPacket
|
|
return
|
|
}
|
|
params[i] = param.BinaryParam{
|
|
Tp: tp,
|
|
IsUnsigned: isUnsigned,
|
|
IsNull: isNull,
|
|
Val: paramValues[pos : pos+int(length)],
|
|
}
|
|
if decodeWithDecoder {
|
|
params[i].Val = enc.DecodeInput(params[i].Val)
|
|
}
|
|
pos += int(length)
|
|
}
|
|
return
|
|
}
|