Merge pull request #221 from pingcap/shenli/dsn-auth
Support basic auth for tidb-server
This commit is contained in:
@ -51,7 +51,7 @@ func bootstrap(s Session) {
|
||||
func initUserTable(s Session) {
|
||||
mustExecute(s, CreateUserTable)
|
||||
// Insert a default user with empty password.
|
||||
mustExecute(s, `INSERT INTO mysql.user VALUES ("localhost", "root", ""), ("127.0.0.1", "root", "");`)
|
||||
mustExecute(s, `INSERT INTO mysql.user VALUES ("localhost", "root", ""), ("127.0.0.1", "root", ""), ("::1", "root", "");`)
|
||||
}
|
||||
|
||||
func mustExecute(s Session, sql string) {
|
||||
|
||||
92
session.go
92
session.go
@ -18,8 +18,11 @@
|
||||
package tidb
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
@ -44,7 +47,6 @@ type Session interface {
|
||||
LastInsertID() uint64 // Last inserted auto_increment id
|
||||
AffectedRows() uint64 // Affected rows by lastest executed stmt
|
||||
Execute(sql string) ([]rset.Recordset, error) // Execute a sql statement
|
||||
SetUsername(name string) // Current user name
|
||||
String() string // For debug
|
||||
FinishTxn(rollback bool) error
|
||||
// For execute prepare statement in binary protocol
|
||||
@ -55,6 +57,7 @@ type Session interface {
|
||||
SetClientCapability(uint32) // Set client capability flags
|
||||
Close() error
|
||||
Retry() error
|
||||
Auth(user string, auth []byte, salt []byte) bool
|
||||
}
|
||||
|
||||
var (
|
||||
@ -96,13 +99,13 @@ func (h *stmtHistory) clone() *stmtHistory {
|
||||
}
|
||||
|
||||
type session struct {
|
||||
txn kv.Transaction // Current transaction
|
||||
userName string
|
||||
args []interface{} // Statment execution args, this should be cleaned up after exec
|
||||
values map[fmt.Stringer]interface{}
|
||||
store kv.Storage
|
||||
sid int64
|
||||
history stmtHistory
|
||||
txn kv.Transaction // Current transaction
|
||||
user string
|
||||
args []interface{} // Statment execution args, this should be cleaned up after exec
|
||||
values map[fmt.Stringer]interface{}
|
||||
store kv.Storage
|
||||
sid int64
|
||||
history stmtHistory
|
||||
}
|
||||
|
||||
func (s *session) Status() uint16 {
|
||||
@ -117,10 +120,6 @@ func (s *session) AffectedRows() uint64 {
|
||||
return variable.GetSessionVars(s).AffectedRows
|
||||
}
|
||||
|
||||
func (s *session) SetUsername(name string) {
|
||||
s.userName = name
|
||||
}
|
||||
|
||||
func (s *session) resetHistory() {
|
||||
s.ClearValue(forupdate.ForUpdateKey)
|
||||
s.history.reset()
|
||||
@ -157,7 +156,7 @@ func (s *session) FinishTxn(rollback bool) error {
|
||||
func (s *session) String() string {
|
||||
// TODO: how to print binded context in values appropriately?
|
||||
data := map[string]interface{}{
|
||||
"userName": s.userName,
|
||||
"user": s.user,
|
||||
"currDBName": db.GetCurrentSchema(s),
|
||||
"sid": s.sid,
|
||||
}
|
||||
@ -383,6 +382,73 @@ func (s *session) Close() error {
|
||||
return s.FinishTxn(true)
|
||||
}
|
||||
|
||||
func calcPassword(scramble, password []byte) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stage1Hash = SHA1(password)
|
||||
crypt := sha1.New()
|
||||
crypt.Write(password)
|
||||
stage1 := crypt.Sum(nil)
|
||||
|
||||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
|
||||
// inner Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(stage1)
|
||||
hash := crypt.Sum(nil)
|
||||
|
||||
// outer Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(scramble)
|
||||
crypt.Write(hash)
|
||||
scramble = crypt.Sum(nil)
|
||||
|
||||
// token = scrambleHash XOR stage1Hash
|
||||
for i := range scramble {
|
||||
scramble[i] ^= stage1[i]
|
||||
}
|
||||
return scramble
|
||||
}
|
||||
|
||||
func (s *session) Auth(user string, auth []byte, salt []byte) bool {
|
||||
strs := strings.Split(user, "@")
|
||||
if len(strs) != 2 {
|
||||
log.Warnf("Invalid format for user: %s", user)
|
||||
return false
|
||||
}
|
||||
// Get user password.
|
||||
name := strs[0]
|
||||
host := strs[1]
|
||||
authSQL := fmt.Sprintf("SELECT Password FROM %s.%s WHERE User=\"%s\" and Host=\"%s\";", mysql.SystemDB, mysql.UserTable, name, host)
|
||||
rs, err := s.Execute(authSQL)
|
||||
if err != nil {
|
||||
log.Warnf("Encounter error when auth user %s. Error: %v", user, err)
|
||||
return false
|
||||
}
|
||||
if len(rs) == 0 {
|
||||
return false
|
||||
}
|
||||
row, err := rs[0].Next()
|
||||
if err != nil {
|
||||
log.Warnf("Encounter error when auth user %s. Error: %v", user, err)
|
||||
return false
|
||||
}
|
||||
if row == nil || len(row.Data) == 0 {
|
||||
return false
|
||||
}
|
||||
pwd, ok := row.Data[0].(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
checkAuth := calcPassword(salt, []byte(pwd))
|
||||
if !bytes.Equal(auth, checkAuth) {
|
||||
return false
|
||||
}
|
||||
s.user = user
|
||||
return true
|
||||
}
|
||||
|
||||
// CreateSession creates a new session environment.
|
||||
func CreateSession(store kv.Storage) (Session, error) {
|
||||
s := &session{
|
||||
|
||||
@ -42,8 +42,6 @@ func main() {
|
||||
|
||||
cfg := &server.Config{
|
||||
Addr: fmt.Sprintf(":%s", *port),
|
||||
User: "root",
|
||||
Password: "",
|
||||
LogLevel: *logLevel,
|
||||
}
|
||||
|
||||
|
||||
@ -16,8 +16,6 @@ package server
|
||||
// Config contains configuration options.
|
||||
type Config struct {
|
||||
Addr string `json:"addr" toml:"addr"`
|
||||
User string `json:"user" toml:"user"`
|
||||
Password string `json:"password" toml:"password"`
|
||||
LogLevel string `json:"log_level" toml:"log_level"`
|
||||
SkipAuth bool `json:"skip_auth" toml:"skip_auth"`
|
||||
}
|
||||
|
||||
@ -36,7 +36,6 @@ package server
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -157,35 +156,6 @@ func (cc *clientConn) writePacket(data []byte) error {
|
||||
return cc.pkg.writePacket(data)
|
||||
}
|
||||
|
||||
func calcPassword(scramble, password []byte) []byte {
|
||||
if len(password) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
// stage1Hash = SHA1(password)
|
||||
crypt := sha1.New()
|
||||
crypt.Write(password)
|
||||
stage1 := crypt.Sum(nil)
|
||||
|
||||
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
|
||||
// inner Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(stage1)
|
||||
hash := crypt.Sum(nil)
|
||||
|
||||
// outer Hash
|
||||
crypt.Reset()
|
||||
crypt.Write(scramble)
|
||||
crypt.Write(hash)
|
||||
scramble = crypt.Sum(nil)
|
||||
|
||||
// token = scrambleHash XOR stage1Hash
|
||||
for i := range scramble {
|
||||
scramble[i] ^= stage1[i]
|
||||
}
|
||||
return scramble
|
||||
}
|
||||
|
||||
func (cc *clientConn) readHandshakeResponse() error {
|
||||
data, err := cc.readPacket()
|
||||
if err != nil {
|
||||
@ -210,20 +180,31 @@ func (cc *clientConn) readHandshakeResponse() error {
|
||||
authLen := int(data[pos])
|
||||
pos++
|
||||
auth := data[pos : pos+authLen]
|
||||
checkAuth := calcPassword(cc.salt, []byte(cc.server.cfgGetPwd(cc.user)))
|
||||
if !bytes.Equal(auth, checkAuth) && !cc.server.skipAuth() {
|
||||
return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.conn.RemoteAddr().String(), cc.user, "Yes"))
|
||||
}
|
||||
|
||||
pos += authLen
|
||||
if cc.capability|mysql.ClientConnectWithDB > 0 {
|
||||
if len(data[pos:]) == 0 {
|
||||
return nil
|
||||
if len(data[pos:]) > 0 {
|
||||
idx := bytes.IndexByte(data[pos:], 0)
|
||||
cc.dbname = string(data[pos : pos+idx])
|
||||
}
|
||||
}
|
||||
// Open session and do auth
|
||||
cc.ctx, err = cc.server.driver.OpenCtx(cc.capability, uint8(cc.collation), cc.dbname)
|
||||
if err != nil {
|
||||
cc.Close()
|
||||
return errors.Trace(err)
|
||||
}
|
||||
if !cc.server.skipAuth() {
|
||||
// Do Auth
|
||||
addr := cc.conn.RemoteAddr().String()
|
||||
host, _, err1 := net.SplitHostPort(addr)
|
||||
if err1 != nil {
|
||||
return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, addr, "Yes"))
|
||||
}
|
||||
user := fmt.Sprintf("%s@%s", cc.user, host)
|
||||
if !cc.ctx.Auth(user, auth, cc.salt) {
|
||||
return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, host, "Yes"))
|
||||
}
|
||||
idx := bytes.IndexByte(data[pos:], 0)
|
||||
cc.dbname = string(data[pos : pos+idx])
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@ -50,6 +50,9 @@ type IContext interface {
|
||||
|
||||
// Close closes the IContext.
|
||||
Close() error
|
||||
|
||||
// Auth verifies user's authentication.
|
||||
Auth(user string, auth []byte, salt []byte) bool
|
||||
}
|
||||
|
||||
// IStatement is the interface to use a prepared statement.
|
||||
|
||||
@ -178,6 +178,11 @@ func (tc *TiDBContext) Close() (err error) {
|
||||
return tc.session.Close()
|
||||
}
|
||||
|
||||
// Auth implements IContext Auth method.
|
||||
func (tc *TiDBContext) Auth(user string, auth []byte, salt []byte) bool {
|
||||
return tc.session.Auth(user, auth, salt)
|
||||
}
|
||||
|
||||
// FieldList implements IContext FieldList method.
|
||||
func (tc *TiDBContext) FieldList(table string) (colums []*ColumnInfo, err error) {
|
||||
rs, err := tc.Execute("SELECT * FROM " + table + " LIMIT 0")
|
||||
|
||||
@ -88,10 +88,6 @@ func (s *Server) skipAuth() bool {
|
||||
return s.cfg.SkipAuth
|
||||
}
|
||||
|
||||
func (s *Server) cfgGetPwd(user string) string {
|
||||
return s.cfg.Password // TODO: support multiple users
|
||||
}
|
||||
|
||||
// NewServer creates a new Server.
|
||||
func NewServer(cfg *Config, driver IDriver) (*Server, error) {
|
||||
s := &Server{
|
||||
@ -147,13 +143,6 @@ func (s *Server) onConn(c net.Conn) {
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
conn.ctx, err = s.driver.OpenCtx(conn.capability, uint8(conn.collation), conn.dbname)
|
||||
if err != nil {
|
||||
log.Errorf("open ctx error %s", errors.ErrorStack(err))
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
log.Infof("close %s", conn)
|
||||
}()
|
||||
|
||||
@ -219,3 +219,20 @@ func runTestConcurrentUpdate(c *C) {
|
||||
c.Assert(err, IsNil)
|
||||
})
|
||||
}
|
||||
|
||||
func runTestAuth(c *C) {
|
||||
runTests(c, dsn, func(dbt *DBTest) {
|
||||
dbt.mustExec(`CREATE USER 'test'@'127.0.0.1' IDENTIFIED BY '123';`)
|
||||
dbt.mustExec(`CREATE USER 'test'@'localhost' IDENTIFIED BY '123';`)
|
||||
dbt.mustExec(`CREATE USER 'test'@'::1' IDENTIFIED BY '123';`)
|
||||
})
|
||||
newDsn := "test:123@tcp(localhost:4001)/test?strict=true"
|
||||
runTests(c, newDsn, func(dbt *DBTest) {
|
||||
dbt.mustExec(`USE mysql;`)
|
||||
})
|
||||
|
||||
db, err := sql.Open("mysql", "test:456@tcp(localhost:4001)/test?strict=true")
|
||||
_, err = db.Query("USE mysql;")
|
||||
c.Assert(err, NotNil, Commentf("Wrong password should be failed"))
|
||||
db.Close()
|
||||
}
|
||||
|
||||
@ -33,8 +33,6 @@ func (ts *TidbTestSuite) SetUpSuite(c *C) {
|
||||
ts.tidbdrv = NewTiDBDriver(store)
|
||||
cfg := &Config{
|
||||
Addr: ":4001",
|
||||
User: "root",
|
||||
Password: "",
|
||||
LogLevel: "debug",
|
||||
}
|
||||
server, err := NewServer(cfg, ts.tidbdrv)
|
||||
@ -69,3 +67,7 @@ func (ts *TidbTestSuite) TestPreparedString(c *C) {
|
||||
func (ts *TidbTestSuite) TestConcurrentUpdate(c *C) {
|
||||
runTestConcurrentUpdate(c)
|
||||
}
|
||||
|
||||
func (ts *TidbTestSuite) TestAuth(c *C) {
|
||||
runTestAuth(c)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user