// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// The MIT License (MIT)
//
// Copyright (c) 2014 wandoulabs
// Copyright (c) 2014 siddontang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

package server

import (
	"context"
	"crypto/tls"
	"fmt"
	"io"
	"maps"
	"net"
	"net/http" //nolint:goimports
	_ "net/http/pprof" // #nosec G108 for pprof
	"os"
	"os/user"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"
	"unsafe"

	"github.com/blacktear23/go-proxyprotocol"
	"github.com/pingcap/errors"
	"github.com/pingcap/log"
	autoid "github.com/pingcap/tidb/pkg/autoid_service"
	"github.com/pingcap/tidb/pkg/config"
	"github.com/pingcap/tidb/pkg/domain"
	"github.com/pingcap/tidb/pkg/executor/mppcoordmanager"
	"github.com/pingcap/tidb/pkg/extension"
	"github.com/pingcap/tidb/pkg/infoschema/issyncer/mdldef"
	"github.com/pingcap/tidb/pkg/metrics"
	"github.com/pingcap/tidb/pkg/parser/ast"
	"github.com/pingcap/tidb/pkg/parser/auth"
	"github.com/pingcap/tidb/pkg/parser/mysql"
	"github.com/pingcap/tidb/pkg/parser/terror"
	"github.com/pingcap/tidb/pkg/planner/core"
	"github.com/pingcap/tidb/pkg/plugin"
	"github.com/pingcap/tidb/pkg/privilege/privileges"
	"github.com/pingcap/tidb/pkg/resourcegroup"
	servererr "github.com/pingcap/tidb/pkg/server/err"
	"github.com/pingcap/tidb/pkg/session"
	"github.com/pingcap/tidb/pkg/session/sessmgr"
	"github.com/pingcap/tidb/pkg/session/txninfo"
	"github.com/pingcap/tidb/pkg/sessionctx/variable"
	statsutil "github.com/pingcap/tidb/pkg/statistics/handle/util"
	"github.com/pingcap/tidb/pkg/util"
	"github.com/pingcap/tidb/pkg/util/fastrand"
	"github.com/pingcap/tidb/pkg/util/kvcache"
	"github.com/pingcap/tidb/pkg/util/logutil"
	"github.com/pingcap/tidb/pkg/util/sqlkiller"
	"github.com/pingcap/tidb/pkg/util/sys/linux"
	"github.com/pingcap/tidb/pkg/util/timeutil"
	tlsutil "github.com/pingcap/tidb/pkg/util/tls"
	uatomic "go.uber.org/atomic"
	"go.uber.org/zap"
	"google.golang.org/grpc"
)

var (
	serverPID int
	osUser    string
	osVersion string
	// RunInGoTest represents whether we are running code in a test.
	RunInGoTest bool
	// RunInGoTestChan is used to control RunInGoTest.
	RunInGoTestChan chan struct{}
)

func init() {
	serverPID = os.Getpid()
	currentUser, err := user.Current()
	if err != nil {
		osUser = ""
	} else {
		osUser = currentUser.Name
	}
	osVersion, err = linux.OSVersion()
	if err != nil {
		osVersion = ""
	}
}

// defaultCapability is the capability of the server when it is created using the default configuration.
// When the server is configured with SSL, it has extra capabilities compared to defaultCapability.
const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
	mysql.ClientConnectWithDB | mysql.ClientProtocol41 |
	mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows |
	mysql.ClientMultiStatements | mysql.ClientMultiResults | mysql.ClientLocalFiles |
	mysql.ClientConnectAtts | mysql.ClientPluginAuth | mysql.ClientInteractive |
	mysql.ClientDeprecateEOF | mysql.ClientCompress | mysql.ClientZstdCompressionAlgorithm
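
// The capability set is a plain bitmask, so individual flags can be tested and
// combined with bitwise operators. An illustrative sketch (not part of the
// original source):
//
//	caps := defaultCapability | mysql.ClientSSL // what a TLS-enabled server advertises
//	if caps&mysql.ClientCompress != 0 {
//		// the server supports zlib compression
//	}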

const normalClosedConnsCapacity = 1000

// Server is the MySQL protocol server
type Server struct {
	cfg               *config.Config
	tlsConfig         unsafe.Pointer // *tls.Config
	driver            IDriver
	listener          net.Listener
	socket            net.Listener
	concurrentLimiter *util.TokenLimiter

	rwlock  sync.RWMutex
	clients map[uint64]*clientConn

	normalClosedConnsMutex sync.Mutex
	normalClosedConns      *kvcache.SimpleLRUCache

	userResLock  sync.RWMutex // userResLock is used to protect userResource
	userResource map[string]*userResourceLimits

	capability uint32
	dom        *domain.Domain

	statusAddr     string
	statusListener net.Listener
	statusServer   atomic.Pointer[http.Server]
	grpcServer     *grpc.Server
	inShutdownMode *uatomic.Bool
	health         *uatomic.Bool

	sessionMapMutex     sync.Mutex
	internalSessions    map[any]struct{}
	autoIDService       *autoid.Service
	authTokenCancelFunc context.CancelFunc
	wg                  sync.WaitGroup
	printMDLLogTime     time.Time

	StandbyController
}

// NewTestServer creates a new Server for test.
func NewTestServer(cfg *config.Config) *Server {
	return &Server{
		cfg: cfg,
	}
}

// Socket returns the server's unix socket listener.
func (s *Server) Socket() net.Listener {
	return s.socket
}

// Listener returns the server's listener.
func (s *Server) Listener() net.Listener {
	return s.listener
}

// ListenAddr returns the network address of the server's listener.
func (s *Server) ListenAddr() net.Addr {
	return s.listener.Addr()
}

// StatusListenerAddr returns the network address of the server's status listener.
func (s *Server) StatusListenerAddr() net.Addr {
	return s.statusListener.Addr()
}

// BitwiseXorCapability toggles the given capability bits of the server.
func (s *Server) BitwiseXorCapability(capability uint32) {
	s.capability ^= capability
}

// BitwiseOrAssignCapability adds the capability to the server.
func (s *Server) BitwiseOrAssignCapability(capability uint32) {
	s.capability |= capability
}

// GetStatusServerAddr gets the statusServer address for MppCoordinatorManager usage.
func (s *Server) GetStatusServerAddr() (on bool, addr string) {
	if !s.cfg.Status.ReportStatus {
		return false, ""
	}
	if strings.Contains(s.statusAddr, config.DefStatusHost) {
		if len(s.cfg.AdvertiseAddress) != 0 {
			return true, strings.ReplaceAll(s.statusAddr, config.DefStatusHost, s.cfg.AdvertiseAddress)
		}
		return false, ""
	}
	return true, s.statusAddr
}
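
// For example (illustrative values, not from the original source): if the
// status address is bound to the default status host, e.g. "0.0.0.0:10080",
// and advertise-address is set to "tidb-0.internal", this returns
// (true, "tidb-0.internal:10080"); without an advertise-address it returns
// (false, ""), since the bind-all address is not routable for peers.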

type normalCloseConnKey struct {
	keyspaceName string
	connID       string
}

func (k normalCloseConnKey) Hash() []byte {
	return []byte(fmt.Sprintf("%s-%s", k.keyspaceName, k.connID))
}

// SetNormalClosedConn sets the normal closed connection message for the specified connID.
func (s *Server) SetNormalClosedConn(keyspaceName, connID, msg string) {
	if connID == "" {
		return
	}

	s.normalClosedConnsMutex.Lock()
	defer s.normalClosedConnsMutex.Unlock()
	s.normalClosedConns.Put(normalCloseConnKey{keyspaceName: keyspaceName, connID: connID}, msg)
}

// GetNormalClosedConn gets the normal closed connection message.
func (s *Server) GetNormalClosedConn(keyspaceName, connID string) string {
	if connID == "" {
		return ""
	}

	s.normalClosedConnsMutex.Lock()
	defer s.normalClosedConnsMutex.Unlock()
	v, ok := s.normalClosedConns.Get(normalCloseConnKey{keyspaceName: keyspaceName, connID: connID})
	if !ok {
		return ""
	}
	return v.(string)
}
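
// Usage sketch (hypothetical values): when a connection is closed on purpose,
// record the reason so later diagnostics can distinguish it from an abnormal
// disconnect:
//
//	s.SetNormalClosedConn("ks1", "42", "graceful shutdown")
//	msg := s.GetNormalClosedConn("ks1", "42") // "graceful shutdown"
//
// The backing store is a bounded LRU of normalClosedConnsCapacity entries, so
// old messages are evicted instead of accumulating indefinitely.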

// ConnectionCount gets current connection count.
func (s *Server) ConnectionCount() int {
	s.rwlock.RLock()
	cnt := len(s.clients)
	s.rwlock.RUnlock()
	return cnt
}

func (s *Server) getToken() *util.Token {
	start := time.Now()
	tok := s.concurrentLimiter.Get()
	metrics.TokenGauge.Inc()
	// Note that durations smaller than one microsecond are ignored, because that case can be viewed as non-blocking.
	metrics.GetTokenDurationHistogram.Observe(float64(time.Since(start).Nanoseconds() / 1e3))
	return tok
}

func (s *Server) releaseToken(token *util.Token) {
	s.concurrentLimiter.Put(token)
	metrics.TokenGauge.Dec()
}
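
// The token limiter caps how many connections may execute commands
// concurrently (cfg.TokenLimit tokens in total). Callers pair the two
// methods, roughly like this sketch:
//
//	tok := s.getToken()       // blocks until a token is available
//	defer s.releaseToken(tok)
//	// ... process one client command ...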

// SetDomain is used to set the server domain.
func (s *Server) SetDomain(dom *domain.Domain) {
	s.dom = dom
}

// newConn creates a new *clientConn from a net.Conn.
// It allocates a connection ID and random salt data for authentication.
func (s *Server) newConn(conn net.Conn) *clientConn {
	cc := newClientConn(s)
	if tcpConn, ok := conn.(*net.TCPConn); ok {
		if err := tcpConn.SetKeepAlive(s.cfg.Performance.TCPKeepAlive); err != nil {
			logutil.BgLogger().Error("failed to set tcp keep alive option", zap.Error(err))
		}
		if err := tcpConn.SetNoDelay(s.cfg.Performance.TCPNoDelay); err != nil {
			logutil.BgLogger().Error("failed to set tcp no delay option", zap.Error(err))
		}
	}
	cc.setConn(conn)
	cc.salt = fastrand.Buf(20)
	metrics.ConnGauge.WithLabelValues(resourcegroup.DefaultResourceGroupName).Inc()
	return cc
}

// NewServer creates a new Server.
func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
	s := &Server{
		cfg:               cfg,
		driver:            driver,
		concurrentLimiter: util.NewTokenLimiter(cfg.TokenLimit),
		clients:           make(map[uint64]*clientConn),
		normalClosedConns: kvcache.NewSimpleLRUCache(normalClosedConnsCapacity, 0, 0),
		userResource:      make(map[string]*userResourceLimits),
		internalSessions:  make(map[any]struct{}, 100),
		health:            uatomic.NewBool(false),
		inShutdownMode:    uatomic.NewBool(false),
		printMDLLogTime:   time.Now(),
	}
	s.capability = defaultCapability
	setSystemTimeZoneVariable()

	tlsConfig, autoReload, err := util.LoadTLSCertificates(
		s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert,
		s.cfg.Security.AutoTLS, s.cfg.Security.RSAKeySize)

	// LoadTLSCertificates will auto generate certificates if autoTLS is enabled.
	// It only returns an error if certificates are specified and invalid, in
	// which case we should halt server startup, as a misconfiguration could
	// lead to a connection downgrade.
	if err != nil {
		return nil, errors.Trace(err)
	}

	// Automatically reload auto-generated certificates.
	// The certificates are re-created every 30 days and are valid for 90 days.
	if autoReload {
		go func() {
			for range time.Tick(time.Hour * 24 * 30) { // 30 days
				logutil.BgLogger().Info("Rotating automatically created TLS Certificates")
				tlsConfig, _, err = util.LoadTLSCertificates(
					s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert,
					s.cfg.Security.AutoTLS, s.cfg.Security.RSAKeySize)
				if err != nil {
					logutil.BgLogger().Warn("TLS Certificate rotation failed", zap.Error(err))
				}
				atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
			}
		}()
	}

	if tlsConfig != nil {
		setSSLVariable(s.cfg.Security.SSLCA, s.cfg.Security.SSLKey, s.cfg.Security.SSLCert)
		atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(tlsConfig))
		logutil.BgLogger().Info("mysql protocol server secure connection is enabled",
			zap.Bool("client verification enabled", len(variable.GetSysVar("ssl_ca").Value) > 0))
	}
	if s.tlsConfig != nil {
		s.capability |= mysql.ClientSSL
	}
	variable.RegisterStatistics(s)
	return s, nil
}
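
// Typical wiring at startup (a sketch, not from this file; the real entry
// point in the tidb-server main package does considerably more):
//
//	svr, err := NewServer(cfg, driver) // driver is an IDriver implementation
//	if err != nil {
//		log.Fatal("failed to create server", zap.Error(err))
//	}
//	svr.SetDomain(dom) // attach the domain before serving
//	err = svr.Run(dom) // blocks until the listeners stop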

func (s *Server) initTiDBListener() (err error) {
	if s.cfg.Host != "" && (s.cfg.Port != 0 || RunInGoTest) {
		addr := net.JoinHostPort(s.cfg.Host, strconv.Itoa(int(s.cfg.Port)))
		tcpProto := "tcp"
		if s.cfg.EnableTCP4Only {
			tcpProto = "tcp4"
		}
		if s.listener, err = net.Listen(tcpProto, addr); err != nil {
			return errors.Trace(err)
		}
		logutil.BgLogger().Info("server is running MySQL protocol", zap.String("addr", addr))
		if RunInGoTest && s.cfg.Port == 0 {
			s.cfg.Port = uint(s.listener.Addr().(*net.TCPAddr).Port)
		}
	}

	if s.cfg.Socket != "" {
		if err := cleanupStaleSocket(s.cfg.Socket); err != nil {
			return errors.Trace(err)
		}

		if s.socket, err = net.Listen("unix", s.cfg.Socket); err != nil {
			return errors.Trace(err)
		}
		logutil.BgLogger().Info("server is running MySQL protocol", zap.String("socket", s.cfg.Socket))
	}

	if s.socket == nil && s.listener == nil {
		err = errors.New("Server not configured to listen on either -socket or -host and -port")
		return errors.Trace(err)
	}

	if s.cfg.ProxyProtocol.Networks != "" {
		proxyTarget := s.listener
		if proxyTarget == nil {
			proxyTarget = s.socket
		}
		ppListener, err := proxyprotocol.NewLazyListener(proxyTarget, s.cfg.ProxyProtocol.Networks,
			int(s.cfg.ProxyProtocol.HeaderTimeout), s.cfg.ProxyProtocol.Fallbackable)
		if err != nil {
			logutil.BgLogger().Error("ProxyProtocol networks parameter invalid")
			return errors.Trace(err)
		}
		if s.listener != nil {
			s.listener = ppListener
			logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("host", s.cfg.Host))
		} else {
			s.socket = ppListener
			logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("socket", s.cfg.Socket))
		}
	}
	return nil
}

func (s *Server) initHTTPListener() (err error) {
	if s.cfg.Status.ReportStatus {
		if err = s.listenStatusHTTPServer(); err != nil {
			return errors.Trace(err)
		}
	}

	// Automatically reload JWKS for tidb_auth_token.
	if len(s.cfg.Security.AuthTokenJWKS) > 0 {
		var (
			timeInterval time.Duration
			err          error
			ctx          context.Context
		)
		if timeInterval, err = time.ParseDuration(s.cfg.Security.AuthTokenRefreshInterval); err != nil {
			logutil.BgLogger().Error("Fail to parse security.auth-token-refresh-interval. Use default value",
				zap.String("security.auth-token-refresh-interval", s.cfg.Security.AuthTokenRefreshInterval))
			timeInterval = config.DefAuthTokenRefreshInterval
		}
		ctx, s.authTokenCancelFunc = context.WithCancel(context.Background())
		if err = privileges.GlobalJWKS.LoadJWKS4AuthToken(ctx, &s.wg, s.cfg.Security.AuthTokenJWKS, timeInterval); err != nil {
			logutil.BgLogger().Error("Fail to load JWKS from the path", zap.String("jwks", s.cfg.Security.AuthTokenJWKS))
		}
	}
	return
}

func cleanupStaleSocket(socket string) error {
	sockStat, err := os.Stat(socket)
	if err != nil {
		return nil
	}

	if sockStat.Mode().Type() != os.ModeSocket {
		return fmt.Errorf(
			"the specified socket file %s is a %s instead of a socket file",
			socket, sockStat.Mode().String())
	}

	if _, err = net.Dial("unix", socket); err == nil {
		return fmt.Errorf("unix socket %s exists and is functional, not removing it", socket)
	}

	if err2 := os.Remove(socket); err2 != nil {
		return fmt.Errorf("failed to cleanup stale Unix socket file %s: %w", socket, err2)
	}

	return nil
}
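
// The dial probe above is the usual way to tell a stale socket file from a
// live one: dialing a unix socket that no process is accepting on fails
// (typically with a connection-refused error), while a running server accepts
// it. A stricter sketch would also close the probe connection on success:
//
//	if c, err := net.Dial("unix", socket); err == nil {
//		_ = c.Close()
//		return fmt.Errorf("unix socket %s exists and is functional, not removing it", socket)
//	}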

func setSSLVariable(ca, key, cert string) {
	variable.SetSysVar("have_openssl", "YES")
	variable.SetSysVar("have_ssl", "YES")
	variable.SetSysVar("ssl_cert", cert)
	variable.SetSysVar("ssl_key", key)
	variable.SetSysVar("ssl_ca", ca)
}

// Export config-related metrics
func (s *Server) reportConfig() {
	metrics.ConfigStatus.WithLabelValues("token-limit").Set(float64(s.cfg.TokenLimit))
	metrics.ConfigStatus.WithLabelValues("max_connections").Set(float64(s.cfg.Instance.MaxConnections))
}

// Run runs the server.
func (s *Server) Run(dom *domain.Domain) error {
	metrics.ServerEventCounter.WithLabelValues(metrics.ServerStart).Inc()
	s.reportConfig()

	// Start HTTP API to report tidb info such as TPS.
	if s.cfg.Status.ReportStatus {
		err := s.startStatusHTTP()
		if err != nil {
			log.Error("failed to create the server", zap.Error(err), zap.Stack("stack"))
			return err
		}
		mppcoordmanager.InstanceMPPCoordinatorManager.InitServerAddr(s.GetStatusServerAddr())
	}
	if config.GetGlobalConfig().Performance.ForceInitStats && dom != nil {
		<-dom.StatsHandle().InitStatsDone
	}
	// If an error should be reported and cause the server to exit, it can be
	// sent on this channel. Otherwise, end by sending a nil error to signal "done".
	errChan := make(chan error, 2)
	err := s.initTiDBListener()
	if err != nil {
		log.Error("failed to create the server", zap.Error(err), zap.Stack("stack"))
		return err
	}
	// The error registration API is not thread-safe; the caller MUST NOT register errors after initialization.
	// To prevent misuse, set a flag so that registering a new error panics immediately.
	// This guards against regressions of issues like https://github.com/pingcap/tidb/issues/28190
	terror.RegisterFinish()
	go s.startNetworkListener(s.listener, false, errChan)
	go s.startNetworkListener(s.socket, true, errChan)
	if RunInGoTest && !isClosed(RunInGoTestChan) {
		close(RunInGoTestChan)
	}
	s.health.Store(true)
	err = <-errChan
	if err != nil {
		return err
	}
	return <-errChan
}

// isClosed checks whether the channel is closed without blocking.
func isClosed(ch chan struct{}) bool {
	select {
	case <-ch:
		return true
	default:
	}

	return false
}
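
// This non-blocking probe is reliable here only because RunInGoTestChan is
// never sent values, only closed: receiving from a closed channel succeeds
// immediately, while the default case fires when the channel is still open.
// For a channel that also carries values, detecting closure would need the
// two-value receive form:
//
//	select {
//	case _, ok := <-ch:
//		// ok == false only when ch is closed
//	default:
//	}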

func (s *Server) startNetworkListener(listener net.Listener, isUnixSocket bool, errChan chan error) {
	if listener == nil {
		errChan <- nil
		return
	}
	for {
		conn, err := listener.Accept()
		if err != nil {
			if opErr, ok := err.(*net.OpError); ok {
				if opErr.Err.Error() == "use of closed network connection" {
					if s.inShutdownMode.Load() {
						errChan <- nil
					} else {
						errChan <- err
					}
					return
				}
			}

			// If we got a PROXY protocol error, we should continue to accept.
			if proxyprotocol.IsProxyProtocolError(err) {
				logutil.BgLogger().Error("PROXY protocol failed", zap.Error(err))
				continue
			}

			logutil.BgLogger().Error("accept failed", zap.Error(err))
			errChan <- err
			return
		}

		logutil.BgLogger().Debug("accept new connection success")

		clientConn := s.newConn(conn)
		if isUnixSocket {
			var (
				uc *net.UnixConn
				ok bool
			)
			if clientConn.ppEnabled {
				// Use reflection to get the raw Conn object from the proxy protocol wrapper connection object.
				ppv := reflect.ValueOf(conn)
				vconn := ppv.Elem().FieldByName("Conn")
				rconn := vconn.Interface()
				uc, ok = rconn.(*net.UnixConn)
			} else {
				uc, ok = conn.(*net.UnixConn)
			}
			if !ok {
				logutil.BgLogger().Error("Expected UNIX socket, but got something else")
				return
			}

			clientConn.isUnixSocket = true
			clientConn.peerHost = "localhost"
			clientConn.socketCredUID, err = linux.GetSockUID(*uc)
			if err != nil {
				logutil.BgLogger().Error("Failed to get UNIX socket peer credentials", zap.Error(err))
				return
			}
		}

		err = nil
		if !clientConn.ppEnabled {
			// Check audit plugins when the PROXY protocol is not enabled.
			err = s.checkAuditPlugin(clientConn)
		}
		if err != nil {
			continue
		}

		if s.dom != nil && s.dom.IsLostConnectionToPD() {
			logutil.BgLogger().Warn("reject connection due to lost connection to PD")
			terror.Log(clientConn.Close())
			continue
		}

		go s.onConn(clientConn)
	}
}

func (*Server) checkAuditPlugin(clientConn *clientConn) error {
	return plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
		authPlugin := plugin.DeclareAuditManifest(p.Manifest)
		if authPlugin.OnConnectionEvent == nil {
			return nil
		}
		host, _, err := clientConn.PeerHost("", false)
		if err != nil {
			logutil.BgLogger().Error("get peer host failed", zap.Error(err))
			terror.Log(clientConn.Close())
			return errors.Trace(err)
		}
		if err = authPlugin.OnConnectionEvent(context.Background(), plugin.PreAuth,
			&variable.ConnectionInfo{Host: host}); err != nil {
			logutil.BgLogger().Info("do connection event failed", zap.Error(err))
			terror.Log(clientConn.Close())
			return errors.Trace(err)
		}
		return nil
	})
}

func (s *Server) startShutdown() {
	logutil.BgLogger().Info("setting tidb-server to report unhealthy (shutting-down)")
	s.health.Store(false)
	// Give the load balancer a chance to receive a few unhealthy health reports
	// before acquiring the s.rwlock and blocking connections.
	waitTime := time.Duration(s.cfg.GracefulWaitBeforeShutdown) * time.Second
	if waitTime > 0 {
		logutil.BgLogger().Info("waiting for stray connections before starting shutdown process", zap.Duration("waitTime", waitTime))
		time.Sleep(waitTime)
	}
}

func (s *Server) closeListener() {
	if s.listener != nil {
		err := s.listener.Close()
		terror.Log(errors.Trace(err))
		s.listener = nil
	}
	if s.socket != nil {
		err := s.socket.Close()
		terror.Log(errors.Trace(err))
		s.socket = nil
	}
	if statusServer := s.statusServer.Load(); statusServer != nil {
		err := statusServer.Close()
		terror.Log(errors.Trace(err))
		s.statusServer.Store(nil)
	}
	if s.grpcServer != nil {
		s.grpcServer.Stop()
		s.grpcServer = nil
	}
	if s.autoIDService != nil {
		s.autoIDService.Close()
	}
	if s.authTokenCancelFunc != nil {
		s.authTokenCancelFunc()
	}
	s.wg.Wait()
	metrics.ServerEventCounter.WithLabelValues(metrics.ServerStop).Inc()
}

// Close closes the server.
func (s *Server) Close() {
	s.startShutdown()
	s.rwlock.Lock() // prevent new connections
	defer s.rwlock.Unlock()
	s.inShutdownMode.Store(true)
	s.closeListener()
}

func (s *Server) registerConn(conn *clientConn) bool {
	s.rwlock.Lock()
	defer s.rwlock.Unlock()

	logger := logutil.BgLogger()
	if s.inShutdownMode.Load() {
		logger.Info("close connection directly when shutting down")
		terror.Log(closeConn(conn))
		return false
	}
	s.clients[conn.connectionID] = conn
	return true
}

// onConn runs in its own goroutine, handling queries from this connection.
func (s *Server) onConn(conn *clientConn) {
	if s.StandbyController != nil {
		s.StandbyController.OnConnActive()
	}

	// Init the connInfo.
	_, _, err := conn.PeerHost("", false)
	if err != nil {
		logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
			Error("get peer host failed", zap.Error(err))
		terror.Log(conn.Close())
		return
	}

	extensions, err := extension.GetExtensions()
	if err != nil {
		logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
			Error("error in get extensions", zap.Error(err))
		terror.Log(conn.Close())
		return
	}

	if sessExtensions := extensions.NewSessionExtensions(); sessExtensions != nil {
		conn.extensions = sessExtensions
		conn.onExtensionConnEvent(extension.ConnConnected, nil)
		defer func() {
			conn.onExtensionConnEvent(extension.ConnDisconnected, nil)
		}()
	}

	ctx := logutil.WithConnID(context.Background(), conn.connectionID)

	if err := conn.handshake(ctx); err != nil {
		conn.onExtensionConnEvent(extension.ConnHandshakeRejected, err)
		if plugin.IsEnable(plugin.Audit) && conn.getCtx() != nil {
			conn.getCtx().GetSessionVars().ConnectionInfo = conn.connectInfo()
			err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
				authPlugin := plugin.DeclareAuditManifest(p.Manifest)
				if authPlugin.OnConnectionEvent != nil {
					pluginCtx := context.WithValue(context.Background(), plugin.RejectReasonCtxValue{}, err.Error())
					return authPlugin.OnConnectionEvent(pluginCtx, plugin.Reject, conn.ctx.GetSessionVars().ConnectionInfo)
				}
				return nil
			})
			terror.Log(err)
		}
		switch errors.Cause(err) {
		case io.EOF:
			// `EOF` means the connection is closed normally; we do not treat it as a noticeable error and log it at 'DEBUG' level.
			logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
				Debug("EOF", zap.String("remote addr", conn.bufReadConn.RemoteAddr().String()))
		case servererr.ErrConCount:
			if err := conn.writeError(ctx, err); err != nil {
				logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
					Warn("error in writing errConCount", zap.Error(err),
						zap.String("remote addr", conn.bufReadConn.RemoteAddr().String()))
			}
		default:
			metrics.HandShakeErrorCounter.Inc()
			logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
				Warn("Server.onConn handshake", zap.Error(err),
					zap.String("remote addr", conn.bufReadConn.RemoteAddr().String()))
		}
		terror.Log(conn.Close())
		return
	}

	logutil.Logger(ctx).Debug("new connection", zap.String("remoteAddr", conn.bufReadConn.RemoteAddr().String()))

	defer func() {
		terror.Log(conn.Close())
		logutil.Logger(ctx).Debug("connection closed")
	}()

	if err := conn.increaseUserConnectionsCount(); err != nil {
		logutil.BgLogger().With(zap.Uint64("conn", conn.connectionID)).
			Warn("failed to increase the count of connections", zap.Error(err),
				zap.String("remote addr", conn.bufReadConn.RemoteAddr().String()))
		return
	}
	defer conn.decreaseUserConnectionCount()

	if !s.registerConn(conn) {
		return
	}

	sessionVars := conn.ctx.GetSessionVars()
	sessionVars.ConnectionInfo = conn.connectInfo()
	conn.onExtensionConnEvent(extension.ConnHandshakeAccepted, nil)
	err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
		authPlugin := plugin.DeclareAuditManifest(p.Manifest)
		if authPlugin.OnConnectionEvent != nil {
			return authPlugin.OnConnectionEvent(context.Background(), plugin.Connected, sessionVars.ConnectionInfo)
		}
		return nil
	})
	if err != nil {
		return
	}

	connectedTime := time.Now()
	conn.Run(ctx)

	err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
		authPlugin := plugin.DeclareAuditManifest(p.Manifest)
		if authPlugin.OnConnectionEvent != nil {
			sessionVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond)
			err := authPlugin.OnConnectionEvent(context.Background(), plugin.Disconnect, sessionVars.ConnectionInfo)
			if err != nil {
				logutil.BgLogger().Warn("do connection event failed", zap.String("plugin", authPlugin.Name), zap.Error(err))
			}
		}
		return nil
	})
	if err != nil {
		return
	}
}

func (cc *clientConn) connectInfo() *variable.ConnectionInfo {
	connType := variable.ConnTypeSocket
	sslVersion := ""
	if cc.isUnixSocket {
		connType = variable.ConnTypeUnixSocket
	} else if cc.tlsConn != nil {
		connType = variable.ConnTypeTLS
		sslVersionNum := cc.tlsConn.ConnectionState().Version
		switch sslVersionNum {
		case tls.VersionTLS12:
			sslVersion = "TLSv1.2"
		case tls.VersionTLS13:
			sslVersion = "TLSv1.3"
		default:
			sslVersion = fmt.Sprintf("Unknown TLS version: %d", sslVersionNum)
		}
	}
	connInfo := &variable.ConnectionInfo{
		ConnectionID:      cc.connectionID,
		ConnectionType:    connType,
		Host:              cc.peerHost,
		ClientIP:          cc.peerHost,
		ClientPort:        cc.peerPort,
		ServerID:          1,
		ServerIP:          cc.serverHost,
		ServerPort:        int(cc.server.cfg.Port),
		User:              cc.user,
		ServerOSLoginUser: osUser,
		OSVersion:         osVersion,
		ServerVersion:     mysql.TiDBReleaseVersion,
		SSLVersion:        sslVersion,
		PID:               serverPID,
		DB:                cc.dbname,
		AuthMethod:        cc.authPlugin,
		Attributes:        cc.attrs,
	}
	return connInfo
}

func (s *Server) checkConnectionCount() error {
	// When the value of Instance.MaxConnections is 0, the number of connections is unlimited.
	if int(s.cfg.Instance.MaxConnections) == 0 {
		return nil
	}

	conns := s.ConnectionCount()

	if conns >= int(s.cfg.Instance.MaxConnections) {
		logutil.BgLogger().Error("too many connections",
			zap.Uint32("max connections", s.cfg.Instance.MaxConnections), zap.Error(servererr.ErrConCount))
		return servererr.ErrConCount
	}
	return nil
}

// ShowProcessList implements the SessionManager interface.
func (s *Server) ShowProcessList() map[uint64]*sessmgr.ProcessInfo {
	rs := make(map[uint64]*sessmgr.ProcessInfo)
	maps.Copy(rs, s.GetUserProcessList())
	if s.dom != nil {
		maps.Copy(rs, s.dom.SysProcTracker().GetSysProcessList())
	}
	return rs
}

// GetUserProcessList returns all process info that is created by users.
func (s *Server) GetUserProcessList() map[uint64]*sessmgr.ProcessInfo {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	rs := make(map[uint64]*sessmgr.ProcessInfo)
	for _, client := range s.clients {
		if pi := client.ctx.ShowProcess(); pi != nil {
			rs[pi.ID] = pi
		}
	}
	return rs
}

// GetClientCapabilityList returns all client capabilities.
func (s *Server) GetClientCapabilityList() map[uint64]uint32 {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	rs := make(map[uint64]uint32)
	for id, client := range s.clients {
		if client.ctx.Session != nil {
			rs[id] = client.capability
		}
	}
	return rs
}

// ShowTxnList shows all txn info for displaying in `TIDB_TRX`.
// Internal sessions are not taken into consideration.
func (s *Server) ShowTxnList() []*txninfo.TxnInfo {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	rs := make([]*txninfo.TxnInfo, 0, len(s.clients))
	for _, client := range s.clients {
		if client.ctx.Session != nil {
			info := client.ctx.Session.TxnInfo()
			if info != nil && info.ProcessInfo != nil {
				rs = append(rs, info)
			}
		}
	}
	return rs
}

// UpdateProcessCPUTime updates the specified process's TiDB CPU time while the process is still running.
// It implements the ProcessCPUTimeUpdater interface.
func (s *Server) UpdateProcessCPUTime(connID uint64, sqlID uint64, cpuTime time.Duration) {
	s.rwlock.RLock()
	conn, ok := s.clients[connID]
	s.rwlock.RUnlock()
	if !ok {
		return
	}
	vars := conn.ctx.GetSessionVars()
	if vars != nil {
		vars.SQLCPUUsages.MergeTidbCPUTime(sqlID, cpuTime)
	}
}

// GetProcessInfo implements the SessionManager interface.
func (s *Server) GetProcessInfo(id uint64) (*sessmgr.ProcessInfo, bool) {
	s.rwlock.RLock()
	conn, ok := s.clients[id]
	s.rwlock.RUnlock()
	if !ok {
		if s.dom != nil {
			if pinfo, ok2 := s.dom.SysProcTracker().GetSysProcessList()[id]; ok2 {
				return pinfo, true
			}
		}
		return &sessmgr.ProcessInfo{}, false
	}
	return conn.ctx.ShowProcess(), ok
}

// GetConAttrs returns the connection attributes.
func (s *Server) GetConAttrs(user *auth.UserIdentity) map[uint64]map[string]string {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	rs := make(map[uint64]map[string]string)
	for _, client := range s.clients {
		if user != nil {
			if user.Username != client.user {
				continue
			}
			if user.Hostname != client.peerHost {
				continue
			}
		}
		if pi := client.ctx.ShowProcess(); pi != nil {
			rs[pi.ID] = client.attrs
		}
	}
	return rs
}

// Kill implements the SessionManager interface.
func (s *Server) Kill(connectionID uint64, query bool, maxExecutionTime bool, runaway bool) {
	logutil.BgLogger().Info("kill", zap.Uint64("conn", connectionID),
		zap.Bool("query", query), zap.Bool("maxExecutionTime", maxExecutionTime), zap.Bool("runawayExceed", runaway))
	metrics.ServerEventCounter.WithLabelValues(metrics.EventKill).Inc()

	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	conn, ok := s.clients[connectionID]
	if !ok && s.dom != nil {
		s.dom.SysProcTracker().KillSysProcess(connectionID)
		return
	}

	if !query {
		// Mark the client connection status as WaitShutdown; when clientConn.Run detects
		// this, it will end the dispatch loop and exit.
		conn.setStatus(connStatusWaitShutdown)
		if conn.bufReadConn != nil {
			// When attempting to 'kill connection' and TiDB is stuck in the network stack while writing packets,
			// we can quickly exit the network stack and terminate the SQL execution by setting WriteDeadline.
			if err := conn.bufReadConn.SetWriteDeadline(time.Now()); err != nil {
				logutil.BgLogger().Warn("error setting write deadline for kill.", zap.Error(err))
			}
			if err := conn.bufReadConn.SetReadDeadline(time.Now()); err != nil {
				logutil.BgLogger().Warn("error setting read deadline for kill.", zap.Error(err))
			}
		}
	}
	killQuery(conn, maxExecutionTime, runaway)
}

// UpdateTLSConfig implements the SessionManager interface.
func (s *Server) UpdateTLSConfig(cfg *tls.Config) {
	atomic.StorePointer(&s.tlsConfig, unsafe.Pointer(cfg))
}

// GetTLSConfig implements the SessionManager interface.
func (s *Server) GetTLSConfig() *tls.Config {
	return (*tls.Config)(atomic.LoadPointer(&s.tlsConfig))
}
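
// The tlsConfig field is an unsafe.Pointer swapped with atomic operations so
// that certificate reloads never block in-flight handshakes: readers always
// observe either the old or the new *tls.Config, never a torn value. A reload
// is simply a store followed by loads from any goroutine, e.g. (sketch):
//
//	s.UpdateTLSConfig(newCfg) // atomic store
//	cfg := s.GetTLSConfig()   // atomic load, safe concurrently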

func killQuery(conn *clientConn, maxExecutionTime, runaway bool) {
	sessVars := conn.ctx.GetSessionVars()
	if runaway {
		sessVars.SQLKiller.SendKillSignal(sqlkiller.RunawayQueryExceeded)
	} else if maxExecutionTime {
		sessVars.SQLKiller.SendKillSignal(sqlkiller.MaxExecTimeExceeded)
	} else {
		sessVars.SQLKiller.SendKillSignal(sqlkiller.QueryInterrupted)
	}
	conn.mu.RLock()
	cancelFunc := conn.mu.cancelFunc
	conn.mu.RUnlock()

	if cancelFunc != nil {
		cancelFunc()
	}
	sessVars.SQLKiller.FinishResultSet()
}
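
// Kill delivery is two-pronged: a flag on the session's SQLKiller, which
// executors poll at checkpoints, plus cancellation of the statement context so
// goroutines blocked in I/O or channel waits also return. A cooperative
// execution loop would look roughly like this (sketch; the polling call is an
// assumption, not taken from this file):
//
//	for _, chunk := range chunks {
//		if err := sessVars.SQLKiller.HandleSignal(); err != nil {
//			return err // the query was interrupted
//		}
//		process(chunk)
//	}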

// KillSysProcesses kills sys processes such as auto analyze.
func (s *Server) KillSysProcesses() {
	if s.dom == nil {
		return
	}
	sysProcTracker := s.dom.SysProcTracker()
	for connID := range sysProcTracker.GetSysProcessList() {
		sysProcTracker.KillSysProcess(connID)
	}
}

// KillAllConnections implements the SessionManager interface.
// KillAllConnections kills all connections.
func (s *Server) KillAllConnections() {
	logutil.BgLogger().Info("kill all connections.", zap.String("category", "server"))

	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	for _, conn := range s.clients {
		conn.setStatus(connStatusShutdown)
		if err := conn.closeWithoutLock(); err != nil {
			terror.Log(err)
		}
		if conn.bufReadConn != nil {
			if err := conn.bufReadConn.SetReadDeadline(time.Now()); err != nil {
				logutil.BgLogger().Warn("error setting read deadline for kill.", zap.Error(err))
			}
		}
		killQuery(conn, false, false)
	}

	s.KillSysProcesses()
}

// DrainClients drains all connections within drainWait.
// After the drainWait duration, we kill all connections that still have not quit and wait for cancelWait.
func (s *Server) DrainClients(drainWait time.Duration, cancelWait time.Duration) {
	logger := logutil.BgLogger()
	logger.Info("start drain clients")

	conns := make(map[uint64]*clientConn)

	s.rwlock.Lock()
	maps.Copy(conns, s.clients)
	s.rwlock.Unlock()

	allDone := make(chan struct{})
	quitWaitingForConns := make(chan struct{})
	defer close(quitWaitingForConns)
	go func() {
		defer close(allDone)
		for _, conn := range conns {
			// Wait for the connections with an explicit transaction or an executing auto-commit query.
			if conn.getStatus() == connStatusReading && !conn.getCtx().GetSessionVars().InTxn() {
				// The waitgroup is not protected by the `quitWaitingForConns`. However, the implementation
				// of `client-go` will guarantee this `Wait` will return at least after killing the
				// connections. We also wait for a similar `WaitGroup` on the store after killing the connections.
				//
				// Therefore, it will not cause a goroutine leak; and even if it did, that is not a big
				// issue when TiDB is about to shut down.
				//
				// We should wait for connections in all states, even ones that are not in a transaction
				// and are reading from the client, because they may run background commit goroutines at
				// any time.
				conn.getCtx().Session.GetCommitWaitGroup().Wait()

				continue
			}
			select {
			case <-conn.quit:
			case <-quitWaitingForConns:
				return
			}

			// Wait for the commit wait group after waiting for the `conn.quit` channel to make sure the
			// foreground process has finished, to avoid the situation where, after waiting for the wait
			// group, the transaction starts a new background goroutine and increases the wait group.
			conn.getCtx().Session.GetCommitWaitGroup().Wait()
		}
	}()

	select {
	case <-allDone:
		logger.Info("all sessions quit in drain wait time")
	case <-time.After(drainWait):
		logger.Info("timeout waiting all sessions quit")
	}

	s.KillAllConnections()

	select {
	case <-allDone:
	case <-time.After(cancelWait):
		logger.Warn("some sessions do not quit in cancel wait time")
	}
}
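
// The drain protocol is: (1) snapshot the client map, (2) wait up to drainWait
// for sessions to quit on their own (idle sessions outside a transaction only
// need their commit wait group to settle), (3) kill whatever remains, and
// (4) wait up to cancelWait for those kills to take effect. A shutdown path
// might drive it like this (sketch; durations are illustrative):
//
//	svr.DrainClients(15*time.Second, 5*time.Second)
//	svr.Close()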

// ServerID implements the SessionManager interface.
func (s *Server) ServerID() uint64 {
	return s.dom.ServerID()
}

// StoreInternalSession implements the SessionManager interface.
// @param se The address of a session.session struct variable
func (s *Server) StoreInternalSession(se any) {
	s.sessionMapMutex.Lock()
	s.internalSessions[se] = struct{}{}
	metrics.InternalSessions.Set(float64(len(s.internalSessions)))
	s.sessionMapMutex.Unlock()
}

// ContainsInternalSession implements the SessionManager interface.
func (s *Server) ContainsInternalSession(se any) bool {
	s.sessionMapMutex.Lock()
	defer s.sessionMapMutex.Unlock()
	_, ok := s.internalSessions[se]
	return ok
}

// InternalSessionCount implements the sessionmgr.InfoSchemaCoordinator interface.
func (s *Server) InternalSessionCount() int {
	s.sessionMapMutex.Lock()
	defer s.sessionMapMutex.Unlock()
	return len(s.internalSessions)
}

// DeleteInternalSession implements the SessionManager interface.
// @param se The address of a session.session struct variable
func (s *Server) DeleteInternalSession(se any) {
	s.sessionMapMutex.Lock()
	delete(s.internalSessions, se)
	metrics.InternalSessions.Set(float64(len(s.internalSessions)))
	s.sessionMapMutex.Unlock()
}

// GetInternalSessionStartTSList implements the SessionManager interface.
func (s *Server) GetInternalSessionStartTSList() []uint64 {
	s.sessionMapMutex.Lock()
	defer s.sessionMapMutex.Unlock()
	tsList := make([]uint64, 0, len(s.internalSessions))
	for se := range s.internalSessions {
		if ts, processInfoID := session.GetStartTSFromSession(se); ts != 0 {
			if statsutil.GlobalAutoAnalyzeProcessList.Contains(processInfoID) {
				continue
			}
			tsList = append(tsList, ts)
		}
	}
	return tsList
}

// InternalSessionExists is used for tests.
func (s *Server) InternalSessionExists(se any) bool {
	s.sessionMapMutex.Lock()
	_, ok := s.internalSessions[se]
	s.sessionMapMutex.Unlock()
	return ok
}

// setSysTimeZoneOnce is used for parallel-run tests. When several servers are running,
// only the first will actually call setSystemTimeZoneVariable, thus avoiding a data race.
var setSysTimeZoneOnce = &sync.Once{}

func setSystemTimeZoneVariable() {
	setSysTimeZoneOnce.Do(func() {
		tz, err := timeutil.GetSystemTZ()
		if err != nil {
			logutil.BgLogger().Error(
				"Error getting SystemTZ, use default value instead",
				zap.Error(err),
				zap.String("default system_time_zone", variable.GetSysVar("system_time_zone").Value))
			return
		}
		variable.SetSysVar("system_time_zone", tz)
	})
}

// CheckOldRunningTxn implements the SessionManager interface.
func (s *Server) CheckOldRunningTxn(jobs map[int64]*mdldef.JobMDL) {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()

	printLog := false
	if time.Since(s.printMDLLogTime) > 10*time.Second {
		printLog = true
		s.printMDLLogTime = time.Now()
	}
	for _, client := range s.clients {
		se := client.ctx.Session
		if se != nil {
			variable.RemoveLockDDLJobs(se.GetSessionVars(), jobs, printLog)
		}
	}
}

// KillNonFlashbackClusterConn implements the SessionManager interface.
func (s *Server) KillNonFlashbackClusterConn() {
	s.rwlock.RLock()
	connIDs := make([]uint64, 0, len(s.clients))
	for _, client := range s.clients {
		if client.ctx.Session != nil {
			processInfo := client.ctx.Session.ShowProcess()
			ddl, ok := processInfo.StmtCtx.GetPlan().(*core.DDL)
			if !ok {
				connIDs = append(connIDs, client.connectionID)
				continue
			}
			_, ok = ddl.Statement.(*ast.FlashBackToTimestampStmt)
			if !ok {
				connIDs = append(connIDs, client.connectionID)
				continue
			}
		}
	}
	s.rwlock.RUnlock()
	for _, id := range connIDs {
		s.Kill(id, false, false, false)
	}
}

// GetStatusVars returns the per-process status variables from the server.
func (s *Server) GetStatusVars() map[uint64]map[string]string {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	rs := make(map[uint64]map[string]string)
	for _, client := range s.clients {
		if pi := client.ctx.ShowProcess(); pi != nil {
			if client.tlsConn != nil {
				connState := client.tlsConn.ConnectionState()
				rs[pi.ID] = map[string]string{
					"Ssl_cipher":  tlsutil.CipherSuiteName(connState.CipherSuite),
					"Ssl_version": tlsutil.VersionName(connState.Version),
				}
			}
		}
	}
	return rs
}

// Health reports whether the server is healthy; it returns false once shutdown has begun.
func (s *Server) Health() bool {
	return s.health.Load()
}