Files
tidb/pkg/server/internal/packetio.go

557 lines
15 KiB
Go

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// The MIT License (MIT)
//
// Copyright (c) 2014 wandoulabs
// Copyright (c) 2014 siddontang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
package internal
import (
"bufio"
"bytes"
"compress/zlib"
"io"
"time"
"github.com/klauspost/compress/zstd"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
server_err "github.com/pingcap/tidb/pkg/server/err"
"github.com/pingcap/tidb/pkg/server/internal/util"
server_metrics "github.com/pingcap/tidb/pkg/server/metrics"
"github.com/pingcap/tidb/pkg/sessionctx/vardef"
)
// defaultWriterSize is the size in bytes (16 KiB) of the bufio.Writer that
// PacketIO places in front of the connection for uncompressed output.
const defaultWriterSize = 16 << 10
// PacketIO is a helper to read and write data in packet format.
// MySQL Packets: https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_packets.html
type PacketIO struct {
	// bufReadConn is the underlying connection. Uncompressed reads use it
	// directly, and the compressed reader/writer are layered on top of it
	// (compressed output bypasses bufWriter).
	bufReadConn *util.BufferedReadConn
	// bufWriter buffers outgoing packets while compression is disabled.
	bufWriter *bufio.Writer
	// compressedWriter and compressedReader implement the compression
	// protocol; they are created by SetCompressionAlgorithm.
	compressedWriter *compressedWriter
	compressedReader *compressedReader
	// readTimeout, when positive, is applied as a fresh read deadline before
	// each header and payload read; zero makes ReadPacket clear the deadline.
	readTimeout time.Duration
	// maxAllowedPacket is the maximum size of one packet in ReadPacket.
	maxAllowedPacket uint64
	// accumulatedLength counts the total payload length received so far by
	// the current ReadPacket call.
	accumulatedLength uint64
	// compressionAlgorithm is compared against mysql.CompressionNone,
	// mysql.CompressionZlib and mysql.CompressionZstd.
	compressionAlgorithm int
	// zstdLevel is copied into the compressed reader/writer when
	// SetCompressionAlgorithm creates them.
	zstdLevel zstd.EncoderLevel
	// sequence is the expected/next sequence id of (uncompressed) packets.
	sequence uint8
	// compressedSequence is the sequence id of the compression layer; the
	// compressed reader and writer update it through a shared pointer.
	compressedSequence uint8
}
// NewPacketIO creates a PacketIO on top of bufReadConn. The result starts with
// compression disabled, both sequence counters at zero, a zstd level of 3 and
// vardef.DefMaxAllowedPacket as its per-ReadPacket payload limit.
func NewPacketIO(bufReadConn *util.BufferedReadConn) *PacketIO {
	pkt := new(PacketIO)
	pkt.compressionAlgorithm = mysql.CompressionNone
	pkt.zstdLevel = 3
	pkt.SetBufferedReadConn(bufReadConn)
	pkt.SetMaxAllowedPacket(vardef.DefMaxAllowedPacket)
	return pkt
}
// NewPacketIOForTest creates a PacketIO whose only configured component is
// bufWriter. No read connection is attached, so only the write path is usable.
func NewPacketIOForTest(bufWriter *bufio.Writer) *PacketIO {
	pkt := new(PacketIO)
	pkt.SetBufWriter(bufWriter)
	return pkt
}
// SetZstdLevel records the zstd level to use. It only affects compressed
// readers/writers created by a subsequent SetCompressionAlgorithm call.
func (p *PacketIO) SetZstdLevel(level zstd.EncoderLevel) {
	p.zstdLevel = level
}

// Sequence reports the current sequence id of uncompressed packets.
func (p *PacketIO) Sequence() uint8 {
	return p.sequence
}

// SetSequence overrides the current sequence id of uncompressed packets.
func (p *PacketIO) SetSequence(s uint8) {
	p.sequence = s
}

// SetCompressedSequence overrides the sequence id used by the compression
// layer.
func (p *PacketIO) SetCompressedSequence(s uint8) {
	p.compressedSequence = s
}

// SetBufWriter replaces the buffered writer used for uncompressed output.
func (p *PacketIO) SetBufWriter(bufWriter *bufio.Writer) {
	p.bufWriter = bufWriter
}

