// The MIT License (MIT) // // Copyright (c) 2014 wandoulabs // Copyright (c) 2014 siddontang // // Permission is hereby granted, free of charge, to any person obtaining a copy of // this software and associated documentation files (the "Software"), to deal in // the Software without restriction, including without limitation the rights to // use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of // the Software, and to permit persons to whom the Software is furnished to do so, // subject to the following conditions: // // The above copyright notice and this permission notice shall be included in all // copies or substantial portions of the Software. // Copyright 2015 PingCAP, Inc. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // See the License for the specific language governing permissions and // limitations under the License. 
package server

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"fmt"
	"io"
	"io/ioutil"
	"math/rand"
	"net"
	"net/http"
	// For pprof
	_ "net/http/pprof"
	"sync"
	"sync/atomic"
	"time"

	"github.com/blacktear23/go-proxyprotocol"
	"github.com/pingcap/errors"
	"github.com/pingcap/parser/auth"
	"github.com/pingcap/parser/mysql"
	"github.com/pingcap/parser/terror"
	"github.com/pingcap/tidb/config"
	"github.com/pingcap/tidb/metrics"
	"github.com/pingcap/tidb/plugin"
	"github.com/pingcap/tidb/sessionctx/variable"
	"github.com/pingcap/tidb/util"
	"github.com/pingcap/tidb/util/logutil"
	log "github.com/sirupsen/logrus"
	"go.uber.org/zap"
)

var (
	// baseConnID is not referenced in this part of the file.
	// NOTE(review): presumably the counter that connection IDs are derived
	// from -- confirm against newClientConn.
	baseConnID uint32
)

// Errors of class terror.ClassServer raised by the protocol server. Their codes
// are declared in the "Server error codes" constant block of this file.
var (
	errUnknownFieldType  = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
	errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
	errInvalidSequence   = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
	errInvalidType       = terror.ClassServer.New(codeInvalidType, "invalid type")
	errNotAllowedCommand = terror.ClassServer.New(codeNotAllowedCommand, "the used command is not allowed with this TiDB version")
	errAccessDenied      = terror.ClassServer.New(codeAccessDenied, mysql.MySQLErrName[mysql.ErrAccessDenied])
)

