diff --git a/pkg/privilege/privileges/cache.go b/pkg/privilege/privileges/cache.go index 0fcc5c4c18..30db8dff13 100644 --- a/pkg/privilege/privileges/cache.go +++ b/pkg/privilege/privileges/cache.go @@ -921,7 +921,7 @@ func (p *immutable) decodeUserTableRow(row chunk.Row, fs []*resolve.ResultField) defaultAuthPlugin := "" if p.globalVars != nil { val, err := p.globalVars.GetGlobalSysVar(variable.DefaultAuthPlugin) - if err != nil { + if err == nil { defaultAuthPlugin = val } } @@ -1889,6 +1889,11 @@ func (p *MySQLPrivilege) getAllRoles(user, host string) []*auth.RoleIdentity { return ret } +// SetGlobalVarsAccessor is only used for test. +func (p *MySQLPrivilege) SetGlobalVarsAccessor(globalVars variable.GlobalVarAccessor) { + p.globalVars = globalVars +} + // Handle wraps MySQLPrivilege providing thread safe access. type Handle struct { sctx sqlexec.RestrictedSQLExecutor diff --git a/pkg/privilege/privileges/cache_test.go b/pkg/privilege/privileges/cache_test.go index 3527d347aa..2f734fc5e1 100644 --- a/pkg/privilege/privileges/cache_test.go +++ b/pkg/privilege/privileges/cache_test.go @@ -15,6 +15,7 @@ package privileges_test import ( + "context" "fmt" "testing" "time" @@ -22,6 +23,7 @@ import ( "github.com/pingcap/tidb/pkg/parser/auth" "github.com/pingcap/tidb/pkg/parser/mysql" "github.com/pingcap/tidb/pkg/privilege/privileges" + "github.com/pingcap/tidb/pkg/sessionctx/variable" "github.com/pingcap/tidb/pkg/testkit" "github.com/pingcap/tidb/pkg/util" "github.com/stretchr/testify/require" @@ -66,6 +68,15 @@ func TestLoadUserTable(t *testing.T) { require.Equal(t, false, user[6].PasswordExpired) require.Equal(t, time.Date(2022, 10, 10, 12, 0, 0, 0, time.Local), user[6].PasswordLastChanged) require.Equal(t, int64(-1), user[6].PasswordLifeTime) + + // test switching default auth plugin + for _, plugin := range []string{mysql.AuthNativePassword, mysql.AuthCachingSha2Password, mysql.AuthTiDBSM3Password} { + p = privileges.MySQLPrivilege{} + p.SetGlobalVarsAccessor(se.GetSessionVars().GlobalVarsAccessor) + require.NoError(t, se.GetSessionVars().GlobalVarsAccessor.SetGlobalSysVar(context.Background(), variable.DefaultAuthPlugin, plugin)) + require.NoError(t, p.LoadUserTable(se.GetRestrictedSQLExecutor())) + require.Equal(t, plugin, p.User()[0].AuthPlugin) + } } func TestLoadGlobalPrivTable(t *testing.T) {