// ResetBufWriter discards any unflushed data in the current buffered writer
// and redirects its output to w.
func (p *PacketIO) ResetBufWriter(w io.Writer) {
	p.bufWriter.Reset(w)
}
// SetCompressionAlgorithm switches the connection to compression algorithm ca
// (compared against mysql.CompressionNone/Zlib/Zstd elsewhere in this file).
// Output still sitting in the uncompressed bufWriter is flushed to the
// connection first, because it was framed before compression was enabled.
// Fresh compressed reader/writer instances are then created on top of
// bufReadConn; they share p.compressedSequence and use the level set by
// SetZstdLevel.
func (p *PacketIO) SetCompressionAlgorithm(ca int) {
	// This method has no error return, so a failed flush can only be logged
	// rather than silently dropped.
	if err := p.bufWriter.Flush(); err != nil {
		terror.Log(errors.Trace(err))
	}
	p.compressionAlgorithm = ca
	p.compressedWriter = newCompressedWriter(p.bufReadConn, ca, &p.compressedSequence)
	p.compressedWriter.zstdLevel = p.zstdLevel
	p.compressedReader = newCompressedReader(p.bufReadConn, ca, &p.compressedSequence)
	p.compressedReader.zstdLevel = p.zstdLevel
}
// SetBufferedReadConn makes bufReadConn the underlying connection and puts a
// new defaultWriterSize-byte buffered writer in front of it.
func (p *PacketIO) SetBufferedReadConn(bufReadConn *util.BufferedReadConn) {
	p.bufReadConn, p.bufWriter = bufReadConn, bufio.NewWriterSize(bufReadConn, defaultWriterSize)
}