// defaultCapability is the capability of the server when it is created using the default configuration.
// When the server is configured with SSL, it has extra capabilities compared to defaultCapability.
const defaultCapability = mysql.ClientLongPassword | mysql.ClientLongFlag | mysql.ClientConnectWithDB | mysql.ClientProtocol41 | mysql.ClientTransactions | mysql.ClientSecureConnection | mysql.ClientFoundRows | mysql.ClientMultiStatements | mysql.ClientMultiResults | mysql.ClientLocalFiles | mysql.ClientConnectAtts | mysql.ClientPluginAuth | mysql.ClientInteractive // Server is the MySQL protocol server type Server struct { cfg *config.Config tlsConfig *tls.Config driver IDriver listener net.Listener socket net.Listener rwlock *sync.RWMutex concurrentLimiter *TokenLimiter clients map[uint32]*clientConn capability uint32 // stopListenerCh is used when a critical error occurred, we don't want to exit the process, because there may be // a supervisor automatically restart it, then new client connection will be created, but we can't server it. // So we just stop the listener and store to force clients to chose other TiDB servers. stopListenerCh chan struct{} statusServer *http.Server } // ConnectionCount gets current connection count. func (s *Server) ConnectionCount() int { var cnt int s.rwlock.RLock() cnt = len(s.clients) s.rwlock.RUnlock() return cnt } func (s *Server) getToken() *Token { start := time.Now() tok := s.concurrentLimiter.Get() // Note that data smaller than one microsecond is ignored, because that case can be viewed as non-block. metrics.GetTokenDurationHistogram.Observe(float64(time.Since(start).Nanoseconds() / 1e3)) return tok } func (s *Server) releaseToken(token *Token) { s.concurrentLimiter.Put(token) } // newConn creates a new *clientConn from a net.Conn. // It allocates a connection ID and random salt data for authentication. 
func (s *Server) newConn(conn net.Conn) *clientConn { cc := newClientConn(s) if s.cfg.Performance.TCPKeepAlive { if tcpConn, ok := conn.(*net.TCPConn); ok { if err := tcpConn.SetKeepAlive(true); err != nil { log.Error("failed to set tcp keep alive option:", err) } } } cc.setConn(conn) cc.salt = util.RandomBuf(20) return cc } func (s *Server) isUnixSocket() bool { return s.cfg.Socket != "" } func (s *Server) forwardUnixSocketToTCP() { addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) for { if s.listener == nil { return // server shutdown has started } if uconn, err := s.socket.Accept(); err == nil { log.Infof("server socket forwarding from [%s] to [%s]", s.cfg.Socket, addr) go s.handleForwardedConnection(uconn, addr) } else { if s.listener != nil { log.Errorf("server failed to forward from [%s] to [%s], err: %s", s.cfg.Socket, addr, err) } } } } func (s *Server) handleForwardedConnection(uconn net.Conn, addr string) { defer terror.Call(uconn.Close) if tconn, err := net.Dial("tcp", addr); err == nil { go func() { if _, err := io.Copy(uconn, tconn); err != nil { log.Warningf("copy server to socket failed: %s", err) } }() if _, err := io.Copy(tconn, uconn); err != nil { log.Warningf("socket forward copy failed: %s", err) } } else { log.Warningf("socket forward failed: could not connect to [%s], err: %s", addr, err) } } // NewServer creates a new Server. 
func NewServer(cfg *config.Config, driver IDriver) (*Server, error) { s := &Server{ cfg: cfg, driver: driver, concurrentLimiter: NewTokenLimiter(cfg.TokenLimit), rwlock: &sync.RWMutex{}, clients: make(map[uint32]*clientConn), stopListenerCh: make(chan struct{}, 1), } s.loadTLSCertificates() s.capability = defaultCapability if s.tlsConfig != nil { s.capability |= mysql.ClientSSL } var err error if s.cfg.Host != "" && s.cfg.Port != 0 { addr := fmt.Sprintf("%s:%d", s.cfg.Host, s.cfg.Port) if s.listener, err = net.Listen("tcp", addr); err == nil { log.Infof("Server is running MySQL Protocol at [%s]", addr) if cfg.Socket != "" { if s.socket, err = net.Listen("unix", s.cfg.Socket); err == nil { log.Infof("Server redirecting [%s] to [%s]", s.cfg.Socket, addr) go s.forwardUnixSocketToTCP() } } } } else if cfg.Socket != "" { if s.listener, err = net.Listen("unix", cfg.Socket); err == nil { log.Infof("Server is running MySQL Protocol through Socket [%s]", cfg.Socket) } } else { err = errors.New("Server not configured to listen on either -socket or -host and -port") } if cfg.ProxyProtocol.Networks != "" { pplistener, errProxy := proxyprotocol.NewListener(s.listener, cfg.ProxyProtocol.Networks, int(cfg.ProxyProtocol.HeaderTimeout)) if errProxy != nil { log.Error("ProxyProtocol Networks parameter invalid") return nil, errors.Trace(errProxy) } log.Infof("Server is running MySQL Protocol (through PROXY Protocol) at [%s]", s.cfg.Host) s.listener = pplistener } if err != nil { return nil, errors.Trace(err) } // Init rand seed for randomBuf() rand.Seed(time.Now().UTC().UnixNano()) return s, nil } func (s *Server) loadTLSCertificates() { defer func() { if s.tlsConfig != nil { log.Infof("Secure connection is enabled (client verification enabled = %v)", len(variable.SysVars["ssl_ca"].Value) > 0) variable.SysVars["have_openssl"].Value = "YES" variable.SysVars["have_ssl"].Value = "YES" variable.SysVars["ssl_cert"].Value = s.cfg.Security.SSLCert variable.SysVars["ssl_key"].Value = 
s.cfg.Security.SSLKey } else { log.Warn("Secure connection is NOT ENABLED") } }() if len(s.cfg.Security.SSLCert) == 0 || len(s.cfg.Security.SSLKey) == 0 { s.tlsConfig = nil return } tlsCert, err := tls.LoadX509KeyPair(s.cfg.Security.SSLCert, s.cfg.Security.SSLKey) if err != nil { log.Warn(errors.ErrorStack(err)) s.tlsConfig = nil return } // Try loading CA cert. clientAuthPolicy := tls.NoClientCert var certPool *x509.CertPool if len(s.cfg.Security.SSLCA) > 0 { caCert, err := ioutil.ReadFile(s.cfg.Security.SSLCA) if err != nil { log.Warn(errors.ErrorStack(err)) } else { certPool = x509.NewCertPool() if certPool.AppendCertsFromPEM(caCert) { clientAuthPolicy = tls.VerifyClientCertIfGiven } variable.SysVars["ssl_ca"].Value = s.cfg.Security.SSLCA } } s.tlsConfig = &tls.Config{ Certificates: []tls.Certificate{tlsCert}, ClientCAs: certPool, ClientAuth: clientAuthPolicy, MinVersion: 0, } } // Run runs the server. func (s *Server) Run() error { metrics.ServerEventCounter.WithLabelValues(metrics.EventStart).Inc() // Start HTTP API to report tidb info such as TPS. if s.cfg.Status.ReportStatus { s.startStatusHTTP() } for { conn, err := s.listener.Accept() if err != nil { if opErr, ok := err.(*net.OpError); ok { if opErr.Err.Error() == "use of closed network connection" { return nil } } // If we got PROXY protocol error, we should continue accept. 
if proxyprotocol.IsProxyProtocolError(err) { log.Errorf("PROXY protocol error: %s", err.Error()) continue } log.Errorf("accept error %s", err.Error()) return errors.Trace(err) } if s.shouldStopListener() { err = conn.Close() terror.Log(errors.Trace(err)) break } for _, p := range plugin.GetByKind(plugin.Audit) { authPlugin := plugin.DeclareAuditManifest(p.Manifest) if authPlugin.OnConnectionEvent != nil { host, err := getPeerHost(conn) if err != nil { log.Error(err) terror.Log(conn.Close()) continue } err = authPlugin.OnConnectionEvent(context.Background(), &auth.UserIdentity{Hostname: host}) if err != nil { log.Info(err) terror.Log(conn.Close()) continue } } } go s.onConn(conn) } err := s.listener.Close() terror.Log(errors.Trace(err)) s.listener = nil for { metrics.ServerEventCounter.WithLabelValues(metrics.EventHang).Inc() log.Errorf("listener stopped, waiting for manual kill.") time.Sleep(time.Minute) } } func getPeerHost(conn net.Conn) (string, error) { addr := conn.RemoteAddr().String() host, _, err := net.SplitHostPort(addr) if err != nil { return "", errors.Trace(err) } return host, nil } func (s *Server) shouldStopListener() bool { select { case <-s.stopListenerCh: return true default: return false } } // Close closes the server. func (s *Server) Close() { s.rwlock.Lock() defer s.rwlock.Unlock() if s.listener != nil { err := s.listener.Close() terror.Log(errors.Trace(err)) s.listener = nil } if s.socket != nil { err := s.socket.Close() terror.Log(errors.Trace(err)) s.socket = nil } if s.statusServer != nil { err := s.statusServer.Close() terror.Log(errors.Trace(err)) s.statusServer = nil } metrics.ServerEventCounter.WithLabelValues(metrics.EventClose).Inc() } // onConn runs in its own goroutine, handles queries from this connection. 
func (s *Server) onConn(c net.Conn) { conn := s.newConn(c) ctx := logutil.WithConnID(context.Background(), conn.connectionID) if err := conn.handshake(ctx); err != nil { // Some keep alive services will send request to TiDB and disconnect immediately. // So we only record metrics. metrics.HandShakeErrorCounter.Inc() err = c.Close() terror.Log(errors.Trace(err)) return } logutil.Logger(ctx).Info("new connection", zap.String("remoteAddr", c.RemoteAddr().String())) defer func() { logutil.Logger(ctx).Info("close connection") }() s.rwlock.Lock() s.clients[conn.connectionID] = conn connections := len(s.clients) s.rwlock.Unlock() metrics.ConnGauge.Set(float64(connections)) conn.Run(ctx) } // ShowProcessList implements the SessionManager interface. func (s *Server) ShowProcessList() map[uint64]util.ProcessInfo { s.rwlock.RLock() rs := make(map[uint64]util.ProcessInfo, len(s.clients)) for _, client := range s.clients { if atomic.LoadInt32(&client.status) == connStatusWaitShutdown { continue } pi := client.ctx.ShowProcess() rs[pi.ID] = pi } s.rwlock.RUnlock() return rs } // Kill implements the SessionManager interface. func (s *Server) Kill(connectionID uint64, query bool) { s.rwlock.Lock() defer s.rwlock.Unlock() log.Infof("[server] Kill connectionID %d, query %t]", connectionID, query) metrics.ServerEventCounter.WithLabelValues(metrics.EventKill).Inc() conn, ok := s.clients[uint32(connectionID)] if !ok { return } killConn(conn, query) } func killConn(conn *clientConn, query bool) { if !query { // Mark the client connection status as WaitShutdown, when the goroutine detect // this, it will end the dispatch loop and exit. atomic.StoreInt32(&conn.status, connStatusWaitShutdown) } conn.mu.RLock() cancelFunc := conn.mu.cancelFunc conn.mu.RUnlock() if cancelFunc != nil { cancelFunc() } } // KillAllConnections kills all connections when server is not gracefully shutdown. 
func (s *Server) KillAllConnections() { s.rwlock.Lock() defer s.rwlock.Unlock() log.Info("[server] kill all connections.") for _, conn := range s.clients { atomic.StoreInt32(&conn.status, connStatusShutdown) terror.Log(errors.Trace(conn.closeWithoutLock())) conn.mu.RLock() cancelFunc := conn.mu.cancelFunc conn.mu.RUnlock() if cancelFunc != nil { cancelFunc() } } } var gracefulCloseConnectionsTimeout = 15 * time.Second // TryGracefulDown will try to gracefully close all connection first with timeout. if timeout, will close all connection directly. func (s *Server) TryGracefulDown() { ctx, cancel := context.WithTimeout(context.Background(), gracefulCloseConnectionsTimeout) defer cancel() done := make(chan struct{}) go func() { s.GracefulDown(ctx, done) }() select { case <-ctx.Done(): s.KillAllConnections() case <-done: return } } // GracefulDown waits all clients to close. func (s *Server) GracefulDown(ctx context.Context, done chan struct{}) { log.Info("[server] graceful shutdown.") metrics.ServerEventCounter.WithLabelValues(metrics.EventGracefulDown).Inc() count := s.ConnectionCount() for i := 0; count > 0; i++ { s.kickIdleConnection() count = s.ConnectionCount() if count == 0 { break } // Print information for every 30s. if i%30 == 0 { log.Infof("graceful shutdown...connection count %d\n", count) } ticker := time.After(time.Second) select { case <-ctx.Done(): return case <-ticker: } } close(done) } func (s *Server) kickIdleConnection() { var conns []*clientConn s.rwlock.RLock() for _, cc := range s.clients { if cc.ShutdownOrNotify() { // Shutdowned conn will be closed by us, and notified conn will exist themselves. conns = append(conns, cc) } } s.rwlock.RUnlock() for _, cc := range conns { err := cc.Close() if err != nil { log.Errorf("close connection error: %s", err) } } } // Server error codes. 
const ( codeUnknownFieldType = 1 codeInvalidPayloadLen = 2 codeInvalidSequence = 3 codeInvalidType = 4 codeNotAllowedCommand = 1148 codeAccessDenied = mysql.ErrAccessDenied ) func init() { serverMySQLErrCodes := map[terror.ErrCode]uint16{ codeNotAllowedCommand: mysql.ErrNotAllowedCommand, codeAccessDenied: mysql.ErrAccessDenied, } terror.ErrClassToMySQLCodes[terror.ClassServer] = serverMySQLErrCodes }