Files
tidb/pkg/server/internal/parse/parse.go
2023-11-16 21:30:47 +00:00

173 lines
5.1 KiB
Go

// Copyright 2023 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package parse
import (
"bytes"
"context"
"encoding/binary"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/server/internal/handshake"
util2 "github.com/pingcap/tidb/pkg/server/internal/util"
"github.com/pingcap/tidb/pkg/util/logutil"
"go.uber.org/zap"
)
// maxFetchSize constants
const (
maxFetchSize = 1024
)
// StmtFetchCmd parse COM_STMT_FETCH command
func StmtFetchCmd(data []byte) (stmtID uint32, fetchSize uint32, err error) {
if len(data) != 8 {
return 0, 0, mysql.ErrMalformPacket
}
// Please refer to https://dev.mysql.com/doc/internals/en/com-stmt-fetch.html
stmtID = binary.LittleEndian.Uint32(data[0:4])
fetchSize = binary.LittleEndian.Uint32(data[4:8])
if fetchSize > maxFetchSize {
fetchSize = maxFetchSize
}
return
}
// HandshakeResponseHeader parses the common header of SSLRequest and Response41.
func HandshakeResponseHeader(ctx context.Context, packet *handshake.Response41, data []byte) (parsedBytes int, err error) {
// Ensure there are enough data to read:
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if len(data) < 4+4+1+23 {
logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data))
return 0, mysql.ErrMalformPacket
}
offset := 0
// capability
capability := binary.LittleEndian.Uint32(data[:4])
packet.Capability = capability
offset += 4
// skip max packet size
offset += 4
// charset, skip, if you want to use another charset, use set names
packet.Collation = data[offset]
offset++
// skip reserved 23[00]
offset += 23
return offset, nil
}
// HandshakeResponseBody parse the HandshakeResponse (except the common header part).
func HandshakeResponseBody(ctx context.Context, packet *handshake.Response41, data []byte, offset int) (err error) {
defer func() {
// Check malformat packet cause out of range is disgusting, but don't panic!
if r := recover(); r != nil {
logutil.Logger(ctx).Error("handshake panic", zap.ByteString("packetData", data))
err = mysql.ErrMalformPacket
}
}()
// user name
packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)])
offset += len(packet.User) + 1
if packet.Capability&mysql.ClientPluginAuthLenencClientData > 0 {
// MySQL client sets the wrong capability, it will set this bit even server doesn't
// support ClientPluginAuthLenencClientData.
// https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478
if data[offset] == 0x1 { // No auth data
offset += 2
} else {
num, null, off := util2.ParseLengthEncodedInt(data[offset:])
offset += off
if !null {
packet.Auth = data[offset : offset+int(num)]
offset += int(num)
}
}
} else if packet.Capability&mysql.ClientSecureConnection > 0 {
// auth length and auth
authLen := int(data[offset])
offset++
packet.Auth = data[offset : offset+authLen]
offset += authLen
} else {
packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)]
offset += len(packet.Auth) + 1
}
if packet.Capability&mysql.ClientConnectWithDB > 0 {
if len(data[offset:]) > 0 {
idx := bytes.IndexByte(data[offset:], 0)
packet.DBName = string(data[offset : offset+idx])
offset += idx + 1
}
}
if packet.Capability&mysql.ClientPluginAuth > 0 {
idx := bytes.IndexByte(data[offset:], 0)
s := offset
f := offset + idx
if s < f { // handle unexpected bad packets
packet.AuthPlugin = string(data[s:f])
}
offset += idx + 1
}
if packet.Capability&mysql.ClientConnectAtts > 0 {
if len(data[offset:]) == 0 {
// Defend some ill-formated packet, connection attribute is not important and can be ignored.
return nil
}
if num, null, intOff := util2.ParseLengthEncodedInt(data[offset:]); !null {
offset += intOff // Length of variable length encoded integer itself in bytes
row := data[offset : offset+int(num)]
attrs, err := parseAttrs(row)
if err != nil {
logutil.Logger(ctx).Warn("parse attrs failed", zap.Error(err))
return nil
}
packet.Attrs = attrs
offset += int(num) // Length of attributes
}
}
if packet.Capability&mysql.ClientZstdCompressionAlgorithm > 0 {
packet.ZstdLevel = int(data[offset])
}
return nil
}
func parseAttrs(data []byte) (map[string]string, error) {
attrs := make(map[string]string)
pos := 0
for pos < len(data) {
key, _, off, err := util2.ParseLengthEncodedBytes(data[pos:])
if err != nil {
return attrs, err
}
pos += off
value, _, off, err := util2.ParseLengthEncodedBytes(data[pos:])
if err != nil {
return attrs, err
}
pos += off
attrs[string(key)] = string(value)
}
return attrs, nil
}