// SetReadTimeout sets the per-read timeout. A positive value arms a fresh read
// deadline before every header and payload read; zero makes ReadPacket clear
// any existing deadline.
func (p *PacketIO) SetReadTimeout(timeout time.Duration) {
	p.readTimeout = timeout
}
// readOnePacket reads a single packet (4-byte header followed by the payload)
// and returns its payload.
//
// The packet is read from bufReadConn, or through compressedReader when
// compression is enabled. The header's sequence id must equal p.sequence. On a
// mismatch, ErrInvalidSequence is returned when compression is off, and only
// logged when it is on.
//
// p.sequence is then advanced. The payload length is added to
// p.accumulatedLength, which must stay within p.maxAllowedPacket; this is
// checked before the payload buffer is allocated. When p.readTimeout is
// positive, a fresh read deadline is set before both the header and the
// payload read.
func (p *PacketIO) readOnePacket() ([]byte, error) {
	var r io.Reader = p.bufReadConn
	if p.compressionAlgorithm != mysql.CompressionNone {
		r = p.compressedReader
	}
	// armDeadline restarts the read timeout, if one is configured.
	armDeadline := func() error {
		if p.readTimeout <= 0 {
			return nil
		}
		return p.bufReadConn.SetReadDeadline(time.Now().Add(p.readTimeout))
	}

	var header [4]byte
	if err := armDeadline(); err != nil {
		return nil, err
	}
	if _, err := io.ReadFull(r, header[:]); err != nil {
		return nil, errors.Trace(err)
	}

	length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
	if sequence := header[3]; sequence != p.sequence {
		err := server_err.ErrInvalidSequence.GenWithStack(
			"invalid sequence, received %d while expecting %d", sequence, p.sequence)
		if p.compressionAlgorithm == mysql.CompressionNone {
			return nil, err
		}
		// To be compatible with MariaDB Connector/J 2.x,
		// ignore sequence check and print a log when compression protocol is active.
		terror.Log(err)
	}
	p.sequence++

	// Reject the packet before allocating its buffer if the payload
	// accumulated by the current ReadPacket call exceeds the limit.
	p.accumulatedLength += uint64(length)
	if p.accumulatedLength > p.maxAllowedPacket {
		terror.Log(server_err.ErrNetPacketTooLarge)
		return nil, server_err.ErrNetPacketTooLarge
	}

	data := make([]byte, length)
	if err := armDeadline(); err != nil {
		return nil, err
	}
	if _, err := io.ReadFull(r, data); err != nil {
		return nil, errors.Trace(err)
	}
	return data, nil
}
// SetMaxAllowedPacket limits the total payload size, in bytes, that a single
// ReadPacket call accepts across all of its packet fragments.
func (p *PacketIO) SetMaxAllowedPacket(maxAllowedPacket uint64) {
	p.maxAllowedPacket = maxAllowedPacket
}
// ReadPacket reads one logical packet from the connection.
//
// A fragment whose payload is mysql.MaxPayloadLen bytes or longer means the
// packet continues, so fragments are read and concatenated until a shorter one
// arrives. The accumulated-length counter is reset at the start of each call.
// When no read timeout is configured, any previous read deadline is cleared
// first.
func (p *PacketIO) ReadPacket() ([]byte, error) {
	p.accumulatedLength = 0
	if p.readTimeout == 0 {
		if err := p.bufReadConn.SetReadDeadline(time.Time{}); err != nil {
			return nil, errors.Trace(err)
		}
	}

	data, err := p.readOnePacket()
	if err != nil {
		return nil, errors.Trace(err)
	}
	// Each iteration consumes one continuation fragment; stop once the most
	// recent fragment is shorter than mysql.MaxPayloadLen.
	for chunk := data; len(chunk) >= mysql.MaxPayloadLen; {
		if chunk, err = p.readOnePacket(); err != nil {
			return nil, errors.Trace(err)
		}
		data = append(data, chunk...)
	}
	server_metrics.InPacketBytes.Add(float64(len(data)))
	return data, nil
}
// WritePacket writes data as one or more packets.
//
// The first 4 bytes of data are reserved for the packet header and are
// overwritten here; the rest is the payload, so data must be at least 4 bytes
// long. Payloads of mysql.MaxPayloadLen bytes or more are split into
// fragments, each with its own header and sequence id. A payload that is an
// exact multiple of mysql.MaxPayloadLen ends with an empty packet.
//
// data is modified in place: the header of each following fragment is written
// over the last 4 payload bytes of the fragment just handed to the writer.
//
// Output goes to bufWriter, or to compressedWriter when compression is
// enabled; call Flush to send it. Any write failure is logged and reported as
// mysql.ErrBadConn.
func (p *PacketIO) WritePacket(data []byte) error {
	length := len(data) - 4
	server_metrics.OutPacketBytes.Add(float64(len(data)))

	var w io.Writer = p.bufWriter
	if p.compressionAlgorithm != mysql.CompressionNone {
		w = p.compressedWriter
	}
	// writeFramed emits one framed packet and advances the sequence id.
	writeFramed := func(pkt []byte) error {
		n, err := w.Write(pkt)
		if err != nil {
			terror.Log(errors.Trace(err))
			return errors.Trace(mysql.ErrBadConn)
		}
		if n != len(pkt) {
			return errors.Trace(mysql.ErrBadConn)
		}
		p.sequence++
		return nil
	}

	maxPayloadLen := mysql.MaxPayloadLen
	for length >= maxPayloadLen {
		data[0], data[1], data[2], data[3] = 0xff, 0xff, 0xff, p.sequence
		if err := writeFramed(data[:4+maxPayloadLen]); err != nil {
			return err
		}
		length -= maxPayloadLen
		// The last 4 payload bytes of the fragment just written become the
		// header slot of the next fragment.
		data = data[maxPayloadLen:]
	}
	data[0], data[1], data[2], data[3] = byte(length), byte(length>>8), byte(length>>16), p.sequence
	return writeFramed(data)
}
// Flush pushes all buffered output to the network. With compression enabled
// the pending bytes go out as a compressed packet, after which the packet
// sequence id continues from the compressed sequence id.
func (p *PacketIO) Flush() error {
	if p.compressionAlgorithm == mysql.CompressionNone {
		if err := p.bufWriter.Flush(); err != nil {
			return errors.Trace(err)
		}
		return nil
	}
	if err := p.compressedWriter.Flush(); err != nil {
		return errors.Trace(err)
	}
	p.sequence = p.compressedSequence
	return nil
}
// newCompressedWriter returns a compressedWriter that frames data with
// compression algorithm ca and writes the resulting packets to w.
//
// seq is the compressed sequence counter, incremented once per packet. The
// zstd level defaults to 3; callers may overwrite zstdLevel before the first
// Flush, which creates the encoder. The compressor buffer and packet scratch
// buffer start empty and are allocated lazily by Flush.
func newCompressedWriter(w io.Writer, ca int, seq *uint8) *compressedWriter {
	return &compressedWriter{
		w:                    w,
		buf:                  new(bytes.Buffer),
		compressedSequence:   seq,
		compressionAlgorithm: ca,
		zstdLevel:            3,
	}
}
// compressedWriter accumulates uncompressed bytes and, on Flush, emits them to
// w as one packet of the compression protocol: a 7-byte header followed by the
// payload, which is compressed when it is longer than 50 bytes.
type compressedWriter struct {
	// compressorBuffer holds the lazily created compressor and its output buffer.
	compressorBuffer compressorBuffer
	// w receives complete compressed packets.
	w io.Writer
	// buf accumulates uncompressed data until the next Flush (Write caps it at 1 MiB).
	buf *bytes.Buffer
	// compressedSequence points at the compression-layer sequence counter,
	// written into each packet header and incremented per packet.
	compressedSequence *uint8
	// compressedPacket is scratch space where Flush assembles header + payload.
	compressedPacket *bytes.Buffer
	// compressionAlgorithm selects the compressor (mysql.CompressionZlib or mysql.CompressionZstd).
	compressionAlgorithm int
	// zstdLevel is used when Flush first creates the zstd encoder.
	zstdLevel zstd.EncoderLevel
}
// Write appends data to the pending, not yet compressed buffer and reports
// len(data) on success.
//
// Whenever the buffer would exceed 1 MiB, it is topped up to exactly 1 MiB and
// flushed as one compressed packet, so a single call can emit several packets.
// The remainder waits for the next Write or Flush. If a write or flush fails,
// Write returns 0 and the error.
//
// MySQL starts with `net_buffer_length` (default 16384) and larger packets
// after that. The length itself must fit in the 3-byte field in the header,
// and it can't be bigger than the max value for `net_buffer_length` (1048576).
func (cw *compressedWriter) Write(data []byte) (n int, err error) {
	const maxCompressedSize = 1 << 20 // 1 MiB
	for {
		room := maxCompressedSize - cw.buf.Len()
		chunk, overflow := data, len(data) > room
		if overflow {
			chunk = data[:room]
		}
		written, werr := cw.buf.Write(chunk)
		if werr != nil {
			return 0, werr
		}
		n += written
		if !overflow {
			return n, nil
		}
		data = data[room:]
		if werr = cw.Flush(); werr != nil {
			return 0, werr
		}
	}
}
// Flush emits everything accumulated in cw.buf as a single packet of the
// compression protocol and writes it to cw.w in one Write call. It emits a
// packet even when cw.buf is empty.
//
// The 7-byte header holds the 3-byte payload length, the compressed sequence
// id and the 3-byte uncompressed length. Payloads of at most minCompressLength
// (50) bytes are sent as-is with an uncompressed length of 0. Longer payloads
// are compressed with the configured algorithm.
//
// cw.buf is emptied even if an error is returned.
func (cw *compressedWriter) Flush() error {
	var w io.WriteCloser
	var err error
	// Allocate the reusable scratch buffers on first use.
	if cw.compressorBuffer.payload == nil {
		cw.compressorBuffer.payload = new(bytes.Buffer)
	}
	if cw.compressedPacket == nil {
		cw.compressedPacket = new(bytes.Buffer)
	}
	// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_basic_compression_packet.html
	// suggests a MIN_COMPRESS_LENGTH of 50.
	minCompressLength := 50
	// data aliases cw.buf's backing array: Reset only rewinds the buffer, so
	// data stays valid because nothing below writes to cw.buf.
	data := cw.buf.Bytes()
	cw.buf.Reset()
	// Select the compressor, creating it on first use. It writes into
	// compressorBuffer.payload and is re-targeted there by compressorBuffer.Reset.
	switch cw.compressionAlgorithm {
	case mysql.CompressionZlib:
		if cw.compressorBuffer.zlibWriter == nil {
			cw.compressorBuffer.zlibWriter, err = zlib.NewWriterLevel(cw.compressorBuffer.payload, mysql.ZlibCompressDefaultLevel)
		}
		w = cw.compressorBuffer.zlibWriter
	case mysql.CompressionZstd:
		if cw.compressorBuffer.zstdWriter == nil {
			cw.compressorBuffer.zstdWriter, err = zstd.NewWriter(cw.compressorBuffer.payload, zstd.WithEncoderLevel(cw.zstdLevel))
		}
		w = cw.compressorBuffer.zstdWriter
	default:
		return errors.New("Unknown compression algorithm")
	}
	if err != nil {
		return errors.Trace(err)
	}
	// always reset the compressed packet buffer.
	defer cw.compressedPacket.Reset()
	// An uncompressed length of 0 in the header marks a payload that is sent
	// without compression (compressedReader.Read treats it that way).
	uncompressedLength := 0
	compressedHeader := make([]byte, 7)
	needCompress := len(data) > minCompressLength
	if needCompress {
		// only reset the payload buffer if we are compressing.
		defer cw.compressorBuffer.Reset()
		uncompressedLength = len(data)
		_, err := w.Write(data)
		if err != nil {
			return errors.Trace(err)
		}
		// Close finishes the compressed stream so that payload holds the
		// complete compressed data for this packet.
		err = w.Close()
		if err != nil {
			return errors.Trace(err)
		}
	}
	// The first header field is the length of the bytes following the header:
	// the compressed payload, or the raw data when not compressing.
	var compressedLength int
	if needCompress {
		compressedLength = len(cw.compressorBuffer.payload.Bytes())
	} else {
		compressedLength = len(data)
	}
	// Header layout: 3-byte little-endian payload length, 1-byte compressed
	// sequence id, 3-byte little-endian uncompressed length.
	compressedHeader[0] = byte(compressedLength)
	compressedHeader[1] = byte(compressedLength >> 8)
	compressedHeader[2] = byte(compressedLength >> 16)
	compressedHeader[3] = *cw.compressedSequence
	compressedHeader[4] = byte(uncompressedLength)
	compressedHeader[5] = byte(uncompressedLength >> 8)
	compressedHeader[6] = byte(uncompressedLength >> 16)
	_, err = cw.compressedPacket.Write(compressedHeader)
	if err != nil {
		return errors.Trace(err)
	}
	// The sequence id advances before the packet reaches cw.w, so it also
	// advances if that final write fails.
	*cw.compressedSequence++
	if needCompress {
		_, err = cw.compressedPacket.Write(cw.compressorBuffer.payload.Bytes())
	} else {
		_, err = cw.compressedPacket.Write(data)
	}
	if err != nil {
		return errors.Trace(err)
	}
	// Hand the assembled header+payload to the destination in a single Write.
	_, err = cw.w.Write(cw.compressedPacket.Bytes())
	if err != nil {
		return errors.Trace(err)
	}
	return nil
}
// compressorBuffer bundles the compression output buffer with the compressor
// that writes into it. compressedWriter.Flush lazily creates only the
// compressor matching its compressionAlgorithm.
type compressorBuffer struct {
	// payload receives the compressed bytes of the packet being built.
	payload *bytes.Buffer
	// zlibWriter is the zlib compressor, or nil if not created yet.
	zlibWriter *zlib.Writer
	// zstdWriter is the zstd encoder, or nil if not created yet.
	zstdWriter *zstd.Encoder
}
// Reset empties the payload buffer (when allocated) and points every
// compressor that already exists back at it, so the next packet is compressed
// as a new, independent stream.
func (c *compressorBuffer) Reset() {
	payload := c.payload
	if payload != nil {
		payload.Reset()
	}
	if zw := c.zlibWriter; zw != nil {
		zw.Reset(payload)
	}
	if ze := c.zstdWriter; ze != nil {
		ze.Reset(payload)
	}
}
// newCompressedReader returns a compressedReader that reads packets of the
// compression protocol from r and decompresses them with algorithm ca.
//
// seq is the expected compressed sequence counter, checked and incremented
// for every packet. The reader starts with no buffered packet, and the zstd
// level defaults to 3.
func newCompressedReader(r io.Reader, ca int, seq *uint8) *compressedReader {
	return &compressedReader{
		r:                    r,
		compressedSequence:   seq,
		compressionAlgorithm: ca,
		zstdLevel:            3,
	}
}
// compressedReader is an io.Reader that unwraps packets of the compression
// protocol read from r and yields the uncompressed byte stream. It buffers at
// most one decompressed packet at a time.
type compressedReader struct {
	// r is the source of compressed packets.
	r io.Reader
	// compressedSequence points at the expected compression-layer sequence
	// id; it is verified and incremented for every packet read.
	compressedSequence *uint8
	// data is the current decompressed packet, or nil when the next Read must
	// fetch a new packet.
	data []byte
	// compressionAlgorithm selects zlib or zstd for compressed payloads.
	compressionAlgorithm int
	// zstdLevel is stored alongside the writer's setting; Read does not use it.
	zstdLevel zstd.EncoderLevel
	// pos is the offset of the next unread byte in data.
	pos uint64
}
// Read implements io.Reader over the uncompressed stream.
//
// When no packet is buffered, Read fetches the next compressed packet (see
// readNextPacket). It then copies as much of the buffered payload into data as
// fits. If fetching a packet fails, the error is returned with n == 0 and
// nothing is left buffered.
func (cr *compressedReader) Read(data []byte) (n int, err error) {
	if cr.data == nil {
		if cr.data, err = cr.readNextPacket(); err != nil {
			return 0, err
		}
	} else if cr.pos > uint64(len(cr.data)) {
		// Defensive: pos is reset together with data, so this is not expected.
		return 0, io.EOF
	}
	n = copy(data, cr.data[cr.pos:])
	cr.pos += uint64(n)
	if cr.pos == uint64(len(cr.data)) {
		// The packet is fully consumed; the next Read fetches a new one.
		cr.pos = 0
		cr.data = nil
	}
	return n, nil
}

