// The MIT License (MIT)
//
// Copyright (c) 2014 wandoulabs
// Copyright (c) 2014 siddontang
//
// Permission is hereby granted, free of charge, to any person obtaining a copy of
// this software and associated documentation files (the "Software"), to deal in
// the Software without restriction, including without limitation the rights to
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
// the Software, and to permit persons to whom the Software is furnished to do so,
// subject to the following conditions:
//
// The above copyright notice and this permission notice shall be included in all
// copies or substantial portions of the Software.

// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"net/http"
	"os"
	"os/user"
	"sync"
	"sync/atomic"
	"time"

	// For pprof
	_ "net/http/pprof"

	"github.com/blacktear23/go-proxyprotocol"
	"github.com/pingcap/errors"
	"github.com/pingcap/parser/mysql"
	"github.com/pingcap/parser/terror"
	"github.com/pingcap/tidb/config"
	"github.com/pingcap/tidb/metrics"
	"github.com/pingcap/tidb/plugin"
	"github.com/pingcap/tidb/sessionctx/variable"
	"github.com/pingcap/tidb/util"
	"github.com/pingcap/tidb/util/logutil"
	"github.com/pingcap/tidb/util/sys/linux"
	"go.uber.org/zap"
)

var (
	baseConnID uint32
	serverPID  int
	osUser     string
	osVersion  string
)

func init() {
	serverPID = os.Getpid()
	currentUser, err := user.Current()
	if err != nil {
		osUser = ""
	} else {
		osUser = currentUser.Name
	}
	osVersion, err = linux.OSVersion()
	if err != nil {
		osVersion = ""
	}
}

var (
	errUnknownFieldType    = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
	errInvalidPayloadLen   = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
	errInvalidSequence     = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
	errInvalidType         = terror.ClassServer.New(codeInvalidType, "invalid type")
	errNotAllowedCommand   = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
	errAccessDenied        = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
	errMaxExecTimeExceeded = terror.ClassServer.New(codeMaxExecTimeExceeded, mysql.MySQLErrName[mysql.ErrMaxExecTimeExceeded])
)

// defaultCapability is the capability of the server when it is created using the default configuration.
// When the server is configured with SSL, it gains extra capabilities compared to defaultCapability.
const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag |
	mysql.ClientConnectWithDB | mysql.ClientProtocol41 |
	mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows |
	mysql.ClientMultiStatements | mysql.ClientMultiResults | mysql.ClientLocalFiles |
	mysql.ClientConnectAtts | mysql.ClientPluginAuth | mysql.ClientInteractive

// Server is the MySQL protocol server.
type Server struct {
	cfg               *config.Config
	tlsConfig         *tls.Config
	driver            IDriver
	listener          net.Listener
	socket            net.Listener
	rwlock            sync.RWMutex
	concurrentLimiter *TokenLimiter
	clients           map[uint32]*clientConn
	capability        uint32

	// stopListenerCh is used when a critical error occurs. We don't want to exit the process, because a supervisor
	// may restart it automatically and new client connections would then be created, but we cannot serve them.
	// So we just stop the listener, forcing clients to choose other TiDB servers.
	stopListenerCh chan struct{}
	statusServer   *http.Server
}

// ConnectionCount gets the current connection count.
func (s *Server) ConnectionCount() int {
	s.rwlock.RLock()
	cnt := len(s.clients)
	s.rwlock.RUnlock()
	return cnt
}

func (s *Server) getToken() *Token {
	start := time.Now()
	tok := s.concurrentLimiter.Get()
	// Note that waits shorter than one microsecond are recorded as zero, because such a case can be viewed as non-blocking.
	metrics.GetTokenDurationHistogram.Observe(float64(time.Since(start).Nanoseconds() / 1e3))
	return tok
}

func (s *Server) releaseToken(token *Token) {
	s.concurrentLimiter.Put(token)
}

// newConn creates a new *clientConn from a net.Conn.
// It allocates a connection ID and random salt data for authentication.
func (s *Server) newConn(conn net.Conn) *clientConn {
	cc := newClientConn(s)
	if s.cfg.Performance.TCPKeepAlive {
		if tcpConn, ok := conn.(*net.TCPConn); ok {
			if err := tcpConn.SetKeepAlive(true); err != nil {
				logutil.BgLogger().Error("failed to set tcp keep alive option", zap.Error(err))
			}
		}
	}
	cc.setConn(conn)
	cc.salt = util.RandomBuf(20)
	return cc
}

func (s *Server) isUnixSocket() bool {
	return s.cfg.Socket != ""
}

func (s *Server) forwardUnixSocketToTCP() {
	addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
	for {
		if s.listener == nil {
			return // server shutdown has started
		}
		if uconn, err := s.socket.Accept(); err == nil {
			logutil.BgLogger().Info("server socket forwarding", zap.String("from", s.cfg.Socket), zap.String("to", addr))
			go s.handleForwardedConnection(uconn, addr)
		} else if s.listener != nil {
			logutil.BgLogger().Error("server failed to forward", zap.String("from", s.cfg.Socket), zap.String("to", addr), zap.Error(err))
		}
	}
}

func (s *Server) handleForwardedConnection(uconn net.Conn, addr string) {
	defer terror.Call(uconn.Close)
	if tconn, err := net.Dial("tcp", addr); err == nil {
		go func() {
			if _, err := io.Copy(uconn, tconn); err != nil {
				logutil.BgLogger().Warn("copy server to socket failed", zap.Error(err))
			}
		}()
		if _, err := io.Copy(tconn, uconn); err != nil {
			logutil.BgLogger().Warn("socket forward copy failed", zap.Error(err))
		}
	} else {
		logutil.BgLogger().Warn("socket forward failed: could not connect", zap.String("addr", addr), zap.Error(err))
	}
}

// NewServer creates a new Server.
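//
// A minimal usage sketch (hypothetical; the kv.Storage named store is assumed to be
// opened by the caller, while NewTiDBDriver, NewServer and Run are from this package):
//
//	drv := NewTiDBDriver(store)                          // store: a kv.Storage provided by the caller
//	svr, err := NewServer(config.GetGlobalConfig(), drv) // listens according to the global config
//	if err != nil {
//		log.Fatal(err)
//	}
//	if err := svr.Run(); err != nil {                    // blocks, accepting MySQL protocol connections
//		log.Fatal(err)
//	}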
func NewServer(cfg *config.Config, driver IDriver) (*Server, error) {
	s := &Server{
		cfg:               cfg,
		driver:            driver,
		concurrentLimiter: NewTokenLimiter(cfg.TokenLimit),
		clients:           make(map[uint32]*clientConn),
		stopListenerCh:    make(chan struct{}, 1),
	}
	s.loadTLSCertificates()

	s.capability = defaultCapability
	if s.tlsConfig != nil {
		s.capability |= mysql.ClientSSL
	}

	var err error
	if s.cfg.Host != "" && s.cfg.Port != 0 {
		addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port)
		if s.listener, err = net.Listen("tcp", addr); err == nil {
			logutil.BgLogger().Info("server is running MySQL protocol", zap.String("addr", addr))
			if cfg.Socket != "" {
				if s.socket, err = net.Listen("unix", s.cfg.Socket); err == nil {
					logutil.BgLogger().Info("server redirecting", zap.String("from", s.cfg.Socket), zap.String("to", addr))
					go s.forwardUnixSocketToTCP()
				}
			}
		}
	} else if cfg.Socket != "" {
		if s.listener, err = net.Listen("unix", cfg.Socket); err == nil {
			logutil.BgLogger().Info("server is running MySQL protocol", zap.String("socket", cfg.Socket))
		}
	} else {
		err = errors.New("Server not configured to listen on either -socket or -host and -port")
	}

	if cfg.ProxyProtocol.Networks != "" {
		pplistener, errProxy := proxyprotocol.NewListener(s.listener, cfg.ProxyProtocol.Networks,
			int(cfg.ProxyProtocol.HeaderTimeout))
		if errProxy != nil {
			logutil.BgLogger().Error("ProxyProtocol networks parameter invalid")
			return nil, errors.Trace(errProxy)
		}
		logutil.BgLogger().Info("server is running MySQL protocol (through PROXY protocol)", zap.String("host", s.cfg.Host))
		s.listener = pplistener
	}

	if err != nil {
		return nil, errors.Trace(err)
	}

	// Init rand seed for randomBuf().
	rand.Seed(time.Now().UTC().UnixNano())
	return s, nil
}

// loadTLSCertificates loads the server certificate/key pair and, if configured, the CA certificate used to
// verify client certificates. On success it populates s.tlsConfig and the related SSL system variables;
// otherwise the server falls back to insecure connections.
func (s *Server) loadTLSCertificates() {
	defer func() {
		if s.tlsConfig != nil {
			logutil.BgLogger().Info("secure connection is enabled", zap.Bool("client verification enabled", len(variable.SysVars["ssl_ca"].Value) > 0))
			variable.SysVars["have_openssl"].Value = "YES"
			variable.SysVars["have_ssl"].Value = "YES"
			variable.SysVars["ssl_cert"].Value = s.cfg.Security.SSLCert
			variable.SysVars["ssl_key"].Value = s.cfg.Security.SSLKey
		} else {
			logutil.BgLogger().Warn("secure connection is not enabled")
		}
	}()

	if len(s.cfg.Security.SSLCert) == 0 || len(s.cfg.Security.SSLKey) == 0 {
		s.tlsConfig = nil
		return
	}

	tlsCert, err := tls.LoadX509KeyPair(s.cfg.Security.SSLCert, s.cfg.Security.SSLKey)
	if err != nil {
		logutil.BgLogger().Warn("load x509 failed", zap.Error(err))
		s.tlsConfig = nil
		return
	}

	// Try loading the CA cert.
	clientAuthPolicy := tls.NoClientCert
	var certPool *x509.CertPool
	if len(s.cfg.Security.SSLCA) > 0 {
		caCert, err := ioutil.ReadFile(s.cfg.Security.SSLCA)
		if err != nil {
			logutil.BgLogger().Warn("read file failed", zap.Error(err))
		} else {
			certPool = x509.NewCertPool()
			if certPool.AppendCertsFromPEM(caCert) {
				clientAuthPolicy = tls.VerifyClientCertIfGiven
			}
			variable.SysVars["ssl_ca"].Value = s.cfg.Security.SSLCA
		}
	}
	s.tlsConfig = &tls.Config{
		Certificates: []tls.Certificate{tlsCert},
		ClientCAs:    certPool,
		ClientAuth:   clientAuthPolicy,
	}
}

// Run runs the server.
func (s *Server) Run() error {
	metrics.ServerEventCounter.WithLabelValues(metrics.EventStart).Inc()

	// Start the HTTP API to report tidb info such as TPS.
	if s.cfg.Status.ReportStatus {
		s.startStatusHTTP()
	}
	for {
		conn, err := s.listener.Accept()
		if err != nil {
			if opErr, ok := err.(*net.OpError); ok {
				if opErr.Err.Error() == "use of closed network connection" {
					return nil
				}
			}

			// If we get a PROXY protocol error, we should continue to accept.
			if proxyprotocol.IsProxyProtocolError(err) {
				logutil.BgLogger().Error("PROXY protocol failed", zap.Error(err))
				continue
			}

			logutil.BgLogger().Error("accept failed", zap.Error(err))
			return errors.Trace(err)
		}
		if s.shouldStopListener() {
			err = conn.Close()
			terror.Log(errors.Trace(err))
			break
		}

		clientConn := s.newConn(conn)

		err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
			authPlugin := plugin.DeclareAuditManifest(p.Manifest)
			if authPlugin.OnConnectionEvent != nil {
				host, err := clientConn.PeerHost("")
				if err != nil {
					logutil.BgLogger().Error("get peer host failed", zap.Error(err))
					terror.Log(clientConn.Close())
					return errors.Trace(err)
				}
				err = authPlugin.OnConnectionEvent(context.Background(), plugin.PreAuth, &variable.ConnectionInfo{Host: host})
				if err != nil {
					logutil.BgLogger().Info("do connection event failed", zap.Error(err))
					terror.Log(clientConn.Close())
					return errors.Trace(err)
				}
			}
			return nil
		})
		if err != nil {
			continue
		}

		go s.onConn(clientConn)
	}
	err := s.listener.Close()
	terror.Log(errors.Trace(err))
	s.listener = nil
	for {
		metrics.ServerEventCounter.WithLabelValues(metrics.EventHang).Inc()
		logutil.BgLogger().Error("listener stopped, waiting for manual kill.")
		time.Sleep(time.Minute)
	}
}

func (s *Server) shouldStopListener() bool {
	select {
	case <-s.stopListenerCh:
		return true
	default:
		return false
	}
}

// Close closes the server.
func (s *Server) Close() {
	s.rwlock.Lock()
	defer s.rwlock.Unlock()

	if s.listener != nil {
		err := s.listener.Close()
		terror.Log(errors.Trace(err))
		s.listener = nil
	}
	if s.socket != nil {
		err := s.socket.Close()
		terror.Log(errors.Trace(err))
		s.socket = nil
	}
	if s.statusServer != nil {
		err := s.statusServer.Close()
		terror.Log(errors.Trace(err))
		s.statusServer = nil
	}
	metrics.ServerEventCounter.WithLabelValues(metrics.EventClose).Inc()
}

// onConn runs in its own goroutine and handles queries from this connection.
func (s *Server) onConn(conn *clientConn) {
	ctx := logutil.WithConnID(context.Background(), conn.connectionID)
	if err := conn.handshake(ctx); err != nil {
		// Some keep-alive services will send a request to TiDB and disconnect immediately.
		// So we only record metrics.
		metrics.HandShakeErrorCounter.Inc()
		err = conn.Close()
		terror.Log(errors.Trace(err))
		return
	}

	logutil.Logger(ctx).Info("new connection", zap.String("remoteAddr", conn.bufReadConn.RemoteAddr().String()))
	defer func() {
		logutil.Logger(ctx).Info("connection closed")
	}()
	s.rwlock.Lock()
	s.clients[conn.connectionID] = conn
	connections := len(s.clients)
	s.rwlock.Unlock()
	metrics.ConnGauge.Set(float64(connections))

	if plugin.IsEnable(plugin.Audit) {
		conn.ctx.GetSessionVars().ConnectionInfo = conn.connectInfo()
	}
	err := plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
		authPlugin := plugin.DeclareAuditManifest(p.Manifest)
		if authPlugin.OnConnectionEvent != nil {
			sessionVars := conn.ctx.GetSessionVars()
			return authPlugin.OnConnectionEvent(context.Background(), plugin.Connected, sessionVars.ConnectionInfo)
		}
		return nil
	})
	if err != nil {
		return
	}

	connectedTime := time.Now()
	conn.Run(ctx)

	err = plugin.ForeachPlugin(plugin.Audit, func(p *plugin.Plugin) error {
		authPlugin := plugin.DeclareAuditManifest(p.Manifest)
		if authPlugin.OnConnectionEvent != nil {
			sessionVars := conn.ctx.GetSessionVars()
			sessionVars.ConnectionInfo.Duration = float64(time.Since(connectedTime)) / float64(time.Millisecond)
			err := authPlugin.OnConnectionEvent(context.Background(), plugin.Disconnect, sessionVars.ConnectionInfo)
			if err != nil {
				logutil.BgLogger().Warn("do connection event failed", zap.String("plugin", authPlugin.Name), zap.Error(err))
			}
		}
		return nil
	})
	if err != nil {
		return
	}
}

// connectInfo builds a ConnectionInfo snapshot of the current client connection for the audit plugin.
func (cc *clientConn) connectInfo() *variable.ConnectionInfo {
	connType := "Socket"
	if cc.server.isUnixSocket() {
		connType = "UnixSocket"
	} else if cc.tlsConn != nil {
		connType = "SSL/TLS"
	}
	connInfo := &variable.ConnectionInfo{
		ConnectionID:      cc.connectionID,
		ConnectionType:    connType,
		Host:              cc.peerHost,
		ClientIP:          cc.peerHost,
		ClientPort:        cc.peerPort,
		ServerID:          1,
		ServerPort:        int(cc.server.cfg.Port),
		User:              cc.user,
		ServerOSLoginUser: osUser,
		OSVersion:         osVersion,
		ServerVersion:     mysql.TiDBReleaseVersion,
		SSLVersion:        "v1.2.0", // for current go version
		PID:               serverPID,
		DB:                cc.dbname,
	}
	return connInfo
}

// ShowProcessList implements the SessionManager interface.
func (s *Server) ShowProcessList() map[uint64]*util.ProcessInfo {
	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	rs := make(map[uint64]*util.ProcessInfo, len(s.clients))
	for _, client := range s.clients {
		if atomic.LoadInt32(&client.status) == connStatusWaitShutdown {
			continue
		}
		if pi := client.ctx.ShowProcess(); pi != nil {
			rs[pi.ID] = pi
		}
	}
	return rs
}

// GetProcessInfo implements the SessionManager interface.
func (s *Server) GetProcessInfo(id uint64) (*util.ProcessInfo, bool) {
	s.rwlock.RLock()
	conn, ok := s.clients[uint32(id)]
	s.rwlock.RUnlock()
	if !ok || atomic.LoadInt32(&conn.status) == connStatusWaitShutdown {
		return &util.ProcessInfo{}, false
	}
	return conn.ctx.ShowProcess(), ok
}

// Kill implements the SessionManager interface.
func (s *Server) Kill(connectionID uint64, query bool) {
	logutil.BgLogger().Info("kill", zap.Uint64("connID", connectionID), zap.Bool("query", query))
	metrics.ServerEventCounter.WithLabelValues(metrics.EventKill).Inc()

	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	conn, ok := s.clients[uint32(connectionID)]
	if !ok {
		return
	}

	if !query {
		// Mark the client connection status as WaitShutdown; when the connection's goroutine detects
		// this, it will end the dispatch loop and exit.
		atomic.StoreInt32(&conn.status, connStatusWaitShutdown)
	}
	killConn(conn)
}

func killConn(conn *clientConn) {
	sessVars := conn.ctx.GetSessionVars()
	atomic.CompareAndSwapUint32(&sessVars.Killed, 0, 1)
}

// KillAllConnections kills all connections when the server is not gracefully shut down.
func (s *Server) KillAllConnections() {
	logutil.BgLogger().Info("[server] kill all connections.")

	s.rwlock.RLock()
	defer s.rwlock.RUnlock()
	for _, conn := range s.clients {
		atomic.StoreInt32(&conn.status, connStatusShutdown)
		if err := conn.closeWithoutLock(); err != nil {
			terror.Log(err)
		}
		killConn(conn)
	}
}

var gracefulCloseConnectionsTimeout = 15 * time.Second

// TryGracefulDown first tries to close all connections gracefully within gracefulCloseConnectionsTimeout.
// If the timeout expires, it kills all remaining connections directly.
func (s *Server) TryGracefulDown() {
	ctx, cancel := context.WithTimeout(context.Background(), gracefulCloseConnectionsTimeout)
	defer cancel()
	done := make(chan struct{})
	go func() {
		s.GracefulDown(ctx, done)
	}()
	select {
	case <-ctx.Done():
		s.KillAllConnections()
	case <-done:
		return
	}
}

// GracefulDown waits for all clients to close.
func (s *Server) GracefulDown(ctx context.Context, done chan struct{}) {
	logutil.Logger(ctx).Info("[server] graceful shutdown.")
	metrics.ServerEventCounter.WithLabelValues(metrics.EventGracefulDown).Inc()

	count := s.ConnectionCount()
	for i := 0; count > 0; i++ {
		s.kickIdleConnection()

		count = s.ConnectionCount()
		if count == 0 {
			break
		}
		// Print information every 30 seconds.
		if i%30 == 0 {
			logutil.Logger(ctx).Info("graceful shutdown...", zap.Int("conn count", count))
		}
		ticker := time.After(time.Second)
		select {
		case <-ctx.Done():
			return
		case <-ticker:
		}
	}
	close(done)
}

func (s *Server) kickIdleConnection() {
	var conns []*clientConn
	s.rwlock.RLock()
	for _, cc := range s.clients {
		if cc.ShutdownOrNotify() {
			// A shut-down connection will be closed by us, and a notified connection will exit by itself.
			conns = append(conns, cc)
		}
	}
	s.rwlock.RUnlock()

	for _, cc := range conns {
		err := cc.Close()
		if err != nil {
			logutil.BgLogger().Error("close connection", zap.Error(err))
		}
	}
}

// Server error codes.
const (
	codeUnknownFieldType    = 1
	codeInvalidPayloadLen   = 2
	codeInvalidSequence     = 3
	codeInvalidType         = 4
	codeNotAllowedCommand   = 1148
	codeAccessDenied        = mysql.ErrAccessDenied
	codeMaxExecTimeExceeded = mysql.ErrMaxExecTimeExceeded
)

func init() {
	serverMySQLErrCodes := map[terror.ErrCode]uint16{
		codeNotAllowedCommand:   mysql.ErrNotAllowedCommand,
		codeAccessDenied:        mysql.ErrAccessDenied,
		codeMaxExecTimeExceeded: mysql.ErrMaxExecTimeExceeded,
	}
	terror.ErrClassToMySQLCodes[terror.ClassServer] = serverMySQLErrCodes
}