Files
tidb/server/conn.go
2019-08-13 21:04:11 +08:00

1502 lines
47 KiB
Go

// Copyright 2013 The Go-MySQL-Driver Authors. All rights reserved.
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this file,
// You can obtain one at http://mozilla.org/MPL/2.0/.
// The MIT License (MIT)
//
// Copyright (c) 2014 wandoulabs
// Copyright (c) 2014 siddontang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package server
import (
"bytes"
"context"
"crypto/tls"
"encoding/binary"
"fmt"
"io"
"net"
"runtime"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/opentracing/opentracing-go"
"github.com/pingcap/errors"
"github.com/pingcap/failpoint"
"github.com/pingcap/parser/auth"
"github.com/pingcap/parser/mysql"
"github.com/pingcap/parser/terror"
"github.com/pingcap/tidb/executor"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/metrics"
"github.com/pingcap/tidb/plugin"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/util/arena"
"github.com/pingcap/tidb/util/chunk"
"github.com/pingcap/tidb/util/hack"
"github.com/pingcap/tidb/util/logutil"
"github.com/pingcap/tidb/util/memory"
"github.com/pingcap/tidb/util/sqlexec"
"go.uber.org/zap"
)
// Lifecycle states of a clientConn, stored in clientConn.status and changed with
// atomic operations (see Run and ShutdownOrNotify).
const (
	// connStatusDispatching: a client command is being handled.
	connStatusDispatching int32 = iota
	// connStatusReading: waiting to read the next packet from the client.
	connStatusReading
	// connStatusShutdown: the connection was closed by the server.
	connStatusShutdown
	// connStatusWaitShutdown: the server has asked this connection to close itself.
	connStatusWaitShutdown
)
// Pre-resolved QueryTotalCounter children: one OK/Error pair per command type that
// addMetrics handles explicitly, so the common commands do not have to call
// WithLabelValues on every request.
var (
	queryTotalCounterComSleepOK            = metrics.QueryTotalCounter.WithLabelValues("Sleep", "OK")
	queryTotalCounterComSleepError         = metrics.QueryTotalCounter.WithLabelValues("Sleep", "Error")
	queryTotalCounterComQuitOK             = metrics.QueryTotalCounter.WithLabelValues("Quit", "OK")
	queryTotalCounterComQuitError          = metrics.QueryTotalCounter.WithLabelValues("Quit", "Error")
	queryTotalCounterComInitDBOK           = metrics.QueryTotalCounter.WithLabelValues("InitDB", "OK")
	queryTotalCounterComInitDBError        = metrics.QueryTotalCounter.WithLabelValues("InitDB", "Error")
	queryTotalCounterComQueryOK            = metrics.QueryTotalCounter.WithLabelValues("Query", "OK")
	queryTotalCounterComQueryError         = metrics.QueryTotalCounter.WithLabelValues("Query", "Error")
	queryTotalCounterComPingOK             = metrics.QueryTotalCounter.WithLabelValues("Ping", "OK")
	queryTotalCounterComPingError          = metrics.QueryTotalCounter.WithLabelValues("Ping", "Error")
	queryTotalCounterComFieldListOK        = metrics.QueryTotalCounter.WithLabelValues("FieldList", "OK")
	queryTotalCounterComFieldListError     = metrics.QueryTotalCounter.WithLabelValues("FieldList", "Error")
	queryTotalCounterComPrepareOK          = metrics.QueryTotalCounter.WithLabelValues("StmtPrepare", "OK")
	queryTotalCounterComPrepareError       = metrics.QueryTotalCounter.WithLabelValues("StmtPrepare", "Error")
	queryTotalCounterComExecuteOK          = metrics.QueryTotalCounter.WithLabelValues("StmtExecute", "OK")
	queryTotalCounterComExecuteError       = metrics.QueryTotalCounter.WithLabelValues("StmtExecute", "Error")
	queryTotalCounterComFetchOK            = metrics.QueryTotalCounter.WithLabelValues("StmtFetch", "OK")
	queryTotalCounterComFetchError         = metrics.QueryTotalCounter.WithLabelValues("StmtFetch", "Error")
	queryTotalCounterComCloseOK            = metrics.QueryTotalCounter.WithLabelValues("StmtClose", "OK")
	queryTotalCounterComCloseError         = metrics.QueryTotalCounter.WithLabelValues("StmtClose", "Error")
	queryTotalCounterComSendLongDataOK     = metrics.QueryTotalCounter.WithLabelValues("StmtSendLongData", "OK")
	queryTotalCounterComSendLongDataError  = metrics.QueryTotalCounter.WithLabelValues("StmtSendLongData", "Error")
	queryTotalCounterComResetOK            = metrics.QueryTotalCounter.WithLabelValues("StmtReset", "OK")
	queryTotalCounterComResetError         = metrics.QueryTotalCounter.WithLabelValues("StmtReset", "Error")
	queryTotalCounterComSetOptionOK        = metrics.QueryTotalCounter.WithLabelValues("SetOption", "OK")
	queryTotalCounterComSetOptionError     = metrics.QueryTotalCounter.WithLabelValues("SetOption", "Error")
)

