458 lines
18 KiB
Go
458 lines
18 KiB
Go
// Copyright 2019 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package internal
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"compress/zlib"
|
|
"io"
|
|
"testing"
|
|
|
|
"github.com/klauspost/compress/zstd"
|
|
"github.com/pingcap/tidb/pkg/parser/mysql"
|
|
"github.com/pingcap/tidb/pkg/server/internal/testutil"
|
|
"github.com/pingcap/tidb/pkg/server/internal/util"
|
|
"github.com/stretchr/testify/require"
|
|
)
|
|
|
|
func BenchmarkPacketIOWrite(b *testing.B) {
|
|
b.ResetTimer()
|
|
for i := 0; i < b.N; i++ {
|
|
var outBuffer bytes.Buffer
|
|
pkt := &PacketIO{bufWriter: bufio.NewWriter(&outBuffer)}
|
|
_ = pkt.WritePacket([]byte{0x6d, 0x44, 0x42, 0x3a, 0x35, 0x36, 0x0, 0x0, 0x0, 0xfc, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x68, 0x54, 0x49, 0x44, 0x3a, 0x31, 0x30, 0x38, 0x0, 0xfe})
|
|
}
|
|
}
|
|
|
|
func TestPacketIOWrite(t *testing.T) {
|
|
// Test write one packet
|
|
var outBuffer bytes.Buffer
|
|
pkt := &PacketIO{bufWriter: bufio.NewWriter(&outBuffer)}
|
|
err := pkt.WritePacket([]byte{0x00, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03})
|
|
require.NoError(t, err)
|
|
err = pkt.Flush()
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte{0x03, 0x00, 0x00, 0x00, 0x01, 0x02, 0x03}, outBuffer.Bytes())
|
|
|
|
// Test write more than one packet
|
|
outBuffer.Reset()
|
|
largeInput := make([]byte, mysql.MaxPayloadLen+4)
|
|
pkt = &PacketIO{bufWriter: bufio.NewWriter(&outBuffer)}
|
|
err = pkt.WritePacket(largeInput)
|
|
require.NoError(t, err)
|
|
err = pkt.Flush()
|
|
require.NoError(t, err)
|
|
res := outBuffer.Bytes()
|
|
require.Equal(t, byte(0xff), res[0])
|
|
require.Equal(t, byte(0xff), res[1])
|
|
require.Equal(t, byte(0xff), res[2])
|
|
require.Equal(t, byte(0), res[3])
|
|
}
|
|
|
|
func TestPacketIOWriteCompressed(t *testing.T) {
|
|
var testdata, outBuffer bytes.Buffer
|
|
|
|
seq := uint8(0)
|
|
pkt := &PacketIO{
|
|
bufWriter: bufio.NewWriter(&outBuffer),
|
|
compressionAlgorithm: mysql.CompressionZlib,
|
|
compressedWriter: newCompressedWriter(&testdata, mysql.CompressionZlib, &seq),
|
|
}
|
|
|
|
payload := bytes.Repeat([]byte{'A'}, 16*1024*1024)
|
|
err := pkt.WritePacket(payload)
|
|
require.NoError(t, err)
|
|
|
|
err = pkt.Flush()
|
|
require.NoError(t, err)
|
|
|
|
compressedLength := []byte{0x18, 0x4, 0x0} // 1048 bytes
|
|
packetNr := []byte{0x0}
|
|
uncompressedLength := []byte{0x0, 0x0, 0x10} // 1048576 bytes
|
|
|
|
require.Equal(t, compressedLength, testdata.Bytes()[:3])
|
|
require.Equal(t, packetNr, testdata.Bytes()[3:4])
|
|
require.Equal(t, uncompressedLength, testdata.Bytes()[4:7])
|
|
}
|
|
|
|
func TestPacketIORead(t *testing.T) {
|
|
t.Run("uncompressed", func(t *testing.T) {
|
|
var inBuffer bytes.Buffer
|
|
_, err := inBuffer.Write([]byte{0x01, 0x00, 0x00, 0x00, 0x01})
|
|
require.NoError(t, err)
|
|
// Test read one packet
|
|
brc := util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
|
|
pkt := NewPacketIO(brc)
|
|
readBytes, err := pkt.ReadPacket()
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint8(1), pkt.sequence)
|
|
require.Equal(t, []byte{0x01}, readBytes)
|
|
|
|
inBuffer.Reset()
|
|
buf := make([]byte, mysql.MaxPayloadLen+9)
|
|
buf[0] = 0xff
|
|
buf[1] = 0xff
|
|
buf[2] = 0xff
|
|
buf[3] = 0
|
|
buf[2+mysql.MaxPayloadLen] = 0x00
|
|
buf[3+mysql.MaxPayloadLen] = 0x00
|
|
buf[4+mysql.MaxPayloadLen] = 0x01
|
|
buf[7+mysql.MaxPayloadLen] = 0x01
|
|
buf[8+mysql.MaxPayloadLen] = 0x0a
|
|
|
|
_, err = inBuffer.Write(buf)
|
|
require.NoError(t, err)
|
|
// Test read multiple packets
|
|
brc = util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
|
|
pkt = NewPacketIO(brc)
|
|
readBytes, err = pkt.ReadPacket()
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint8(2), pkt.sequence)
|
|
require.Equal(t, mysql.MaxPayloadLen+1, len(readBytes))
|
|
require.Equal(t, byte(0x0a), readBytes[mysql.MaxPayloadLen])
|
|
})
|
|
t.Run("compressed_short", func(t *testing.T) {
|
|
var inBuffer bytes.Buffer
|
|
_, err := inBuffer.Write([]byte{0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00,
|
|
0x23, 0x00, 0x00, 0x00, 0x03, 0x00, 0x01, 0x73, 0x65, 0x6c, 0x65,
|
|
0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f,
|
|
0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c,
|
|
0x69, 0x6d, 0x69, 0x74, 0x20, 0x31})
|
|
|
|
require.NoError(t, err)
|
|
// Test read one packet
|
|
brc := util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
|
|
pkt := NewPacketIO(brc)
|
|
pkt.SetCompressionAlgorithm(mysql.CompressionZlib)
|
|
readBytes, err := pkt.ReadPacket()
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint8(1), pkt.sequence)
|
|
|
|
// 03 00 01 73 65 6c 65 63 74 20 40 40 76 65 72 73 |...select @@vers|
|
|
// 69 6f 6e 5f 63 6f 6d 6d 65 6e 74 20 6c 69 6d 69 |ion_comment limi|
|
|
// 74 20 31 |t 1|
|
|
expected := []byte{0x3, 0x0, 0x1, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20,
|
|
0x40, 0x40, 0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63,
|
|
0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69,
|
|
0x74, 0x20, 0x31}
|
|
|
|
require.Equal(t, expected, readBytes)
|
|
})
|
|
t.Run("zlib", func(t *testing.T) {
|
|
var inBuffer bytes.Buffer
|
|
|
|
// Header: 7 bytes (compressed_length<3>, packetnr<1>, compressed_length<3>)
|
|
// Payload: Starting with the magic nr for zlib: 0x78
|
|
_, err := inBuffer.Write([]byte{0x39, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00,
|
|
0x78, 0x5e, 0x73, 0x65, 0x60, 0x60, 0x60, 0x0e, 0x76, 0xf5, 0x71,
|
|
0x75, 0x0e, 0x51, 0x30, 0x54, 0xd0, 0xd7, 0x52, 0x48, 0x4c, 0x4a,
|
|
0x4e, 0x49, 0x4d, 0x4b, 0xcf, 0xc8, 0xcc, 0xca, 0xce, 0xc9, 0xcd,
|
|
0xcb, 0x2f, 0x28, 0x2c, 0x2a, 0x2e, 0x29, 0x2d, 0x2b, 0xaf, 0xa8,
|
|
0xac, 0x8a, 0xc7, 0x2d, 0xa5, 0xa0, 0xa5, 0x0f, 0x00, 0x59, 0xd8,
|
|
0x1a, 0x09})
|
|
|
|
require.NoError(t, err)
|
|
// Test read one packet
|
|
brc := util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
|
|
pkt := NewPacketIO(brc)
|
|
pkt.SetCompressionAlgorithm(mysql.CompressionZlib)
|
|
readBytes, err := pkt.ReadPacket()
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint8(1), pkt.sequence)
|
|
|
|
// 03 53 45 4c 45 43 54 20 31 20 2f 2a 20 61 62 63 |.SELECT 1 /* abc|
|
|
// 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70 71 72 73 |defghijklmnopqrs|
|
|
// 74 75 76 77 78 79 7a 5f 61 62 63 64 65 66 67 68 |tuvwxyz_abcdefgh|
|
|
// 69 6a 6b 6c 6d 6e 6f 70 71 72 73 74 75 76 77 78 |ijklmnopqrstuvwx|
|
|
// 79 7a 20 2a 2f |yz */|
|
|
|
|
expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x31, 0x20,
|
|
0x2f, 0x2a, 0x20, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68,
|
|
0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73,
|
|
0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x5f, 0x61, 0x62, 0x63,
|
|
0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
|
|
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
|
|
0x7a, 0x20, 0x2a, 0x2f}
|
|
|
|
require.Equal(t, expected, readBytes)
|
|
})
|
|
t.Run("zstd", func(t *testing.T) {
|
|
var inBuffer bytes.Buffer
|
|
// Header: 7 bytes (compressed_length<3>, packetnr<1>, compressed_length<3>)
|
|
// Payload: Starting with the magic nr for zstd: 0x28, 0xb5, 0x2f, 0xfd
|
|
_, err := inBuffer.Write([]byte{
|
|
0x40, 0x00, 0x00, 0x00, 0x49, 0x00, 0x00, 0x28, 0xb5, 0x2f, 0xfd, 0x20,
|
|
0x49, 0xbd, 0x01, 0x00, 0xf4, 0x02, 0x45, 0x00, 0x00, 0x00, 0x03, 0x53,
|
|
0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x31, 0x20, 0x2f, 0x2a, 0x20, 0x61,
|
|
0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d,
|
|
0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
|
|
0x7a, 0x5f, 0x20, 0x2a, 0x2f, 0x01, 0x00, 0x74, 0x7b, 0x96, 0x01})
|
|
|
|
require.NoError(t, err)
|
|
// Test read one packet
|
|
brc := util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
|
|
pkt := NewPacketIO(brc)
|
|
pkt.SetCompressionAlgorithm(mysql.CompressionZstd)
|
|
readBytes, err := pkt.ReadPacket()
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint8(1), pkt.sequence)
|
|
|
|
// 03 53 45 4c 45 43 54 20 31 20 2f 2a 20 61 62 63 |.SELECT 1 /* abc|
|
|
// 64 65 66 67 68 69 6a 6b 6c 6d 6e 6f 70 71 72 73 |defghijklmnopqrs|
|
|
// 74 75 76 77 78 79 7a 5f 61 62 63 64 65 66 67 68 |tuvwxyz_abcdefgh|
|
|
// 69 6a 6b 6c 6d 6e 6f 70 71 72 73 74 75 76 77 78 |ijklmnopqrstuvwx|
|
|
// 79 7a 20 2a 2f |yz */|
|
|
|
|
expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x31, 0x20,
|
|
0x2f, 0x2a, 0x20, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68,
|
|
0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73,
|
|
0x74, 0x75, 0x76, 0x77, 0x78, 0x79, 0x7a, 0x5f, 0x61, 0x62, 0x63,
|
|
0x64, 0x65, 0x66, 0x67, 0x68, 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e,
|
|
0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, 0x78, 0x79,
|
|
0x7a, 0x20, 0x2a, 0x2f}
|
|
|
|
require.Equal(t, expected, readBytes)
|
|
})
|
|
}
|
|
|
|
// Small payloads of less than minCompressLength (50 bytes) don't get actually
|
|
// compressed, but have the same header as a compressed packet.
|
|
//
|
|
// Header:
|
|
// 0a 00 00 Compressed length: 10
|
|
// 00 Packetnr: 0
|
|
// 00 00 00 Uncompressed length: 0 (meaning not compressed)
|
|
//
|
|
// Payload:
|
|
// 74 65 73 74 5f 73 68 6f 72 74 test_short
|
|
func TestCompressedWriterShort(t *testing.T) {
|
|
var testdata bytes.Buffer
|
|
payload := []byte("test_short")
|
|
seq := uint8(0)
|
|
|
|
cw := newCompressedWriter(&testdata, mysql.CompressionZlib, &seq)
|
|
cw.Write(payload)
|
|
cw.Flush()
|
|
|
|
// Test Header
|
|
compressedLength := []byte{0xa, 0x0, 0x0}
|
|
packetNr := []byte{0x0}
|
|
uncompressedLength := []byte{0x0, 0x0, 0x0}
|
|
require.Equal(t, compressedLength, testdata.Bytes()[:3])
|
|
require.Equal(t, packetNr, testdata.Bytes()[3:4])
|
|
require.Equal(t, uncompressedLength, testdata.Bytes()[4:7])
|
|
|
|
// Payload, making sure it isn't compressed.
|
|
require.Equal(t, payload, testdata.Bytes()[7:])
|
|
}
|
|
|
|
func TestCompressedWriterLong(t *testing.T) {
|
|
t.Run("zlib", func(t *testing.T) {
|
|
var testdata, decoded bytes.Buffer
|
|
payload := []byte("test_zlib test_zlib test_zlib test_zlib test_zlib test_zlib test_zlib")
|
|
seq := uint8(0)
|
|
|
|
cw := newCompressedWriter(&testdata, mysql.CompressionZlib, &seq)
|
|
cw.Write(payload)
|
|
cw.Flush()
|
|
|
|
// Header:
|
|
// 18 00 00 Compressed length: 24
|
|
// 00 Packetnr: 0
|
|
// 45 00 00 Uncompressed length: 69
|
|
compressedLength := []byte{0x18, 0x0, 0x0}
|
|
packetNr := []byte{0x0}
|
|
uncompressedLength := []byte{0x45, 0x0, 0x0}
|
|
require.Equal(t, compressedLength, testdata.Bytes()[:3])
|
|
require.Equal(t, packetNr, testdata.Bytes()[3:4])
|
|
require.Equal(t, uncompressedLength, testdata.Bytes()[4:7])
|
|
|
|
// Payload:
|
|
r, err := zlib.NewReader(bytes.NewReader(testdata.Bytes()[7:]))
|
|
require.NoError(t, err)
|
|
io.Copy(&decoded, r)
|
|
require.Equal(t, payload, decoded.Bytes())
|
|
})
|
|
|
|
t.Run("zstd", func(t *testing.T) {
|
|
var testdata bytes.Buffer
|
|
payload := []byte("test_zstd test_zstd test_zstd test_zstd test_zstd test_zstd test_zstd")
|
|
seq := uint8(0)
|
|
|
|
cw := newCompressedWriter(&testdata, mysql.CompressionZstd, &seq)
|
|
cw.Write(payload)
|
|
cw.Flush()
|
|
|
|
// Header:
|
|
// 1e 00 00 Compressed length: 30
|
|
// 00 Packetnr: 0
|
|
// 45 00 00 Uncompressed length: 69
|
|
compressedLength := []byte{0x1e, 0x0, 0x0}
|
|
packetNr := []byte{0x0}
|
|
uncompressedLength := []byte{0x45, 0x0, 0x0}
|
|
require.Equal(t, compressedLength, testdata.Bytes()[:3])
|
|
require.Equal(t, packetNr, testdata.Bytes()[3:4])
|
|
require.Equal(t, uncompressedLength, testdata.Bytes()[4:7])
|
|
|
|
// Payload:
|
|
d, err := zstd.NewReader(nil, zstd.WithDecoderConcurrency(0))
|
|
require.NoError(t, err)
|
|
decoded, err := d.DecodeAll(testdata.Bytes()[7:], nil)
|
|
require.NoError(t, err)
|
|
require.Equal(t, payload, decoded)
|
|
})
|
|
}
|
|
|
|
// TestCompressedReaderShort test a compressed protocol packet that has an uncompressed
|
|
// length of 0, which means the actual payload isn't compressed.
|
|
func TestCompressedReaderShort(t *testing.T) {
|
|
// payload: 7 bytes compressed header, 37 bytes uncompressed payload
|
|
payload := []byte{0x25, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x21, 0x0, 0x0, 0x0, 0x3,
|
|
0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40, 0x76, 0x65, 0x72,
|
|
0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d, 0x65, 0x6e, 0x74,
|
|
0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31}
|
|
r := bytes.NewReader(payload)
|
|
seq := uint8(0)
|
|
cr := newCompressedReader(r, mysql.CompressionZlib, &seq)
|
|
|
|
// Read 4 byte header from the payload. This is the regular packet header,
|
|
// not the compressed header.
|
|
header := make([]byte, 4)
|
|
_, err := cr.Read(header)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte{0x21, 0x0, 0x0, 0x0}, header)
|
|
|
|
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
|
|
sequence := header[3]
|
|
require.Equal(t, 33, length)
|
|
require.Equal(t, uint8(0), sequence)
|
|
|
|
data := make([]byte, length)
|
|
_, err = cr.Read(data)
|
|
require.NoError(t, err)
|
|
expected := []byte{0x3, 0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x40, 0x40,
|
|
0x76, 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x5f, 0x63, 0x6f, 0x6d, 0x6d,
|
|
0x65, 0x6e, 0x74, 0x20, 0x6c, 0x69, 0x6d, 0x69, 0x74, 0x20, 0x31}
|
|
require.Equal(t, expected, data)
|
|
}
|
|
|
|
func TestCompressedReaderLong(t *testing.T) {
|
|
t.Run("zlib", func(t *testing.T) {
|
|
payload := []byte{0x19, 0x0, 0x0, 0x0, 0x9c, 0x0, 0x0, 0x78, 0x5e, 0x9b,
|
|
0xc1, 0xc0, 0xc0, 0xc0, 0x1c, 0xec, 0xea, 0xe3, 0xea, 0x1c, 0xa2,
|
|
0xa0, 0xe4, 0x38, 0xa8, 0x80, 0x12, 0x0, 0xbe, 0xe6, 0x26, 0xce}
|
|
r := bytes.NewReader(payload)
|
|
seq := uint8(0)
|
|
cr := newCompressedReader(r, mysql.CompressionZlib, &seq)
|
|
header := make([]byte, 4)
|
|
_, err := cr.Read(header)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte{0x98, 0x0, 0x0, 0x0}, header)
|
|
|
|
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
|
|
sequence := header[3]
|
|
require.Equal(t, 152, length)
|
|
require.Equal(t, uint8(0), sequence)
|
|
|
|
data := make([]byte, length)
|
|
_, err = cr.Read(data)
|
|
require.NoError(t, err)
|
|
expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x22}
|
|
require.Equal(t, expected, data)
|
|
})
|
|
t.Run("zstd", func(t *testing.T) {
|
|
payload := []byte{0x1f, 0x0, 0x0, 0x0, 0x9c, 0x0, 0x0, 0x28, 0xb5, 0x2f, 0xfd,
|
|
0x20, 0x9c, 0xb5, 0x0, 0x0, 0x78, 0x98, 0x0, 0x0, 0x0, 0x3, 0x53,
|
|
0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22, 0x41, 0x22, 0x1, 0x0, 0xa,
|
|
0xa, 0x28, 0x1}
|
|
r := bytes.NewReader(payload)
|
|
seq := uint8(0)
|
|
cr := newCompressedReader(r, mysql.CompressionZstd, &seq)
|
|
header := make([]byte, 4)
|
|
_, err := cr.Read(header)
|
|
require.NoError(t, err)
|
|
require.Equal(t, []byte{0x98, 0x0, 0x0, 0x0}, header)
|
|
|
|
length := int(uint32(header[0]) | uint32(header[1])<<8 | uint32(header[2])<<16)
|
|
sequence := header[3]
|
|
require.Equal(t, 152, length)
|
|
require.Equal(t, uint8(0), sequence)
|
|
|
|
data := make([]byte, length)
|
|
_, err = cr.Read(data)
|
|
require.NoError(t, err)
|
|
expected := []byte{0x3, 0x53, 0x45, 0x4c, 0x45, 0x43, 0x54, 0x20, 0x22,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41,
|
|
0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x41, 0x22}
|
|
require.Equal(t, expected, data)
|
|
})
|
|
}
|
|
|
|
// MariaDB Java Connecter 2.X would generate wrong sequence number for the sub header.
|
|
// TiDB should be compatible with it.
|
|
//
|
|
// MySQL Compressed Protocol Header:
|
|
// 0e 00 00 Compressed length
|
|
// 00 Compressed Packetnr
|
|
// 00 00 00 Uncompressed length
|
|
//
|
|
// MySQL Protocol Header:
|
|
// 0a 00 00 Payload length
|
|
// 01 Packet Sequence Number (it should be 0x00, but MariaDB Connector/J 2.x sets it to 0x01)
|
|
// 03 COM_QUERY
|
|
// 73 65 6c 65 63 74 20 31 3b "select 1;"
|
|
func TestSubHeaderWithWrongSequenceNumber(t *testing.T) {
|
|
var inBuffer bytes.Buffer
|
|
_, err := inBuffer.Write([]byte{0x0e, 0x00, 0x00, 0x00, 0x00, 0x00,
|
|
0x00, 0x0a, 0x00, 0x00, 0x01, 0x03, 0x73, 0x65, 0x6c,
|
|
0x65, 0x63, 0x74, 0x20, 0x31, 0x3b})
|
|
require.NoError(t, err)
|
|
// Test read one packet
|
|
brc := util.NewBufferedReadConn(&testutil.BytesConn{Buffer: inBuffer})
|
|
pkt := NewPacketIO(brc)
|
|
pkt.SetCompressionAlgorithm(mysql.CompressionZlib)
|
|
readBytes, err := pkt.ReadPacket()
|
|
require.NoError(t, err)
|
|
require.Equal(t, uint8(1), pkt.sequence)
|
|
require.Equal(t, uint8(1), pkt.compressedSequence)
|
|
require.Equal(t, []byte{0x03, 0x73, 0x65, 0x6c, 0x65, 0x63,
|
|
0x74, 0x20, 0x31, 0x3b}, readBytes)
|
|
}
|