From ccdd43275efdec90cc06d83fdbf56adde2cc4c8e Mon Sep 17 00:00:00 2001 From: Yujie Xia Date: Thu, 10 Mar 2022 15:19:50 +0800 Subject: [PATCH] lightning: support base64 encoding of password (#31195) close pingcap/tidb#31194 --- br/pkg/lightning/checkpoints/checkpoints.go | 2 +- br/pkg/lightning/common/util.go | 52 +++++++++++++++- br/pkg/lightning/common/util_test.go | 69 +++++++++++++++++++++ 3 files changed, 119 insertions(+), 4 deletions(-) diff --git a/br/pkg/lightning/checkpoints/checkpoints.go b/br/pkg/lightning/checkpoints/checkpoints.go index d1b59a12e4..6f9ab545fa 100644 --- a/br/pkg/lightning/checkpoints/checkpoints.go +++ b/br/pkg/lightning/checkpoints/checkpoints.go @@ -515,7 +515,7 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) { switch cfg.Checkpoint.Driver { case config.CheckpointDriverMySQL: - db, err := sql.Open("mysql", cfg.Checkpoint.DSN) + db, err := common.ConnectMySQL(cfg.Checkpoint.DSN) if err != nil { return nil, errors.Trace(err) } diff --git a/br/pkg/lightning/common/util.go b/br/pkg/lightning/common/util.go index c24ee74fe3..054edcbb17 100644 --- a/br/pkg/lightning/common/util.go +++ b/br/pkg/lightning/common/util.go @@ -17,6 +17,7 @@ package common import ( "context" "database/sql" + "encoding/base64" "encoding/json" "fmt" "io" @@ -27,9 +28,12 @@ import ( "syscall" "time" + "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/br/pkg/lightning/log" "github.com/pingcap/tidb/br/pkg/utils" + tmysql "github.com/pingcap/tidb/errno" "github.com/pingcap/tidb/parser/model" "go.uber.org/zap" ) @@ -64,13 +68,55 @@ func (param *MySQLConnectParam) ToDSN() string { return dsn } -func (param *MySQLConnectParam) Connect() (*sql.DB, error) { - db, err := sql.Open("mysql", param.ToDSN()) +func tryConnectMySQL(dsn string) (*sql.DB, error) { + driverName := "mysql" + failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) { + driverName = val.(string) + }) + db, err := sql.Open(driverName, dsn) if err != nil { return nil, errors.Trace(err) } + if err = db.Ping(); err != nil { + _ = db.Close() + return nil, errors.Trace(err) + } + return db, nil +} - return db, errors.Trace(db.Ping()) +// ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding, +// we will try to connect MySQL with the base64 decoding of the password. +func ConnectMySQL(dsn string) (*sql.DB, error) { + cfg, err := mysql.ParseDSN(dsn) + if err != nil { + return nil, errors.Trace(err) + } + // Try plain password first. + db, firstErr := tryConnectMySQL(dsn) + if firstErr == nil { + return db, nil + } + // If access is denied and password is encoded by base64, try the decoded string as well. + if mysqlErr, ok := errors.Cause(firstErr).(*mysql.MySQLError); ok && mysqlErr.Number == tmysql.ErrAccessDenied { + // If password is encoded by base64, try the decoded string as well. + if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd { + cfg.Passwd = string(password) + db, err = tryConnectMySQL(cfg.FormatDSN()) + if err == nil { + return db, nil + } + } + } + // If we can't connect successfully, return the first error. + return nil, errors.Trace(firstErr) +} + +func (param *MySQLConnectParam) Connect() (*sql.DB, error) { + db, err := ConnectMySQL(param.ToDSN()) + if err != nil { + return nil, errors.Trace(err) + } + return db, nil } // IsDirExists checks if dir exists. diff --git a/br/pkg/lightning/common/util_test.go b/br/pkg/lightning/common/util_test.go index d06daa00b1..3b6ff5f92e 100644 --- a/br/pkg/lightning/common/util_test.go +++ b/br/pkg/lightning/common/util_test.go @@ -16,17 +16,26 @@ package common_test import ( "context" + "database/sql" + "database/sql/driver" + "encoding/base64" "encoding/json" + "fmt" "io" + "math/rand" "net/http" "net/http/httptest" + "strconv" "testing" "time" "github.com/DATA-DOG/go-sqlmock" + "github.com/go-sql-driver/mysql" "github.com/pingcap/errors" + "github.com/pingcap/failpoint" "github.com/pingcap/tidb/br/pkg/lightning/common" "github.com/pingcap/tidb/br/pkg/lightning/log" + tmysql "github.com/pingcap/tidb/errno" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -92,6 +101,66 @@ func TestToDSN(t *testing.T) { require.Equal(t, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN()) } +type mockDriver struct { + driver.Driver + plainPsw string +} + +func (m *mockDriver) Open(dsn string) (driver.Conn, error) { + cfg, err := mysql.ParseDSN(dsn) + if err != nil { + return nil, err + } + accessDenied := cfg.Passwd != m.plainPsw + return &mockConn{accessDenied: accessDenied}, nil +} + +type mockConn struct { + driver.Conn + driver.Pinger + accessDenied bool +} + +func (c *mockConn) Ping(ctx context.Context) error { + if c.accessDenied { + return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"} + } + return nil +} + +func (c *mockConn) Close() error { + return nil +} + +func TestConnect(t *testing.T) { + plainPsw := "dQAUoDiyb1ucWZk7" + driverName := "mysql-mock-" + strconv.Itoa(rand.Int()) + sql.Register(driverName, &mockDriver{plainPsw: plainPsw}) + + require.NoError(t, failpoint.Enable( + "github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver", + fmt.Sprintf("return(\"%s\")", driverName))) + defer func() { + require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver")) + }() + + param := common.MySQLConnectParam{ + Host: "127.0.0.1", + Port: 4000, + User: "root", + Password: plainPsw, + SQLMode: "strict", + MaxAllowedPacket: 1234, + } + db, err := param.Connect() + require.NoError(t, err) + require.NoError(t, db.Close()) + param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw)) + db, err = param.Connect() + require.NoError(t, err) + require.NoError(t, db.Close()) +} + func TestIsContextCanceledError(t *testing.T) { require.True(t, common.IsContextCanceledError(context.Canceled)) require.False(t, common.IsContextCanceledError(io.EOF))