diff --git a/bootstrap.go b/bootstrap.go index f3f558d26e..450fa11316 100644 --- a/bootstrap.go +++ b/bootstrap.go @@ -190,9 +190,7 @@ func doDMLWorks(s Session) { // Insert a default user with empty password. mustExecute(s, `INSERT INTO mysql.user VALUES - ("localhost", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y"), - ("127.0.0.1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y"), - ("::1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y");`) + ("%", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y")`) // Init global system variables table. values := make([]string, 0, len(variable.SysVars)) diff --git a/session.go b/session.go index 0b83dc5d99..6450c084cf 100644 --- a/session.go +++ b/session.go @@ -518,6 +518,21 @@ func (s *session) Close() error { return s.FinishTxn(true) } +func (s *session) getPassword(name, host string) (string, error) { + // Get password for name and host. + authSQL := fmt.Sprintf("SELECT Password FROM %s.%s WHERE User='%s' and Host='%s';", mysql.SystemDB, mysql.UserTable, name, host) + pwd, err := s.getExecRet(s, authSQL) + if err == nil { + return pwd, nil + } else if !terror.ExecResultIsEmpty.Equal(err) { + return "", errors.Trace(err) + } + //Try to get user password for name with any host(%). + authSQL = fmt.Sprintf("SELECT Password FROM %s.%s WHERE User='%s' and Host='%%';", mysql.SystemDB, mysql.UserTable, name) + pwd, err = s.getExecRet(s, authSQL) + return pwd, errors.Trace(err) +} + func (s *session) Auth(user string, auth []byte, salt []byte) bool { strs := strings.Split(user, "@") if len(strs) != 2 { @@ -527,27 +542,7 @@ func (s *session) Auth(user string, auth []byte, salt []byte) bool { // Get user password. name := strs[0] host := strs[1] - authSQL := fmt.Sprintf("SELECT Password FROM %s.%s WHERE User=\"%s\" and Host=\"%s\";", mysql.SystemDB, mysql.UserTable, name, host) - rs, err := s.Execute(authSQL) - if err != nil { - log.Warnf("Encounter error when auth user %s. Error: %v", user, err) - return false - } - if len(rs) == 0 { - return false - } - row, err := rs[0].Next() - if err != nil { - log.Warnf("Encounter error when auth user %s. Error: %v", user, err) - return false - } - if row == nil || len(row.Data) == 0 { - return false - } - pwd, ok := row.Data[0].(string) - if !ok { - return false - } + pwd, err := s.getPassword(name, host) hpwd, err := util.DecodePassword(pwd) if err != nil { log.Errorf("Decode password string error %v", err) diff --git a/session_test.go b/session_test.go index d9a709d7f8..7209a1aea4 100644 --- a/session_test.go +++ b/session_test.go @@ -723,11 +723,9 @@ func (s *testSessionSuite) TestBootstrap(c *C) { row, err := r.Next() c.Assert(err, IsNil) c.Assert(row, NotNil) - match(c, row.Data, "localhost", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") - row, err = r.Next() - c.Assert(err, IsNil) - c.Assert(row, NotNil) - match(c, row.Data, "127.0.0.1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") + match(c, row.Data, "%", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") + + c.Assert(se.Auth("root@anyhost", []byte(""), []byte("")), IsTrue) mustExecSQL(c, se, "USE test;") // Check privilege tables. mustExecSQL(c, se, "SELECT * from mysql.db;") @@ -794,11 +792,7 @@ func (s *testSessionSuite) TestBootstrapWithError(c *C) { row, err := r.Next() c.Assert(err, IsNil) c.Assert(row, NotNil) - match(c, row.Data, "localhost", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") - row, err = r.Next() - c.Assert(err, IsNil) - c.Assert(row, NotNil) - match(c, row.Data, "127.0.0.1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") + match(c, row.Data, "%", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") mustExecSQL(c, se, "USE test;") // Check privilege tables. mustExecSQL(c, se, "SELECT * from mysql.db;") diff --git a/stmt/stmts/account_manage_test.go b/stmt/stmts/account_manage_test.go index af827c133c..d41e379917 100644 --- a/stmt/stmts/account_manage_test.go +++ b/stmt/stmts/account_manage_test.go @@ -52,8 +52,14 @@ func (s *testStmtSuite) TestCreateUserStmt(c *C) { } func (s *testStmtSuite) TestSetPwdStmt(c *C) { + createUserSQL := `CREATE USER 'testpwd'@'localhost' IDENTIFIED BY '';` tx := mustBegin(c, s.testDB) - rows, err := tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`) + _, err := tx.Query(createUserSQL) + c.Assert(err, NotNil) + mustCommit(c, tx) + + tx = mustBegin(c, s.testDB) + rows, err := tx.Query(`SELECT Password FROM mysql.User WHERE User="testpwd" and Host="localhost"`) c.Assert(err, IsNil) rows.Next() var pwd string @@ -64,11 +70,11 @@ func (s *testStmtSuite) TestSetPwdStmt(c *C) { mustCommit(c, tx) tx = mustBegin(c, s.testDB) - tx.Query(`SET PASSWORD FOR 'root'@'localhost' = 'password';`) + tx.Query(`SET PASSWORD FOR 'testpwd'@'localhost' = 'password';`) mustCommit(c, tx) tx = mustBegin(c, s.testDB) - rows, err = tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`) + rows, err = tx.Query(`SELECT Password FROM mysql.User WHERE User="testpwd" and Host="localhost"`) c.Assert(err, IsNil) rows.Next() rows.Scan(&pwd) diff --git a/util/auth.go b/util/auth.go index 2869613e6a..8531929656 100644 --- a/util/auth.go +++ b/util/auth.go @@ -55,6 +55,9 @@ func Sha1Hash(bs []byte) []byte { // EncodePassword converts plaintext password to hashed hex string. func EncodePassword(pwd string) string { + if len(pwd) == 0 { + return "" + } hash := Sha1Hash([]byte(pwd)) return hex.EncodeToString(hash) }