// Pre-resolved QueryDurationHistogram children, keyed by the statement type that
// addMetrics switches on; other statement types are resolved on demand there.
var (
	queryDurationHistogramUse      = metrics.QueryDurationHistogram.WithLabelValues("Use")
	queryDurationHistogramShow     = metrics.QueryDurationHistogram.WithLabelValues("Show")
	queryDurationHistogramBegin    = metrics.QueryDurationHistogram.WithLabelValues("Begin")
	queryDurationHistogramCommit   = metrics.QueryDurationHistogram.WithLabelValues("Commit")
	queryDurationHistogramRollback = metrics.QueryDurationHistogram.WithLabelValues("Rollback")
	queryDurationHistogramInsert   = metrics.QueryDurationHistogram.WithLabelValues("Insert")
	queryDurationHistogramReplace  = metrics.QueryDurationHistogram.WithLabelValues("Replace")
	queryDurationHistogramDelete   = metrics.QueryDurationHistogram.WithLabelValues("Delete")
	queryDurationHistogramUpdate   = metrics.QueryDurationHistogram.WithLabelValues("Update")
	queryDurationHistogramSelect   = metrics.QueryDurationHistogram.WithLabelValues("Select")
	queryDurationHistogramExecute  = metrics.QueryDurationHistogram.WithLabelValues("Execute")
	queryDurationHistogramSet      = metrics.QueryDurationHistogram.WithLabelValues("Set")
	queryDurationHistogramGeneral  = metrics.QueryDurationHistogram.WithLabelValues(metrics.LblGeneral)
)
// newClientConn allocates a clientConn bound to server s.
//
// The connection ID is taken from the package-level baseConnID counter with an atomic
// increment, so concurrently created connections get distinct IDs. The connection
// starts with the default collation, an arena allocator sized 32*1024, and in the
// dispatching state.
func newClientConn(s *Server) *clientConn {
	cc := new(clientConn)
	cc.server = s
	cc.connectionID = atomic.AddUint32(&baseConnID, 1)
	cc.collation = mysql.DefaultCollationID
	cc.alloc = arena.NewAllocator(32 * 1024)
	cc.status = connStatusDispatching
	return cc
}
// clientConn is the server side of one client connection. It holds the
// connection-specific state and handles the queries the client sends.
type clientConn struct {
	// pkt reads and writes data in MySQL packet format.
	pkt *packetIO
	// bufReadConn is the buffered-read net.Conn or buffered-read tls.Conn.
	bufReadConn *bufferedReadConn
	// tlsConn is the TLS connection, or nil when TLS is not in use.
	tlsConn *tls.Conn
	// server is the owning server instance.
	server *Server
	// capability is the negotiated client capability; it affects how requests are handled.
	capability uint32
	// connectionID is allocated atomically from a global counter and is unique within the process.
	connectionID uint32
	// user is the client's user name.
	user string
	// dbname is the current default database.
	dbname string
	// salt holds the random bytes used for authentication.
	salt []byte
	// alloc is an arena allocator used to reduce memory allocations.
	alloc arena.Allocator
	// lastCmd is the latest command payload, currently used when logging errors.
	lastCmd string
	// ctx executes SQL statements for this connection.
	ctx QueryCtx
	// attrs are the connection attributes from the handshake response; not used for now.
	attrs map[string]string
	// peerHost is the client's host.
	peerHost string
	// peerPort is the client's port.
	peerPort string
	// status is one of the connStatus* values (dispatching/reading/shutdown/waitshutdown).
	status int32
	// lastCode is the last error code sent to the client.
	lastCode uint16
	// collation is the client's collation, which may differ from the database's.
	collation uint8
}
func (cc *clientConn) String() string {
collationStr := mysql.Collations[cc.collation]
return fmt.Sprintf("id:%d, addr:%s status:%d, collation:%s, user:%s",
cc.connectionID, cc.bufReadConn.RemoteAddr(), cc.ctx.Status(), collationStr, cc.user,
)
}
// handshake performs the MySQL connection phase.
//
// It sends the initial handshake packet, then reads the client's (optional SSLRequest
// and) handshake response and authenticates. On failure, the error is also sent to the
// client (a failure to send it is only logged at debug level) and returned. On success,
// an OK packet is written and flushed, after which the client may send queries. The
// packet sequence number is reset after writing the OK packet, whether or not the
// write succeeded.
func (cc *clientConn) handshake(ctx context.Context) error {
	if err := cc.writeInitialHandshake(); err != nil {
		return err
	}
	if err := cc.readOptionalSSLRequestAndHandshakeResponse(ctx); err != nil {
		if writeErr := cc.writeError(err); writeErr != nil {
			logutil.Logger(ctx).Debug("writeError failed", zap.Error(writeErr))
		}
		return err
	}

	okPkt := cc.alloc.AllocWithLen(4, 32)
	okPkt = append(okPkt, mysql.OKHeader, 0, 0)
	if cc.capability&mysql.ClientProtocol41 > 0 {
		okPkt = dumpUint16(okPkt, mysql.ServerStatusAutocommit)
		okPkt = append(okPkt, 0, 0)
	}
	writeErr := cc.writePacket(okPkt)
	cc.pkt.sequence = 0
	if writeErr != nil {
		return writeErr
	}
	return cc.flush()
}
// Close unregisters the connection from the server's client map and then closes it.
// The server lock is held only while updating the map, not while closing.
func (cc *clientConn) Close() error {
	remaining := func() int {
		cc.server.rwlock.Lock()
		defer cc.server.rwlock.Unlock()
		delete(cc.server.clients, cc.connectionID)
		return len(cc.server.clients)
	}()
	return closeConn(cc, remaining)
}
// closeConn updates the connection gauge to connections, closes the underlying
// network connection (only logging any error), and closes the session if one exists.
// It returns the session's Close error, or nil when there is no session.
func closeConn(cc *clientConn, connections int) error {
	metrics.ConnGauge.Set(float64(connections))
	terror.Log(cc.bufReadConn.Close())
	if cc.ctx == nil {
		return nil
	}
	return cc.ctx.Close()
}
// closeWithoutLock is Close without taking cc.server.rwlock.
// NOTE(review): presumably the caller already holds that lock, since the clients map
// is modified here; confirm at call sites.
func (cc *clientConn) closeWithoutLock() error {
	clients := cc.server.clients
	delete(clients, cc.connectionID)
	return closeConn(cc, len(clients))
}
// writeInitialHandshake sends the initial handshake packet to the client and flushes it.
// The packet carries the protocol version, server version, connection ID, the auth salt
// (split into two parts), the server capability flags, the collation, the server status
// and the auth plugin name ("mysql_native_password").
//
// If cc.collation is still 0, it is first set to the default collation.
// NOTE(review): cc.salt is sliced at [0:8] and [8:], so it must be at least 8 bytes
// long; it is set outside this chunk — confirm.
func (cc *clientConn) writeInitialHandshake() error {
	// The first 4 bytes are reserved for the packet header (presumably filled in by
	// packetIO when writing).
	data := make([]byte, 4, 128)
	// protocol version 10
	data = append(data, 10)
	// server version, NUL-terminated
	data = append(data, mysql.ServerVersion...)
	data = append(data, 0)
	// connection id, 4 bytes little-endian
	data = append(data, byte(cc.connectionID), byte(cc.connectionID>>8), byte(cc.connectionID>>16), byte(cc.connectionID>>24))
	// auth-plugin-data-part-1: first 8 bytes of the salt
	data = append(data, cc.salt[0:8]...)
	// filler [00]
	data = append(data, 0)
	// capability flags, lower 2 bytes (server default capability)
	data = append(data, byte(cc.server.capability), byte(cc.server.capability>>8))
	// charset / collation
	if cc.collation == 0 {
		cc.collation = uint8(mysql.DefaultCollationID)
	}
	data = append(data, cc.collation)
	// status flags
	data = dumpUint16(data, mysql.ServerStatusAutocommit)
	// The following fields may be ignored by older clients.
	// capability flags, upper 2 bytes (server default capability)
	data = append(data, byte(cc.server.capability>>16), byte(cc.server.capability>>24))
	// length of auth-plugin-data: salt length plus its trailing NUL
	data = append(data, byte(len(cc.salt)+1))
	// reserved 10 [00]
	data = append(data, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0)
	// auth-plugin-data-part-2: remainder of the salt, NUL-terminated
	data = append(data, cc.salt[8:]...)
	data = append(data, 0)
	// auth plugin name, NUL-terminated
	data = append(data, []byte("mysql_native_password")...)
	data = append(data, 0)
	err := cc.writePacket(data)
	if err != nil {
		return err
	}
	return cc.flush()
}
// readPacket reads the next packet from the client by delegating to cc.pkt.
func (cc *clientConn) readPacket() ([]byte, error) {
	pktIO := cc.pkt
	return pktIO.readPacket()
}
// writePacket writes one packet through cc.pkt without flushing. Callers in this file
// build data with 4 leading bytes reserved for the packet header and call flush
// separately when the response is complete.
//
// When the FakeClientConn failpoint is enabled and cc.pkt is nil, the write is skipped
// and nil is returned (presumably so tests can run without a real packetIO).
func (cc *clientConn) writePacket(data []byte) error {
	failpoint.Inject("FakeClientConn", func() {
		if cc.pkt == nil {
			failpoint.Return(nil)
		}
	})
	return cc.pkt.writePacket(data)
}
// getSessionVarsWaitTimeout returns the session's wait_timeout value, in seconds as
// used by Run. It falls back to variable.DefWaitTimeout when the variable is unset,
// and also (after logging a warning) when its value is not a valid unsigned integer.
func (cc *clientConn) getSessionVarsWaitTimeout(ctx context.Context) uint64 {
	valStr, exists := cc.ctx.GetSessionVars().GetSystemVar(variable.WaitTimeout)
	if !exists {
		return variable.DefWaitTimeout
	}
	waitTimeout, err := strconv.ParseUint(valStr, 10, 64)
	if err == nil {
		return waitTimeout
	}
	logutil.Logger(ctx).Warn("get sysval wait_timeout error, use default value", zap.Error(err))
	return variable.DefWaitTimeout
}
// handshakeResponse41 holds the fields parsed from a client's handshake response.
// It is also used for the old HandshakeResponse320 format, whose parser fills in a
// default collation and forces ClientProtocol41 on.
type handshakeResponse41 struct {
	// Capability holds the client capability flags.
	Capability uint32
	// Collation is the client's requested collation ID.
	Collation uint8
	// User is the login user name.
	User string
	// DBName is the initial database, if one was sent.
	DBName string
	// Auth is the raw auth response.
	Auth []byte
	// Attrs holds connection attributes, if any were sent.
	Attrs map[string]string
}
// parseOldHandshakeResponseHeader parses the old version handshake header HandshakeResponse320
func parseOldHandshakeResponseHeader(ctx context.Context, packet *handshakeResponse41, data []byte) (parsedBytes int, err error) {
// Ensure there are enough data to read:
// https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse320
logutil.Logger(ctx).Debug("try to parse hanshake response as Protocol::HandshakeResponse320", zap.ByteString("packetData", data))
if len(data) < 2+3 {
logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data))
return 0, mysql.ErrMalformPacket
}
offset := 0
// capability
capability := binary.LittleEndian.Uint16(data[:2])
packet.Capability = uint32(capability)
// be compatible with Protocol::HandshakeResponse41
packet.Capability = packet.Capability | mysql.ClientProtocol41
offset += 2
// skip max packet size
offset += 3
// usa default CharsetID
packet.Collation = mysql.CollationNames["utf8mb4_general_ci"]
return offset, nil
}
// parseOldHandshakeResponseBody parses the body of a Protocol::HandshakeResponse320
// packet (everything after the common header), starting at offset.
//
// The layout is:
//
//	string[NUL]  username
//	if CLIENT_CONNECT_WITH_DB:
//	    string[NUL]  auth-response
//	    string[NUL]  database
//	else:
//	    string[EOF]  auth-response
//
// The auth response comes before the database name. Clients commonly NUL-terminate the
// auth response even in the string[EOF] case, so it is cut at the first NUL when one is
// present and otherwise runs to the end of data. A trailing database name without a NUL
// terminator is accepted as-is.
//
// Any out-of-range slicing caused by a malformed packet (for example a user name without
// a NUL terminator) is recovered and reported as mysql.ErrMalformPacket.
//
// See https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse320
func parseOldHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41, data []byte, offset int) (err error) {
	defer func() {
		// Checking every malformed-packet index is tedious, so recover from the
		// resulting out-of-range panic instead of crashing.
		if r := recover(); r != nil {
			logutil.Logger(ctx).Error("handshake panic", zap.ByteString("packetData", data))
			err = mysql.ErrMalformPacket
		}
	}()
	// user name: string[NUL]
	packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)])
	offset += len(packet.User) + 1

	// auth-response: string[NUL] with CLIENT_CONNECT_WITH_DB, string[EOF] otherwise.
	auth := data[offset:]
	if idx := bytes.IndexByte(auth, 0); idx >= 0 {
		auth = auth[:idx]
		offset += idx + 1
	} else {
		offset = len(data)
	}
	packet.Auth = auth

	// database: string[NUL], only when CLIENT_CONNECT_WITH_DB is set and data remains.
	if packet.Capability&mysql.ClientConnectWithDB > 0 && offset < len(data) {
		db := data[offset:]
		if idx := bytes.IndexByte(db, 0); idx >= 0 {
			db = db[:idx]
		}
		packet.DBName = string(db)
	}
	return nil
}
// parseHandshakeResponseHeader parses the common header of SSLRequest and HandshakeResponse41.
func parseHandshakeResponseHeader(ctx context.Context, packet *handshakeResponse41, data []byte) (parsedBytes int, err error) {
// Ensure there are enough data to read:
// http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::SSLRequest
if len(data) < 4+4+1+23 {
logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data))
return 0, mysql.ErrMalformPacket
}
offset := 0
// capability
capability := binary.LittleEndian.Uint32(data[:4])
packet.Capability = capability
offset += 4
// skip max packet size
offset += 4
// charset, skip, if you want to use another charset, use set names
packet.Collation = data[offset]
offset++
// skip reserved 23[00]
offset += 23
return offset, nil
}
// parseHandshakeResponseBody parses the body of a Protocol::HandshakeResponse41 packet
// (everything after the common 32-byte header), starting at offset. In order it reads:
// the NUL-terminated user name; the auth response (length-encoded, 1-byte-length
// prefixed, or NUL-terminated depending on capability flags); the database name when
// ClientConnectWithDB is set; the auth plugin name when ClientPluginAuth is set
// (skipped); and the connection attributes when ClientConnectAtts is set.
//
// Out-of-range slicing caused by a malformed packet is recovered and reported as
// mysql.ErrMalformPacket. Malformed connection attributes are only logged and ignored.
func parseHandshakeResponseBody(ctx context.Context, packet *handshakeResponse41, data []byte, offset int) (err error) {
	defer func() {
		// Checking every malformed-packet index is tedious, so recover from the
		// resulting out-of-range panic instead of crashing.
		if r := recover(); r != nil {
			logutil.Logger(ctx).Error("handshake panic", zap.ByteString("packetData", data))
			err = mysql.ErrMalformPacket
		}
	}()
	// user name
	packet.User = string(data[offset : offset+bytes.IndexByte(data[offset:], 0)])
	offset += len(packet.User) + 1
	if packet.Capability&mysql.ClientPluginAuthLenencClientData > 0 {
		// MySQL client sets the wrong capability, it will set this bit even server doesn't
		// support ClientPluginAuthLenencClientData.
		// https://github.com/mysql/mysql-server/blob/5.7/sql-common/client.c#L3478
		// Auth response is a length-encoded string; a NULL length leaves Auth unset.
		num, null, off := parseLengthEncodedInt(data[offset:])
		offset += off
		if !null {
			packet.Auth = data[offset : offset+int(num)]
			offset += int(num)
		}
	} else if packet.Capability&mysql.ClientSecureConnection > 0 {
		// auth length (1 byte) followed by the auth data
		authLen := int(data[offset])
		offset++
		packet.Auth = data[offset : offset+authLen]
		offset += authLen
	} else {
		// NUL-terminated auth data
		packet.Auth = data[offset : offset+bytes.IndexByte(data[offset:], 0)]
		offset += len(packet.Auth) + 1
	}
	if packet.Capability&mysql.ClientConnectWithDB > 0 {
		// NUL-terminated database name, if any bytes remain.
		if len(data[offset:]) > 0 {
			idx := bytes.IndexByte(data[offset:], 0)
			packet.DBName = string(data[offset : offset+idx])
			offset = offset + idx + 1
		}
	}
	if packet.Capability&mysql.ClientPluginAuth > 0 {
		// TODO: Support mysql.ClientPluginAuth, skip it now
		// NOTE(review): if the plugin name has no NUL terminator, idx is -1 and offset
		// does not advance, so the attribute parsing below starts at the plugin name.
		idx := bytes.IndexByte(data[offset:], 0)
		offset = offset + idx + 1
	}
	if packet.Capability&mysql.ClientConnectAtts > 0 {
		if len(data[offset:]) == 0 {
			// Defend some ill-formated packet, connection attribute is not important and can be ignored.
			return nil
		}
		// Attributes are a length-encoded block of key/value pairs; a parse error is
		// only logged because attributes are not needed to establish the connection.
		if num, null, off := parseLengthEncodedInt(data[offset:]); !null {
			offset += off
			row := data[offset : offset+int(num)]
			attrs, err := parseAttrs(row)
			if err != nil {
				logutil.Logger(ctx).Warn("parse attrs failed", zap.Error(err))
				return nil
			}
			packet.Attrs = attrs
		}
	}
	return nil
}
// parseAttrs decodes connection attributes, stored as consecutive length-encoded
// key/value string pairs, into a map. On a decoding error it returns the pairs decoded
// so far together with the error.
func parseAttrs(data []byte) (map[string]string, error) {
	attrs := make(map[string]string)
	for pos := 0; pos < len(data); {
		key, _, keyLen, err := parseLengthEncodedBytes(data[pos:])
		if err != nil {
			return attrs, err
		}
		pos += keyLen

		value, _, valLen, err := parseLengthEncodedBytes(data[pos:])
		if err != nil {
			return attrs, err
		}
		pos += valLen

		attrs[string(key)] = string(value)
	}
	return attrs, nil
}
// readOptionalSSLRequestAndHandshakeResponse reads the client's reply to the initial
// handshake, upgrades to TLS if requested, records the negotiated connection settings
// on cc, and authenticates.
//
// The first packet is either an SSLRequest or a HandshakeResponse. Its low 16
// capability bits decide between the 4.1 format and the old HandshakeResponse320
// format. If the client set ClientSSL and the server has a TLS config, the connection is
// upgraded to TLS, and the real HandshakeResponse is read and its header parsed again,
// in the same format as the first packet. The body is then parsed, cc.capability is set
// to the intersection of client and server capabilities, user/dbname/collation/attrs
// are copied from the response, and openSessionAndDoAuth is called.
func (cc *clientConn) readOptionalSSLRequestAndHandshakeResponse(ctx context.Context) error {
	// Read a packet. It may be a SSLRequest or HandshakeResponse.
	data, err := cc.readPacket()
	if err != nil {
		return err
	}
	isOldVersion := false
	var resp handshakeResponse41
	var pos int
	// At least the 2 low capability bytes are needed to choose the packet format.
	if len(data) < 2 {
		logutil.Logger(ctx).Error("got malformed handshake response", zap.ByteString("packetData", data))
		return mysql.ErrMalformPacket
	}
	capability := uint32(binary.LittleEndian.Uint16(data[:2]))
	if capability&mysql.ClientProtocol41 > 0 {
		pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
	} else {
		pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data)
		isOldVersion = true
	}
	if err != nil {
		return err
	}
	if (resp.Capability&mysql.ClientSSL > 0) && cc.server.tlsConfig != nil {
		// The packet is a SSLRequest, let's switch to TLS.
		if err = cc.upgradeToTLS(cc.server.tlsConfig); err != nil {
			return err
		}
		// Read the following HandshakeResponse packet.
		data, err = cc.readPacket()
		if err != nil {
			return err
		}
		// Re-parse the header from the new packet, using the format chosen above.
		if isOldVersion {
			pos, err = parseOldHandshakeResponseHeader(ctx, &resp, data)
		} else {
			pos, err = parseHandshakeResponseHeader(ctx, &resp, data)
		}
		if err != nil {
			return err
		}
	}
	// Read the remaining part of the packet.
	if isOldVersion {
		err = parseOldHandshakeResponseBody(ctx, &resp, data, pos)
	} else {
		err = parseHandshakeResponseBody(ctx, &resp, data, pos)
	}
	if err != nil {
		return err
	}
	// Only capabilities supported by both sides are kept.
	cc.capability = resp.Capability & cc.server.capability
	cc.user = resp.User
	cc.dbname = resp.DBName
	cc.collation = resp.Collation
	cc.attrs = resp.Attrs
	err = cc.openSessionAndDoAuth(resp.Auth)
	return err
}
// openSessionAndDoAuth opens a new session (cc.ctx) for the connection and
// authenticates cc.user with authData against cc.salt.
//
// The TLS connection state, if any, is passed to the driver. cc.ctx is assigned whatever
// OpenCtx returns, even on error. Authentication failure yields an access-denied error
// that reports whether a password was supplied. On success, cc.dbname (if set) is
// selected, and the server is registered as the session manager.
func (cc *clientConn) openSessionAndDoAuth(authData []byte) error {
	var tlsStatePtr *tls.ConnectionState
	if cc.tlsConn != nil {
		state := cc.tlsConn.ConnectionState()
		tlsStatePtr = &state
	}

	var err error
	cc.ctx, err = cc.server.driver.OpenCtx(uint64(cc.connectionID), cc.capability, cc.collation, cc.dbname, tlsStatePtr)
	if err != nil {
		return err
	}

	hasPassword := "NO"
	if len(authData) > 0 {
		hasPassword = "YES"
	}
	host, err := cc.PeerHost(hasPassword)
	if err != nil {
		return err
	}
	identity := &auth.UserIdentity{Username: cc.user, Hostname: host}
	if ok := cc.ctx.Auth(identity, authData, cc.salt); !ok {
		return errAccessDenied.GenWithStackByArgs(cc.user, host, hasPassword)
	}

	if cc.dbname != "" {
		if err = cc.useDB(context.Background(), cc.dbname); err != nil {
			return err
		}
	}
	cc.ctx.SetSessionManager(cc.server)
	return nil
}
func (cc *clientConn) PeerHost(hasPassword string) (host string, err error) {
if len(cc.peerHost) > 0 {
return cc.peerHost, nil
}
host = variable.DefHostname
if cc.server.isUnixSocket() {
cc.peerHost = host
return
}
addr := cc.bufReadConn.RemoteAddr().String()
var port string
host, port, err = net.SplitHostPort(addr)
if err != nil {
err = errAccessDenied.GenWithStackByArgs(cc.user, addr, hasPassword)
return
}
cc.peerHost = host
cc.peerPort = port
return
}
// Run is the per-connection command loop. It repeatedly reads one packet from the
// client and passes it to dispatch, writing any dispatch error back to the client.
//
// A panic in the loop is recovered, logged with the last command and a stack trace
// (truncated to 4096 bytes), and counted in metrics.PanicCounter. Run returns on a read
// error or read timeout, on COM_QUIT (io.EOF from dispatch), on an undetermined-result
// or critical error, when a status CAS fails (the server asked the connection to shut
// down), or after a panic. On return the connection is closed via cc.Close, unless its
// status is already connStatusShutdown (closed by the server).
func (cc *clientConn) Run(ctx context.Context) {
	const size = 4096
	defer func() {
		r := recover()
		if r != nil {
			buf := make([]byte, size)
			stackSize := runtime.Stack(buf, false)
			buf = buf[:stackSize]
			logutil.Logger(ctx).Error("connection running loop panic",
				zap.String("lastCmd", cc.lastCmd),
				zap.Reflect("err", r),
				zap.String("stack", string(buf)),
			)
			metrics.PanicCounter.WithLabelValues(metrics.LabelSession).Inc()
		}
		if atomic.LoadInt32(&cc.status) != connStatusShutdown {
			err := cc.Close()
			terror.Log(err)
		}
	}()
	// Usually, client connection status changes between [dispatching] <=> [reading].
	// When some event happens, server may notify this client connection by setting
	// the status to special values, for example: kill or graceful shutdown.
	// The client connection would detect the events when it fails to change status
	// by CAS operation, it would then take some actions accordingly.
	for {
		if !atomic.CompareAndSwapInt32(&cc.status, connStatusDispatching, connStatusReading) {
			return
		}
		cc.alloc.Reset()
		// Close the connection when it has been idle longer than wait_timeout (seconds).
		waitTimeout := cc.getSessionVarsWaitTimeout(ctx)
		cc.pkt.setReadTimeout(time.Duration(waitTimeout) * time.Second)
		start := time.Now()
		data, err := cc.readPacket()
		if err != nil {
			// io.EOF (client went away) is not logged; timeouts are logged at info level;
			// other errors are logged unless the connection was already closed.
			if terror.ErrorNotEqual(err, io.EOF) {
				if netErr, isNetErr := errors.Cause(err).(net.Error); isNetErr && netErr.Timeout() {
					idleTime := time.Since(start)
					logutil.Logger(ctx).Info("read packet timeout, close this connection",
						zap.Duration("idle", idleTime),
						zap.Uint64("waitTimeout", waitTimeout),
						zap.Error(err),
					)
				} else {
					errStack := errors.ErrorStack(err)
					if !strings.Contains(errStack, "use of closed network connection") {
						logutil.Logger(ctx).Warn("read packet failed, close this connection",
							zap.Error(errors.SuspendStack(err)))
					}
				}
			}
			return
		}
		if !atomic.CompareAndSwapInt32(&cc.status, connStatusReading, connStatusDispatching) {
			return
		}
		// NOTE(review): data[0] below assumes readPacket never returns an empty payload;
		// an empty packet would panic here and be handled by the recover above.
		startTime := time.Now()
		if err = cc.dispatch(ctx, data); err != nil {
			if terror.ErrorEqual(err, io.EOF) {
				// COM_QUIT: record it as a successful command and end the loop.
				cc.addMetrics(data[0], startTime, nil)
				return
			} else if terror.ErrResultUndetermined.Equal(err) {
				logutil.Logger(ctx).Error("result undetermined, close this connection", zap.Error(err))
				return
			} else if terror.ErrCritical.Equal(err) {
				logutil.Logger(ctx).Error("critical error, stop the server listener", zap.Error(err))
				metrics.CriticalErrorCounter.Add(1)
				// Non-blocking send: if a stop request is already pending, don't wait.
				select {
				case cc.server.stopListenerCh <- struct{}{}:
				default:
				}
				return
			}
			// Ordinary errors are logged and reported to the client; the loop continues.
			logutil.Logger(ctx).Warn("dispatch error",
				zap.String("connInfo", cc.String()),
				zap.String("sql", queryStrForLog(string(data[1:]))),
				zap.String("err", errStrForLog(err)),
			)
			err1 := cc.writeError(err)
			terror.Log(err1)
		}
		cc.addMetrics(data[0], startTime, err)
		// Each command starts a new packet sequence.
		cc.pkt.sequence = 0
	}
}
// ShutdownOrNotify shuts this client connection down if that is safe right now, and
// otherwise asks it to shut itself down.
//
// A connection inside a transaction is left alone and false is returned. A connection
// that is waiting to read its next command is marked connStatusShutdown and true is
// returned. In any other state (e.g. still dispatching a command), the status is set to
// connStatusWaitShutdown so that Run notices on its next status transition and exits,
// and false is returned.
func (cc *clientConn) ShutdownOrNotify() bool {
	inTxn := cc.ctx.Status()&mysql.ServerStatusInTrans != 0
	if inTxn {
		return false
	}
	shutNow := atomic.CompareAndSwapInt32(&cc.status, connStatusReading, connStatusShutdown)
	if !shutNow {
		atomic.StoreInt32(&cc.status, connStatusWaitShutdown)
	}
	return shutNow
}
// queryStrForLog limits a query string for logging. Queries of at most 4096 bytes are
// returned unchanged. Longer ones are cut to their first 4096 bytes, followed by
// "(len: N)", where N is the original byte length.
func queryStrForLog(query string) string {
	const maxLogLen = 4096
	if len(query) <= maxLogLen {
		return query
	}
	return query[:maxLogLen] + fmt.Sprintf("(len: %d)", len(query))
}
// errStrForLog formats err for logging. Duplicate-entry errors (kv.ErrKeyExists) are
// logged with their message only. All other errors include the full error stack.
func errStrForLog(err error) string {
	if !kv.ErrKeyExists.Equal(err) {
		return errors.ErrorStack(err)
	}
	return err.Error()
}
// addMetrics records the outcome and latency of one dispatched command.
//
// cmd is the MySQL command byte, startTime is when dispatch started, and err is the
// error dispatch returned (nil on success). The per-command QueryTotalCounter is
// incremented with result "OK" or "Error". Unknown commands use their numeric code as
// the label, and the result label is the same "OK"/"Error" as for known commands.
// ComQuery commands that executed DDL are skipped entirely, because DDL time is
// recorded elsewhere. Latency goes to QueryDurationHistogram, labelled with the
// statement type from the session's statement context, or metrics.LblGeneral when
// there is none.
//
// cc.ctx may be nil here, e.g. when COM_CHANGE_USER failed to open a new session. In
// that case the DDL check is skipped and the general latency label is used, instead of
// panicking.
func (cc *clientConn) addMetrics(cmd byte, startTime time.Time, err error) {
	switch cmd {
	case mysql.ComSleep:
		if err != nil {
			queryTotalCounterComSleepError.Inc()
		} else {
			queryTotalCounterComSleepOK.Inc()
		}
	case mysql.ComQuit:
		if err != nil {
			queryTotalCounterComQuitError.Inc()
		} else {
			queryTotalCounterComQuitOK.Inc()
		}
	case mysql.ComQuery:
		if cc.ctx != nil && cc.ctx.Value(sessionctx.LastExecuteDDL) != nil {
			// Don't take DDL execute time into account.
			// It's already recorded by other metrics in ddl package.
			return
		}
		if err != nil {
			queryTotalCounterComQueryError.Inc()
		} else {
			queryTotalCounterComQueryOK.Inc()
		}
	case mysql.ComPing:
		if err != nil {
			queryTotalCounterComPingError.Inc()
		} else {
			queryTotalCounterComPingOK.Inc()
		}
	case mysql.ComInitDB:
		if err != nil {
			queryTotalCounterComInitDBError.Inc()
		} else {
			queryTotalCounterComInitDBOK.Inc()
		}
	case mysql.ComFieldList:
		if err != nil {
			queryTotalCounterComFieldListError.Inc()
		} else {
			queryTotalCounterComFieldListOK.Inc()
		}
	case mysql.ComStmtPrepare:
		if err != nil {
			queryTotalCounterComPrepareError.Inc()
		} else {
			queryTotalCounterComPrepareOK.Inc()
		}
	case mysql.ComStmtExecute:
		if err != nil {
			queryTotalCounterComExecuteError.Inc()
		} else {
			queryTotalCounterComExecuteOK.Inc()
		}
	case mysql.ComStmtFetch:
		if err != nil {
			queryTotalCounterComFetchError.Inc()
		} else {
			queryTotalCounterComFetchOK.Inc()
		}
	case mysql.ComStmtClose:
		if err != nil {
			queryTotalCounterComCloseError.Inc()
		} else {
			queryTotalCounterComCloseOK.Inc()
		}
	case mysql.ComStmtSendLongData:
		if err != nil {
			queryTotalCounterComSendLongDataError.Inc()
		} else {
			queryTotalCounterComSendLongDataOK.Inc()
		}
	case mysql.ComStmtReset:
		if err != nil {
			queryTotalCounterComResetError.Inc()
		} else {
			queryTotalCounterComResetOK.Inc()
		}
	case mysql.ComSetOption:
		if err != nil {
			queryTotalCounterComSetOptionError.Inc()
		} else {
			queryTotalCounterComSetOptionOK.Inc()
		}
	default:
		label := strconv.Itoa(int(cmd))
		// Use the same "Error" result label as the pre-resolved counters so that
		// aggregations over the result label include unknown commands.
		if err != nil {
			metrics.QueryTotalCounter.WithLabelValues(label, "Error").Inc()
		} else {
			metrics.QueryTotalCounter.WithLabelValues(label, "OK").Inc()
		}
	}

	sqlType := metrics.LblGeneral
	if cc.ctx != nil {
		if stmtType := cc.ctx.GetSessionVars().StmtCtx.StmtType; stmtType != "" {
			sqlType = stmtType
		}
	}
	cost := time.Since(startTime).Seconds()
	switch sqlType {
	case "Use":
		queryDurationHistogramUse.Observe(cost)
	case "Show":
		queryDurationHistogramShow.Observe(cost)
	case "Begin":
		queryDurationHistogramBegin.Observe(cost)
	case "Commit":
		queryDurationHistogramCommit.Observe(cost)
	case "Rollback":
		queryDurationHistogramRollback.Observe(cost)
	case "Insert":
		queryDurationHistogramInsert.Observe(cost)
	case "Replace":
		queryDurationHistogramReplace.Observe(cost)
	case "Delete":
		queryDurationHistogramDelete.Observe(cost)
	case "Update":
		queryDurationHistogramUpdate.Observe(cost)
	case "Select":
		queryDurationHistogramSelect.Observe(cost)
	case "Execute":
		queryDurationHistogramExecute.Observe(cost)
	case "Set":
		queryDurationHistogramSet.Observe(cost)
	case metrics.LblGeneral:
		queryDurationHistogramGeneral.Observe(cost)
	default:
		metrics.QueryDurationHistogram.WithLabelValues(sqlType).Observe(cost)
	}
}
// dispatch handles one client request. The first byte of data is the command and the
// rest is its payload.
//
// A token is taken from the server for the duration of the call (it limits how many
// clients are handled concurrently) and released on return. On return the process info
// is also reset to ComSleep (when a session exists) and the tracing span is finished.
// Before handling the command, the session's Killed flag is cleared and, for known
// commands, the command value and process info are recorded. COM_QUIT returns io.EOF so
// that Run closes the connection; unsupported commands return an ErrUnknown error.
//
// NOTE(review): data must be non-empty (data[0] is read unconditionally). hack.String
// presumably converts without copying, in which case cc.lastCmd and dataStr alias the
// packet buffer — confirm the buffer is not reused while they are live.
func (cc *clientConn) dispatch(ctx context.Context, data []byte) error {
	span := opentracing.StartSpan("server.dispatch")
	t := time.Now()
	cmd := data[0]
	data = data[1:]
	cc.lastCmd = string(hack.String(data))
	token := cc.server.getToken()
	defer func() {
		// if handleChangeUser failed, cc.ctx may be nil
		if cc.ctx != nil {
			cc.ctx.SetProcessInfo("", t, mysql.ComSleep, 0)
		}
		cc.server.releaseToken(token)
		span.Finish()
	}()
	vars := cc.ctx.GetSessionVars()
	// Clear any kill request left over from a previous command.
	atomic.StoreUint32(&vars.Killed, 0)
	if cmd < mysql.ComEnd {
		cc.ctx.SetCommandValue(cmd)
	}
	dataStr := string(hack.String(data))
	// Commands that carry no SQL text are recorded in the process info here; the others
	// presumably set it in their own handlers.
	switch cmd {
	case mysql.ComPing, mysql.ComStmtClose, mysql.ComStmtSendLongData, mysql.ComStmtReset,
		mysql.ComSetOption, mysql.ComChangeUser:
		cc.ctx.SetProcessInfo("", t, cmd, 0)
	case mysql.ComInitDB:
		cc.ctx.SetProcessInfo("use "+dataStr, t, cmd, 0)
	}
	switch cmd {
	case mysql.ComSleep:
		// TODO: According to mysql document, this command is supposed to be used only internally.
		// So it's just a temp fix, not sure if it's done right.
		// Investigate this command and write test case later.
		return nil
	case mysql.ComQuit:
		return io.EOF
	case mysql.ComQuery: // Most frequently used command.
		// For issue 1989
		// Input payload may end with byte '\0', we didn't find related mysql document about it, but mysql
		// implementation accept that case. So trim the last '\0' here as if the payload an EOF string.
		// See http://dev.mysql.com/doc/internals/en/com-query.html
		if len(data) > 0 && data[len(data)-1] == 0 {
			data = data[:len(data)-1]
			dataStr = string(hack.String(data))
		}
		return cc.handleQuery(ctx, dataStr)
	case mysql.ComPing:
		return cc.writeOK()
	case mysql.ComInitDB:
		if err := cc.useDB(ctx, dataStr); err != nil {
			return err
		}
		return cc.writeOK()
	case mysql.ComFieldList:
		return cc.handleFieldList(dataStr)
	case mysql.ComStmtPrepare:
		return cc.handleStmtPrepare(dataStr)
	case mysql.ComStmtExecute:
		return cc.handleStmtExecute(ctx, data)
	case mysql.ComStmtFetch:
		return cc.handleStmtFetch(ctx, data)
	case mysql.ComStmtClose:
		return cc.handleStmtClose(data)
	case mysql.ComStmtSendLongData:
		return cc.handleStmtSendLongData(data)
	case mysql.ComStmtReset:
		return cc.handleStmtReset(data)
	case mysql.ComSetOption:
		return cc.handleSetOption(data)
	case mysql.ComChangeUser:
		return cc.handleChangeUser(ctx, data)
	default:
		return mysql.NewErrf(mysql.ErrUnknown, "command %d not supported now", cmd)
	}
}
// useDB switches the session's current database to db by executing a USE statement,
// and records db in cc.dbname on success.
//
// db arrives unquoted from the client (the COM_INIT_DB payload or the handshake's
// database field). For "use `SELECT`" the mysql client sends just SELECT. The name is
// therefore wrapped in backquotes, with every embedded backquote doubled (the MySQL
// escape for quoted identifiers). Without that escaping, a name such as "a`b" could not
// be selected, and a crafted name could end the quoting and append further statements.
func (cc *clientConn) useDB(ctx context.Context, db string) (err error) {
	quoted := "`" + strings.Replace(db, "`", "``", -1) + "`"
	if _, err = cc.ctx.Execute(ctx, "use "+quoted); err != nil {
		return err
	}
	cc.dbname = db
	return nil
}
// flush sends any buffered packet data to the client by delegating to cc.pkt.flush.
//
// When the FakeClientConn failpoint is enabled and cc.pkt is nil, nothing is flushed and
// nil is returned (presumably so tests can run without a real packetIO).
func (cc *clientConn) flush() error {
	failpoint.Inject("FakeClientConn", func() {
		if cc.pkt == nil {
			failpoint.Return(nil)
		}
	})
	return cc.pkt.flush()
}
// writeOK sends an OK packet that reflects the session's last statement: its message,
// affected rows, last insert ID, server status and warning count.
func (cc *clientConn) writeOK() error {
	qctx := cc.ctx
	return cc.writeOkWith(qctx.LastMessage(), qctx.AffectedRows(), qctx.LastInsertID(), qctx.Status(), qctx.WarningCount())
}
// writeOkWith builds an OK packet from the given values, writes it and flushes.
//
// affectedRows and lastInsertID are length-encoded integers. status and warnCnt are
// included only when the client speaks the 4.1 protocol. A non-empty msg is appended as
// a length-encoded string.
func (cc *clientConn) writeOkWith(msg string, affectedRows, lastInsertID uint64, status, warnCnt uint16) error {
	var msgSize int
	if msg != "" {
		msgSize = lengthEncodedIntSize(uint64(len(msg))) + len(msg)
	}

	pkt := cc.alloc.AllocWithLen(4, 32+msgSize)
	pkt = append(pkt, mysql.OKHeader)
	pkt = dumpLengthEncodedInt(pkt, affectedRows)
	pkt = dumpLengthEncodedInt(pkt, lastInsertID)
	if cc.capability&mysql.ClientProtocol41 > 0 {
		pkt = dumpUint16(dumpUint16(pkt, status), warnCnt)
	}
	if msgSize > 0 {
		// The MySQL manual documents the info message as string<EOF>
		// (https://dev.mysql.com/doc/internals/en/packet-OK_Packet.html),
		// but in practice it is string<lenenc>.
		pkt = dumpLengthEncodedString(pkt, []byte(msg))
	}

	if err := cc.writePacket(pkt); err != nil {
		return err
	}
	return cc.flush()
}
// writeError sends e to the client as an ERR packet and flushes it.
//
// If the root cause of e (per errors.Cause) is a *terror.Error, it is converted with
// ToSQLError so that its MySQL code and state are kept. Any other error is reported as
// ErrUnknown with e's message. The code sent is also stored in cc.lastCode. The SQL
// state marker and state are included only for 4.1-protocol clients.
func (cc *clientConn) writeError(e error) error {
	var sqlErr *mysql.SQLError
	if te, ok := errors.Cause(e).(*terror.Error); ok {
		sqlErr = te.ToSQLError()
	} else {
		sqlErr = mysql.NewErrf(mysql.ErrUnknown, "%s", e.Error())
	}
	cc.lastCode = sqlErr.Code

	pkt := cc.alloc.AllocWithLen(4, 16+len(sqlErr.Message))
	pkt = append(pkt, mysql.ErrHeader, byte(sqlErr.Code), byte(sqlErr.Code>>8))
	if cc.capability&mysql.ClientProtocol41 > 0 {
		pkt = append(pkt, '#')
		pkt = append(pkt, sqlErr.State...)
	}
	pkt = append(pkt, sqlErr.Message...)

	if err := cc.writePacket(pkt); err != nil {
		return err
	}
	return cc.flush()
}
// writeEOF writes an EOF packet without flushing, because more packets may
// follow it.
// For CLIENT_PROTOCOL_41 clients the packet carries the session's warning
// count and the session status flags OR-ed with serverStatus.
func (cc *clientConn) writeEOF(serverStatus uint16) error {
	pkt := cc.alloc.AllocWithLen(4, 9)
	pkt = append(pkt, mysql.EOFHeader)
	if cc.capability&mysql.ClientProtocol41 > 0 {
		pkt = dumpUint16(pkt, cc.ctx.WarningCount())
		pkt = dumpUint16(pkt, cc.ctx.Status()|serverStatus)
	}
	return cc.writePacket(pkt)
}
// writeReq sends the client a LOCAL INFILE request naming filePath and flushes
// it immediately, so the client can start sending the file contents.
func (cc *clientConn) writeReq(filePath string) error {
	pkt := cc.alloc.AllocWithLen(4, 5+len(filePath))
	pkt = append(append(pkt, mysql.LocalInFileHeader), filePath...)
	if err := cc.writePacket(pkt); err != nil {
		return err
	}
	return cc.flush()
}
// defaultLoadDataBatchCnt is the row limit per batch that handleLoadData
// passes to LoadDataInfo.SetMaxRowsInBatch. Reaching the limit triggers an
// insert and a statement commit in insertDataWithCommit.
var defaultLoadDataBatchCnt uint64 = 20000
// insertDataWithCommit feeds curData, preceded by the leftover prevData, to
// loadDataInfo.InsertData. Each time InsertData reports that the batch row
// limit was reached, it does three things: inserts the batch, commits the
// statement, and refreshes the transaction context. It then continues with
// the bytes InsertData left over.
//
// It returns the final leftover bytes, which the caller passes back as
// prevData together with the next packet. On error it returns nil and the
// error.
func insertDataWithCommit(ctx context.Context, prevData, curData []byte, loadDataInfo *executor.LoadDataInfo) ([]byte, error) {
	for {
		var (
			reachLimit bool
			err        error
		)
		prevData, reachLimit, err = loadDataInfo.InsertData(ctx, prevData, curData)
		if err != nil {
			return nil, err
		}
		if !reachLimit {
			return prevData, nil
		}
		if err = loadDataInfo.CheckAndInsertOneBatch(ctx); err != nil {
			return nil, err
		}
		if err = loadDataInfo.Ctx.StmtCommit(); err != nil {
			return nil, err
		}
		// Make sure that there are no retries when committing.
		if err = loadDataInfo.Ctx.RefreshTxnCtx(ctx); err != nil {
			return nil, err
		}
		// Re-feed the leftover bytes as the current data of the next round.
		curData = prevData
		prevData = nil
	}
}
// handleLoadData does the additional work after processing the 'load data' query.
// It sends client a file path, then reads the file content from client, inserts data into database.
//
// The batch row limit is set to defaultLoadDataBatchCnt. Every full batch is
// committed as a statement by insertDataWithCommit. If reading or inserting
// fails, the statement and the transaction are rolled back and the error is
// returned. Otherwise the remaining rows are inserted, the statement is
// committed and the transaction is committed through cc.ctx.CommitTxn.
func (cc *clientConn) handleLoadData(ctx context.Context, loadDataInfo *executor.LoadDataInfo) error {
	// If the server handles the load data request, the client has to set the ClientLocalFiles capability.
	if cc.capability&mysql.ClientLocalFiles == 0 {
		return errNotAllowedCommand
	}
	if loadDataInfo == nil {
		return errors.New("load data info is empty")
	}
	err := cc.writeReq(loadDataInfo.Path)
	if err != nil {
		return err
	}
	// shouldBreak is set when the client sends an empty packet (end of file);
	// any leftover prevData is still inserted before the loop ends.
	var shouldBreak bool
	// prevData holds the leftover bytes returned by insertDataWithCommit, which
	// are fed again together with the next packet.
	var prevData, curData []byte
	// TODO: Make the loadDataRowCnt settable.
	loadDataInfo.SetMaxRowsInBatch(defaultLoadDataBatchCnt)
	err = loadDataInfo.Ctx.NewTxn(ctx)
	if err != nil {
		return err
	}
	for {
		curData, err = cc.readPacket()
		if err != nil {
			if terror.ErrorNotEqual(err, io.EOF) {
				logutil.Logger(ctx).Error("read packet failed", zap.Error(err))
				break
			}
			// NOTE(review): on io.EOF, err stays set and processing continues. If
			// the loop then breaks before insertDataWithCommit reassigns err, the
			// io.EOF is treated as a failure below. Confirm whether readPacket can
			// return io.EOF here.
		}
		if len(curData) == 0 {
			shouldBreak = true
			if len(prevData) == 0 {
				break
			}
		}
		prevData, err = insertDataWithCommit(ctx, prevData, curData, loadDataInfo)
		if err != nil {
			break
		}
		if shouldBreak {
			break
		}
	}
	loadDataInfo.SetMessage()

	if err != nil {
		loadDataInfo.Ctx.StmtRollback()
	} else {
		// Insert the rows still buffered in loadDataInfo (presumably the final,
		// partial batch), then commit the statement.
		err = loadDataInfo.CheckAndInsertOneBatch(ctx)
		if err == nil {
			err = loadDataInfo.Ctx.StmtCommit()
		}
	}

	var txn kv.Transaction
	var err1 error
	txn, err1 = loadDataInfo.Ctx.Txn(true)
	if err1 == nil {
		if txn != nil && txn.Valid() {
			if err != nil {
				// Roll back the whole transaction. A rollback failure is only
				// logged; the original error is what gets returned.
				if err1 := txn.Rollback(); err1 != nil {
					logutil.Logger(ctx).Error("load data rollback failed", zap.Error(err1))
				}
				return err
			}
			return cc.ctx.CommitTxn(sessionctx.SetCommitCtx(ctx, loadDataInfo.Ctx))
		}
	}
	// Should never reach here.
	// NOTE(review): if Txn succeeds but returns a nil or invalid txn, err1 is nil
	// and this panics with a nil value.
	panic(err1)
}
// handleLoadStats does the additional work after processing the 'load stats' query.
// It sends the client a file path and reads the file content back packet by
// packet until an empty packet arrives. The accumulated bytes are then handed
// to loadStatsInfo.Update. An empty file is a no-op.
func (cc *clientConn) handleLoadStats(ctx context.Context, loadStatsInfo *executor.LoadStatsInfo) error {
	switch {
	case cc.capability&mysql.ClientLocalFiles == 0:
		// Serving the file request requires the client's ClientLocalFiles capability.
		return errNotAllowedCommand
	case loadStatsInfo == nil:
		return errors.New("load stats: info is empty")
	}

	if err := cc.writeReq(loadStatsInfo.Path); err != nil {
		return err
	}

	var content []byte
	for {
		pkt, err := cc.readPacket()
		if err != nil && terror.ErrorNotEqual(err, io.EOF) {
			return err
		}
		if len(pkt) == 0 {
			break
		}
		content = append(content, pkt...)
	}

	if len(content) == 0 {
		return nil
	}
	return loadStatsInfo.Update(content)
}
// handleQuery executes the sql query string and writes result set or result ok to the client.
// There is a special query `load data` that does not return result, which is handled differently.
// Query `load stats` does not return result either.
//
// Execution errors are counted in metrics.ExecuteErrorCounter before being
// returned. If the query produced result sets while the connection is
// shutting down, three things happen: the result sets are closed, the
// connection is killed, and executor.ErrQueryInterrupted is returned.
func (cc *clientConn) handleQuery(ctx context.Context, sql string) (err error) {
	rs, err := cc.ctx.Execute(ctx, sql)
	if err != nil {
		metrics.ExecuteErrorCounter.WithLabelValues(metrics.ExecuteErrorToLabel(err)).Inc()
		return err
	}
	status := atomic.LoadInt32(&cc.status)
	if rs != nil && (status == connStatusShutdown || status == connStatusWaitShutdown) {
		// These result sets will never reach writeResultset, which is what
		// normally closes them, so close them here instead of leaking them.
		for _, r := range rs {
			terror.Call(r.Close)
		}
		killConn(cc)
		return executor.ErrQueryInterrupted
	}
	if rs != nil {
		if len(rs) == 1 {
			err = cc.writeResultset(ctx, rs[0], false, 0, 0)
		} else {
			err = cc.writeMultiResultset(ctx, rs, false)
		}
	} else {
		// Statements without a result set may have left a pending LOAD DATA or
		// LOAD STATS request in the session; serve it before replying OK.
		loadDataInfo := cc.ctx.Value(executor.LoadDataVarKey)
		if loadDataInfo != nil {
			defer cc.ctx.SetValue(executor.LoadDataVarKey, nil)
			if err = cc.handleLoadData(ctx, loadDataInfo.(*executor.LoadDataInfo)); err != nil {
				return err
			}
		}

		loadStats := cc.ctx.Value(executor.LoadStatsVarKey)
		if loadStats != nil {
			defer cc.ctx.SetValue(executor.LoadStatsVarKey, nil)
			if err = cc.handleLoadStats(ctx, loadStats.(*executor.LoadStatsInfo)); err != nil {
				return err
			}
		}

		err = cc.writeOK()
	}
	return err
}
// handleFieldList returns the field list for a table.
// The sql string is a table name followed by a terminating \x00; only the text
// before the first \x00 is used. Each column definition is sent in its own
// packet, followed by an EOF packet, and the connection is then flushed.
func (cc *clientConn) handleFieldList(sql string) (err error) {
	tableName := strings.Split(sql, "\x00")[0]
	columns, err := cc.ctx.FieldList(tableName)
	if err != nil {
		return err
	}

	buf := cc.alloc.AllocWithLen(4, 1024)
	for _, col := range columns {
		// No default value is sent yet, but the default-value length byte is
		// kept so that mariadb clients parse the definition correctly.
		// https://dev.mysql.com/doc/internals/en/com-query-response.html#column-definition
		// TODO: fill the right DefaultValues.
		col.DefaultValueLength = 0
		col.DefaultValue = []byte{}
		buf = col.Dump(buf[:4])
		if err = cc.writePacket(buf); err != nil {
			return err
		}
	}

	if err = cc.writeEOF(0); err != nil {
		return err
	}
	return cc.flush()
}
// writeResultset writes data into a resultset and uses rs.Next to get row data back.
// If binary is true, the data would be encoded in BINARY format.
// serverStatus, a flag bit represents server information.
// fetchSize, the desired number of rows to be fetched each time when client uses cursor.
//
// There are two paths:
//   - serverStatus has the cursor-exists flag: rows are written by
//     writeChunksWithFetchSize and rs is left open for later fetches.
//   - otherwise: all rows are written by writeChunks and rs is closed when this
//     function returns, including on panic.
//
// The connection is flushed on success.
// A panic whose string value starts with memory.PanicMemoryExceed is recovered.
// It is returned as runErr, and its stack (up to 4096 bytes) is logged. Any
// other panic is re-raised.
func (cc *clientConn) writeResultset(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16, fetchSize int) (runErr error) {
	defer func() {
		// close ResultSet when cursor doesn't exist
		if !mysql.HasCursorExistsFlag(serverStatus) {
			terror.Call(rs.Close)
		}
		r := recover()
		if r == nil {
			return
		}
		// Only memory-exceeded panics are turned into an error; re-panic on anything else.
		if str, ok := r.(string); !ok || !strings.HasPrefix(str, memory.PanicMemoryExceed) {
			panic(r)
		}
		// TODO(jianzhang.zj: add metrics here)
		runErr = errors.Errorf("%v", r)
		buf := make([]byte, 4096)
		stackSize := runtime.Stack(buf, false)
		buf = buf[:stackSize]
		logutil.Logger(ctx).Error("write query result panic", zap.String("lastCmd", cc.lastCmd), zap.String("stack", string(buf)))
	}()
	var err error
	if mysql.HasCursorExistsFlag(serverStatus) {
		err = cc.writeChunksWithFetchSize(ctx, rs, serverStatus, fetchSize)
	} else {
		err = cc.writeChunks(ctx, rs, binary, serverStatus)
	}
	if err != nil {
		return err
	}
	return cc.flush()
}
// writeColumnInfo sends the column-definition part of a result set without
// flushing. It writes a length-encoded column-count packet, then one packet
// per column definition, then an EOF packet carrying serverStatus.
func (cc *clientConn) writeColumnInfo(columns []*ColumnInfo, serverStatus uint16) error {
	buf := dumpLengthEncodedInt(cc.alloc.AllocWithLen(4, 1024), uint64(len(columns)))
	if err := cc.writePacket(buf); err != nil {
		return err
	}
	for _, col := range columns {
		// Reuse the buffer, keeping only the 4-byte packet header space.
		buf = col.Dump(buf[:4])
		if err := cc.writePacket(buf); err != nil {
			return err
		}
	}
	return cc.writeEOF(serverStatus)
}
// writeChunks streams every row of rs to the client, in three parts:
//   - the column definitions, sent after the first Next call because the
//     columns are only reliable from then on;
//   - one packet per row, dumped in binary format if binary is true and in
//     text format otherwise;
//   - a closing EOF packet carrying serverStatus.
//
// Any error while reading or dumping rows is returned immediately.
func (cc *clientConn) writeChunks(ctx context.Context, rs ResultSet, binary bool, serverStatus uint16) error {
	buf := cc.alloc.AllocWithLen(4, 1024)
	chk := rs.NewChunk()
	for first := true; ; first = false {
		// server.tidbResultSet provides the Next implementation.
		if err := rs.Next(ctx, chk); err != nil {
			return err
		}
		if first {
			// Columns must be read after Next; before it they may be incorrect.
			if err := cc.writeColumnInfo(rs.Columns(), serverStatus); err != nil {
				return err
			}
		}
		n := chk.NumRows()
		if n == 0 {
			return cc.writeEOF(serverStatus)
		}
		for i := 0; i < n; i++ {
			var err error
			buf = buf[:4]
			if binary {
				buf, err = dumpBinaryRow(buf, rs.Columns(), chk.GetRow(i))
			} else {
				buf, err = dumpTextRow(buf, rs.Columns(), chk.GetRow(i))
			}
			if err != nil {
				return err
			}
			if err = cc.writePacket(buf); err != nil {
				return err
			}
		}
	}
}
// writeChunksWithFetchSize writes data from a Chunk, which filled data by a ResultSet, into a connection.
// Rows are always dumped in binary format; this function takes no binary flag.
// serverStatus, a flag bit represents server information.
// fetchSize, the desired number of rows to be fetched each time when client uses cursor.
//
// The steps are:
//  1. Take the rows left over from an earlier call (rs.GetFetchedRows).
//  2. Top them up from rs until at least fetchSize rows are buffered or rs
//     yields no more rows.
//  3. Send at most fetchSize rows, followed by an EOF packet.
//  4. Store the remaining rows with rs.StoreFetchedRows for the next fetch.
//
// When no rows are left at all, ServerStatusLastRowSend is added to
// serverStatus, rs is closed and only the EOF packet is sent.
func (cc *clientConn) writeChunksWithFetchSize(ctx context.Context, rs ResultSet, serverStatus uint16, fetchSize int) error {
	fetchedRows := rs.GetFetchedRows()

	// if fetchedRows is not enough, getting data from recordSet.
	req := rs.NewChunk()
	for len(fetchedRows) < fetchSize {
		// Here server.tidbResultSet implements Next method.
		err := rs.Next(ctx, req)
		if err != nil {
			return err
		}
		rowCount := req.NumRows()
		if rowCount == 0 {
			break
		}
		// filling fetchedRows with chunk
		for i := 0; i < rowCount; i++ {
			fetchedRows = append(fetchedRows, req.GetRow(i))
		}
		// Switch to a renewed chunk for the next Next call; presumably this keeps
		// the rows just appended to fetchedRows from being overwritten.
		req = chunk.Renew(req, cc.ctx.GetSessionVars().MaxChunkSize)
	}

	// tell the client COM_STMT_FETCH has finished by setting proper serverStatus,
	// and close ResultSet.
	if len(fetchedRows) == 0 {
		serverStatus |= mysql.ServerStatusLastRowSend
		terror.Call(rs.Close)
		return cc.writeEOF(serverStatus)
	}

	// construct the rows sent to the client according to fetchSize.
	var curRows []chunk.Row
	if fetchSize < len(fetchedRows) {
		curRows = fetchedRows[:fetchSize]
		fetchedRows = fetchedRows[fetchSize:]
	} else {
		curRows = fetchedRows[:]
		// NOTE(review): this empty slice shares its backing array with curRows.
		// It is safe here because curRows is fully written below before this
		// function returns, but a later append to the stored slice reuses that
		// array.
		fetchedRows = fetchedRows[:0]
	}
	rs.StoreFetchedRows(fetchedRows)

	data := cc.alloc.AllocWithLen(4, 1024)
	var err error
	for _, row := range curRows {
		data = data[0:4]
		data, err = dumpBinaryRow(data, rs.Columns(), row)
		if err != nil {
			return err
		}
		if err = cc.writePacket(data); err != nil {
			return err
		}
	}
	return cc.writeEOF(serverStatus)
}
// writeMultiResultset writes several results in one response, using the
// MULTI_RESULTS protocol. Every result except the last carries
// ServerMoreResultsExists in its status.
//
// A *tidbResultSet whose record set implements sqlexec.MultiQueryNoDelayResult
// is sent as an OK packet. The packet uses that result's own message, affected
// rows, last insert ID, status and warning count. Every other result is
// written through writeResultset.
func (cc *clientConn) writeMultiResultset(ctx context.Context, rss []ResultSet, binary bool) error {
	for i, rs := range rss {
		var moreFlag uint16
		if i != len(rss)-1 {
			moreFlag = mysql.ServerMoreResultsExists
		}
		// Use checked type assertions: a ResultSet that is not a *tidbResultSet
		// is written as a regular result set instead of panicking.
		if trs, ok := rs.(*tidbResultSet); ok {
			if r, ok := trs.recordSet.(sqlexec.MultiQueryNoDelayResult); ok {
				status := r.Status() | moreFlag
				if err := cc.writeOkWith(r.LastMessage(), r.AffectedRows(), r.LastInsertID(), status, r.WarnCount()); err != nil {
					return err
				}
				continue
			}
		}
		if err := cc.writeResultset(ctx, rs, binary, moreFlag, 0); err != nil {
			return err
		}
	}
	return nil
}
// setConn wraps conn in a buffered reader and points the packet I/O at it.
// An existing packetIO is re-targeted rather than replaced, so its current
// sequence number is preserved.
func (cc *clientConn) setConn(conn net.Conn) {
	cc.bufReadConn = newBufferedReadConn(conn)
	if cc.pkt != nil {
		// Reuse the existing packetIO to keep its sequence number.
		cc.pkt.setBufferedReadConn(cc.bufReadConn)
		return
	}
	cc.pkt = newPacketIO(cc.bufReadConn)
}
// upgradeToTLS performs a server-side TLS handshake over the current
// connection. Once the handshake succeeds, all further packet I/O goes through
// the TLS connection, which is also stored in cc.tlsConn.
func (cc *clientConn) upgradeToTLS(tlsConfig *tls.Config) error {
	// Wrap the buffered reader rather than the raw net.Conn: it may already
	// hold bytes the handshake needs.
	tlsConn := tls.Server(cc.bufReadConn, tlsConfig)
	err := tlsConn.Handshake()
	if err != nil {
		return err
	}
	cc.setConn(tlsConn)
	cc.tlsConn = tlsConn
	return nil
}
// handleChangeUser handles a change-user request. It:
//   - parses the new user name, the length-prefixed auth data and the database
//     name from data;
//   - closes the current session and opens a new one via openSessionAndDoAuth;
//   - notifies audit plugins with a ChangeUser connection event;
//   - replies with an OK packet.
//
// The payload is fully validated before cc.user or cc.dbname is modified, so
// a malformed packet (mysql.ErrMalformPacket) leaves them unchanged.
func (cc *clientConn) handleChangeUser(ctx context.Context, data []byte) error {
	user, data := parseNullTermString(data)
	if len(data) < 1 {
		return mysql.ErrMalformPacket
	}
	// The auth data is prefixed by a single length byte.
	passLen := int(data[0])
	data = data[1:]
	if passLen > len(data) {
		return mysql.ErrMalformPacket
	}
	pass := data[:passLen]
	data = data[passLen:]
	dbName, _ := parseNullTermString(data)

	cc.user = string(hack.String(user))
	cc.dbname = string(hack.String(dbName))
	if err := cc.ctx.Close(); err != nil {
		logutil.Logger(ctx).Debug("close old context error", zap.Error(err))
	}
	if err := cc.openSessionAndDoAuth(pass); err != nil {
		return err
	}

	if plugin.IsEnable(plugin.Audit) {
		cc.ctx.GetSessionVars().ConnectionInfo = cc.connectInfo()
	}
	err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
		authPlugin := plugin.DeclareAuditManifest(p.Manifest)
		if authPlugin.OnConnectionEvent == nil {
			return nil
		}
		connInfo := cc.ctx.GetSessionVars().ConnectionInfo
		// Return the callback's error directly instead of writing it into a
		// variable of the enclosing function.
		return authPlugin.OnConnectionEvent(context.Background(), plugin.ChangeUser, connInfo)
	})
	if err != nil {
		return err
	}
	return cc.writeOK()
}