lightning: support base64 encoding of password (#31195)
close pingcap/tidb#31194
This commit is contained in:
@ -515,7 +515,7 @@ func OpenCheckpointsDB(ctx context.Context, cfg *config.Config) (DB, error) {
|
||||
|
||||
switch cfg.Checkpoint.Driver {
|
||||
case config.CheckpointDriverMySQL:
|
||||
db, err := sql.Open("mysql", cfg.Checkpoint.DSN)
|
||||
db, err := common.ConnectMySQL(cfg.Checkpoint.DSN)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -17,6 +17,7 @@ package common
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@ -27,9 +28,12 @@ import (
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/pingcap/errors"
|
||||
"github.com/pingcap/failpoint"
|
||||
"github.com/pingcap/tidb/br/pkg/lightning/log"
|
||||
"github.com/pingcap/tidb/br/pkg/utils"
|
||||
tmysql "github.com/pingcap/tidb/errno"
|
||||
"github.com/pingcap/tidb/parser/model"
|
||||
"go.uber.org/zap"
|
||||
)
|
||||
@ -64,13 +68,55 @@ func (param *MySQLConnectParam) ToDSN() string {
|
||||
return dsn
|
||||
}
|
||||
|
||||
func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
|
||||
db, err := sql.Open("mysql", param.ToDSN())
|
||||
func tryConnectMySQL(dsn string) (*sql.DB, error) {
|
||||
driverName := "mysql"
|
||||
failpoint.Inject("MockMySQLDriver", func(val failpoint.Value) {
|
||||
driverName = val.(string)
|
||||
})
|
||||
db, err := sql.Open(driverName, dsn)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
if err = db.Ping(); err != nil {
|
||||
_ = db.Close()
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
return db, errors.Trace(db.Ping())
|
||||
// ConnectMySQL connects MySQL with the dsn. If access is denied and the password is a valid base64 encoding,
|
||||
// we will try to connect MySQL with the base64 decoding of the password.
|
||||
func ConnectMySQL(dsn string) (*sql.DB, error) {
|
||||
cfg, err := mysql.ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
// Try plain password first.
|
||||
db, firstErr := tryConnectMySQL(dsn)
|
||||
if firstErr == nil {
|
||||
return db, nil
|
||||
}
|
||||
// If access is denied and password is encoded by base64, try the decoded string as well.
|
||||
if mysqlErr, ok := errors.Cause(firstErr).(*mysql.MySQLError); ok && mysqlErr.Number == tmysql.ErrAccessDenied {
|
||||
// If password is encoded by base64, try the decoded string as well.
|
||||
if password, decodeErr := base64.StdEncoding.DecodeString(cfg.Passwd); decodeErr == nil && string(password) != cfg.Passwd {
|
||||
cfg.Passwd = string(password)
|
||||
db, err = tryConnectMySQL(cfg.FormatDSN())
|
||||
if err == nil {
|
||||
return db, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
// If we can't connect successfully, return the first error.
|
||||
return nil, errors.Trace(firstErr)
|
||||
}
|
||||
|
||||
func (param *MySQLConnectParam) Connect() (*sql.DB, error) {
|
||||
db, err := ConnectMySQL(param.ToDSN())
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
return db, nil
|
||||
}
|
||||
|
||||
// IsDirExists checks if dir exists.
|
||||
|
||||
@ -16,17 +16,26 @@ package common_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/rand"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/go-sql-driver/mysql"
|
||||
"github.com/pingcap/errors"
|
||||
"github.com/pingcap/failpoint"
|
||||
"github.com/pingcap/tidb/br/pkg/lightning/common"
|
||||
"github.com/pingcap/tidb/br/pkg/lightning/log"
|
||||
tmysql "github.com/pingcap/tidb/errno"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
@ -92,6 +101,66 @@ func TestToDSN(t *testing.T) {
|
||||
require.Equal(t, "root:123456@tcp(127.0.0.1:4000)/?charset=utf8mb4&sql_mode='strict'&maxAllowedPacket=1234&tls=cluster&tidb_distsql_scan_concurrency='1'", param.ToDSN())
|
||||
}
|
||||
|
||||
type mockDriver struct {
|
||||
driver.Driver
|
||||
plainPsw string
|
||||
}
|
||||
|
||||
func (m *mockDriver) Open(dsn string) (driver.Conn, error) {
|
||||
cfg, err := mysql.ParseDSN(dsn)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
accessDenied := cfg.Passwd != m.plainPsw
|
||||
return &mockConn{accessDenied: accessDenied}, nil
|
||||
}
|
||||
|
||||
type mockConn struct {
|
||||
driver.Conn
|
||||
driver.Pinger
|
||||
accessDenied bool
|
||||
}
|
||||
|
||||
func (c *mockConn) Ping(ctx context.Context) error {
|
||||
if c.accessDenied {
|
||||
return &mysql.MySQLError{Number: tmysql.ErrAccessDenied, Message: "access denied"}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mockConn) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
plainPsw := "dQAUoDiyb1ucWZk7"
|
||||
driverName := "mysql-mock-" + strconv.Itoa(rand.Int())
|
||||
sql.Register(driverName, &mockDriver{plainPsw: plainPsw})
|
||||
|
||||
require.NoError(t, failpoint.Enable(
|
||||
"github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver",
|
||||
fmt.Sprintf("return(\"%s\")", driverName)))
|
||||
defer func() {
|
||||
require.NoError(t, failpoint.Disable("github.com/pingcap/tidb/br/pkg/lightning/common/MockMySQLDriver"))
|
||||
}()
|
||||
|
||||
param := common.MySQLConnectParam{
|
||||
Host: "127.0.0.1",
|
||||
Port: 4000,
|
||||
User: "root",
|
||||
Password: plainPsw,
|
||||
SQLMode: "strict",
|
||||
MaxAllowedPacket: 1234,
|
||||
}
|
||||
db, err := param.Connect()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Close())
|
||||
param.Password = base64.StdEncoding.EncodeToString([]byte(plainPsw))
|
||||
db, err = param.Connect()
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, db.Close())
|
||||
}
|
||||
|
||||
func TestIsContextCanceledError(t *testing.T) {
|
||||
require.True(t, common.IsContextCanceledError(context.Canceled))
|
||||
require.False(t, common.IsContextCanceledError(io.EOF))
|
||||
|
||||
Reference in New Issue
Block a user