privileges: fix create temporary tables privilege (#29279)
This commit is contained in:
@ -255,7 +255,7 @@ const (
|
||||
CreateRolePriv
|
||||
// DropRolePriv is the privilege to drop a role.
|
||||
DropRolePriv
|
||||
|
||||
// CreateTMPTablePriv is the privilege to create a temporary table.
|
||||
CreateTMPTablePriv
|
||||
LockTablesPriv
|
||||
CreateRoutinePriv
|
||||
|
||||
@ -365,7 +365,8 @@ func (e *Execute) handleExecuteBuilderOption(sctx sessionctx.Context,
|
||||
func (e *Execute) checkPreparedPriv(ctx context.Context, sctx sessionctx.Context,
|
||||
preparedObj *CachedPrepareStmt, is infoschema.InfoSchema) error {
|
||||
if pm := privilege.GetPrivilegeManager(sctx); pm != nil {
|
||||
if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, preparedObj.VisitInfos); err != nil {
|
||||
visitInfo := VisitInfo4PrivCheck(is, preparedObj.PreparedAst.Stmt, preparedObj.VisitInfos)
|
||||
if err := CheckPrivilege(sctx.GetSessionVars().ActiveRoles, pm, visitInfo); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"github.com/pingcap/tidb/lock"
|
||||
"github.com/pingcap/tidb/parser/ast"
|
||||
"github.com/pingcap/tidb/parser/auth"
|
||||
"github.com/pingcap/tidb/parser/model"
|
||||
"github.com/pingcap/tidb/parser/mysql"
|
||||
"github.com/pingcap/tidb/planner/property"
|
||||
"github.com/pingcap/tidb/privilege"
|
||||
@ -120,6 +121,76 @@ func CheckPrivilege(activeRoles []*auth.RoleIdentity, pm privilege.Manager, vs [
|
||||
return nil
|
||||
}
|
||||
|
||||
// VisitInfo4PrivCheck generates privilege check infos because privilege check of local temporary tables is different
|
||||
// with normal tables. `CREATE` statement needs `CREATE TEMPORARY TABLE` privilege from the database, and subsequent
|
||||
// statements do not need any privileges.
|
||||
func VisitInfo4PrivCheck(is infoschema.InfoSchema, node ast.Node, vs []visitInfo) (privVisitInfo []visitInfo) {
|
||||
if node == nil {
|
||||
return vs
|
||||
}
|
||||
|
||||
switch stmt := node.(type) {
|
||||
case *ast.CreateTableStmt:
|
||||
privVisitInfo = make([]visitInfo, 0, len(vs))
|
||||
for _, v := range vs {
|
||||
if v.privilege == mysql.CreatePriv {
|
||||
if stmt.TemporaryKeyword == ast.TemporaryLocal {
|
||||
// `CREATE TEMPORARY TABLE` privilege is required from the database, not the table.
|
||||
newVisitInfo := v
|
||||
newVisitInfo.privilege = mysql.CreateTMPTablePriv
|
||||
newVisitInfo.table = ""
|
||||
privVisitInfo = append(privVisitInfo, newVisitInfo)
|
||||
} else {
|
||||
// If both the normal table and temporary table already exist, we need to check the privilege.
|
||||
privVisitInfo = append(privVisitInfo, v)
|
||||
}
|
||||
} else {
|
||||
// `CREATE TABLE LIKE tmp` or `CREATE TABLE FROM SELECT tmp` in the future.
|
||||
if needCheckTmpTablePriv(is, v) {
|
||||
privVisitInfo = append(privVisitInfo, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
case *ast.DropTableStmt:
|
||||
// Dropping a local temporary table doesn't need any privileges.
|
||||
if stmt.IsView {
|
||||
privVisitInfo = vs
|
||||
} else {
|
||||
privVisitInfo = make([]visitInfo, 0, len(vs))
|
||||
if stmt.TemporaryKeyword != ast.TemporaryLocal {
|
||||
for _, v := range vs {
|
||||
if needCheckTmpTablePriv(is, v) {
|
||||
privVisitInfo = append(privVisitInfo, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
case *ast.GrantStmt, *ast.DropSequenceStmt, *ast.DropPlacementPolicyStmt:
|
||||
// Some statements ignore local temporary tables, so they should check the privileges on normal tables.
|
||||
privVisitInfo = vs
|
||||
default:
|
||||
privVisitInfo = make([]visitInfo, 0, len(vs))
|
||||
for _, v := range vs {
|
||||
if needCheckTmpTablePriv(is, v) {
|
||||
privVisitInfo = append(privVisitInfo, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func needCheckTmpTablePriv(is infoschema.InfoSchema, v visitInfo) bool {
|
||||
if v.db != "" && v.table != "" {
|
||||
// Other statements on local temporary tables except `CREATE` do not check any privileges.
|
||||
tb, err := is.TableByName(model.NewCIStr(v.db), model.NewCIStr(v.table))
|
||||
// If the table doesn't exist, we do not report errors to avoid leaking the existence of the table.
|
||||
if err == nil && tb.Meta().TempTableType == model.TempTableLocal {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// CheckTableLock checks the table lock.
|
||||
func CheckTableLock(ctx sessionctx.Context, is infoschema.InfoSchema, vs []visitInfo) error {
|
||||
if !config.TableLockEnabled() {
|
||||
|
||||
@ -3868,8 +3868,15 @@ func (b *PlanBuilder) buildDDL(ctx context.Context, node ast.DDLNode) (Plan, err
|
||||
}
|
||||
}
|
||||
if b.ctx.GetSessionVars().User != nil {
|
||||
authErr = ErrTableaccessDenied.GenWithStackByArgs("CREATE", b.ctx.GetSessionVars().User.AuthUsername,
|
||||
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L)
|
||||
// This is tricky here: we always need the visitInfo because it's not only used in privilege checks, and we
|
||||
// must pass the table name. However, the privilege check is towards the database. We'll deal with it later.
|
||||
if v.TemporaryKeyword == ast.TemporaryLocal {
|
||||
authErr = ErrDBaccessDenied.GenWithStackByArgs(b.ctx.GetSessionVars().User.AuthUsername,
|
||||
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Schema.L)
|
||||
} else {
|
||||
authErr = ErrTableaccessDenied.GenWithStackByArgs("CREATE", b.ctx.GetSessionVars().User.AuthUsername,
|
||||
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L)
|
||||
}
|
||||
}
|
||||
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.CreatePriv, v.Table.Schema.L,
|
||||
v.Table.Name.L, "", authErr)
|
||||
@ -3936,7 +3943,7 @@ func (b *PlanBuilder) buildDDL(ctx context.Context, node ast.DDLNode) (Plan, err
|
||||
"", "", authErr)
|
||||
case *ast.DropIndexStmt:
|
||||
if b.ctx.GetSessionVars().User != nil {
|
||||
authErr = ErrTableaccessDenied.GenWithStackByArgs("INDEx", b.ctx.GetSessionVars().User.AuthUsername,
|
||||
authErr = ErrTableaccessDenied.GenWithStackByArgs("INDEX", b.ctx.GetSessionVars().User.AuthUsername,
|
||||
b.ctx.GetSessionVars().User.AuthHostname, v.Table.Name.L)
|
||||
}
|
||||
b.visitInfo = appendVisitInfo(b.visitInfo, mysql.IndexPriv, v.Table.Schema.L,
|
||||
|
||||
@ -345,7 +345,8 @@ func optimize(ctx context.Context, sctx sessionctx.Context, node ast.Node, is in
|
||||
// we need the table information to check privilege, which is collected
|
||||
// into the visitInfo in the logical plan builder.
|
||||
if pm := privilege.GetPrivilegeManager(sctx); pm != nil {
|
||||
if err := plannercore.CheckPrivilege(activeRoles, pm, builder.GetVisitInfo()); err != nil {
|
||||
visitInfo := plannercore.VisitInfo4PrivCheck(is, node, builder.GetVisitInfo())
|
||||
if err := plannercore.CheckPrivilege(activeRoles, pm, visitInfo); err != nil {
|
||||
return nil, nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -26,6 +26,7 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pingcap/tidb/errno"
|
||||
"github.com/pingcap/tidb/executor"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/parser/auth"
|
||||
@ -2666,7 +2667,7 @@ func TestGrantCreateTmpTables(t *testing.T) {
|
||||
tk.MustExec("CREATE TABLE create_tmp_table_table (a int)")
|
||||
tk.MustExec("GRANT CREATE TEMPORARY TABLES on create_tmp_table_db.* to u1")
|
||||
tk.MustExec("GRANT CREATE TEMPORARY TABLES on *.* to u1")
|
||||
// Must set a session user to avoid null pointer dereferencing
|
||||
// Must set a session user to avoid null pointer dereference
|
||||
tk.Session().Auth(&auth.UserIdentity{
|
||||
Username: "root",
|
||||
Hostname: "localhost",
|
||||
@ -2678,6 +2679,164 @@ func TestGrantCreateTmpTables(t *testing.T) {
|
||||
tk.MustExec("DROP DATABASE create_tmp_table_db")
|
||||
}
|
||||
|
||||
func TestCreateTmpTablesPriv(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, clean := newStore(t)
|
||||
defer clean()
|
||||
|
||||
createStmt := "CREATE TEMPORARY TABLE test.tmp(id int)"
|
||||
dropStmt := "DROP TEMPORARY TABLE IF EXISTS test.tmp"
|
||||
|
||||
tk := testkit.NewTestKit(t, store)
|
||||
tk.MustExec(dropStmt)
|
||||
tk.MustExec("CREATE TABLE test.t(id int primary key)")
|
||||
tk.MustExec("CREATE SEQUENCE test.tmp")
|
||||
tk.MustExec("CREATE USER vcreate, vcreate_tmp, vcreate_tmp_all")
|
||||
tk.MustExec("GRANT CREATE, USAGE ON test.* TO vcreate")
|
||||
tk.MustExec("GRANT CREATE TEMPORARY TABLES, USAGE ON test.* TO vcreate_tmp")
|
||||
tk.MustExec("GRANT CREATE TEMPORARY TABLES, USAGE ON *.* TO vcreate_tmp_all")
|
||||
|
||||
tk.Session().Auth(&auth.UserIdentity{Username: "vcreate", Hostname: "localhost"}, nil, nil)
|
||||
err := tk.ExecToErr(createStmt)
|
||||
require.EqualError(t, err, "[planner:1044]Access denied for user 'vcreate'@'%' to database 'test'")
|
||||
tk.Session().Auth(&auth.UserIdentity{Username: "vcreate_tmp", Hostname: "localhost"}, nil, nil)
|
||||
tk.MustExec(createStmt)
|
||||
tk.MustExec(dropStmt)
|
||||
tk.Session().Auth(&auth.UserIdentity{Username: "vcreate_tmp_all", Hostname: "localhost"}, nil, nil)
|
||||
// TODO: issue #29280 to be fixed.
|
||||
//err = tk.ExecToErr(createStmt)
|
||||
//require.EqualError(t, err, "[planner:1044]Access denied for user 'vcreate_tmp_all'@'%' to database 'test'")
|
||||
|
||||
tests := []struct {
|
||||
sql string
|
||||
errcode int
|
||||
}{
|
||||
{
|
||||
sql: "create temporary table tmp(id int primary key)",
|
||||
},
|
||||
{
|
||||
sql: "insert into tmp value(1)",
|
||||
},
|
||||
{
|
||||
sql: "insert into tmp value(1) on duplicate key update id=1",
|
||||
},
|
||||
{
|
||||
sql: "replace tmp values(1)",
|
||||
},
|
||||
{
|
||||
sql: "insert into tmp select * from t",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "update tmp set id=1 where id=1",
|
||||
},
|
||||
{
|
||||
sql: "update tmp t1, t t2 set t1.id=t2.id where t1.id=t2.id",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "delete from tmp where id=1",
|
||||
},
|
||||
{
|
||||
sql: "delete t1 from tmp t1 join t t2 where t1.id=t2.id",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "select * from tmp where id=1",
|
||||
},
|
||||
{
|
||||
sql: "select * from tmp where id in (1,2)",
|
||||
},
|
||||
{
|
||||
sql: "select * from tmp",
|
||||
},
|
||||
{
|
||||
sql: "select * from tmp join t where tmp.id=t.id",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "(select * from tmp) union (select * from t)",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "create temporary table tmp1 like t",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "create table tmp(id int primary key)",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "create table t(id int primary key)",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "analyze table tmp",
|
||||
},
|
||||
{
|
||||
sql: "analyze table tmp, t",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "show create table tmp",
|
||||
},
|
||||
// TODO: issue #29281 to be fixed.
|
||||
//{
|
||||
// sql: "show create table t",
|
||||
// errcode: mysql.ErrTableaccessDenied,
|
||||
//},
|
||||
{
|
||||
sql: "drop sequence tmp",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "alter table tmp add column c1 char(10)",
|
||||
errcode: errno.ErrUnsupportedDDLOperation,
|
||||
},
|
||||
{
|
||||
sql: "truncate table tmp",
|
||||
},
|
||||
{
|
||||
sql: "drop temporary table t",
|
||||
errcode: mysql.ErrBadTable,
|
||||
},
|
||||
{
|
||||
sql: "drop table t",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "drop table t, tmp",
|
||||
errcode: mysql.ErrTableaccessDenied,
|
||||
},
|
||||
{
|
||||
sql: "drop temporary table tmp",
|
||||
},
|
||||
}
|
||||
|
||||
tk.Session().Auth(&auth.UserIdentity{Username: "vcreate_tmp", Hostname: "localhost"}, nil, nil)
|
||||
tk.MustExec("use test")
|
||||
tk.MustExec(dropStmt)
|
||||
for _, test := range tests {
|
||||
if test.errcode == 0 {
|
||||
tk.MustExec(test.sql)
|
||||
} else {
|
||||
tk.MustGetErrCode(test.sql, test.errcode)
|
||||
}
|
||||
}
|
||||
|
||||
// TODO: issue #29282 to be fixed.
|
||||
//for i, test := range tests {
|
||||
// preparedStmt := fmt.Sprintf("prepare stmt%d from '%s'", i, test.sql)
|
||||
// executeStmt := fmt.Sprintf("execute stmt%d", i)
|
||||
// tk.MustExec(preparedStmt)
|
||||
// if test.errcode == 0 {
|
||||
// tk.MustExec(executeStmt)
|
||||
// } else {
|
||||
// tk.MustGetErrCode(executeStmt, test.errcode)
|
||||
// }
|
||||
//}
|
||||
}
|
||||
|
||||
func TestRevokeSecondSyntax(t *testing.T) {
|
||||
t.Parallel()
|
||||
store, clean := newStore(t)
|
||||
|
||||
Reference in New Issue
Block a user