diff --git a/bootstrap.go b/bootstrap.go index b9340b439b..4ded29e59f 100644 --- a/bootstrap.go +++ b/bootstrap.go @@ -51,7 +51,7 @@ func bootstrap(s Session) { func initUserTable(s Session) { mustExecute(s, CreateUserTable) // Insert a default user with empty password. - mustExecute(s, `INSERT INTO mysql.user VALUES ("localhost", "root", ""), ("127.0.0.1", "root", "");`) + mustExecute(s, `INSERT INTO mysql.user VALUES ("localhost", "root", ""), ("127.0.0.1", "root", ""), ("::1", "root", "");`) } func mustExecute(s Session, sql string) { diff --git a/session.go b/session.go index 33ec018de8..4fff174146 100644 --- a/session.go +++ b/session.go @@ -18,8 +18,11 @@ package tidb import ( + "bytes" + "crypto/sha1" "encoding/json" "fmt" + "strings" "sync/atomic" "time" @@ -44,7 +47,6 @@ type Session interface { LastInsertID() uint64 // Last inserted auto_increment id AffectedRows() uint64 // Affected rows by lastest executed stmt Execute(sql string) ([]rset.Recordset, error) // Execute a sql statement - SetUsername(name string) // Current user name String() string // For debug FinishTxn(rollback bool) error // For execute prepare statement in binary protocol @@ -55,6 +57,7 @@ type Session interface { SetClientCapability(uint32) // Set client capability flags Close() error Retry() error + Auth(user string, auth []byte, salt []byte) bool } var ( @@ -96,13 +99,13 @@ func (h *stmtHistory) clone() *stmtHistory { } type session struct { - txn kv.Transaction // Current transaction - userName string - args []interface{} // Statment execution args, this should be cleaned up after exec - values map[fmt.Stringer]interface{} - store kv.Storage - sid int64 - history stmtHistory + txn kv.Transaction // Current transaction + user string + args []interface{} // Statment execution args, this should be cleaned up after exec + values map[fmt.Stringer]interface{} + store kv.Storage + sid int64 + history stmtHistory } func (s *session) Status() uint16 { @@ -117,10 +120,6 @@ func (s *session) AffectedRows() uint64 { return variable.GetSessionVars(s).AffectedRows } -func (s *session) SetUsername(name string) { - s.userName = name -} - func (s *session) resetHistory() { s.ClearValue(forupdate.ForUpdateKey) s.history.reset() @@ -157,7 +156,7 @@ func (s *session) FinishTxn(rollback bool) error { func (s *session) String() string { // TODO: how to print binded context in values appropriately? data := map[string]interface{}{ - "userName": s.userName, + "user": s.user, "currDBName": db.GetCurrentSchema(s), "sid": s.sid, } @@ -383,6 +382,73 @@ func (s *session) Close() error { return s.FinishTxn(true) } +func calcPassword(scramble, password []byte) []byte { + if len(password) == 0 { + return nil + } + + // stage1Hash = SHA1(password) + crypt := sha1.New() + crypt.Write(password) + stage1 := crypt.Sum(nil) + + // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) + // inner Hash + crypt.Reset() + crypt.Write(stage1) + hash := crypt.Sum(nil) + + // outer Hash + crypt.Reset() + crypt.Write(scramble) + crypt.Write(hash) + scramble = crypt.Sum(nil) + + // token = scrambleHash XOR stage1Hash + for i := range scramble { + scramble[i] ^= stage1[i] + } + return scramble +} + +func (s *session) Auth(user string, auth []byte, salt []byte) bool { + strs := strings.Split(user, "@") + if len(strs) != 2 { + log.Warnf("Invalid format for user: %s", user) + return false + } + // Get user password. + name := strs[0] + host := strs[1] + authSQL := fmt.Sprintf("SELECT Password FROM %s.%s WHERE User=\"%s\" and Host=\"%s\";", mysql.SystemDB, mysql.UserTable, name, host) + rs, err := s.Execute(authSQL) + if err != nil { + log.Warnf("Encounter error when auth user %s. Error: %v", user, err) + return false + } + if len(rs) == 0 { + return false + } + row, err := rs[0].Next() + if err != nil { + log.Warnf("Encounter error when auth user %s. Error: %v", user, err) + return false + } + if row == nil || len(row.Data) == 0 { + return false + } + pwd, ok := row.Data[0].(string) + if !ok { + return false + } + checkAuth := calcPassword(salt, []byte(pwd)) + if !bytes.Equal(auth, checkAuth) { + return false + } + s.user = user + return true +} + // CreateSession creates a new session environment. func CreateSession(store kv.Storage) (Session, error) { s := &session{ diff --git a/tidb-server/main.go b/tidb-server/main.go index 1d5d4dbf9a..9456f46ee5 100644 --- a/tidb-server/main.go +++ b/tidb-server/main.go @@ -42,8 +42,6 @@ func main() { cfg := &server.Config{ Addr: fmt.Sprintf(":%s", *port), - User: "root", - Password: "", LogLevel: *logLevel, } diff --git a/tidb-server/server/config.go b/tidb-server/server/config.go index 6100039d1a..738583a153 100644 --- a/tidb-server/server/config.go +++ b/tidb-server/server/config.go @@ -16,8 +16,6 @@ package server // Config contains configuration options. type Config struct { Addr string `json:"addr" toml:"addr"` - User string `json:"user" toml:"user"` - Password string `json:"password" toml:"password"` LogLevel string `json:"log_level" toml:"log_level"` SkipAuth bool `json:"skip_auth" toml:"skip_auth"` } diff --git a/tidb-server/server/conn.go b/tidb-server/server/conn.go index 734fb611c0..a96b0cccf7 100644 --- a/tidb-server/server/conn.go +++ b/tidb-server/server/conn.go @@ -36,7 +36,6 @@ package server import ( "bytes" - "crypto/sha1" "encoding/binary" "fmt" "io" @@ -157,35 +156,6 @@ func (cc *clientConn) writePacket(data []byte) error { return cc.pkg.writePacket(data) } -func calcPassword(scramble, password []byte) []byte { - if len(password) == 0 { - return nil - } - - // stage1Hash = SHA1(password) - crypt := sha1.New() - crypt.Write(password) - stage1 := crypt.Sum(nil) - - // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) - // inner Hash - crypt.Reset() - crypt.Write(stage1) - hash := crypt.Sum(nil) - - // outer Hash - crypt.Reset() - crypt.Write(scramble) - crypt.Write(hash) - scramble = crypt.Sum(nil) - - // token = scrambleHash XOR stage1Hash - for i := range scramble { - scramble[i] ^= stage1[i] - } - return scramble -} - func (cc *clientConn) readHandshakeResponse() error { data, err := cc.readPacket() if err != nil { @@ -210,20 +180,31 @@ func (cc *clientConn) readHandshakeResponse() error { authLen := int(data[pos]) pos++ auth := data[pos : pos+authLen] - checkAuth := calcPassword(cc.salt, []byte(cc.server.cfgGetPwd(cc.user))) - if !bytes.Equal(auth, checkAuth) && !cc.server.skipAuth() { - return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.conn.RemoteAddr().String(), cc.user, "Yes")) - } - pos += authLen if cc.capability|mysql.ClientConnectWithDB > 0 { - if len(data[pos:]) == 0 { - return nil + if len(data[pos:]) > 0 { + idx := bytes.IndexByte(data[pos:], 0) + cc.dbname = string(data[pos : pos+idx]) + } + } + // Open session and do auth + cc.ctx, err = cc.server.driver.OpenCtx(cc.capability, uint8(cc.collation), cc.dbname) + if err != nil { + cc.Close() + return errors.Trace(err) + } + if !cc.server.skipAuth() { + // Do Auth + addr := cc.conn.RemoteAddr().String() + host, _, err1 := net.SplitHostPort(addr) + if err1 != nil { + return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, addr, "Yes")) + } + user := fmt.Sprintf("%s@%s", cc.user, host) + if !cc.ctx.Auth(user, auth, cc.salt) { + return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, host, "Yes")) } - idx := bytes.IndexByte(data[pos:], 0) - cc.dbname = string(data[pos : pos+idx]) } - return nil } diff --git a/tidb-server/server/driver.go b/tidb-server/server/driver.go index e481b8e47e..1f1c9a15ff 100644 --- a/tidb-server/server/driver.go +++ b/tidb-server/server/driver.go @@ -50,6 +50,9 @@ type IContext interface { // Close closes the IContext. Close() error + + // Auth verifies user's authentication. + Auth(user string, auth []byte, salt []byte) bool } // IStatement is the interface to use a prepared statement. diff --git a/tidb-server/server/driver_tidb.go b/tidb-server/server/driver_tidb.go index 25c3870eb2..1ea32908fa 100644 --- a/tidb-server/server/driver_tidb.go +++ b/tidb-server/server/driver_tidb.go @@ -178,6 +178,11 @@ func (tc *TiDBContext) Close() (err error) { return tc.session.Close() } +// Auth implements IContext Auth method. +func (tc *TiDBContext) Auth(user string, auth []byte, salt []byte) bool { + return tc.session.Auth(user, auth, salt) +} + // FieldList implements IContext FieldList method. func (tc *TiDBContext) FieldList(table string) (colums []*ColumnInfo, err error) { rs, err := tc.Execute("SELECT * FROM " + table + " LIMIT 0") diff --git a/tidb-server/server/server.go b/tidb-server/server/server.go index d162090544..efd05bc70f 100644 --- a/tidb-server/server/server.go +++ b/tidb-server/server/server.go @@ -88,10 +88,6 @@ func (s *Server) skipAuth() bool { return s.cfg.SkipAuth } -func (s *Server) cfgGetPwd(user string) string { - return s.cfg.Password // TODO: support multiple users -} - // NewServer creates a new Server. func NewServer(cfg *Config, driver IDriver) (*Server, error) { s := &Server{ @@ -147,13 +143,6 @@ func (s *Server) onConn(c net.Conn) { c.Close() return } - conn.ctx, err = s.driver.OpenCtx(conn.capability, uint8(conn.collation), conn.dbname) - if err != nil { - log.Errorf("open ctx error %s", errors.ErrorStack(err)) - c.Close() - return - } - defer func() { log.Infof("close %s", conn) }() diff --git a/tidb-server/server/server_test.go b/tidb-server/server/server_test.go index 63c12675fd..4e2badcd8c 100644 --- a/tidb-server/server/server_test.go +++ b/tidb-server/server/server_test.go @@ -219,3 +219,20 @@ func runTestConcurrentUpdate(c *C) { c.Assert(err, IsNil) }) } + +func runTestAuth(c *C) { + runTests(c, dsn, func(dbt *DBTest) { + dbt.mustExec(`CREATE USER 'test'@'127.0.0.1' IDENTIFIED BY '123';`) + dbt.mustExec(`CREATE USER 'test'@'localhost' IDENTIFIED BY '123';`) + dbt.mustExec(`CREATE USER 'test'@'::1' IDENTIFIED BY '123';`) + }) + newDsn := "test:123@tcp(localhost:4001)/test?strict=true" + runTests(c, newDsn, func(dbt *DBTest) { + dbt.mustExec(`USE mysql;`) + }) + + db, err := sql.Open("mysql", "test:456@tcp(localhost:4001)/test?strict=true") + _, err = db.Query("USE mysql;") + c.Assert(err, NotNil, Commentf("Wrong password should be failed")) + db.Close() +} diff --git a/tidb-server/server/tidb_test.go b/tidb-server/server/tidb_test.go index 9e5fea33ee..b5141f98e3 100644 --- a/tidb-server/server/tidb_test.go +++ b/tidb-server/server/tidb_test.go @@ -33,8 +33,6 @@ func (ts *TidbTestSuite) SetUpSuite(c *C) { ts.tidbdrv = NewTiDBDriver(store) cfg := &Config{ Addr: ":4001", - User: "root", - Password: "", LogLevel: "debug", } server, err := NewServer(cfg, ts.tidbdrv) @@ -69,3 +67,7 @@ func (ts *TidbTestSuite) TestPreparedString(c *C) { func (ts *TidbTestSuite) TestConcurrentUpdate(c *C) { runTestConcurrentUpdate(c) } + +func (ts *TidbTestSuite) TestAuth(c *C) { + runTestAuth(c) +}