Merge pull request #221 from pingcap/shenli/dsn-auth

Support basic auth for tidb-server
This commit is contained in:
Shen Li
2015-09-22 11:47:17 +08:00
10 changed files with 130 additions and 71 deletions

View File

@ -51,7 +51,7 @@ func bootstrap(s Session) {
func initUserTable(s Session) {
mustExecute(s, CreateUserTable)
// Insert a default user with empty password.
mustExecute(s, `INSERT INTO mysql.user VALUES ("localhost", "root", ""), ("127.0.0.1", "root", "");`)
mustExecute(s, `INSERT INTO mysql.user VALUES ("localhost", "root", ""), ("127.0.0.1", "root", ""), ("::1", "root", "");`)
}
func mustExecute(s Session, sql string) {

View File

@ -18,8 +18,11 @@
package tidb
import (
"bytes"
"crypto/sha1"
"encoding/json"
"fmt"
"strings"
"sync/atomic"
"time"
@ -44,7 +47,6 @@ type Session interface {
LastInsertID() uint64 // Last inserted auto_increment id
AffectedRows() uint64 // Affected rows by lastest executed stmt
Execute(sql string) ([]rset.Recordset, error) // Execute a sql statement
SetUsername(name string) // Current user name
String() string // For debug
FinishTxn(rollback bool) error
// For execute prepare statement in binary protocol
@ -55,6 +57,7 @@ type Session interface {
SetClientCapability(uint32) // Set client capability flags
Close() error
Retry() error
Auth(user string, auth []byte, salt []byte) bool
}
var (
@ -96,13 +99,13 @@ func (h *stmtHistory) clone() *stmtHistory {
}
type session struct {
txn kv.Transaction // Current transaction
userName string
args []interface{} // Statment execution args, this should be cleaned up after exec
values map[fmt.Stringer]interface{}
store kv.Storage
sid int64
history stmtHistory
txn kv.Transaction // Current transaction
user string
args []interface{} // Statment execution args, this should be cleaned up after exec
values map[fmt.Stringer]interface{}
store kv.Storage
sid int64
history stmtHistory
}
func (s *session) Status() uint16 {
@ -117,10 +120,6 @@ func (s *session) AffectedRows() uint64 {
return variable.GetSessionVars(s).AffectedRows
}
func (s *session) SetUsername(name string) {
s.userName = name
}
func (s *session) resetHistory() {
s.ClearValue(forupdate.ForUpdateKey)
s.history.reset()
@ -157,7 +156,7 @@ func (s *session) FinishTxn(rollback bool) error {
func (s *session) String() string {
// TODO: how to print binded context in values appropriately?
data := map[string]interface{}{
"userName": s.userName,
"user": s.user,
"currDBName": db.GetCurrentSchema(s),
"sid": s.sid,
}
@ -383,6 +382,73 @@ func (s *session) Close() error {
return s.FinishTxn(true)
}
func calcPassword(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}
// stage1Hash = SHA1(password)
crypt := sha1.New()
crypt.Write(password)
stage1 := crypt.Sum(nil)
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
// inner Hash
crypt.Reset()
crypt.Write(stage1)
hash := crypt.Sum(nil)
// outer Hash
crypt.Reset()
crypt.Write(scramble)
crypt.Write(hash)
scramble = crypt.Sum(nil)
// token = scrambleHash XOR stage1Hash
for i := range scramble {
scramble[i] ^= stage1[i]
}
return scramble
}
func (s *session) Auth(user string, auth []byte, salt []byte) bool {
strs := strings.Split(user, "@")
if len(strs) != 2 {
log.Warnf("Invalid format for user: %s", user)
return false
}
// Get user password.
name := strs[0]
host := strs[1]
authSQL := fmt.Sprintf("SELECT Password FROM %s.%s WHERE User=\"%s\" and Host=\"%s\";", mysql.SystemDB, mysql.UserTable, name, host)
rs, err := s.Execute(authSQL)
if err != nil {
log.Warnf("Encounter error when auth user %s. Error: %v", user, err)
return false
}
if len(rs) == 0 {
return false
}
row, err := rs[0].Next()
if err != nil {
log.Warnf("Encounter error when auth user %s. Error: %v", user, err)
return false
}
if row == nil || len(row.Data) == 0 {
return false
}
pwd, ok := row.Data[0].(string)
if !ok {
return false
}
checkAuth := calcPassword(salt, []byte(pwd))
if !bytes.Equal(auth, checkAuth) {
return false
}
s.user = user
return true
}
// CreateSession creates a new session environment.
func CreateSession(store kv.Storage) (Session, error) {
s := &session{

View File

@ -42,8 +42,6 @@ func main() {
cfg := &server.Config{
Addr: fmt.Sprintf(":%s", *port),
User: "root",
Password: "",
LogLevel: *logLevel,
}

View File

@ -16,8 +16,6 @@ package server
// Config contains configuration options.
type Config struct {
Addr string `json:"addr" toml:"addr"`
User string `json:"user" toml:"user"`
Password string `json:"password" toml:"password"`
LogLevel string `json:"log_level" toml:"log_level"`
SkipAuth bool `json:"skip_auth" toml:"skip_auth"`
}

View File

@ -36,7 +36,6 @@ package server
import (
"bytes"
"crypto/sha1"
"encoding/binary"
"fmt"
"io"
@ -157,35 +156,6 @@ func (cc *clientConn) writePacket(data []byte) error {
return cc.pkg.writePacket(data)
}
func calcPassword(scramble, password []byte) []byte {
if len(password) == 0 {
return nil
}
// stage1Hash = SHA1(password)
crypt := sha1.New()
crypt.Write(password)
stage1 := crypt.Sum(nil)
// scrambleHash = SHA1(scramble + SHA1(stage1Hash))
// inner Hash
crypt.Reset()
crypt.Write(stage1)
hash := crypt.Sum(nil)
// outer Hash
crypt.Reset()
crypt.Write(scramble)
crypt.Write(hash)
scramble = crypt.Sum(nil)
// token = scrambleHash XOR stage1Hash
for i := range scramble {
scramble[i] ^= stage1[i]
}
return scramble
}
func (cc *clientConn) readHandshakeResponse() error {
data, err := cc.readPacket()
if err != nil {
@ -210,20 +180,31 @@ func (cc *clientConn) readHandshakeResponse() error {
authLen := int(data[pos])
pos++
auth := data[pos : pos+authLen]
checkAuth := calcPassword(cc.salt, []byte(cc.server.cfgGetPwd(cc.user)))
if !bytes.Equal(auth, checkAuth) && !cc.server.skipAuth() {
return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.conn.RemoteAddr().String(), cc.user, "Yes"))
}
pos += authLen
if cc.capability|mysql.ClientConnectWithDB > 0 {
if len(data[pos:]) == 0 {
return nil
if len(data[pos:]) > 0 {
idx := bytes.IndexByte(data[pos:], 0)
cc.dbname = string(data[pos : pos+idx])
}
}
// Open session and do auth
cc.ctx, err = cc.server.driver.OpenCtx(cc.capability, uint8(cc.collation), cc.dbname)
if err != nil {
cc.Close()
return errors.Trace(err)
}
if !cc.server.skipAuth() {
// Do Auth
addr := cc.conn.RemoteAddr().String()
host, _, err1 := net.SplitHostPort(addr)
if err1 != nil {
return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, addr, "Yes"))
}
user := fmt.Sprintf("%s@%s", cc.user, host)
if !cc.ctx.Auth(user, auth, cc.salt) {
return errors.Trace(mysql.NewDefaultError(mysql.ErAccessDeniedError, cc.user, host, "Yes"))
}
idx := bytes.IndexByte(data[pos:], 0)
cc.dbname = string(data[pos : pos+idx])
}
return nil
}

View File

@ -50,6 +50,9 @@ type IContext interface {
// Close closes the IContext.
Close() error
// Auth verifies user's authentication.
Auth(user string, auth []byte, salt []byte) bool
}
// IStatement is the interface to use a prepared statement.

View File

@ -178,6 +178,11 @@ func (tc *TiDBContext) Close() (err error) {
return tc.session.Close()
}
// Auth implements IContext Auth method.
func (tc *TiDBContext) Auth(user string, auth []byte, salt []byte) bool {
return tc.session.Auth(user, auth, salt)
}
// FieldList implements IContext FieldList method.
func (tc *TiDBContext) FieldList(table string) (colums []*ColumnInfo, err error) {
rs, err := tc.Execute("SELECT * FROM " + table + " LIMIT 0")

View File

@ -88,10 +88,6 @@ func (s *Server) skipAuth() bool {
return s.cfg.SkipAuth
}
func (s *Server) cfgGetPwd(user string) string {
return s.cfg.Password // TODO: support multiple users
}
// NewServer creates a new Server.
func NewServer(cfg *Config, driver IDriver) (*Server, error) {
s := &Server{
@ -147,13 +143,6 @@ func (s *Server) onConn(c net.Conn) {
c.Close()
return
}
conn.ctx, err = s.driver.OpenCtx(conn.capability, uint8(conn.collation), conn.dbname)
if err != nil {
log.Errorf("open ctx error %s", errors.ErrorStack(err))
c.Close()
return
}
defer func() {
log.Infof("close %s", conn)
}()

View File

@ -219,3 +219,20 @@ func runTestConcurrentUpdate(c *C) {
c.Assert(err, IsNil)
})
}
func runTestAuth(c *C) {
runTests(c, dsn, func(dbt *DBTest) {
dbt.mustExec(`CREATE USER 'test'@'127.0.0.1' IDENTIFIED BY '123';`)
dbt.mustExec(`CREATE USER 'test'@'localhost' IDENTIFIED BY '123';`)
dbt.mustExec(`CREATE USER 'test'@'::1' IDENTIFIED BY '123';`)
})
newDsn := "test:123@tcp(localhost:4001)/test?strict=true"
runTests(c, newDsn, func(dbt *DBTest) {
dbt.mustExec(`USE mysql;`)
})
db, err := sql.Open("mysql", "test:456@tcp(localhost:4001)/test?strict=true")
_, err = db.Query("USE mysql;")
c.Assert(err, NotNil, Commentf("Wrong password should be failed"))
db.Close()
}

View File

@ -33,8 +33,6 @@ func (ts *TidbTestSuite) SetUpSuite(c *C) {
ts.tidbdrv = NewTiDBDriver(store)
cfg := &Config{
Addr: ":4001",
User: "root",
Password: "",
LogLevel: "debug",
}
server, err := NewServer(cfg, ts.tidbdrv)
@ -69,3 +67,7 @@ func (ts *TidbTestSuite) TestPreparedString(c *C) {
func (ts *TidbTestSuite) TestConcurrentUpdate(c *C) {
runTestConcurrentUpdate(c)
}
func (ts *TidbTestSuite) TestAuth(c *C) {
runTestAuth(c)
}