stmts: Support grant column scope privilege

This commit is contained in:
Shen Li
2015-10-02 23:09:56 +08:00
parent 2eb08c0033
commit 4c3b015edd
3 changed files with 263 additions and 7 deletions

View File

@ -126,6 +126,8 @@ const (
DBTable = "DB"
// TablePrivTable is the table in system db contains table scope privilege info.
TablePrivTable = "Tables_priv"
// ColumnPrivTable is the table in system db contains column scope privilege info.
ColumnPrivTable = "Columns_priv"
)
// PrivilegeType privilege

View File

@ -22,6 +22,7 @@ import (
"strings"
"github.com/juju/errors"
"github.com/pingcap/tidb/column"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/model"
@ -104,9 +105,15 @@ func (s *GrantStmt) Exec(ctx context.Context) (rset.Recordset, error) {
}
// Grant each priv to the user.
for _, priv := range s.Privs {
err := s.grantPriv(ctx, priv, user)
if err != nil {
return nil, errors.Trace(err)
if len(priv.Cols) > 0 {
err1 := s.checkAndInitColumnPriv(ctx, userName, host, priv.Cols)
if err1 != nil {
return nil, errors.Trace(err1)
}
}
err2 := s.grantPriv(ctx, priv, user)
if err2 != nil {
return nil, errors.Trace(err2)
}
}
}
@ -145,6 +152,32 @@ func (s *GrantStmt) checkAndInitTablePriv(ctx context.Context, user string, host
return initTablePrivEntry(ctx, user, host, db.Name.O, tbl.TableName().O)
}
func (s *GrantStmt) checkAndInitColumnPriv(ctx context.Context, user string, host string, cols []string) error {
db, tbl, err := s.getTargetSchemaAndTable(ctx)
if err != nil {
return errors.Trace(err)
}
for _, c := range cols {
col := column.FindCol(tbl.Cols(), c)
if col == nil {
return errors.Errorf("Unknown column: %s", c)
}
ok, err := columnPrivEntryExists(ctx, user, host, db.Name.O, tbl.TableName().O, col.Name.O)
if err != nil {
return errors.Trace(err)
}
if ok {
continue
}
// Entry does not exists for user/host/db. Insert a new entry.
err = initColumnPrivEntry(ctx, user, host, db.Name.O, tbl.TableName().O, col.Name.O)
if err != nil {
return errors.Trace(err)
}
}
return nil
}
func initDBPrivEntry(ctx context.Context, user string, host string, db string) error {
st := &InsertIntoStmt{
TableIdent: table.Ident{
@ -170,7 +203,7 @@ func initTablePrivEntry(ctx context.Context, user string, host string, db string
Name: model.NewCIStr(mysql.TablePrivTable),
Schema: model.NewCIStr(mysql.SystemDB),
},
ColNames: []string{"Host", "User", "DB", "Table_name"},
ColNames: []string{"Host", "User", "DB", "Table_name", "Table_priv", "Column_priv"},
}
values := make([][]expression.Expression, 0, 1)
value := make([]expression.Expression, 0, 3)
@ -178,6 +211,30 @@ func initTablePrivEntry(ctx context.Context, user string, host string, db string
value = append(value, &expression.Value{Val: user})
value = append(value, &expression.Value{Val: db})
value = append(value, &expression.Value{Val: tbl})
value = append(value, &expression.Value{Val: ""})
value = append(value, &expression.Value{Val: ""})
values = append(values, value)
st.Lists = values
_, err := st.Exec(ctx)
return errors.Trace(err)
}
func initColumnPrivEntry(ctx context.Context, user string, host string, db string, tbl string, col string) error {
st := &InsertIntoStmt{
TableIdent: table.Ident{
Name: model.NewCIStr(mysql.ColumnPrivTable),
Schema: model.NewCIStr(mysql.SystemDB),
},
ColNames: []string{"Host", "User", "DB", "Table_name", "Column_name", "Column_priv"},
}
values := make([][]expression.Expression, 0, 1)
value := make([]expression.Expression, 0, 3)
value = append(value, &expression.Value{Val: host}) // Host
value = append(value, &expression.Value{Val: user}) // User
value = append(value, &expression.Value{Val: db}) // DB
value = append(value, &expression.Value{Val: tbl}) // Table_name
value = append(value, &expression.Value{Val: col}) // Column_name
value = append(value, &expression.Value{Val: ""}) // Empty Column_priv
values = append(values, value)
st.Lists = values
_, err := st.Exec(ctx)
@ -191,7 +248,10 @@ func (s *GrantStmt) grantPriv(ctx context.Context, priv *coldef.PrivElem, user *
case coldef.GrantLevelDB:
return s.grantDBPriv(ctx, priv, user)
case coldef.GrantLevelTable:
return s.grantTablePriv(ctx, priv, user)
if len(priv.Cols) == 0 {
return s.grantTablePriv(ctx, priv, user)
}
return s.grantColumnPriv(ctx, priv, user)
default:
return errors.Errorf("Unknown grant level: %s", s.Level)
}
@ -335,6 +395,42 @@ func composeTablePrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name
return assigns, nil
}
func composeColumnPrivUpdate(ctx context.Context, priv mysql.PrivilegeType, name string, host string, db string, tbl string, col string) ([]expression.Assignment, error) {
newColumnPriv := ""
if priv == mysql.AllPriv {
for _, p := range mysql.AllColumnPrivs {
v, ok := mysql.Priv2SetStr[p]
if !ok {
return nil, errors.Errorf("Unknown column privilege %s", p)
}
if len(newColumnPriv) == 0 {
newColumnPriv = v
} else {
newColumnPriv = fmt.Sprintf("%s,%s", newColumnPriv, v)
}
}
} else {
currColumnPriv, err := getColumnPriv(ctx, name, host, db, tbl, col)
if err != nil {
return nil, errors.Trace(err)
}
p, ok := mysql.Priv2SetStr[priv]
if !ok {
return nil, errors.Errorf("Unknown priv: %s", priv)
}
if len(currColumnPriv) == 0 {
newColumnPriv = p
} else {
newColumnPriv = fmt.Sprintf("%s,%s", currColumnPriv, p)
}
}
t := expression.Assignment{
ColName: "Column_priv",
Expr: expression.Value{Val: newColumnPriv},
}
assigns := []expression.Assignment{t}
return assigns, nil
}
func composeDBTableFilter(name string, host string, db string) expression.Expression {
dbMatch := expression.NewBinaryOperation(opcode.EQ, &expression.Ident{CIStr: model.NewCIStr("DB")}, &expression.Value{Val: db})
return expression.NewBinaryOperation(opcode.AndAnd, composeUserTableFilter(name, host), dbMatch)
@ -367,6 +463,22 @@ func composeTablePrivRset() *rsets.JoinRset {
}
}
func composeColumnPrivFilter(name string, host string, db string, tbl string, col string) expression.Expression {
colMatch := expression.NewBinaryOperation(opcode.EQ, &expression.Ident{CIStr: model.NewCIStr("Column_name")}, &expression.Value{Val: col})
return expression.NewBinaryOperation(opcode.AndAnd, composeTablePrivFilter(name, host, db, tbl), colMatch)
}
func composeColumnPrivRset() *rsets.JoinRset {
return &rsets.JoinRset{
Left: &rsets.TableSource{
Source: table.Ident{
Name: model.NewCIStr(mysql.ColumnPrivTable),
Schema: model.NewCIStr(mysql.SystemDB),
},
},
}
}
func dbUserExists(ctx context.Context, name string, host string, db string) (bool, error) {
r := composeDBTableRset()
p, err := r.Plan(ctx)
@ -411,6 +523,28 @@ func tableUserExists(ctx context.Context, name string, host string, db string, t
return row != nil, nil
}
func columnPrivEntryExists(ctx context.Context, name string, host string, db string, tbl string, col string) (bool, error) {
r := composeColumnPrivRset()
p, err := r.Plan(ctx)
if err != nil {
return false, errors.Trace(err)
}
where := &rsets.WhereRset{
Src: p,
Expr: composeColumnPrivFilter(name, host, db, tbl, col),
}
p, err = where.Plan(ctx)
if err != nil {
return false, errors.Trace(err)
}
defer p.Close()
row, err := p.Next(ctx)
if err != nil {
return false, errors.Trace(err)
}
return row != nil, nil
}
func getTablePriv(ctx context.Context, name string, host string, db string, tbl string) (string, string, error) {
r := composeTablePrivRset()
p, err := r.Plan(ctx)
@ -451,6 +585,36 @@ func getTablePriv(ctx context.Context, name string, host string, db string, tbl
return tPriv, cPriv, nil
}
func getColumnPriv(ctx context.Context, name string, host string, db string, tbl string, col string) (string, error) {
r := composeColumnPrivRset()
p, err := r.Plan(ctx)
if err != nil {
return "", errors.Trace(err)
}
where := &rsets.WhereRset{
Src: p,
Expr: composeColumnPrivFilter(name, host, db, tbl, col),
}
p, err = where.Plan(ctx)
if err != nil {
return "", errors.Trace(err)
}
defer p.Close()
row, err := p.Next(ctx)
if err != nil {
return "", errors.Trace(err)
}
cPriv := ""
if row.Data[6] != nil {
columnPriv, ok := row.Data[6].(mysql.Set)
if !ok {
return "", errors.Errorf("Column Priv should be mysql.Set but get %v with type %T", row.Data[6], row.Data[6])
}
cPriv = columnPriv.Name
}
return cPriv, nil
}
func (s *GrantStmt) getTargetSchema(ctx context.Context) (*model.DBInfo, error) {
dbName := s.Level.DBName
if len(dbName) == 0 {
@ -527,3 +691,34 @@ func (s *GrantStmt) grantTablePriv(ctx context.Context, priv *coldef.PrivElem, u
_, err = st.Exec(ctx)
return errors.Trace(err)
}
// Manipulate mysql.tables_priv table.
func (s *GrantStmt) grantColumnPriv(ctx context.Context, priv *coldef.PrivElem, user *coldef.UserSpecification) error {
db, tbl, err := s.getTargetSchemaAndTable(ctx)
if err != nil {
return errors.Trace(err)
}
strs := strings.Split(user.User, "@")
userName := strs[0]
host := strs[1]
for _, c := range priv.Cols {
col := column.FindCol(tbl.Cols(), c)
if col == nil {
return errors.Errorf("Unknown column: %s", c)
}
asgns, err := composeColumnPrivUpdate(ctx, priv.Priv, userName, host, db.Name.O, tbl.TableName().O, col.Name.O)
if err != nil {
return errors.Trace(err)
}
st := &UpdateStmt{
TableRefs: composeColumnPrivRset(),
List: asgns,
Where: composeColumnPrivFilter(userName, host, db.Name.O, tbl.TableName().O, col.Name.O),
}
_, err = st.Exec(ctx)
if err != nil {
return errors.Trace(err)
}
}
return nil
}

View File

@ -135,7 +135,7 @@ func (s *testStmtSuite) TestGrantDBScope(c *C) {
}
}
func (s *testStmtSuite) TestTableDBScope(c *C) {
func (s *testStmtSuite) TestTableScope(c *C) {
tx := mustBegin(c, s.testDB)
// Create a new user.
createUserSQL := `CREATE USER 'testTbl'@'localhost' IDENTIFIED BY '123';`
@ -144,7 +144,7 @@ func (s *testStmtSuite) TestTableDBScope(c *C) {
mustCommit(c, tx)
// Make sure all the table privs for new user is empty.
tx = mustBegin(c, s.testDB)
rows, err := tx.Query(`SELECT * FROM mysql.User WHERE User="testTbl" and host="localhost" and db="test" and Table_name="test1"`)
rows, err := tx.Query(`SELECT * FROM mysql.Tables_priv WHERE User="testTbl" and host="localhost" and db="test" and Table_name="test1"`)
c.Assert(err, IsNil)
c.Assert(rows.Next(), IsFalse)
mustCommit(c, tx)
@ -188,3 +188,62 @@ func (s *testStmtSuite) TestTableDBScope(c *C) {
mustCommit(c, tx)
}
}
func (s *testStmtSuite) TestColumnScope(c *C) {
tx := mustBegin(c, s.testDB)
// Create a new user.
createUserSQL := `CREATE USER 'testCol'@'localhost' IDENTIFIED BY '123';`
mustExec(c, s.testDB, createUserSQL)
mustExec(c, s.testDB, `CREATE TABLE test.test3(c1 int, c2 int);`)
mustCommit(c, tx)
// Make sure all the column privs for new user is empty.
tx = mustBegin(c, s.testDB)
rows, err := tx.Query(`SELECT * FROM mysql.Columns_priv WHERE User="testCol" and host="localhost" and db="test" and Table_name="test3" and Column_name="c1"`)
c.Assert(err, IsNil)
c.Assert(rows.Next(), IsFalse)
mustCommit(c, tx)
tx = mustBegin(c, s.testDB)
rows, err = tx.Query(`SELECT * FROM mysql.Columns_priv WHERE User="testCol" and host="localhost" and db="test" and Table_name="test3" and Column_name="c2"`)
c.Assert(err, IsNil)
c.Assert(rows.Next(), IsFalse)
mustCommit(c, tx)
// Grant each priv to the user.
for _, v := range mysql.AllColumnPrivs {
sql := fmt.Sprintf("GRANT %s(c1) ON test.test3 TO 'testCol'@'localhost';", mysql.Priv2Str[v])
mustExec(c, s.testDB, sql)
tx = mustBegin(c, s.testDB)
rows, err := tx.Query(`SELECT Column_priv FROM mysql.Columns_priv WHERE User="testCol" and host="localhost" and db="test" and Table_name="test3" and Column_name="c1";`)
c.Assert(err, IsNil)
rows.Next()
var p string
rows.Scan(&p)
c.Assert(strings.Index(p, mysql.Priv2SetStr[v]), Greater, -1)
c.Assert(rows.Next(), IsFalse)
rows.Close()
mustCommit(c, tx)
}
tx = mustBegin(c, s.testDB)
// Create a new user.
createUserSQL = `CREATE USER 'testCol1'@'localhost' IDENTIFIED BY '123';`
mustExec(c, s.testDB, createUserSQL)
mustExec(c, s.testDB, "USE test;")
// Grant all column scope privs.
mustExec(c, s.testDB, "GRANT ALL(c2) ON test3 TO 'testCol1'@'localhost';")
mustCommit(c, tx)
// Make sure all the column privs for granted user are in the Column_priv set.
for _, v := range mysql.AllColumnPrivs {
tx = mustBegin(c, s.testDB)
rows, err := tx.Query(`SELECT Column_priv FROM mysql.Columns_priv WHERE User="testCol1" and host="localhost" and db="test" and Table_name="test3" and Column_name="c2";`)
c.Assert(err, IsNil)
rows.Next()
var p string
rows.Scan(&p)
c.Assert(strings.Index(p, mysql.Priv2SetStr[v]), Greater, -1)
c.Assert(rows.Next(), IsFalse)
rows.Close()
mustCommit(c, tx)
}
}