// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// The MIT License (MIT)
//
// Copyright (c) 2014 wandoulabs
// Copyright (c) 2014 siddontang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.
package server
import (
"context"
"crypto/tls"
"fmt"
"io"
"maps"
"net"
"net/http" //nolint:goimports
_ "net/http/pprof" // #nosec G108 for pprof
"os"
"os/user"
"reflect"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"unsafe"
"github.com/blacktear23/go-proxyprotocol"
"github.com/pingcap/errors"
"github.com/pingcap/log"
autoid "github.com/pingcap/tidb/pkg/autoid_service"
"github.com/pingcap/tidb/pkg/config"
"github.com/pingcap/tidb/pkg/domain"
"github.com/pingcap/tidb/pkg/executor/mppcoordmanager"
"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/infoschema/issyncer/mdldef"
"github.com/pingcap/tidb/pkg/metrics"
"github.com/pingcap/tidb/pkg/parser/ast"
"github.com/pingcap/tidb/pkg/parser/auth"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/parser/terror"
"github.com/pingcap/tidb/pkg/planner/core"
"github.com/pingcap/tidb/pkg/plugin"
"github.com/pingcap/tidb/pkg/privilege/privileges"
"github.com/pingcap/tidb/pkg/resourcegroup"
servererr "github.com/pingcap/tidb/pkg/server/err"
"github.com/pingcap/tidb/pkg/session"
"github.com/pingcap/tidb/pkg/session/sessmgr"
"github.com/pingcap/tidb/pkg/session/txninfo"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util"
"github.com/pingcap/tidb/pkg/util"
"github.com/pingcap/tidb/pkg/util/fastrand"
"github.com/pingcap/tidb/pkg/util/kvcache"
"github.com/pingcap/tidb/pkg/util/logutil"
"github.com/pingcap/tidb/pkg/util/sqlkiller"
"github.com/pingcap/tidb/pkg/util/sys/linux"
"github.com/pingcap/tidb/pkg/util/timeutil"
tlsutil "github.com/pingcap/tidb/pkg/util/tls"
uatomic "go.uber.org/atomic"
"go.uber.org/zap"
"google.golang.org/grpc"
)
var (
serverPID int
osUser string
osVersion string
// RunInGoTest represents whether the code is running in a test.
RunInGoTest bool
// RunInGoTestChan is closed once the server has started its network listeners, so tests can wait for readiness.
RunInGoTestChan chan struct{}
)
func init() {
serverPID = os.Getpid()
currentUser, err := user.Current()
if err != nil {
osUser = ""
} else {
osUser = currentUser.Name
}
osVersion, err = linux.OSVersion()
if err != nil {
osVersion = ""
}
}
// defaultCapability is the capability of the server when it is created using the default configuration.
// When the server is configured with SSL, it has extra capabilities compared to defaultCapability.
const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
mysql.ClientConnectWithDB | mysql.ClientProtocol41 |
mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows |
mysql.ClientMultiStatements | mysql.ClientMultiResults | mysql.ClientLocalFiles |
mysql.ClientConnectAtts | mysql.ClientPluginAuth | mysql.ClientInteractive |
mysql.ClientDeprecateEOF | mysql.ClientCompress | mysql.ClientZstdCompressionAlgorithm
const normalClosedConnsCapacity = 1000
// Server is the MySQL protocol server
type Server struct {
cfg *config.Config
tlsConfig unsafe.Pointer // *tls.Config
driver IDriver
listener net.Listener
socket net.Listener
concurrentLimiter *util.TokenLimiter
rwlock sync.RWMutex
clients map[uint64]*clientConn
normalClosedConnsMutex sync.Mutex
normalClosedConns *kvcache.SimpleLRUCache
userResLock sync.RWMutex // userResLock is used to protect userResource
userResource map[string]*userResourceLimits
capability uint32
dom *domain.Domain
statusAddr string
statusListener net.Listener
statusServer atomic.Pointer[http.Server]
grpcServer *grpc.Server
inShutdownMode *uatomic.Bool
health *uatomic.Bool
sessionMapMutex sync.Mutex
internalSessions map[any]struct{}
autoIDService *autoid.Service
authTokenCancelFunc context.CancelFunc
wg sync.WaitGroup
printMDLLogTime time.Time
StandbyController
}
// NewTestServer creates a new Server for test.
func NewTestServer(cfg *config.Config) *Server {
return &Server{
cfg: cfg,
}
}
// Socket returns the server's socket file.
func (s *Server) Socket() net.Listener {
return s.socket
}
// Listener returns the server's listener.
func (s *Server) Listener() net.Listener {
return s.listener
}
// ListenAddr returns the server's listener's network address.
func (s *Server) ListenAddr() net.Addr {
return s.listener.Addr()
}
// StatusListenerAddr returns the server's status listener's network address.
func (s *Server) StatusListenerAddr() net.Addr {
return s.statusListener.Addr()
}
// BitwiseXorCapability toggles the given capability bits of the server.
func (s *Server) BitwiseXorCapability(capability uint32) {
s.capability ^= capability
}
// BitwiseOrAssignCapability adds the capability to the server.
func (s *Server) BitwiseOrAssignCapability(capability uint32) {
s.capability |= capability
}
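// Example (illustrative sketch, not part of the original file): the two capability
// helpers above are plain bitwise operations, so a caller can set a flag and toggle
// it back like this. exampleToggleCapability is a hypothetical name.
func exampleToggleCapability(s *Server) bool {
	s.BitwiseOrAssignCapability(mysql.ClientSSL) // set the SSL bit
	s.BitwiseXorCapability(mysql.ClientSSL)      // toggle the same bit back off
	return s.capability&mysql.ClientSSL != 0     // false after set-then-toggle
}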
// GetStatusServerAddr gets the statusServer address for MppCoordinatorManager usage.
func (s *Server) GetStatusServerAddr() (on bool, addr string) {
if !s.cfg.Status.ReportStatus {
return false, ""
}
if strings.Contains(s.statusAddr, config.DefStatusHost) {
if len(s.cfg.AdvertiseAddress) != 0 {
return true, strings.ReplaceAll(s.statusAddr, config.DefStatusHost, s.cfg.AdvertiseAddress)
}
return false, ""
}
return true, s.statusAddr
}
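// Example (sketch): GetStatusServerAddr substitutes the advertise address when the
// status address is bound to the default status host. exampleAdvertisedStatusAddr is
// a hypothetical helper, and the literal "0.0.0.0" stands in for config.DefStatusHost,
// whose concrete value is an assumption here.
func exampleAdvertisedStatusAddr() string {
	statusAddr := "0.0.0.0:10080"
	advertise := "10.0.0.5"
	return strings.ReplaceAll(statusAddr, "0.0.0.0", advertise) // "10.0.0.5:10080"
}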
type normalCloseConnKey struct {
keyspaceName string
connID string
}
func (k normalCloseConnKey) Hash() []byte {
return []byte(fmt.Sprintf("%s-%s", k.keyspaceName, k.connID))
}
// SetNormalClosedConn sets the normal closed connection message for the specified connID.
func (s *Server) SetNormalClosedConn(keyspaceName, connID, msg string) {
if connID == "" {
return
}
s.normalClosedConnsMutex.Lock()
defer s.normalClosedConnsMutex.Unlock()
s.normalClosedConns.Put(normalCloseConnKey{keyspaceName: keyspaceName, connID: connID}, msg)
}
// GetNormalClosedConn gets the normal closed connection message.
func (s *Server) GetNormalClosedConn(keyspaceName, connID string) string {
if connID == "" {
return ""
}
s.normalClosedConnsMutex.Lock()
defer s.normalClosedConnsMutex.Unlock()
v, ok := s.normalClosedConns.Get(normalCloseConnKey{keyspaceName: keyspaceName, connID: connID})
if !ok {
return ""
}
return v.(string)
}
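// Example (sketch): SetNormalClosedConn and GetNormalClosedConn form a simple put/get
// round-trip on the LRU cache, keyed by (keyspace, connID). exampleNormalClosedConnRoundTrip
// is a hypothetical usage, not part of the original file.
func exampleNormalClosedConnRoundTrip(s *Server) string {
	s.SetNormalClosedConn("ks1", "conn42", "client quit normally")
	return s.GetNormalClosedConn("ks1", "conn42") // "client quit normally"
}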
// ConnectionCount gets current connection count.
func (s *Server) ConnectionCount() int {
s.rwlock.RLock()
cnt := len(s.clients)
s.rwlock.RUnlock()
return cnt
}
func (s *Server) getToken() *util.Token {
start := time.Now()
tok := s.concurrentLimiter.Get()
metrics.TokenGauge.Inc()
// Note that durations smaller than one microsecond are ignored, because such cases can be viewed as non-blocking.
metrics.GetTokenDurationHistogram.Observe(float64(time.Since(start).Nanoseconds() / 1e3))
return tok
}
func (s *Server) releaseToken(token *util.Token) {
s.concurrentLimiter.Put(token)
metrics.TokenGauge.Dec()
}
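// Example (illustrative sketch): the token limiter bounds the number of concurrently
// served requests. A typical acquire/release pairing defers the release so the token
// is returned even if the work panics. exampleWithToken is a hypothetical helper.
func exampleWithToken(s *Server, work func()) {
	tok := s.getToken()
	defer s.releaseToken(tok)
	work()
}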
// SetDomain use to set the server domain.
func (s *Server) SetDomain(dom *domain.Domain) {
s.dom = dom
}
// newConn creates a new *clientConn from a net.Conn.
// It allocates a connection ID and random salt data for authentication.
func (s *Server) newConn(conn net.Conn) *clientConn {
cc := newClientConn(s)
if tcpConn, ok := conn.(*net.TCPConn); ok {
if err := tcpConn.SetKeepAlive(s.cfg.Performance.TCPKeepAlive); err != nil {
logutil.BgLogger().Error("failed to set tcp keep alive option", zap.Error(err))
}
if err := tcpConn.SetNoDelay(s.cfg.Performance.TCPNoDelay); err != nil {
logutil.BgLogger().Error("failed to set tcp no delay option", zap.Error(err))
}
}
cc.setConn(conn)
cc.salt = fastrand.Buf(20)
metrics.ConnGauge.WithLabelValues(resourcegroup.DefaultResourceGroupName).Inc()
return cc
}
// NewServer creates a new Server.
func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
s := &Server{
cfg: cfg,
driver: driver,
concurrentLimiter: util.NewTokenLimiter(cfg.TokenLimit),
clients: make(map[uint64]*clientConn),
normalClosedConns: kvcache.NewSimpleLRUCache(normalClosedConnsCapacity, 0, 0),
userResource: make(map[string]*userResourceLimits),
internalSessions: make(map[any]struct{}, 100),
health: uatomic.NewBool(false),
inShutdownMode: uatomic.NewBool(false),
printMDLLogTime: time.Now(),
}
s.capability = defaultCapability
setSystemTimeZoneVariable()
tlsConfig, autoReload, err := util.LoadTLSCertificates(
s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert,
s.cfg.Security.AutoTLS, s.cfg.Security.RSAKeySize)
// LoadTLSCertificates will auto-generate certificates if AutoTLS is enabled.
// It only returns an error if certificates are specified and invalid, in which
// case we should halt server startup, as a misconfiguration could lead to a
// connection downgrade.
if err != nil {
return nil, errors.Trace(err)
}
// Automatically reload auto-generated certificates.
// The certificates are re-created every 30 days and are valid for 90 days.
if autoReload {
go func() {
for range time.Tick(time.Hour * 24 * 30) { // 30 days
logutil.BgLogger().Info("Rotating automatically created TLS Certificates")
tlsConfig, _, err = util.LoadTLSCertificates(
s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert,
s.cfg.Security.AutoTLS, s.cfg.Security.RSAKeySize)
if err != nil {
logutil.BgLogger().Warn("TLS Certificate rotation failed", zap.Error(err))
// Skip the store on failure so we keep serving with the previous
// certificates instead of publishing a nil config.
continue
}
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
}
}()
}
if tlsConfig != nil {
setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
logutil.BgLogger().Info("mysql protocol server secure connection is enabled",
zap.Bool("client verification enabled", len(variable.GetSysVar("ssl_ca").Value) > 0))
}
if s.tlsConfig != nil {
s.capability |= mysql.ClientSSL
}
variable.RegisterStatistics(s)
return s, nil
}
func (s *Server) initTiDBListener() (err error) {
if s.cfg.Host != "" && (s.cfg.Port != 0 || RunInGoTest) {
addr := net.JoinHostPort(s.cfg.Host, strconv.Itoa(int(s.cfg.Port)))
tcpProto := "tcp"
if s.cfg.EnableTCP4Only {
tcpProto = "tcp4"
}
if s.listener, err = net.Listen(tcpProto, addr); err != nil {
return errors.Trace(err)
}
logutil.BgLogger().Info("server is running MySQL protocol", zap.String("addr", addr))
if RunInGoTest && s.cfg.Port == 0 {
s.cfg.Port = uint(s.listener.Addr().(*net.TCPAddr).Port)
}
}
if s.cfg.Socket != "" {
if err := cleanupStaleSocket(s.cfg.Socket); err != nil {
return errors.Trace(err)
}
if s.socket, err = net.Listen("unix", s.cfg.Socket); err != nil {
return errors.Trace(err)
}
logutil.BgLogger().Info("server is running MySQL protocol", zap.String("socket", s.cfg.Socket))
}
if s.socket == nil && s.listener == nil {
err = errors.New("Server not configured to listen on either -socket or -host and -port")
return errors.Trace(err)
}
if s.cfg.ProxyProtocol.Networks != "" {
proxyTarget := s.listener
if proxyTarget == nil {
proxyTarget = s.socket
}
ppListener, err := proxyprotocol.NewLazyListener(proxyTarget, s.cfg.ProxyProtocol.Networks,
int(s.cfg.ProxyProtocol.HeaderTimeout), s.cfg.ProxyProtocol.Fallbackable)
if err != nil {
logutil.BgLogger().Error("ProxyProtocol networks parameter invalid")
return errors.Trace(err)
}
if s.listener != nil {
s.listener = ppListener
logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("host", s.cfg.Host))
} else {
s.socket = ppListener
logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("socket", s.cfg.Socket))
}
}
return nil
}
func (s *Server) initHTTPListener() (err error) {
if s.cfg.Status.ReportStatus {
if err = s.listenStatusHTTPServer(); err != nil {
return errors.Trace(err)
}
}
// Automatically reload JWKS for tidb_auth_token.
if len(s.cfg.Security.AuthTokenJWKS) > 0 {
var (
timeInterval time.Duration
err error
ctx context.Context
)
if timeInterval, err = time.ParseDuration(s.cfg.Security.AuthTokenRefreshInterval); err != nil {
logutil.BgLogger().Error("Fail to parse security.auth-token-refresh-interval. Use default value",
zap.String("security.auth-token-refresh-interval", s.cfg.Security.AuthTokenRefreshInterval))
timeInterval = config.DefAuthTokenRefreshInterval
}
ctx, s.authTokenCancelFunc = context.WithCancel(context.Background())
if err = privileges.GlobalJWKS.LoadJWKS4AuthToken(ctx, &s.wg, s.cfg.Security.AuthTokenJWKS, timeInterval); err != nil {
logutil.BgLogger().Error("Fail to load JWKS from the path", zap.String("jwks", s.cfg.Security.AuthTokenJWKS))
}
}
return
}
func cleanupStaleSocket(socket string) error {
sockStat, err := os.Stat(socket)
if err != nil {
return nil
}
if sockStat.Mode().Type() != os.ModeSocket {
return fmt.Errorf(
"the specified socket file %s is a %s instead of a socket file",
socket, sockStat.Mode().String())
}
if _, err = net.Dial("unix", socket); err == nil {
return fmt.Errorf("unix socket %s exists and is functional, not removing it", socket)
}
if err2 := os.Remove(socket); err2 != nil {
return fmt.Errorf("failed to cleanup stale Unix socket file %s: %w", socket, err2)
}
return nil
}
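// Example (sketch): cleanupStaleSocket is meant to run right before binding a unix
// socket, so a file left over from an unclean shutdown does not make net.Listen fail.
// exampleBindUnixSocket is a hypothetical usage.
func exampleBindUnixSocket(path string) (net.Listener, error) {
	if err := cleanupStaleSocket(path); err != nil {
		return nil, err
	}
	return net.Listen("unix", path)
}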
func setSSLVariable(ca, key, cert string) {
variable.SetSysVar("have_openssl", "YES")
variable.SetSysVar("have_ssl", "YES")
variable.SetSysVar("ssl_cert", cert)
variable.SetSysVar("ssl_key", key)
variable.SetSysVar("ssl_ca", ca)
}
// Export config-related metrics
func (s *Server) reportConfig() {
metrics.ConfigStatus.WithLabelValues("token-limit").Set(float64(s.cfg.TokenLimit))
metrics.ConfigStatus.WithLabelValues("max_connections").Set(float64(s.cfg.Instance.MaxConnections))
}
// Run runs the server.
func (s *Server) Run(dom *domain.Domain) error {
metrics.ServerEventCounter.WithLabelValues(metrics.ServerStart).Inc()
s.reportConfig()
// Start HTTP API to report tidb info such as TPS.
if s.cfg.Status.ReportStatus {
err := s.startStatusHTTP()
if err != nil {
log.Error("failed to create the server", zap.Error(err), zap.Stack("stack"))
return err
}
mppcoordmanager.InstanceMPPCoordinatorManager.InitServerAddr(s.GetStatusServerAddr())
}
if config.GetGlobalConfig().Performance.ForceInitStats && dom != nil {
<-dom.StatsHandle().InitStatsDone
}
// Errors that should be reported and cause the server to exit can be sent on this
// channel. Otherwise, end with sending a nil error to signal "done".
errChan := make(chan error, 2)
err := s.initTiDBListener()
if err != nil {
log.Error("failed to create the server", zap.Error(err), zap.Stack("stack"))
return err
}
// The error-registration API is not thread-safe; callers MUST NOT register errors after initialization.
// To prevent misuse, set a flag so that registering a new error panics immediately.
// For regression of issue like https://github.com/pingcap/tidb/issues/28190
terror.RegisterFinish()
go s.startNetworkListener(s.listener, false, errChan)
go s.startNetworkListener(s.socket, true, errChan)
if RunInGoTest && !isClosed(RunInGoTestChan) {
close(RunInGoTestChan)
}
s.health.Store(true)
err = <-errChan
if err != nil {
return err
}
return <-errChan
}
// isClosed checks whether the channel is closed.
func isClosed(ch chan struct{}) bool {
select {
case <-ch:
return true
default:
}
return false
}
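// Example (sketch): isClosed relies on the fact that receiving from a closed channel
// never blocks, while the default case fires for an open, empty channel.
// exampleIsClosed is a hypothetical demonstration.
func exampleIsClosed() {
	ch := make(chan struct{})
	_ = isClosed(ch) // false: the channel is open and empty
	close(ch)
	_ = isClosed(ch) // true: a closed channel always receives immediately
}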
func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool, errChan chan error) {
if listener == nil {
errChan <- nil
return
}
for {
conn, err := listener.Accept()
if err != nil {
if opErr, ok := err.(*net.OpError); ok {
if opErr.Err.Error() == "use of closed network connection" {
if s.inShutdownMode.Load() {
errChan <- nil
} else {
errChan <- err
}
return
}
}
// If we got a PROXY protocol error, we should continue accepting connections.
if proxyprotocol.IsProxyProtocolError(err) {
logutil.BgLogger().Error("PROXY protocol failed", zap.Error(err))
continue
}
logutil.BgLogger().Error("accept failed", zap.Error(err))
errChan <- err
return
}
logutil.BgLogger().Debug("accept new connection success")
clientConn := s.newConn(conn)
if isUnixSocket {
var (
uc *net.UnixConn
ok bool
)
if clientConn.ppEnabled {
// Use reflection to extract the raw connection from the PROXY protocol wrapper object.
ppv := reflect.ValueOf(conn)
vconn := ppv.Elem().FieldByName("Conn")
rconn := vconn.Interface()
uc, ok = rconn.(*net.UnixConn)
} else {
uc, ok = conn.(*net.UnixConn)
}
if !ok {
logutil.BgLogger().Error("Expected UNIX socket, but got something else")
return
}
clientConn.isUnixSocket = true
clientConn.peerHost = "localhost"
clientConn.socketCredUID, err = linux.GetSockUID(*uc)
if err != nil {
logutil.BgLogger().Error("Failed to get UNIX socket peer credentials", zap.Error(err))
return
}
}
err = nil
if !clientConn.ppEnabled {
// Check audit plugins when PROXY protocol is not enabled.
err = s.checkAuditPlugin(clientConn)
}
if err != nil {
continue
}
if s.dom != nil && s.dom.IsLostConnectionToPD() {
logutil.BgLogger().Warn("reject connection due to lost connection to PD")
terror.Log(clientConn.Close())
continue
}
go s.onConn(clientConn)
}
}
func (*Server) checkAuditPlugin(clientConn *clientConn) error {
return plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent == nil {
return nil
}
host, _, err := clientConn.PeerHost("", false)
if err != nil {
logutil.BgLogger().Error("get peer host failed", zap.Error(err))
terror.Log(clientConn.Close())
return errors.Trace(err)
}
if err = authPlugin.OnConnectionEvent(context.Background(), plugin.PreAuth,
&variable.ConnectionInfo{Host: host}); err != nil {
logutil.BgLogger().Info("do connection event failed", zap.Error(err))
terror.Log(clientConn.Close())
return errors.Trace(err)
}
return nil
})
}
func (s *Server) startShutdown() {
logutil.BgLogger().Info("setting tidb-server to report unhealthy (shutting-down)")
s.health.Store(false)
// give the load balancer a chance to receive a few unhealthy health reports
// before acquiring the s.rwlock and blocking connections.
waitTime := time.Duration(s.cfg.GracefulWaitBeforeShutdown) * time.Second
if waitTime > 0 {
logutil.BgLogger().Info("waiting for stray connections before starting shutdown process", zap.Duration("waitTime", waitTime))
time.Sleep(waitTime)
}
}
func (s *Server) closeListener() {
if s.listener != nil {
err := s.listener.Close()
terror.Log(errors.Trace(err))
s.listener = nil
}
if s.socket != nil {
err := s.socket.Close()
terror.Log(errors.Trace(err))
s.socket = nil
}
if statusServer := s.statusServer.Load(); statusServer != nil {
err := statusServer.Close()
terror.Log(errors.Trace(err))
s.statusServer.Store(nil)
}
if s.grpcServer != nil {
s.grpcServer.Stop()
s.grpcServer = nil
}
if s.autoIDService != nil {
s.autoIDService.Close()
}
if s.authTokenCancelFunc != nil {
s.authTokenCancelFunc()
}
s.wg.Wait()
metrics.ServerEventCounter.WithLabelValues(metrics.ServerStop).Inc()
}
// Close closes the server.
func (s *Server) Close() {
s.startShutdown()
s.rwlock.Lock() // prevent new connections
defer s.rwlock.Unlock()
s.inShutdownMode.Store(true)
s.closeListener()
}
func (s *Server) registerConn(conn *clientConn) bool {
s.rwlock.Lock()
defer s.rwlock.Unlock()
logger := logutil.BgLogger()
if s.inShutdownMode.Load() {
logger.Info("close connection directly when shutting down")
terror.Log(closeConn(conn))
return false
}
s.clients[conn.connectionID] = conn
return true
}
// onConn runs in its own goroutine, handles queries from this connection.
func (s *Server) onConn(conn *clientConn) {
if s.StandbyController != nil {
s.StandbyController.OnConnActive()
}
// init the connInfo
_, _, err := conn.PeerHost("", false)
if err != nil {
logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
Error("get peer host failed", zap.Error(err))
terror.Log(conn.Close())
return
}
extensions, err := extension.GetExtensions()
if err != nil {
logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
Error("error in get extensions", zap.Error(err))
terror.Log(conn.Close())
return
}
if sessExtensions := extensions.NewSessionExtensions(); sessExtensions != nil {
conn.extensions = sessExtensions
conn.onExtensionConnEvent(extension.ConnConnected, nil)
defer func() {
conn.onExtensionConnEvent(extension.ConnDisconnected, nil)
}()
}
ctx := logutil.WithConnID(context.Background(), conn.connectionID)
if err := conn.handshake(ctx); err != nil {
conn.onExtensionConnEvent(extension.ConnHandshakeRejected, err)
if plugin.IsEnable(plugin.Audit) && conn.getCtx() != nil {
conn.getCtx().GetSessionVars().ConnectionInfo = conn.connectInfo()
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
pluginCtx := context.WithValue(context.Background(), plugin.RejectReasonCtxValue{}, err.Error())
return authPlugin.OnConnectionEvent(pluginCtx, plugin.Reject, conn.ctx.GetSessionVars().ConnectionInfo)
}
return nil
})
terror.Log(err)
}
switch errors.Cause(err) {
case io.EOF:
// `EOF` means the connection is closed normally; we do not treat it as a noticeable error and log it at 'DEBUG' level.
logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
Debug("EOF", zap.String("remote addr", conn.bufReadConn.RemoteAddr().String()))
case servererr.ErrConCount:
if err := conn.writeError(ctx, err); err != nil {
logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
Warn("error in writing errConCount", zap.Error(err),
zap.String("remote addr", conn.bufReadConn.RemoteAddr().String()))
}
default:
metrics.HandShakeErrorCounter.Inc()
logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
Warn("Server.onConn handshake", zap.Error(err),
zap.String("remote addr", conn.bufReadConn.RemoteAddr().String()))
}
terror.Log(conn.Close())
return
}
logutil.Logger(ctx).Debug("new connection", zap.String("remoteAddr", conn.bufReadConn.RemoteAddr().String()))
defer func() {
terror.Log(conn.Close())
logutil.Logger(ctx).Debug("connection closed")
}()
if err := conn.increaseUserConnectionsCount(); err != nil {
logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
Warn("failed to increase the count of connections", zap.Error(err),
zap.String("remote addr", conn.bufReadConn.RemoteAddr().String()))
return
}
defer conn.decreaseUserConnectionCount()
if !s.registerConn(conn) {
return
}
sessionVars := conn.ctx.GetSessionVars()
sessionVars.ConnectionInfo = conn.connectInfo()
conn.onExtensionConnEvent(extension.ConnHandshakeAccepted, nil)
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
return authPlugin.OnConnectionEvent(context.Background(), plugin.Connected, sessionVars.ConnectionInfo)
}
return nil
})
if err != nil {
return
}
connectedTime := time.Now()
conn.Run(ctx)
err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
authPlugin := plugin.DeclareAuditManifest(p.Manifest)
if authPlugin.OnConnectionEvent != nil {
sessionVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond)
err := authPlugin.OnConnectionEvent(context.Background(), plugin.Disconnect, sessionVars.ConnectionInfo)
if err != nil {
logutil.BgLogger().Warn("do connection event failed", zap.String("plugin", authPlugin.Name), zap.Error(err))
}
}
return nil
})
if err != nil {
return
}
}
func (cc *clientConn) connectInfo() *variable.ConnectionInfo {
connType := variable.ConnTypeSocket
sslVersion := ""
if cc.isUnixSocket {
connType = variable.ConnTypeUnixSocket
} else if cc.tlsConn != nil {
connType = variable.ConnTypeTLS
sslVersionNum := cc.tlsConn.ConnectionState().Version
switch sslVersionNum {
case tls.VersionTLS12:
sslVersion = "TLSv1.2"
case tls.VersionTLS13:
sslVersion = "TLSv1.3"
default:
sslVersion = fmt.Sprintf("Unknown TLS version: %d", sslVersionNum)
}
}
connInfo := &variable.ConnectionInfo{
ConnectionID: cc.connectionID,
ConnectionType: connType,
Host: cc.peerHost,
ClientIP: cc.peerHost,
ClientPort: cc.peerPort,
ServerID: 1,
ServerIP: cc.serverHost,
ServerPort: int(cc.server.cfg.Port),
User: cc.user,
ServerOSLoginUser: osUser,
OSVersion: osVersion,
ServerVersion: mysql.TiDBReleaseVersion,
SSLVersion: sslVersion,
PID: serverPID,
DB: cc.dbname,
AuthMethod: cc.authPlugin,
Attributes: cc.attrs,
}
return connInfo
}
func (s *Server) checkConnectionCount() error {
// When the value of Instance.MaxConnections is 0, the number of connections is unlimited.
if int(s.cfg.Instance.MaxConnections) == 0 {
return nil
}
conns := s.ConnectionCount()
if conns >= int(s.cfg.Instance.MaxConnections) {
logutil.BgLogger().Error("too many connections",
zap.Uint32("max connections", s.cfg.Instance.MaxConnections), zap.Error(servererr.ErrConCount))
return servererr.ErrConCount
}
return nil
}
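// Example (sketch): the limit check above treats MaxConnections == 0 as "unlimited",
// so it only rejects when a positive cap is configured and reached.
// exampleConnLimitReached is a hypothetical restatement of that predicate.
func exampleConnLimitReached(current int, maxConns uint32) bool {
	return maxConns != 0 && current >= int(maxConns)
}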
// ShowProcessList implements the SessionManager interface.
func (s *Server) ShowProcessList() map[uint64]*sessmgr.ProcessInfo {
rs := make(map[uint64]*sessmgr.ProcessInfo)
maps.Copy(rs, s.GetUserProcessList())
if s.dom != nil {
maps.Copy(rs, s.dom.SysProcTracker().GetSysProcessList())
}
return rs
}
// GetUserProcessList returns all process info that are created by user.
func (s *Server) GetUserProcessList() map[uint64]*sessmgr.ProcessInfo {
s.rwlock.RLock()
defer s.rwlock.RUnlock()
rs := make(map[uint64]*sessmgr.ProcessInfo)
for _, client := range s.clients {
if pi := client.ctx.ShowProcess(); pi != nil {
rs[pi.ID] = pi
}
}
return rs
}
// GetClientCapabilityList returns the capability flags of all clients.
func (s *Server) GetClientCapabilityList() map[uint64]uint32 {
s.rwlock.RLock()
defer s.rwlock.RUnlock()
rs := make(map[uint64]uint32)
for id, client := range s.clients {
if client.ctx.Session != nil {
rs[id] = client.capability
}
}
return rs
}
// ShowTxnList shows all txn info for displaying in `TIDB_TRX`.
// Internal sessions are not taken into consideration.
func (s *Server) ShowTxnList() []*txninfo.TxnInfo {
s.rwlock.RLock()
defer s.rwlock.RUnlock()
rs := make([]*txninfo.TxnInfo, 0, len(s.clients))
for _, client := range s.clients {
if client.ctx.Session != nil {
info := client.ctx.Session.TxnInfo()
if info != nil && info.ProcessInfo != nil {
rs = append(rs, info)
}
}
}
return rs
}
// UpdateProcessCPUTime updates the specified process's TiDB CPU time while the process is still running.
// It implements the ProcessCPUTimeUpdater interface.
func (s *Server) UpdateProcessCPUTime(connID uint64, sqlID uint64, cpuTime time.Duration) {
s.rwlock.RLock()
conn, ok := s.clients[connID]
s.rwlock.RUnlock()
if !ok {
return
}
vars := conn.ctx.GetSessionVars()
if vars != nil {
vars.SQLCPUUsages.MergeTidbCPUTime(sqlID, cpuTime)
}
}
// GetProcessInfo implements the SessionManager interface.
func (s *Server) GetProcessInfo(id uint64) (*sessmgr.ProcessInfo, bool) {
s.rwlock.RLock()
conn, ok := s.clients[id]
s.rwlock.RUnlock()
if !ok {
if s.dom != nil {
if pinfo, ok2 := s.dom.SysProcTracker().GetSysProcessList()[id]; ok2 {
return pinfo, true
}
}
return &sessmgr.ProcessInfo{}, false
}
return conn.ctx.ShowProcess(), ok
}
// GetConAttrs returns the connection attributes
func (s *Server) GetConAttrs(user *auth.UserIdentity) map[uint64]map[string]string {
s.rwlock.RLock()
defer s.rwlock.RUnlock()
rs := make(map[uint64]map[string]string)
for _, client := range s.clients {
if user != nil {
if user.Username != client.user {
continue
}
if user.Hostname != client.peerHost {
continue
}
}
if pi := client.ctx.ShowProcess(); pi != nil {
rs[pi.ID] = client.attrs
}
}
return rs
}
// Kill implements the SessionManager interface.
func (s *Server) Kill(connectionID uint64, query bool, maxExecutionTime bool, runaway bool) {
logutil.BgLogger().Info("kill", zap.Uint64("conn", connectionID),
zap.Bool("query", query), zap.Bool("maxExecutionTime", maxExecutionTime), zap.Bool("runawayExceed", runaway))
metrics.ServerEventCounter.WithLabelValues(metrics.EventKill).Inc()
s.rwlock.RLock()
defer s.rwlock.RUnlock()
conn, ok := s.clients[connectionID]
if !ok && s.dom != nil {
s.dom.SysProcTracker().KillSysProcess(connectionID)
return
}
if !query {
// Mark the client connection status as WaitShutdown; when clientConn.Run detects
// this, it ends the dispatch loop and exits.
conn.setStatus(connStatusWaitShutdown)
if conn.bufReadConn != nil {
// When attempting to 'kill connection' and TiDB is stuck in the network stack while writing packets,
// we can quickly exit the network stack and terminate the SQL execution by setting WriteDeadline.
if err := conn.bufReadConn.SetWriteDeadline(time.Now()); err != nil {
logutil.BgLogger().Warn("error setting write deadline for kill.", zap.Error(err))
}
if err := conn.bufReadConn.SetReadDeadline(time.Now()); err != nil {
logutil.BgLogger().Warn("error setting read deadline for kill.", zap.Error(err))
}
}
}
killQuery(conn, maxExecutionTime, runaway)
}
// UpdateTLSConfig implements the SessionManager interface.
func (s *Server) UpdateTLSConfig(cfg *tls.Config) {
atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(cfg))
}
// GetTLSConfig implements the SessionManager interface.
func (s *Server) GetTLSConfig() *tls.Config {
return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig))
}
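// Example (illustrative sketch): UpdateTLSConfig and GetTLSConfig implement a lock-free
// hot swap. A writer publishes a fresh *tls.Config in one atomic store, and readers load
// an immutable snapshot, so no reader ever observes a half-updated config.
// exampleHotSwapTLS is a hypothetical usage.
func exampleHotSwapTLS(s *Server, newCfg *tls.Config) *tls.Config {
	s.UpdateTLSConfig(newCfg) // atomic publish
	return s.GetTLSConfig()   // atomic snapshot; may be newCfg or a later config
}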
func killQuery(conn *clientConn, maxExecutionTime, runaway bool) {
sessVars := conn.ctx.GetSessionVars()
if runaway {
sessVars.SQLKiller.SendKillSignal(sqlkiller.RunawayQueryExceeded)
} else if maxExecutionTime {
sessVars.SQLKiller.SendKillSignal(sqlkiller.MaxExecTimeExceeded)
} else {
sessVars.SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted)
}
conn.mu.RLock()
cancelFunc := conn.mu.cancelFunc
conn.mu.RUnlock()
if cancelFunc != nil {
cancelFunc()
}
sessVars.SQLKiller.FinishResultSet()
}
// KillSysProcesses kills system processes such as auto analyze.
func (s *Server) KillSysProcesses() {
if s.dom == nil {
return
}
sysProcTracker := s.dom.SysProcTracker()
for connID := range sysProcTracker.GetSysProcessList() {
sysProcTracker.KillSysProcess(connID)
}
}
// KillAllConnections implements the SessionManager interface.
// KillAllConnections kills all connections.
func (s *Server) KillAllConnections() {
logutil.BgLogger().Info("kill all connections.", zap.String("category", "server"))
s.rwlock.RLock()
defer s.rwlock.RUnlock()
for _, conn := range s.clients {
conn.setStatus(connStatusShutdown)
if err := conn.closeWithoutLock(); err != nil {
terror.Log(err)
}
if conn.bufReadConn != nil {
if err := conn.bufReadConn.SetReadDeadline(time.Now()); err != nil {
logutil.BgLogger().Warn("error setting read deadline for kill.", zap.Error(err))
}
}
killQuery(conn, false, false)
}
s.KillSysProcesses()
}
// DrainClients drains all connections within drainWait.
// After the drainWait duration, we explicitly kill all connections that still have not quit, and wait up to cancelWait for them.
func (s *Server) DrainClients(drainWait time.Duration, cancelWait time.Duration) {
logger := logutil.BgLogger()
logger.Info("start drain clients")
conns := make(map[uint64]*clientConn)
s.rwlock.Lock()
maps.Copy(conns, s.clients)
s.rwlock.Unlock()
allDone := make(chan struct{})
quitWaitingForConns := make(chan struct{})
defer close(quitWaitingForConns)
go func() {
defer close(allDone)
for _, conn := range conns {
// Wait for connections with an explicit transaction or an executing auto-commit query.
if conn.getStatus() == connStatusReading && !conn.getCtx().GetSessionVars().InTxn() {
// The wait group is not protected by `quitWaitingForConns`. However, the implementation
// of `client-go` guarantees that this `Wait` will return, at the latest, after the
// connections are killed. We also wait for a similar `WaitGroup` on the store after
// killing the connections.
//
// Therefore, it will not cause a goroutine leak. Even if it did, it would not be a big
// issue while TiDB is shutting down.
//
// We should wait for connections in all states, even one that is not in a transaction
// and is reading from the client, because it may run background commit goroutines at any time.
conn.getCtx().Session.GetCommitWaitGroup().Wait()
continue
}
select {
case <-conn.quit:
case <-quitWaitingForConns:
return
}
// Wait for the commit wait group after waiting on the `conn.quit` channel, to make sure the
// foreground work has finished. This avoids the situation where, after we wait on the wait
// group, the transaction starts a new background goroutine and increases the wait group again.
conn.getCtx().Session.GetCommitWaitGroup().Wait()
}
}()
select {
case <-allDone:
logger.Info("all sessions quit in drain wait time")
case <-time.After(drainWait):
logger.Info("timeout waiting all sessions quit")
}
s.KillAllConnections()
select {
case <-allDone:
case <-time.After(cancelWait):
logger.Warn("some sessions do not quit in cancel wait time")
}
}
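// Example (sketch): DrainClients uses the common "wait with a deadline" pattern: a
// goroutine closes a done channel when the wait finishes, and the caller selects between
// that channel and a timer. A minimal standalone form, under hypothetical names:
func exampleWaitWithTimeout(wg *sync.WaitGroup, timeout time.Duration) bool {
	done := make(chan struct{})
	go func() {
		defer close(done)
		wg.Wait()
	}()
	select {
	case <-done:
		return true // everything finished in time
	case <-time.After(timeout):
		return false // gave up waiting
	}
}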
// ServerID implements SessionManager interface.
func (s *Server) ServerID() uint64 {
return s.dom.ServerID()
}
// StoreInternalSession implements SessionManager interface.
// @param se The address of a session.session struct variable
func (s *Server) StoreInternalSession(se any) {
s.sessionMapMutex.Lock()
s.internalSessions[se] = struct{}{}
metrics.InternalSessions.Set(float64(len(s.internalSessions)))
s.sessionMapMutex.Unlock()
}
// ContainsInternalSession implements SessionManager interface.
func (s *Server) ContainsInternalSession(se any) bool {
s.sessionMapMutex.Lock()
defer s.sessionMapMutex.Unlock()
_, ok := s.internalSessions[se]
return ok
}
// InternalSessionCount implements sessionmgr.InfoSchemaCoordinator interface.
func (s *Server) InternalSessionCount() int {
s.sessionMapMutex.Lock()
defer s.sessionMapMutex.Unlock()
return len(s.internalSessions)
}
// DeleteInternalSession implements SessionManager interface.
// @param se The address of a session.session struct variable
func (s *Server) DeleteInternalSession(se any) {
s.sessionMapMutex.Lock()
delete(s.internalSessions, se)
metrics.InternalSessions.Set(float64(len(s.internalSessions)))
s.sessionMapMutex.Unlock()
}
// GetInternalSessionStartTSList implements SessionManager interface.
func (s *Server) GetInternalSessionStartTSList() []uint64 {
s.sessionMapMutex.Lock()
defer s.sessionMapMutex.Unlock()
tsList := make([]uint64, 0, len(s.internalSessions))
for se := range s.internalSessions {
if ts, processInfoID := session.GetStartTSFromSession(se); ts != 0 {
if statsutil.GlobalAutoAnalyzeProcessList.Contains(processInfoID) {
continue
}
tsList = append(tsList, ts)
}
}
return tsList
}
// InternalSessionExists is used for tests.
func (s *Server) InternalSessionExists(se any) bool {
s.sessionMapMutex.Lock()
_, ok := s.internalSessions[se]
s.sessionMapMutex.Unlock()
return ok
}
// setSysTimeZoneOnce is used for tests that run in parallel. When several servers are running,
// only the first actually calls setSystemTimeZoneVariable, so we avoid a data race.
var setSysTimeZoneOnce = &sync.Once{}
func setSystemTimeZoneVariable() {
setSysTimeZoneOnce.Do(func() {
tz, err := timeutil.GetSystemTZ()
if err != nil {
logutil.BgLogger().Error(
"Error getting SystemTZ, use default value instead",
zap.Error(err),
zap.String("default system_time_zone", variable.GetSysVar("system_time_zone").Value))
return
}
variable.SetSysVar("system_time_zone", tz)
})
}
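// Example (sketch): sync.Once runs the wrapped function at most once; concurrent callers
// block until the first call finishes and then return without re-running it.
// exampleOnceSemantics is a hypothetical demonstration.
func exampleOnceSemantics() int {
	var once sync.Once
	n := 0
	for i := 0; i < 3; i++ {
		once.Do(func() { n++ })
	}
	return n // 1: the body ran exactly once
}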
// CheckOldRunningTxn implements SessionManager interface.
func (s *Server) CheckOldRunningTxn(jobs map[int64]*mdldef.JobMDL) {
s.rwlock.RLock()
defer s.rwlock.RUnlock()
printLog := false
if time.Since(s.printMDLLogTime) > 10*time.Second {
printLog = true
s.printMDLLogTime = time.Now()
}
for _, client := range s.clients {
se := client.ctx.Session
if se != nil {
variable.RemoveLockDDLJobs(se.GetSessionVars(), jobs, printLog)
}
}
}
// KillNonFlashbackClusterConn implements SessionManager interface.
func (s *Server) KillNonFlashbackClusterConn() {
s.rwlock.RLock()
connIDs := make([]uint64, 0, len(s.clients))
for _, client := range s.clients {
if client.ctx.Session != nil {
processInfo := client.ctx.Session.ShowProcess()
ddl, ok := processInfo.StmtCtx.GetPlan().(*core.DDL)
if !ok {
connIDs = append(connIDs, client.connectionID)
continue
}
_, ok = ddl.Statement.(*ast.FlashBackToTimestampStmt)
if !ok {
connIDs = append(connIDs, client.connectionID)
continue
}
}
}
s.rwlock.RUnlock()
for _, id := range connIDs {
s.Kill(id, false, false, false)
}
}
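// Example (illustrative sketch): KillNonFlashbackClusterConn snapshots connection IDs
// under the read lock and performs the kills only after releasing it, because Kill
// re-acquires s.rwlock. exampleSnapshotThenAct is a hypothetical generalization of that
// snapshot-then-act pattern.
func exampleSnapshotThenAct(s *Server, act func(id uint64)) {
	s.rwlock.RLock()
	ids := make([]uint64, 0, len(s.clients))
	for id := range s.clients {
		ids = append(ids, id)
	}
	s.rwlock.RUnlock()
	for _, id := range ids {
		act(id) // safe: the lock is no longer held, so act may lock again
	}
}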
// GetStatusVars returns the per-process status variables from the server.
func (s *Server) GetStatusVars() map[uint64]map[string]string {
s.rwlock.RLock()
defer s.rwlock.RUnlock()
rs := make(map[uint64]map[string]string)
for _, client := range s.clients {
if pi := client.ctx.ShowProcess(); pi != nil {
if client.tlsConn != nil {
connState := client.tlsConn.ConnectionState()
rs[pi.ID] = map[string]string{
"Ssl_cipher": tlsutil.CipherSuiteName(connState.CipherSuite),
"Ssl_version": tlsutil.VersionName(connState.Version),
}
}
}
}
return rs
}
// Health returns whether the server is healthy (false once it begins to shut down).
func (s *Server) Health() bool {
return s.health.Load()
}