stmts: Support grant column scope privilege
This commit is contained in:
@ -126,6 +126,8 @@ const (
|
||||
DBTable = "DB"
|
||||
// TablePrivTable is the table in system db contains table scope privilege info.
|
||||
TablePrivTable = "Tables_priv"
|
||||
// ColumnPrivTable is the table in system db contains column scope privilege info.
|
||||
ColumnPrivTable = "Columns_priv"
|
||||
)
|
||||
|
||||
// PrivilegeType privilege
|
||||
|
||||
@ -22,6 +22,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/column"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/model"
|
||||
@ -104,9 +105,15 @@ func (s *GrantStmt) Exec(ctx context.Context) (rset.Recordset, error) {
|
||||
}
|
||||
// Grant each priv to the user.
|
||||
for _, priv := range s.Privs {
|
||||
err := s.grantPriv(ctx, priv, user)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
if len(priv.Cols) > 0 {
|
||||
err1 := s.checkAndInitColumnPriv(ctx, userName, host, priv.Cols)
|
||||
if err1 != nil {
|
||||
return nil, errors.Trace(err1)
|
||||
}
|
||||
}
|
||||
err2 := s.grantPriv(ctx, priv, user)
|
||||
if err2 != nil {
|
||||
return nil, errors.Trace(err2)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -145,6 +152,32 @@ func (s *GrantStmt) checkAndInitTablePriv(ctx context.Context, user string, host
|
||||
return initTablePrivEntry(ctx, user, host, db.Name.O, tbl.TableName().O)
|
||||
}
|
||||
|
||||
func (s *GrantStmt) checkAndInitColumnPriv(ctx context.Context, user string, host string, cols []string) error {
|
||||
db, tbl, err := s.getTargetSchemaAndTable(ctx)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
for _, c := range cols {
|
||||
col := column.FindCol(tbl.Cols(), c)
|
||||
if col == nil {
|
||||
return errors.Errorf("Unknown column: %s", c)
|
||||
}
|
||||
ok, err := columnPrivEntryExists(ctx, user, host, db.Name.O, tbl.TableName().O, col.Name.O)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
if ok {
|
||||
continue
|
||||
}
|
||||
// Entry does not exists for user/host/db. Insert a new entry.
|
||||
err = initColumnPrivEntry(ctx, user, host, db.Name.O, tbl.TableName().O, col.Name.O)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func initDBPrivEntry(ctx context.Context, user string, host string, db string) error {
|
||||
st := &InsertIntoStmt{
|
||||
TableIdent: table.Ident{
|
||||
@ -170,7 +203,7 @@ func initTablePrivEntry(ctx context.Context, user string, host string, db string
|
||||
Name: model.NewCIStr(mysql.TablePrivTable),
|
||||
Schema: model.NewCIStr(mysql.SystemDB),
|
||||
},
|
||||
ColNames: []string{"Host", "User", "DB", "Table_name"},
|
||||
ColNames: []string{"Host", "User", "DB", "Table_name", "Table_priv", "Column_priv"},
|
||||
}
|
||||
values := make([][]expression.Expression, 0, 1)
|
||||
value := make([]expression.Expression, 0, 3)
|
||||
@ -178,6 +211,30 @@ func initTablePrivEntry(ctx context.Context, user string, host string, db string
|
||||
value = append(value, &expression.Value{Val: user})
|
||||
value = append(value, &expression.Value{Val: db})
|
||||
value = append(value, &expression.Value{Val: tbl})
|
||||
value = append(value, &expression.Value{Val: ""})
|
||||
value = append(value, &expression.Value{Val: ""})
|
||||
values = append(values, value)
|
||||
st.Lists = values
|
||||
_, err := st.Exec(ctx)
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
func initColumnPrivEntry(ctx context.Context, user string, host string, db string, tbl string, col string) error {
|
||||
st := &InsertIntoStmt{
|
||||
TableIdent: table.Ident{
|
||||
Name: model.NewCIStr(mysql.ColumnPrivTable),
|
||||
Schema: model.NewCIStr(mysql.SystemDB),
|
||||
},
|
||||
ColNames: []string{"Host", "User", "DB", "Table_name", "Column_name", "Column_priv"},
|
||||
}
|
||||
values := make([][]expression.Expression, 0, 1)
|
||||
value := make([]expression.Expression, 0, 3)
|
||||
value = append(value, &expression.Value{Val: host}) // Host
|
||||
value = append(value, &expression.Value{Val: user}) // User
|
||||
value = append(value, &expression.Value{Val: db}) // DB
|
||||
value = append(value, &expression.Value{Val: tbl}) // Table_name
|
||||
value = append(value, &expression.Value{Val: col}) // Column_name
|
||||
value = append(value, &expression.Value{Val: ""}) // Empty Column_priv
|
||||
values = append(values, value)
|
||||
st.Lists = values
|
||||
_, err := st.Exec(ctx)
|
||||
@ -191,7 +248,10 @@ func (s *GrantStmt) grantPriv(ctx context.Context, priv *coldef.PrivElem, user *
|
||||
case coldef.GrantLevelDB:
|
||||
return s.grantDBPriv(ctx, priv, user)
|
||||
case coldef.GrantLevelTable:
|
||||
return s.grantTablePriv(ctx, priv, user)
|
||||
if len(priv.Cols) == 0 {
|
||||
return s.grantTablePriv(ctx, priv, user)
|
||||
}
|
||||
return s.grantColumnPriv(ctx, priv, user)
|
||||
default:
|
||||
return errors.Errorf("Unknown grant level: %s", s.Level)
|
||||
}
|
||||
@ -335,6 +395,42 @@ func composeTablePrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name
|
||||
return assigns, nil
|
||||
}
|
||||
|
||||
func composeColumnPrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) ([]expression.Assignment, error) {
|
||||
newColumnPriv := ""
|
||||
if priv == mysql.AllPriv {
|
||||
for _, p := range mysql.AllColumnPrivs {
|
||||
v, ok := mysql.Priv2SetStr[p]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("Unknown column privilege %s", p)
|
||||
}
|
||||
if len(newColumnPriv) == 0 {
|
||||
newColumnPriv = v
|
||||
} else {
|
||||
newColumnPriv = fmt.Sprintf("%s,%s", newColumnPriv, v)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
p, ok := mysql.Priv2SetStr[priv]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("Unknown priv: %s", priv)
|
||||
}
|
||||
if len(currColumnPriv) == 0 {
|
||||
newColumnPriv = p
|
||||
} else {
|
||||
newColumnPriv = fmt.Sprintf("%s,%s", currColumnPriv, p)
|
||||
}
|
||||
}
|
||||
t := expression.Assignment{
|
||||
ColName: "Column_priv",
|
||||
Expr: expression.Value{Val: newColumnPriv},
|
||||
}
|
||||
assigns := []expression.Assignment{t}
|
||||
return assigns, nil
|
||||
}
|
||||
func composeDBTableFilter(name string, host string, db string) expression.Expression {
|
||||
dbMatch := expression.NewBinaryOperation(opcode.EQ, &expression.Ident{CIStr: model.NewCIStr("DB")}, &expression.Value{Val: db})
|
||||
return expression.NewBinaryOperation(opcode.AndAnd, composeUserTableFilter(name, host), dbMatch)
|
||||
@ -367,6 +463,22 @@ func composeTablePrivRset() *rsets.JoinRset {
|
||||
}
|
||||
}
|
||||
|
||||
func composeColumnPrivFilter(name string, host string, db string, tbl string, col string) expression.Expression {
|
||||
colMatch := expression.NewBinaryOperation(opcode.EQ, &expression.Ident{CIStr: model.NewCIStr("Column_name")}, &expression.Value{Val: col})
|
||||
return expression.NewBinaryOperation(opcode.AndAnd, composeTablePrivFilter(name, host, db, tbl), colMatch)
|
||||
}
|
||||
|
||||
func composeColumnPrivRset() *rsets.JoinRset {
|
||||
return &rsets.JoinRset{
|
||||
Left: &rsets.TableSource{
|
||||
Source: table.Ident{
|
||||
Name: model.NewCIStr(mysql.ColumnPrivTable),
|
||||
Schema: model.NewCIStr(mysql.SystemDB),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func dbUserExists(ctx context.Context, name string, host string, db string) (bool, error) {
|
||||
r := composeDBTableRset()
|
||||
p, err := r.Plan(ctx)
|
||||
@ -411,6 +523,28 @@ func tableUserExists(ctx context.Context, name string, host string, db string, t
|
||||
return row != nil, nil
|
||||
}
|
||||
|
||||
func columnPrivEntryExists(ctx context.Context, name string, host string, db string, tbl string, col string) (bool, error) {
|
||||
r := composeColumnPrivRset()
|
||||
p, err := r.Plan(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Trace(err)
|
||||
}
|
||||
where := &rsets.WhereRset{
|
||||
Src: p,
|
||||
Expr: composeColumnPrivFilter(name, host, db, tbl, col),
|
||||
}
|
||||
p, err = where.Plan(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Trace(err)
|
||||
}
|
||||
defer p.Close()
|
||||
row, err := p.Next(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Trace(err)
|
||||
}
|
||||
return row != nil, nil
|
||||
}
|
||||
|
||||
func getTablePriv(ctx context.Context, name string, host string, db string, tbl string) (string, string, error) {
|
||||
r := composeTablePrivRset()
|
||||
p, err := r.Plan(ctx)
|
||||
@ -451,6 +585,36 @@ func getTablePriv(ctx context.Context, name string, host string, db string, tbl
|
||||
return tPriv, cPriv, nil
|
||||
}
|
||||
|
||||
func getColumnPriv(ctx context.Context, name string, host string, db string, tbl string, col string) (string, error) {
|
||||
r := composeColumnPrivRset()
|
||||
p, err := r.Plan(ctx)
|
||||
if err != nil {
|
||||
return "", errors.Trace(err)
|
||||
}
|
||||
where := &rsets.WhereRset{
|
||||
Src: p,
|
||||
Expr: composeColumnPrivFilter(name, host, db, tbl, col),
|
||||
}
|
||||
p, err = where.Plan(ctx)
|
||||
if err != nil {
|
||||
return "", errors.Trace(err)
|
||||
}
|
||||
defer p.Close()
|
||||
row, err := p.Next(ctx)
|
||||
if err != nil {
|
||||
return "", errors.Trace(err)
|
||||
}
|
||||
cPriv := ""
|
||||
if row.Data[6] != nil {
|
||||
columnPriv, ok := row.Data[6].(mysql.Set)
|
||||
if !ok {
|
||||
return "", errors.Errorf("Column Priv should be mysql.Set but get %v with type %T", row.Data[6], row.Data[6])
|
||||
}
|
||||
cPriv = columnPriv.Name
|
||||
}
|
||||
return cPriv, nil
|
||||
}
|
||||
|
||||
func (s *GrantStmt) getTargetSchema(ctx context.Context) (*model.DBInfo, error) {
|
||||
dbName := s.Level.DBName
|
||||
if len(dbName) == 0 {
|
||||
@ -527,3 +691,34 @@ func (s *GrantStmt) grantTablePriv(ctx context.Context, priv *coldef.PrivElem, u
|
||||
_, err = st.Exec(ctx)
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
// Manipulate mysql.tables_priv table.
|
||||
func (s *GrantStmt) grantColumnPriv(ctx context.Context, priv *coldef.PrivElem, user *coldef.UserSpecification) error {
|
||||
db, tbl, err := s.getTargetSchemaAndTable(ctx)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
strs := strings.Split(user.User, "@")
|
||||
userName := strs[0]
|
||||
host := strs[1]
|
||||
for _, c := range priv.Cols {
|
||||
col := column.FindCol(tbl.Cols(), c)
|
||||
if col == nil {
|
||||
return errors.Errorf("Unknown column: %s", c)
|
||||
}
|
||||
asgns, err := composeColumnPrivUpdate(ctx, priv.Priv, userName, host, db.Name.O, tbl.TableName().O, col.Name.O)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
st := &UpdateStmt{
|
||||
TableRefs: composeColumnPrivRset(),
|
||||
List: asgns,
|
||||
Where: composeColumnPrivFilter(userName, host, db.Name.O, tbl.TableName().O, col.Name.O),
|
||||
}
|
||||
_, err = st.Exec(ctx)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -135,7 +135,7 @@ func (s *testStmtSuite) TestGrantDBScope(c *C) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testStmtSuite) TestTableDBScope(c *C) {
|
||||
func (s *testStmtSuite) TestTableScope(c *C) {
|
||||
tx := mustBegin(c, s.testDB)
|
||||
// Create a new user.
|
||||
createUserSQL := `CREATE USER 'testTbl'@'localhost' IDENTIFIED BY '123';`
|
||||
@ -144,7 +144,7 @@ func (s *testStmtSuite) TestTableDBScope(c *C) {
|
||||
mustCommit(c, tx)
|
||||
// Make sure all the table privs for new user is empty.
|
||||
tx = mustBegin(c, s.testDB)
|
||||
rows, err := tx.Query(`SELECT * FROM mysql.User WHERE User="testTbl" and host="localhost" and db="test" and Table_name="test1"`)
|
||||
rows, err := tx.Query(`SELECT * FROM mysql.Tables_priv WHERE User="testTbl" and host="localhost" and db="test" and Table_name="test1"`)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(rows.Next(), IsFalse)
|
||||
mustCommit(c, tx)
|
||||
@ -188,3 +188,62 @@ func (s *testStmtSuite) TestTableDBScope(c *C) {
|
||||
mustCommit(c, tx)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testStmtSuite) TestColumnScope(c *C) {
|
||||
tx := mustBegin(c, s.testDB)
|
||||
// Create a new user.
|
||||
createUserSQL := `CREATE USER 'testCol'@'localhost' IDENTIFIED BY '123';`
|
||||
mustExec(c, s.testDB, createUserSQL)
|
||||
mustExec(c, s.testDB, `CREATE TABLE test.test3(c1 int, c2 int);`)
|
||||
mustCommit(c, tx)
|
||||
|
||||
// Make sure all the column privs for new user is empty.
|
||||
tx = mustBegin(c, s.testDB)
|
||||
rows, err := tx.Query(`SELECT * FROM mysql.Columns_priv WHERE User="testCol" and host="localhost" and db="test" and Table_name="test3" and Column_name="c1"`)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(rows.Next(), IsFalse)
|
||||
mustCommit(c, tx)
|
||||
tx = mustBegin(c, s.testDB)
|
||||
rows, err = tx.Query(`SELECT * FROM mysql.Columns_priv WHERE User="testCol" and host="localhost" and db="test" and Table_name="test3" and Column_name="c2"`)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(rows.Next(), IsFalse)
|
||||
mustCommit(c, tx)
|
||||
|
||||
// Grant each priv to the user.
|
||||
for _, v := range mysql.AllColumnPrivs {
|
||||
sql := fmt.Sprintf("GRANT %s(c1) ON test.test3 TO 'testCol'@'localhost';", mysql.Priv2Str[v])
|
||||
mustExec(c, s.testDB, sql)
|
||||
tx = mustBegin(c, s.testDB)
|
||||
rows, err := tx.Query(`SELECT Column_priv FROM mysql.Columns_priv WHERE User="testCol" and host="localhost" and db="test" and Table_name="test3" and Column_name="c1";`)
|
||||
c.Assert(err, IsNil)
|
||||
rows.Next()
|
||||
var p string
|
||||
rows.Scan(&p)
|
||||
c.Assert(strings.Index(p, mysql.Priv2SetStr[v]), Greater, -1)
|
||||
c.Assert(rows.Next(), IsFalse)
|
||||
rows.Close()
|
||||
mustCommit(c, tx)
|
||||
}
|
||||
|
||||
tx = mustBegin(c, s.testDB)
|
||||
// Create a new user.
|
||||
createUserSQL = `CREATE USER 'testCol1'@'localhost' IDENTIFIED BY '123';`
|
||||
mustExec(c, s.testDB, createUserSQL)
|
||||
mustExec(c, s.testDB, "USE test;")
|
||||
// Grant all column scope privs.
|
||||
mustExec(c, s.testDB, "GRANT ALL(c2) ON test3 TO 'testCol1'@'localhost';")
|
||||
mustCommit(c, tx)
|
||||
// Make sure all the column privs for granted user are in the Column_priv set.
|
||||
for _, v := range mysql.AllColumnPrivs {
|
||||
tx = mustBegin(c, s.testDB)
|
||||
rows, err := tx.Query(`SELECT Column_priv FROM mysql.Columns_priv WHERE User="testCol1" and host="localhost" and db="test" and Table_name="test3" and Column_name="c2";`)
|
||||
c.Assert(err, IsNil)
|
||||
rows.Next()
|
||||
var p string
|
||||
rows.Scan(&p)
|
||||
c.Assert(strings.Index(p, mysql.Priv2SetStr[v]), Greater, -1)
|
||||
c.Assert(rows.Next(), IsFalse)
|
||||
rows.Close()
|
||||
mustCommit(c, tx)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user