257 lines
6.4 KiB
Go
257 lines
6.4 KiB
Go
// The MIT License (MIT)
|
|
//
|
|
// Copyright (c) 2014 wandoulabs
|
|
// Copyright (c) 2014 siddontang
|
|
//
|
|
// Permission is hereby granted, free of charge, to any person obtaining a copy of
|
|
// this software and associated documentation files (the "Software"), to deal in
|
|
// the Software without restriction, including without limitation the rights to
|
|
// use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
|
|
// the Software, and to permit persons to whom the Software is furnished to do so,
|
|
// subject to the following conditions:
|
|
//
|
|
// The above copyright notice and this permission notice shall be included in all
|
|
// copies or substantial portions of the Software.
|
|
|
|
// Copyright 2015 PingCAP, Inc.
|
|
//
|
|
// Licensed under the Apache License, Version 2.0 (the "License");
|
|
// you may not use this file except in compliance with the License.
|
|
// You may obtain a copy of the License at
|
|
//
|
|
// http://www.apache.org/licenses/LICENSE-2.0
|
|
//
|
|
// Unless required by applicable law or agreed to in writing, software
|
|
// distributed under the License is distributed on an "AS IS" BASIS,
|
|
// See the License for the specific language governing permissions and
|
|
// limitations under the License.
|
|
|
|
package server
|
|
|
|
import (
	crand "crypto/rand"
	"encoding/json"
	"math/rand"
	"net"
	"net/http"
	"sync"
	"sync/atomic"
	"time"

	// For pprof
	_ "net/http/pprof"

	"github.com/juju/errors"
	"github.com/ngaut/log"
	"github.com/pingcap/tidb/mysql"
	"github.com/pingcap/tidb/terror"
	"github.com/pingcap/tidb/util/arena"
	"github.com/pingcap/tidb/util/printer"
	"github.com/prometheus/client_golang/prometheus"
)
|
|
|
|
// baseConnID holds the most recently issued connection ID. newConn bumps it
// with atomic.AddUint32, so the first client connection receives ID 10001.
var baseConnID uint32 = 10000
|
|
|
|
// Protocol-level errors of the server package. Each one belongs to
// terror.ClassServer and uses one of the code* constants declared in this file.
var (
	// errUnknownFieldType uses codeUnknownFieldType.
	errUnknownFieldType = terror.ClassServer.New(codeUnknownFieldType, "unknown field type")
	// errInvalidPayloadLen uses codeInvalidPayloadLen.
	errInvalidPayloadLen = terror.ClassServer.New(codeInvalidPayloadLen, "invalid payload length")
	// errInvalidSequence uses codeInvalidSequence.
	errInvalidSequence = terror.ClassServer.New(codeInvalidSequence, "invalid sequence")
	// errInvalidType uses codeInvalidType.
	errInvalidType = terror.ClassServer.New(codeInvalidType, "invalid type")
)
|
|
|
|
// Server is the MySQL protocol server.
//
// Construct it with NewServer, serve with Run, and stop accepting new
// connections with Close.
type Server struct {
	// cfg is the configuration passed to NewServer (NewServer may set
	// cfg.SkipAuth when a Unix socket is configured).
	cfg *Config
	// driver is the IDriver passed to NewServer; presumably used by client
	// connections to execute queries — confirm in conn.go.
	driver IDriver
	// listener accepts client connections; Close closes it and sets it to nil
	// while holding rwlock.
	listener net.Listener
	// rwlock is held for access to clients and by Close when tearing down
	// listener.
	rwlock *sync.RWMutex
	// concurrentLimiter hands out tokens through getToken/releaseToken.
	concurrentLimiter *TokenLimiter
	// clients maps a connection's connectionID to the connection registered
	// by onConn after a successful handshake.
	clients map[uint32]*clientConn
}
|
|
|
|
// ConnectionCount gets current connection count.
|
|
func (s *Server) ConnectionCount() int {
|
|
var cnt int
|
|
s.rwlock.RLock()
|
|
cnt = len(s.clients)
|
|
s.rwlock.RUnlock()
|
|
return cnt
|
|
}
|
|
|
|
func (s *Server) getToken() *Token {
|
|
return s.concurrentLimiter.Get()
|
|
}
|
|
|
|
func (s *Server) releaseToken(token *Token) {
|
|
s.concurrentLimiter.Put(token)
|
|
}
|
|
|
|
// randomBuf returns size random bytes for use as the authentication salt in
// the MySQL handshake. Every byte is 7-bit ASCII and is never NUL or '$' (the
// separator character), mirroring generate_user_salt in MySQL:
// See https://github.com/mysql/mysql-server/blob/5.7/mysys_ssl/crypt_genhash_impl.cc#L435
//
// The bytes are drawn from crypto/rand, as MySQL draws them from RAND_bytes,
// so that a client or eavesdropper cannot predict the salt. math/rand is
// seeded from the wall clock and is therefore guessable. If the system CSPRNG
// fails, math/rand is used as a fallback so the handshake can still proceed.
// make panics if size is negative.
func randomBuf(size int) []byte {
	buf := make([]byte, size)
	if _, err := crand.Read(buf); err != nil {
		// Fallback: a weaker salt is better than refusing every connection.
		for i := range buf {
			buf[i] = byte(rand.Intn(128))
		}
	}
	for i, b := range buf {
		b &= 0x7f // keep the salt within 7-bit ASCII
		if b == 0 || b == '$' {
			b++ // NUL would terminate the string; '$' is the separator
		}
		buf[i] = b
	}
	return buf
}
|
|
|
|
// newConn creates a new *clientConn from a net.Conn.
|
|
// It allocates a connection ID and random salt data for authentication.
|
|
func (s *Server) newConn(conn net.Conn) *clientConn {
|
|
log.Info("newConn", conn.RemoteAddr().String())
|
|
cc := &clientConn{
|
|
conn: conn,
|
|
pkt: newPacketIO(conn),
|
|
server: s,
|
|
connectionID: atomic.AddUint32(&baseConnID, 1),
|
|
collation: mysql.DefaultCollationID,
|
|
alloc: arena.NewAllocator(32 * 1024),
|
|
}
|
|
cc.salt = randomBuf(20)
|
|
return cc
|
|
}
|
|
|
|
func (s *Server) skipAuth() bool {
|
|
return s.cfg.SkipAuth
|
|
}
|
|
|
|
// NewServer creates a new Server.
|
|
func NewServer(cfg *Config, driver IDriver) (*Server, error) {
|
|
s := &Server{
|
|
cfg: cfg,
|
|
driver: driver,
|
|
concurrentLimiter: NewTokenLimiter(100),
|
|
rwlock: &sync.RWMutex{},
|
|
clients: make(map[uint32]*clientConn),
|
|
}
|
|
|
|
var err error
|
|
if cfg.Socket != "" {
|
|
cfg.SkipAuth = true
|
|
s.listener, err = net.Listen("unix", cfg.Socket)
|
|
} else {
|
|
s.listener, err = net.Listen("tcp", s.cfg.Addr)
|
|
}
|
|
if err != nil {
|
|
return nil, errors.Trace(err)
|
|
}
|
|
|
|
// Init rand seed for randomBuf()
|
|
rand.Seed(time.Now().UTC().UnixNano())
|
|
log.Infof("Server run MySql Protocol Listen at [%s]", s.cfg.Addr)
|
|
return s, nil
|
|
}
|
|
|
|
// Run runs the server.
|
|
func (s *Server) Run() error {
|
|
|
|
// Start http api to report tidb info such as tps.
|
|
if s.cfg.ReportStatus {
|
|
s.startStatusHTTP()
|
|
}
|
|
|
|
for {
|
|
conn, err := s.listener.Accept()
|
|
if err != nil {
|
|
if opErr, ok := err.(*net.OpError); ok {
|
|
if opErr.Err.Error() == "use of closed network connection" {
|
|
return nil
|
|
}
|
|
}
|
|
log.Errorf("accept error %s", err.Error())
|
|
return errors.Trace(err)
|
|
}
|
|
|
|
go s.onConn(conn)
|
|
}
|
|
}
|
|
|
|
// Close closes the server.
|
|
func (s *Server) Close() {
|
|
s.rwlock.Lock()
|
|
defer s.rwlock.Unlock()
|
|
|
|
if s.listener != nil {
|
|
s.listener.Close()
|
|
s.listener = nil
|
|
}
|
|
}
|
|
|
|
// onConn runs in its own goroutine, handles queries from this connection.
|
|
func (s *Server) onConn(c net.Conn) {
|
|
conn := s.newConn(c)
|
|
if err := conn.handshake(); err != nil {
|
|
log.Errorf("handshake error %s", errors.ErrorStack(err))
|
|
c.Close()
|
|
return
|
|
}
|
|
defer func() {
|
|
log.Infof("close %s", conn)
|
|
}()
|
|
|
|
s.rwlock.Lock()
|
|
s.clients[conn.connectionID] = conn
|
|
connections := len(s.clients)
|
|
s.rwlock.Unlock()
|
|
connGauge.Set(float64(connections))
|
|
|
|
conn.Run()
|
|
}
|
|
|
|
var once sync.Once
|
|
|
|
const defaultStatusAddr = ":10080"
|
|
|
|
func (s *Server) startStatusHTTP() {
|
|
once.Do(func() {
|
|
go func() {
|
|
http.HandleFunc("/status", func(w http.ResponseWriter, req *http.Request) {
|
|
w.Header().Set("Content-Type", "application/json")
|
|
s := status{
|
|
Connections: s.ConnectionCount(),
|
|
Version: mysql.ServerVersion,
|
|
GitHash: printer.TiDBGitHash,
|
|
}
|
|
js, err := json.Marshal(s)
|
|
if err != nil {
|
|
w.WriteHeader(http.StatusInternalServerError)
|
|
log.Error("Encode json error", err)
|
|
} else {
|
|
w.Write(js)
|
|
}
|
|
|
|
})
|
|
// HTTP path for prometheus.
|
|
http.Handle("/metrics", prometheus.Handler())
|
|
addr := s.cfg.StatusAddr
|
|
if len(addr) == 0 {
|
|
addr = defaultStatusAddr
|
|
}
|
|
log.Infof("Listening on %v for status and metrics report.", addr)
|
|
err := http.ListenAndServe(addr, nil)
|
|
if err != nil {
|
|
log.Fatal(err)
|
|
}
|
|
}()
|
|
})
|
|
}
|
|
|
|
// status is the TiDB status document that the /status HTTP handler marshals
// to JSON.
type status struct {
	// Connections is the server's current ConnectionCount.
	Connections int `json:"connections"`
	// Version is mysql.ServerVersion.
	Version string `json:"version"`
	// GitHash is printer.TiDBGitHash.
	GitHash string `json:"git_hash"`
}
|
|
|
|
// Server error codes, used with terror.ClassServer by the err* variables of
// this package. The values are 1 through 4 in declaration order. Append new
// codes at the end only, so that existing codes never change value.
const (
	codeUnknownFieldType = iota + 1
	codeInvalidPayloadLen
	codeInvalidSequence
	codeInvalidType
)
|