diff --git a/mysqldef/const.go b/mysqldef/const.go index cc3dc9f893..4f67492971 100644 --- a/mysqldef/const.go +++ b/mysqldef/const.go @@ -126,6 +126,8 @@ const ( DBTable = "DB" // TablePrivTable is the table in system db contains table scope privilege info. TablePrivTable = "Tables_priv" + // ColumnPrivTable is the table in system db contains column scope privilege info. + ColumnPrivTable = "Columns_priv" ) // PrivilegeType privilege diff --git a/stmt/stmts/grant.go b/stmt/stmts/grant.go index 0af1def1e9..6e7b5e8151 100644 --- a/stmt/stmts/grant.go +++ b/stmt/stmts/grant.go @@ -22,6 +22,7 @@ import ( "strings" "github.com/juju/errors" + "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/model" @@ -104,9 +105,15 @@ func (s *GrantStmt) Exec(ctx context.Context) (rset.Recordset, error) { } // Grant each priv to the user. for _, priv := range s.Privs { - err := s.grantPriv(ctx, priv, user) - if err != nil { - return nil, errors.Trace(err) + if len(priv.Cols) > 0 { + err1 := s.checkAndInitColumnPriv(ctx, userName, host, priv.Cols) + if err1 != nil { + return nil, errors.Trace(err1) + } + } + err2 := s.grantPriv(ctx, priv, user) + if err2 != nil { + return nil, errors.Trace(err2) } } } @@ -145,6 +152,32 @@ func (s *GrantStmt) checkAndInitTablePriv(ctx context.Context, user string, host return initTablePrivEntry(ctx, user, host, db.Name.O, tbl.TableName().O) } +func (s *GrantStmt) checkAndInitColumnPriv(ctx context.Context, user string, host string, cols []string) error { + db, tbl, err := s.getTargetSchemaAndTable(ctx) + if err != nil { + return errors.Trace(err) + } + for _, c := range cols { + col := column.FindCol(tbl.Cols(), c) + if col == nil { + return errors.Errorf("Unknown column: %s", c) + } + ok, err := columnPrivEntryExists(ctx, user, host, db.Name.O, tbl.TableName().O, col.Name.O) + if err != nil { + return errors.Trace(err) + } + if ok { + continue + } + // Entry does not exists for user/host/db. Insert a new entry. + err = initColumnPrivEntry(ctx, user, host, db.Name.O, tbl.TableName().O, col.Name.O) + if err != nil { + return errors.Trace(err) + } + } + return nil +} + func initDBPrivEntry(ctx context.Context, user string, host string, db string) error { st := &InsertIntoStmt{ TableIdent: table.Ident{ @@ -170,7 +203,7 @@ func initTablePrivEntry(ctx context.Context, user string, host string, db string Name: model.NewCIStr(mysql.TablePrivTable), Schema: model.NewCIStr(mysql.SystemDB), }, - ColNames: []string{"Host", "User", "DB", "Table_name"}, + ColNames: []string{"Host", "User", "DB", "Table_name", "Table_priv", "Column_priv"}, } values := make([][]expression.Expression, 0, 1) value := make([]expression.Expression, 0, 3) @@ -178,6 +211,30 @@ func initTablePrivEntry(ctx context.Context, user string, host string, db string value = append(value, &expression.Value{Val: user}) value = append(value, &expression.Value{Val: db}) value = append(value, &expression.Value{Val: tbl}) + value = append(value, &expression.Value{Val: ""}) + value = append(value, &expression.Value{Val: ""}) + values = append(values, value) + st.Lists = values + _, err := st.Exec(ctx) + return errors.Trace(err) +} + +func initColumnPrivEntry(ctx context.Context, user string, host string, db string, tbl string, col string) error { + st := &InsertIntoStmt{ + TableIdent: table.Ident{ + Name: model.NewCIStr(mysql.ColumnPrivTable), + Schema: model.NewCIStr(mysql.SystemDB), + }, + ColNames: []string{"Host", "User", "DB", "Table_name", "Column_name", "Column_priv"}, + } + values := make([][]expression.Expression, 0, 1) + value := make([]expression.Expression, 0, 3) + value = append(value, &expression.Value{Val: host}) // Host + value = append(value, &expression.Value{Val: user}) // User + value = append(value, &expression.Value{Val: db}) // DB + value = append(value, &expression.Value{Val: tbl}) // Table_name + value = append(value, &expression.Value{Val: col}) // Column_name + value = append(value, &expression.Value{Val: ""}) // Empty Column_priv values = append(values, value) st.Lists = values _, err := st.Exec(ctx) @@ -191,7 +248,10 @@ func (s *GrantStmt) grantPriv(ctx context.Context, priv *coldef.PrivElem, user * case coldef.GrantLevelDB: return s.grantDBPriv(ctx, priv, user) case coldef.GrantLevelTable: - return s.grantTablePriv(ctx, priv, user) + if len(priv.Cols) == 0 { + return s.grantTablePriv(ctx, priv, user) + } + return s.grantColumnPriv(ctx, priv, user) default: return errors.Errorf("Unknown grant level: %s", s.Level) } @@ -335,6 +395,42 @@ func composeTablePrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name return assigns, nil } +func composeColumnPrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) ([]expression.Assignment, error) { + newColumnPriv := "" + if priv == mysql.AllPriv { + for _, p := range mysql.AllColumnPrivs { + v, ok := mysql.Priv2SetStr[p] + if !ok { + return nil, errors.Errorf("Unknown column privilege %s", p) + } + if len(newColumnPriv) == 0 { + newColumnPriv = v + } else { + newColumnPriv = fmt.Sprintf("%s,%s", newColumnPriv, v) + } + } + } else { + currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col) + if err != nil { + return nil, errors.Trace(err) + } + p, ok := mysql.Priv2SetStr[priv] + if !ok { + return nil, errors.Errorf("Unknown priv: %s", priv) + } + if len(currColumnPriv) == 0 { + newColumnPriv = p + } else { + newColumnPriv = fmt.Sprintf("%s,%s", currColumnPriv, p) + } + } + t := expression.Assignment{ + ColName: "Column_priv", + Expr: expression.Value{Val: newColumnPriv}, + } + assigns := []expression.Assignment{t} + return assigns, nil +} func composeDBTableFilter(name string, host string, db string) expression.Expression { dbMatch := expression.NewBinaryOperation(opcode.EQ, &expression.Ident{CIStr: model.NewCIStr("DB")}, &expression.Value{Val: db}) return expression.NewBinaryOperation(opcode.AndAnd, composeUserTableFilter(name, host), dbMatch) @@ -367,6 +463,22 @@ func composeTablePrivRset() *rsets.JoinRset { } } +func composeColumnPrivFilter(name string, host string, db string, tbl string, col string) expression.Expression { + colMatch := expression.NewBinaryOperation(opcode.EQ, &expression.Ident{CIStr: model.NewCIStr("Column_name")}, &expression.Value{Val: col}) + return expression.NewBinaryOperation(opcode.AndAnd, composeTablePrivFilter(name, host, db, tbl), colMatch) +} + +func composeColumnPrivRset() *rsets.JoinRset { + return &rsets.JoinRset{ + Left: &rsets.TableSource{ + Source: table.Ident{ + Name: model.NewCIStr(mysql.ColumnPrivTable), + Schema: model.NewCIStr(mysql.SystemDB), + }, + }, + } +} + func dbUserExists(ctx context.Context, name string, host string, db string) (bool, error) { r := composeDBTableRset() p, err := r.Plan(ctx) @@ -411,6 +523,28 @@ func tableUserExists(ctx context.Context, name string, host string, db string, t return row != nil, nil } +func columnPrivEntryExists(ctx context.Context, name string, host string, db string, tbl string, col string) (bool, error) { + r := composeColumnPrivRset() + p, err := r.Plan(ctx) + if err != nil { + return false, errors.Trace(err) + } + where := &rsets.WhereRset{ + Src: p, + Expr: composeColumnPrivFilter(name, host, db, tbl, col), + } + p, err = where.Plan(ctx) + if err != nil { + return false, errors.Trace(err) + } + defer p.Close() + row, err := p.Next(ctx) + if err != nil { + return false, errors.Trace(err) + } + return row != nil, nil +} + func getTablePriv(ctx context.Context, name string, host string, db string, tbl string) (string, string, error) { r := composeTablePrivRset() p, err := r.Plan(ctx) @@ -451,6 +585,36 @@ func getTablePriv(ctx context.Context, name string, host string, db string, tbl return tPriv, cPriv, nil } +func getColumnPriv(ctx context.Context, name string, host string, db string, tbl string, col string) (string, error) { + r := composeColumnPrivRset() + p, err := r.Plan(ctx) + if err != nil { + return "", errors.Trace(err) + } + where := &rsets.WhereRset{ + Src: p, + Expr: composeColumnPrivFilter(name, host, db, tbl, col), + } + p, err = where.Plan(ctx) + if err != nil { + return "", errors.Trace(err) + } + defer p.Close() + row, err := p.Next(ctx) + if err != nil { + return "", errors.Trace(err) + } + cPriv := "" + if row.Data[6] != nil { + columnPriv, ok := row.Data[6].(mysql.Set) + if !ok { + return "", errors.Errorf("Column Priv should be mysql.Set but get %v with type %T", row.Data[6], row.Data[6]) + } + cPriv = columnPriv.Name + } + return cPriv, nil +} + func (s *GrantStmt) getTargetSchema(ctx context.Context) (*model.DBInfo, error) { dbName := s.Level.DBName if len(dbName) == 0 { @@ -527,3 +691,34 @@ func (s *GrantStmt) grantTablePriv(ctx context.Context, priv *coldef.PrivElem, u _, err = st.Exec(ctx) return errors.Trace(err) } + +// Manipulate mysql.tables_priv table. +func (s *GrantStmt) grantColumnPriv(ctx context.Context, priv *coldef.PrivElem, user *coldef.UserSpecification) error { + db, tbl, err := s.getTargetSchemaAndTable(ctx) + if err != nil { + return errors.Trace(err) + } + strs := strings.Split(user.User, "@") + userName := strs[0] + host := strs[1] + for _, c := range priv.Cols { + col := column.FindCol(tbl.Cols(), c) + if col == nil { + return errors.Errorf("Unknown column: %s", c) + } + asgns, err := composeColumnPrivUpdate(ctx, priv.Priv, userName, host, db.Name.O, tbl.TableName().O, col.Name.O) + if err != nil { + return errors.Trace(err) + } + st := &UpdateStmt{ + TableRefs: composeColumnPrivRset(), + List: asgns, + Where: composeColumnPrivFilter(userName, host, db.Name.O, tbl.TableName().O, col.Name.O), + } + _, err = st.Exec(ctx) + if err != nil { + return errors.Trace(err) + } + } + return nil +} diff --git a/stmt/stmts/grant_test.go b/stmt/stmts/grant_test.go index aaf66f5d40..e53d189eaf 100644 --- a/stmt/stmts/grant_test.go +++ b/stmt/stmts/grant_test.go @@ -135,7 +135,7 @@ func (s *testStmtSuite) TestGrantDBScope(c *C) { } } -func (s *testStmtSuite) TestTableDBScope(c *C) { +func (s *testStmtSuite) TestTableScope(c *C) { tx := mustBegin(c, s.testDB) // Create a new user. createUserSQL := `CREATE USER 'testTbl'@'localhost' IDENTIFIED BY '123';` @@ -144,7 +144,7 @@ func (s *testStmtSuite) TestTableDBScope(c *C) { mustCommit(c, tx) // Make sure all the table privs for new user is empty. tx = mustBegin(c, s.testDB) - rows, err := tx.Query(`SELECT * FROM mysql.User WHERE User="testTbl" and host="localhost" and db="test" and Table_name="test1"`) + rows, err := tx.Query(`SELECT * FROM mysql.Tables_priv WHERE User="testTbl" and host="localhost" and db="test" and Table_name="test1"`) c.Assert(err, IsNil) c.Assert(rows.Next(), IsFalse) mustCommit(c, tx) @@ -188,3 +188,62 @@ func (s *testStmtSuite) TestTableDBScope(c *C) { mustCommit(c, tx) } } + +func (s *testStmtSuite) TestColumnScope(c *C) { + tx := mustBegin(c, s.testDB) + // Create a new user. + createUserSQL := `CREATE USER 'testCol'@'localhost' IDENTIFIED BY '123';` + mustExec(c, s.testDB, createUserSQL) + mustExec(c, s.testDB, `CREATE TABLE test.test3(c1 int, c2 int);`) + mustCommit(c, tx) + + // Make sure all the column privs for new user is empty. + tx = mustBegin(c, s.testDB) + rows, err := tx.Query(`SELECT * FROM mysql.Columns_priv WHERE User="testCol" and host="localhost" and db="test" and Table_name="test3" and Column_name="c1"`) + c.Assert(err, IsNil) + c.Assert(rows.Next(), IsFalse) + mustCommit(c, tx) + tx = mustBegin(c, s.testDB) + rows, err = tx.Query(`SELECT * FROM mysql.Columns_priv WHERE User="testCol" and host="localhost" and db="test" and Table_name="test3" and Column_name="c2"`) + c.Assert(err, IsNil) + c.Assert(rows.Next(), IsFalse) + mustCommit(c, tx) + + // Grant each priv to the user. + for _, v := range mysql.AllColumnPrivs { + sql := fmt.Sprintf("GRANT %s(c1) ON test.test3 TO 'testCol'@'localhost';", mysql.Priv2Str[v]) + mustExec(c, s.testDB, sql) + tx = mustBegin(c, s.testDB) + rows, err := tx.Query(`SELECT Column_priv FROM mysql.Columns_priv WHERE User="testCol" and host="localhost" and db="test" and Table_name="test3" and Column_name="c1";`) + c.Assert(err, IsNil) + rows.Next() + var p string + rows.Scan(&p) + c.Assert(strings.Index(p, mysql.Priv2SetStr[v]), Greater, -1) + c.Assert(rows.Next(), IsFalse) + rows.Close() + mustCommit(c, tx) + } + + tx = mustBegin(c, s.testDB) + // Create a new user. + createUserSQL = `CREATE USER 'testCol1'@'localhost' IDENTIFIED BY '123';` + mustExec(c, s.testDB, createUserSQL) + mustExec(c, s.testDB, "USE test;") + // Grant all column scope privs. + mustExec(c, s.testDB, "GRANT ALL(c2) ON test3 TO 'testCol1'@'localhost';") + mustCommit(c, tx) + // Make sure all the column privs for granted user are in the Column_priv set. + for _, v := range mysql.AllColumnPrivs { + tx = mustBegin(c, s.testDB) + rows, err := tx.Query(`SELECT Column_priv FROM mysql.Columns_priv WHERE User="testCol1" and host="localhost" and db="test" and Table_name="test3" and Column_name="c2";`) + c.Assert(err, IsNil) + rows.Next() + var p string + rows.Scan(&p) + c.Assert(strings.Index(p, mysql.Priv2SetStr[v]), Greater, -1) + c.Assert(rows.Next(), IsFalse) + rows.Close() + mustCommit(c, tx) + } +}