stmts: Add missing files
This commit is contained in:
205
stmt/stmts/account_manage.go
Normal file
205
stmt/stmts/account_manage.go
Normal file
@ -0,0 +1,205 @@
|
||||
// Copyright 2013 The ql Authors. All rights reserved.
|
||||
// Use of this source code is governed by a BSD-style
|
||||
// license that can be found in the LICENSES/QL-LICENSE file.
|
||||
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stmts
|
||||
|
||||
import (
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/expression/expressions"
|
||||
"github.com/pingcap/tidb/model"
|
||||
mysql "github.com/pingcap/tidb/mysqldef"
|
||||
"github.com/pingcap/tidb/parser/coldef"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/rset"
|
||||
"github.com/pingcap/tidb/rset/rsets"
|
||||
"github.com/pingcap/tidb/stmt"
|
||||
"github.com/pingcap/tidb/table"
|
||||
"github.com/pingcap/tidb/util/format"
|
||||
"strings"
|
||||
)
|
||||
|
||||
/************************************************************************************
|
||||
* Account Management Statements
|
||||
* https://dev.mysql.com/doc/refman/5.7/en/account-management-sql.html
|
||||
************************************************************************************/
|
||||
var (
|
||||
_ stmt.Statement = (*CreateUserStmt)(nil)
|
||||
_ stmt.Statement = (*SetPwdStmt)(nil)
|
||||
)
|
||||
|
||||
// CreateUserStmt creates user account.
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html
|
||||
type CreateUserStmt struct {
|
||||
IfNotExists bool
|
||||
Specs []*coldef.UserSpecification
|
||||
|
||||
Text string
|
||||
}
|
||||
|
||||
// Explain implements the stmt.Statement Explain interface.
|
||||
func (s *CreateUserStmt) Explain(ctx context.Context, w format.Formatter) {
|
||||
w.Format("%s\n", s.Text)
|
||||
}
|
||||
|
||||
// IsDDL implements the stmt.Statement IsDDL interface.
|
||||
func (s *CreateUserStmt) IsDDL() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// OriginText implements the stmt.Statement OriginText interface.
|
||||
func (s *CreateUserStmt) OriginText() string {
|
||||
return s.Text
|
||||
}
|
||||
|
||||
// SetText implements the stmt.Statement SetText interface.
|
||||
func (s *CreateUserStmt) SetText(text string) {
|
||||
s.Text = text
|
||||
}
|
||||
|
||||
func composeUserTableFilter(name string, host string) expression.Expression {
|
||||
nameMatch := expressions.NewBinaryOperation(opcode.EQ, &expressions.Ident{CIStr: model.NewCIStr("User")}, &expressions.Value{Val: name})
|
||||
hostMatch := expressions.NewBinaryOperation(opcode.EQ, &expressions.Ident{CIStr: model.NewCIStr("Host")}, &expressions.Value{Val: host})
|
||||
return expressions.NewBinaryOperation(opcode.AndAnd, nameMatch, hostMatch)
|
||||
}
|
||||
|
||||
func composeUserTableRset() *rsets.JoinRset {
|
||||
return &rsets.JoinRset{
|
||||
Left: &rsets.TableSource{
|
||||
Source: table.Ident{
|
||||
Name: model.NewCIStr(mysql.UserTable),
|
||||
Schema: model.NewCIStr(mysql.SystemDB),
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *CreateUserStmt) userExists(ctx context.Context, name string, host string) (bool, error) {
|
||||
r := composeUserTableRset()
|
||||
p, err := r.Plan(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Trace(err)
|
||||
}
|
||||
where := &rsets.WhereRset{
|
||||
Src: p,
|
||||
Expr: composeUserTableFilter(name, host),
|
||||
}
|
||||
p, err = where.Plan(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Trace(err)
|
||||
}
|
||||
row, err := p.Next(ctx)
|
||||
if err != nil {
|
||||
return false, errors.Trace(err)
|
||||
}
|
||||
p.Close()
|
||||
return row != nil, nil
|
||||
}
|
||||
|
||||
// Exec implements the stmt.Statement Exec interface.
|
||||
func (s *CreateUserStmt) Exec(ctx context.Context) (rset.Recordset, error) {
|
||||
st := &InsertIntoStmt{
|
||||
TableIdent: table.Ident{
|
||||
Name: model.NewCIStr(mysql.UserTable),
|
||||
Schema: model.NewCIStr(mysql.SystemDB),
|
||||
},
|
||||
}
|
||||
values := make([][]expression.Expression, 0, len(s.Specs))
|
||||
for _, spec := range s.Specs {
|
||||
strs := strings.Split(spec.User, "@")
|
||||
userName := strs[0]
|
||||
host := strs[1]
|
||||
exists, err1 := s.userExists(ctx, userName, host)
|
||||
if err1 != nil {
|
||||
return nil, errors.Trace(err1)
|
||||
}
|
||||
if exists {
|
||||
if !s.IfNotExists {
|
||||
return nil, errors.Errorf("Duplicate user")
|
||||
}
|
||||
continue
|
||||
}
|
||||
value := make([]expression.Expression, 0, 3)
|
||||
value = append(value, expressions.Value{Val: host})
|
||||
value = append(value, expressions.Value{Val: userName})
|
||||
if spec.AuthOpt.ByAuthString {
|
||||
value = append(value, expressions.Value{Val: spec.AuthOpt.AuthString})
|
||||
} else {
|
||||
// TODO: Maybe we should hash the string here?
|
||||
value = append(value, expressions.Value{Val: spec.AuthOpt.HashString})
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
if len(values) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
st.Lists = values
|
||||
_, err := st.Exec(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// SetPwdStmt is a statement to assign a password to user account.
|
||||
// See: https://dev.mysql.com/doc/refman/5.7/en/set-password.html
|
||||
type SetPwdStmt struct {
|
||||
User string
|
||||
Password string
|
||||
|
||||
Text string
|
||||
}
|
||||
|
||||
// Explain implements the stmt.Statement Explain interface.
|
||||
func (s *SetPwdStmt) Explain(ctx context.Context, w format.Formatter) {
|
||||
w.Format("%s\n", s.Text)
|
||||
}
|
||||
|
||||
// IsDDL implements the stmt.Statement IsDDL interface.
|
||||
func (s *SetPwdStmt) IsDDL() bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// OriginText implements the stmt.Statement OriginText interface.
|
||||
func (s *SetPwdStmt) OriginText() string {
|
||||
return s.Text
|
||||
}
|
||||
|
||||
// SetText implements the stmt.Statement SetText interface.
|
||||
func (s *SetPwdStmt) SetText(text string) {
|
||||
s.Text = text
|
||||
}
|
||||
|
||||
// Exec implements the stmt.Statement Exec interface.
|
||||
func (s *SetPwdStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
|
||||
// If len(s.User) == 0, use CURRENT_USER()
|
||||
strs := strings.Split(s.User, "@")
|
||||
userName := strs[0]
|
||||
host := strs[1]
|
||||
// Update mysql.user
|
||||
asgn := expressions.Assignment{
|
||||
ColName: "Password",
|
||||
Expr: expressions.Value{Val: s.Password},
|
||||
}
|
||||
st := &UpdateStmt{
|
||||
TableRefs: composeUserTableRset(),
|
||||
List: []expressions.Assignment{asgn},
|
||||
Where: composeUserTableFilter(userName, host),
|
||||
}
|
||||
return st.Exec(ctx)
|
||||
}
|
||||
80
stmt/stmts/account_manage_test.go
Normal file
80
stmt/stmts/account_manage_test.go
Normal file
@ -0,0 +1,80 @@
|
||||
// Copyright 2015 PingCAP, Inc.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package stmts_test
|
||||
|
||||
import (
|
||||
. "github.com/pingcap/check"
|
||||
)
|
||||
|
||||
func (s *testStmtSuite) TestCreateUserStmt(c *C) {
|
||||
// Make sure user test not in mysql.User.
|
||||
tx := mustBegin(c, s.testDB)
|
||||
rows, err := tx.Query(`SELECT Password FROM mysql.User WHERE User="test" and Host="localhost"`)
|
||||
c.Assert(err, IsNil)
|
||||
rows.Next()
|
||||
c.Assert(rows.Next(), IsFalse)
|
||||
rows.Close()
|
||||
mustCommit(c, tx)
|
||||
// Create user test.
|
||||
createUserSQL := `CREATE USER 'test'@'localhost' IDENTIFIED BY '123';`
|
||||
mustExec(c, s.testDB, createUserSQL)
|
||||
// Make sure user test in mysql.User.
|
||||
tx = mustBegin(c, s.testDB)
|
||||
rows, err = tx.Query(`SELECT Password FROM mysql.User WHERE User="test" and Host="localhost"`)
|
||||
c.Assert(err, IsNil)
|
||||
rows.Next()
|
||||
var pwd string
|
||||
rows.Scan(&pwd)
|
||||
c.Assert(pwd, Equals, "123")
|
||||
c.Assert(rows.Next(), IsFalse)
|
||||
rows.Close()
|
||||
mustCommit(c, tx)
|
||||
// Create duplicate user with IfNotExists will be success.
|
||||
createUserSQL = `CREATE USER IF NOT EXISTS 'test'@'localhost' IDENTIFIED BY '123';`
|
||||
mustExec(c, s.testDB, createUserSQL)
|
||||
|
||||
// Create duplicate user without IfNotExists will cause error.
|
||||
createUserSQL = `CREATE USER 'test'@'localhost' IDENTIFIED BY '123';`
|
||||
tx = mustBegin(c, s.testDB)
|
||||
_, err = tx.Query(createUserSQL)
|
||||
c.Assert(err, NotNil)
|
||||
}
|
||||
|
||||
func (s *testStmtSuite) TestSetPwdStmt(c *C) {
|
||||
tx := mustBegin(c, s.testDB)
|
||||
tx.Query(`INSERT INTO mysql.User VALUES ("localhost", "root", ""), ("127.0.0.1", "root", "")`)
|
||||
rows, err := tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`)
|
||||
c.Assert(err, IsNil)
|
||||
rows.Next()
|
||||
var pwd string
|
||||
rows.Scan(&pwd)
|
||||
c.Assert(pwd, Equals, "")
|
||||
c.Assert(rows.Next(), IsFalse)
|
||||
rows.Close()
|
||||
mustCommit(c, tx)
|
||||
|
||||
tx = mustBegin(c, s.testDB)
|
||||
tx.Query(`SET PASSWORD FOR 'root'@'localhost' = 'password';`)
|
||||
mustCommit(c, tx)
|
||||
|
||||
tx = mustBegin(c, s.testDB)
|
||||
rows, err = tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`)
|
||||
c.Assert(err, IsNil)
|
||||
rows.Next()
|
||||
rows.Scan(&pwd)
|
||||
c.Assert(pwd, Equals, "password")
|
||||
c.Assert(rows.Next(), IsFalse)
|
||||
rows.Close()
|
||||
mustCommit(c, tx)
|
||||
}
|
||||
Reference in New Issue
Block a user