// File: tidb/pkg/extension/registry_test.go (477 lines, 16 KiB, Go)

// Copyright 2022 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package extension_test
import (
"testing"
"github.com/pingcap/errors"
"github.com/pingcap/tidb/pkg/extension"
"github.com/pingcap/tidb/pkg/parser/auth"
"github.com/pingcap/tidb/pkg/parser/mysql"
"github.com/pingcap/tidb/pkg/privilege/privileges"
"github.com/pingcap/tidb/pkg/sessionctx/variable"
"github.com/pingcap/tidb/pkg/testkit"
"github.com/pingcap/tidb/pkg/util/sem"
"github.com/stretchr/testify/require"
)
// TestSetupExtensions checks that Setup succeeds both with no registered
// extensions and with two registered ones, and that GetExtensions afterwards
// exposes one manifest per registration, in registration order.
func TestSetupExtensions(t *testing.T) {
	// Reset the extension package state again when the test finishes.
	defer extension.Reset()

	// Nothing registered: setup succeeds and there are no manifests.
	extension.Reset()
	require.NoError(t, extension.Setup())
	extensions, err := extension.GetExtensions()
	require.NoError(t, err)
	require.Empty(t, extensions.Manifests())

	// Two extensions registered: both manifests show up, in the order in
	// which they were registered.
	extension.Reset()
	require.NoError(t, extension.Register("test1"))
	require.NoError(t, extension.Register("test2"))
	require.NoError(t, extension.Setup())
	extensions, err = extension.GetExtensions()
	require.NoError(t, err)
	manifests := extensions.Manifests()
	// require.Len stops the test on mismatch, so the indexing below is safe.
	require.Len(t, manifests, 2)
	require.Equal(t, "test1", manifests[0].Name())
	require.Equal(t, "test2", manifests[1].Name())
}
// TestExtensionRegisterName checks the name validation performed by Register:
// an empty name and a name that is already registered are both rejected.
func TestExtensionRegisterName(t *testing.T) {
	defer extension.Reset()

	// Registering with an empty name must fail.
	extension.Reset()
	emptyNameErr := extension.Register("")
	require.EqualError(t, emptyNameErr, "extension name should not be empty")

	// Registering the same name twice must fail on the second attempt.
	extension.Reset()
	firstErr := extension.Register("test")
	require.NoError(t, firstErr)
	duplicateErr := extension.Register("test")
	require.EqualError(t, duplicateErr, "extension with name 'test' already registered")
}
// TestRegisterExtensionWithClose checks when the callback supplied through
// WithClose is invoked: not during Setup, exactly once on the first Reset
// after a successful Setup, and once when Setup itself fails.
func TestRegisterExtensionWithClose(t *testing.T) {
	defer extension.Reset()

	// closeCalls counts how many times the close callback has fired.
	closeCalls := 0
	countClose := func() {
		closeCalls++
	}

	// A successful setup does not trigger the close callback.
	extension.Reset()
	require.NoError(t, extension.Register("test1", extension.WithClose(countClose)))
	require.NoError(t, extension.Setup())
	require.Equal(t, 0, closeCalls)

	// The first Reset fires the callback ...
	extension.Reset()
	require.Equal(t, 1, closeCalls)

	// ... and a second Reset does not fire it again.
	extension.Reset()
	require.Equal(t, 1, closeCalls)

	// When another extension's factory fails during Setup, the callback of
	// the extension registered before it is expected to fire once.
	closeCalls = 0
	extension.Reset()
	require.NoError(t, extension.Register("test1", extension.WithClose(countClose)))
	failingFactory := func() ([]extension.Option, error) {
		return nil, errors.New("error abc")
	}
	require.NoError(t, extension.RegisterFactory("test2", failingFactory))
	require.EqualError(t, extension.Setup(), "error abc")
	require.Equal(t, 1, closeCalls)
}
// TestRegisterExtensionWithDyncPrivs checks registration of custom dynamic
// privileges: valid names are appended (upper-cased) after the pre-existing
// ones, while an empty name or a name that is already taken makes Setup fail
// and leaves the privilege list as it was before.
func TestRegisterExtensionWithDyncPrivs(t *testing.T) {
	defer extension.Reset()

	// Snapshot the privilege list into a fresh slice so later comparisons
	// do not share backing storage with the list returned by the package.
	baseline := append([]string{}, privileges.GetDynamicPrivileges()...)

	// Happy path: the new privileges follow the baseline ones.
	extension.Reset()
	require.NoError(t, extension.Register("test", extension.WithCustomDynPrivs([]string{"priv1", "priv2"})))
	require.NoError(t, extension.Setup())
	current := privileges.GetDynamicPrivileges()
	baselineLen := len(baseline)
	require.Equal(t, baseline, current[:baselineLen])
	require.Equal(t, []string{"PRIV1", "PRIV2"}, current[baselineLen:])

	// Failure scenarios. Each one registers its extensions in order, expects
	// Setup to fail with wantErr, and expects the list to equal the baseline.
	type registration struct {
		name  string
		privs []string
	}
	failures := []struct {
		registrations []registration
		wantErr       string
	}{
		{
			// empty dynamic privilege name
			registrations: []registration{
				{name: "test", privs: []string{"priv1", ""}},
			},
			wantErr: "privilege name should not be empty",
		},
		{
			// duplicate of the builtin ROLE_ADMIN privilege
			registrations: []registration{
				{name: "test", privs: []string{"priv1", "ROLE_ADMIN"}},
			},
			wantErr: "privilege is already registered",
		},
		{
			// duplicate of a privilege registered by another extension
			registrations: []registration{
				{name: "test1", privs: []string{"priv1"}},
				{name: "test2", privs: []string{"priv2", "priv1"}},
			},
			wantErr: "privilege is already registered",
		},
	}
	for _, failure := range failures {
		extension.Reset()
		for _, reg := range failure.registrations {
			require.NoError(t, extension.Register(reg.name, extension.WithCustomDynPrivs(reg.privs)))
		}
		require.EqualError(t, extension.Setup(), failure.wantErr)
		require.Equal(t, baseline, privileges.GetDynamicPrivileges())
	}
}
// TestRegisterExtensionWithSysVars checks registration of custom system
// variables: valid variables become visible through variable.GetSysVar, while
// an empty name or a name clash makes Setup fail without leaving any of the
// extension's variables registered.
func TestRegisterExtensionWithSysVars(t *testing.T) {
	defer extension.Reset()

	boolVar := &variable.SysVar{
		Name:  "var1",
		Type:  variable.TypeBool,
		Value: variable.On,
		Scope: variable.ScopeGlobal | variable.ScopeSession,
	}
	strVar := &variable.SysVar{
		Name:  "var2",
		Type:  variable.TypeStr,
		Value: "val2",
		Scope: variable.ScopeSession,
	}

	// register adds an extension carrying the given system variables and
	// requires the registration itself to succeed.
	register := func(name string, vars ...*variable.SysVar) {
		require.NoError(t, extension.Register(name, extension.WithCustomSysVariables(vars)))
	}

	// Happy path: the very same SysVar pointers are returned by GetSysVar.
	extension.Reset()
	register("test", boolVar, strVar)
	require.NoError(t, extension.Setup())
	require.Same(t, boolVar, variable.GetSysVar("var1"))
	require.Same(t, strVar, variable.GetSysVar("var2"))

	// A variable with an empty name is rejected.
	extension.Reset()
	register("test", &variable.SysVar{Scope: variable.ScopeGlobal, Name: "", Value: "val3"})
	require.EqualError(t, extension.Setup(), "system var name should not be empty")
	require.Nil(t, variable.GetSysVar(""))

	// A clash with the builtin tidb_snapshot variable is rejected; var1 is
	// not left registered and the builtin keeps its value and scope.
	extension.Reset()
	register(
		"test",
		boolVar,
		&variable.SysVar{Scope: variable.ScopeGlobal, Name: variable.TiDBSnapshot, Value: "val3"},
	)
	require.EqualError(t, extension.Setup(), "system var 'tidb_snapshot' has already registered")
	require.Nil(t, variable.GetSysVar("var1"))
	snapshotVar := variable.GetSysVar(variable.TiDBSnapshot)
	require.Equal(t, "", snapshotVar.Value)
	require.Equal(t, variable.ScopeSession, snapshotVar.Scope)

	// A clash between two extensions is rejected and neither variable of
	// the first extension stays registered.
	extension.Reset()
	register("test1", boolVar, strVar)
	register("test2", boolVar)
	require.EqualError(t, extension.Setup(), "system var 'var1' has already registered")
	require.Nil(t, variable.GetSysVar("var1"))
	require.Nil(t, variable.GetSysVar("var2"))
}
// TestSetVariablePrivilege checks that a custom system variable's
// RequireDynamicPrivileges callback is enforced by SET statements: priv1 for
// any change, priv2 in addition for global scope, and restricted_priv3 in
// addition once sem is enabled.
func TestSetVariablePrivilege(t *testing.T) {
	defer extension.Reset()

	guardedVar := &variable.SysVar{
		Name:     "var1",
		Type:     variable.TypeInt,
		Value:    "1",
		MinValue: 0,
		MaxValue: 100,
		Scope:    variable.ScopeGlobal | variable.ScopeSession,
		RequireDynamicPrivileges: func(isGlobal bool, semEnabled bool) []string {
			required := []string{"priv1"}
			if isGlobal {
				required = append(required, "priv2")
			}
			if semEnabled {
				required = append(required, "restricted_priv3")
			}
			return required
		},
	}

	extension.Reset()
	require.NoError(t, extension.Register(
		"test",
		extension.WithCustomSysVariables([]*variable.SysVar{guardedVar}),
		extension.WithCustomDynPrivs([]string{"priv1", "priv2", "restricted_priv3"}),
	))
	require.NoError(t, extension.Setup())

	store := testkit.CreateMockStore(t)

	// adminTK creates the user and issues the GRANTs; rootTK and userTK are
	// authenticated as root@localhost and u2@localhost respectively.
	adminTK := testkit.NewTestKit(t, store)
	adminTK.MustExec("use test")
	adminTK.MustExec("create user u2@localhost")

	rootTK := testkit.NewTestKit(t, store)
	rootIdentity := &auth.UserIdentity{Username: "root", Hostname: "localhost"}
	require.NoError(t, rootTK.Session().Auth(rootIdentity, nil, nil, nil))

	userTK := testkit.NewTestKit(t, store)
	userIdentity := &auth.UserIdentity{Username: "u2", Hostname: "localhost"}
	require.NoError(t, userTK.Session().Auth(userIdentity, nil, nil, nil))

	// expectDenied requires stmt to fail on tk with exactly message msg.
	expectDenied := func(tk *testkit.TestKit, stmt, msg string) {
		require.EqualError(t, tk.ExecToErr(stmt), msg)
	}
	// expectValue requires query to return a single row holding want.
	expectValue := func(tk *testkit.TestKit, query, want string) {
		tk.MustQuery(query).Check(testkit.Rows(want))
	}

	const (
		needPriv1      = "[planner:1227]Access denied; you need (at least one of) the SUPER or priv1 privilege(s) for this operation"
		needPriv2      = "[planner:1227]Access denied; you need (at least one of) the SUPER or priv2 privilege(s) for this operation"
		needRestricted = "[planner:1227]Access denied; you need (at least one of) the restricted_priv3 privilege(s) for this operation"
	)

	sem.Disable()

	// Session scope: root may set the variable, u2 needs priv1 first.
	rootTK.MustExec("set @@var1=7")
	expectValue(rootTK, "select @@var1", "7")

	expectDenied(userTK, "set @@var1=10", needPriv1)
	expectValue(userTK, "select @@var1", "1")

	adminTK.MustExec("GRANT priv1 on *.* TO u2@localhost")
	userTK.MustExec("set @@var1=8")
	expectValue(userTK, "select @@var1", "8")

	// Global scope: root may set the variable; for u2 neither priv1 nor
	// SYSTEM_VARIABLES_ADMIN is enough until priv2 is granted.
	rootTK.MustExec("set @@global.var1=17")
	expectValue(rootTK, "select @@global.var1", "17")

	adminTK.MustExec("GRANT SYSTEM_VARIABLES_ADMIN on *.* TO u2@localhost")
	expectDenied(userTK, "set @@global.var1=18", needPriv2)
	expectValue(userTK, "select @@global.var1", "17")

	adminTK.MustExec("GRANT priv2 on *.* TO u2@localhost")
	userTK.MustExec("set @@global.var1=18")
	expectValue(userTK, "select @@global.var1", "18")

	// With sem enabled both root and u2 are refused until restricted_priv3
	// is granted explicitly.
	sem.Enable()
	defer sem.Disable()

	expectDenied(rootTK, "set @@global.var1=27", needRestricted)
	expectValue(rootTK, "select @@global.var1", "18")

	expectDenied(userTK, "set @@global.var1=27", needRestricted)
	expectValue(userTK, "select @@global.var1", "18")

	adminTK.MustExec("GRANT restricted_priv3 on *.* TO u2@localhost")
	userTK.MustExec("set @@global.var1=28")
	expectValue(userTK, "select @@global.var1", "28")
}
// TestCustomAccessCheck checks that the callback supplied through
// WithCustomAccessCheck adds dynamic-privilege requirements on test.t1:
// priv1 for SELECT, priv2 for UPDATE, and restricted_priv3 in addition for
// UPDATE once sem is enabled.
func TestCustomAccessCheck(t *testing.T) {
	defer extension.Reset()
	extension.Reset()

	// accessCheck returns the extra dynamic privileges required for the given
	// object and privilege type, or nil when it imposes no extra requirement.
	accessCheck := func(db, tbl, column string, priv mysql.PrivilegeType, semEnabled bool) []string {
		if !(db == "test" && tbl == "t1") {
			return nil
		}
		switch priv {
		case mysql.SelectPriv:
			return []string{"priv1"}
		case mysql.UpdatePriv:
			if semEnabled {
				return []string{"priv2", "restricted_priv3"}
			}
			return []string{"priv2"}
		default:
			return nil
		}
	}

	require.NoError(t, extension.Register(
		"test",
		extension.WithCustomDynPrivs([]string{"priv1", "priv2", "restricted_priv3"}),
		extension.WithCustomAccessCheck(accessCheck),
	))
	require.NoError(t, extension.Setup())

	store := testkit.CreateMockStore(t)

	// adminTK creates the user and issues the GRANTs; rootTK and userTK are
	// authenticated as root@localhost and u2@localhost respectively.
	adminTK := testkit.NewTestKit(t, store)
	adminTK.MustExec("use test")
	adminTK.MustExec("create user u2@localhost")

	rootTK := testkit.NewTestKit(t, store)
	rootIdentity := &auth.UserIdentity{Username: "root", Hostname: "localhost"}
	require.NoError(t, rootTK.Session().Auth(rootIdentity, nil, nil, nil))
	rootTK.MustExec("use test")

	userTK := testkit.NewTestKit(t, store)
	userIdentity := &auth.UserIdentity{Username: "u2", Hostname: "localhost"}
	require.NoError(t, userTK.Session().Auth(userIdentity, nil, nil, nil))
	adminTK.MustExec("GRANT all on test.t1 TO u2@localhost")
	userTK.MustExec("use test")

	// expectDenied requires stmt to fail on tk with exactly message msg.
	expectDenied := func(tk *testkit.TestKit, stmt, msg string) {
		require.EqualError(t, tk.ExecToErr(stmt), msg)
	}
	// expectRows requires query to return exactly the rows in want.
	expectRows := func(tk *testkit.TestKit, query string, want ...string) {
		tk.MustQuery(query).Check(testkit.Rows(want...))
	}

	const (
		selectDenied = "[planner:1142]SELECT command denied to user 'u2'@'localhost' for table 't1'"
		updateDenied = "[planner:8121]privilege check for 'Update' fail"
		selectRowOne = "select * from t1 where id=1"
		selectAll    = "select * from t1"
	)

	rootTK.MustExec("create table t1(id int primary key, v int)")
	rootTK.MustExec("insert into t1 values (1, 10), (2, 20)")

	// SELECT: root can read; u2 is refused despite "GRANT all" until priv1
	// is granted.
	expectRows(rootTK, selectRowOne, "1 10")
	expectRows(rootTK, selectAll, "1 10", "2 20")

	expectDenied(userTK, selectRowOne, selectDenied)
	expectDenied(userTK, selectAll, selectDenied)

	adminTK.MustExec("GRANT priv1 on *.* TO u2@localhost")
	expectRows(userTK, selectRowOne, "1 10")
	expectRows(userTK, selectAll, "1 10", "2 20")

	// UPDATE: u2 is refused until priv2 is granted.
	expectDenied(userTK, "update t1 set v=11 where id=1", updateDenied)
	expectDenied(userTK, "update t1 set v=11 where id<2", updateDenied)
	expectRows(userTK, selectRowOne, "1 10")

	adminTK.MustExec("GRANT priv2 on *.* TO u2@localhost")
	userTK.MustExec("update t1 set v=11 where id=1")
	expectRows(userTK, selectRowOne, "1 11")
	userTK.MustExec("update t1 set v=12 where id<2")
	expectRows(userTK, selectRowOne, "1 12")

	// With sem enabled UPDATE is refused for both root and u2 until
	// restricted_priv3 is granted.
	sem.Enable()
	defer sem.Disable()

	expectDenied(rootTK, "update t1 set v=21 where id=1", updateDenied)
	expectDenied(rootTK, "update t1 set v=21 where id<2", updateDenied)
	expectRows(rootTK, selectRowOne, "1 12")

	expectDenied(userTK, "update t1 set v=21 where id=1", updateDenied)
	expectDenied(userTK, "update t1 set v=21 where id<2", updateDenied)
	expectRows(userTK, selectRowOne, "1 12")

	adminTK.MustExec("GRANT restricted_priv3 on *.* TO u2@localhost")
	userTK.MustExec("update t1 set v=31 where id=1")
	expectRows(userTK, selectRowOne, "1 31")
	userTK.MustExec("update t1 set v=32 where id<2")
	expectRows(userTK, selectRowOne, "1 32")
}
// TestAuthPluginValidation checks the validation Setup applies to custom auth
// plugins: the name must be non-empty, unique and not a reserved default
// plugin name, and the AuthenticateUser, ValidateAuthString and
// GenerateAuthString callbacks must all be provided.
func TestAuthPluginValidation(t *testing.T) {
	defer extension.Reset()

	// Callback values shared by the plugin definitions below.
	authenticate := func(ctx extension.AuthenticateRequest) error {
		return nil
	}
	generate := func(pwd string) (string, bool) {
		return pwd, true
	}
	acceptHash := func(pwdHash string) bool {
		return true
	}
	rejectHash := func(pwdHash string) bool {
		return false
	}

	// completePlugin builds a fresh plugin with every callback populated.
	completePlugin := func(name string) *extension.AuthPlugin {
		return &extension.AuthPlugin{
			Name:               name,
			AuthenticateUser:   authenticate,
			GenerateAuthString: generate,
			ValidateAuthString: acceptHash,
		}
	}

	// Each case registers its plugins under the extension name "test" and
	// runs Setup. An empty wantErr means Setup must succeed; otherwise the
	// Setup error must contain wantErr.
	cases := []struct {
		plugins []*extension.AuthPlugin
		wantErr string
	}{
		{
			// empty plugin name
			plugins: []*extension.AuthPlugin{{Name: ""}},
			wantErr: "auth plugin name cannot be empty",
		},
		{
			// AuthenticateUser not set
			plugins: []*extension.AuthPlugin{{
				Name:               "plugin1",
				ValidateAuthString: rejectHash,
				GenerateAuthString: generate,
			}},
			wantErr: "auth plugin AuthenticateUser function cannot be nil for plugin1",
		},
		{
			// ValidateAuthString not set
			plugins: []*extension.AuthPlugin{{
				Name:               "plugin1",
				AuthenticateUser:   authenticate,
				GenerateAuthString: generate,
			}},
			wantErr: "auth plugin ValidateAuthString function cannot be nil for plugin1",
		},
		{
			// GenerateAuthString not set
			plugins: []*extension.AuthPlugin{{
				Name:               "plugin1",
				AuthenticateUser:   authenticate,
				ValidateAuthString: acceptHash,
			}},
			wantErr: "auth plugin GenerateAuthString function cannot be nil for plugin1",
		},
		{
			// two plugins sharing one name
			plugins: []*extension.AuthPlugin{
				completePlugin("plugin1"),
				completePlugin("plugin1"),
			},
			wantErr: "has already been registered",
		},
		{
			// name collides with a default auth plugin
			plugins: []*extension.AuthPlugin{completePlugin("mysql_native_password")},
			wantErr: "a reserved name for default auth plugins",
		},
		{
			// fully specified plugin with a fresh name
			plugins: []*extension.AuthPlugin{completePlugin("plugin1")},
		},
	}

	for _, tc := range cases {
		extension.Reset()
		require.NoError(t, extension.Register("test", extension.WithCustomAuthPlugins(tc.plugins)))
		setupErr := extension.Setup()
		if tc.wantErr == "" {
			require.NoError(t, setupErr)
			continue
		}
		require.ErrorContains(t, setupErr, tc.wantErr)
	}
}