// readNextPacket reads one compressed packet from cr.r and returns its
// uncompressed payload, or nil and an error.
//
// It verifies and advances the shared compressed sequence id. An uncompressed
// length of 0 in the header means the payload was sent as-is. Decompression is
// confined to the packet's compressedLength bytes. Any bytes of the packet the
// decompressor did not consume (such as a trailing checksum) are discarded, so
// the next header is read from the correct offset. Decompressors are closed
// after use.
func (cr *compressedReader) readNextPacket() ([]byte, error) {
	var header [7]byte
	if _, err := io.ReadFull(cr.r, header[:]); err != nil {
		return nil, err
	}
	compressedLength := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
	sequence := header[3]
	uncompressedLength := int(uint32(header[4]) | uint32(header[5])<<8 | uint32(header[6])<<16)
	if sequence != *cr.compressedSequence {
		return nil, server_err.ErrInvalidSequence.GenWithStack(
			"invalid compressed sequence, received %d while expecting %d", sequence, *cr.compressedSequence)
	}
	*cr.compressedSequence++

	if uncompressedLength == 0 {
		payload := make([]byte, compressedLength)
		if _, err := io.ReadFull(cr.r, payload); err != nil {
			return nil, errors.Trace(err)
		}
		return payload, nil
	}

	// Never let a decompressor read past the end of this packet.
	lr := io.LimitReader(cr.r, int64(compressedLength))
	var dec io.ReadCloser
	switch cr.compressionAlgorithm {
	case mysql.CompressionZlib:
		zr, err := zlib.NewReader(lr)
		if err != nil {
			return nil, errors.Trace(err)
		}
		dec = zr
	case mysql.CompressionZstd:
		zr, err := zstd.NewReader(lr, zstd.WithDecoderConcurrency(1))
		if err != nil {
			return nil, errors.Trace(err)
		}
		dec = zr.IOReadCloser()
	default:
		return nil, errors.New("Unknown compression algorithm")
	}

	payload := make([]byte, uncompressedLength)
	_, err := io.ReadFull(dec, payload)
	if closeErr := dec.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		return nil, errors.Trace(err)
	}
	// Drain whatever is left of this packet so the stream stays aligned on
	// packet boundaries.
	if _, err := io.Copy(io.Discard, lr); err != nil {
		return nil, errors.Trace(err)
	}
	return payload, nil
}
// Close does nothing and always returns nil; in particular it does not close
// the underlying reader r.
func (*compressedReader) Close() error { return nil }