stmts: Add missing files

This commit is contained in:
Shen Li
2015-09-20 16:22:14 +08:00
parent 227c9a16d1
commit 428c629dee
2 changed files with 285 additions and 0 deletions

View File

@ -0,0 +1,205 @@
// Copyright 2013 The ql Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSES/QL-LICENSE file.
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package stmts
import (
"github.com/juju/errors"
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/expression"
"github.com/pingcap/tidb/expression/expressions"
"github.com/pingcap/tidb/model"
mysql "github.com/pingcap/tidb/mysqldef"
"github.com/pingcap/tidb/parser/coldef"
"github.com/pingcap/tidb/parser/opcode"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/rset/rsets"
"github.com/pingcap/tidb/stmt"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/util/format"
"strings"
)
/************************************************************************************
* Account Management Statements
* https://dev.mysql.com/doc/refman/5.7/en/account-management-sql.html
************************************************************************************/
var (
_ stmt.Statement = (*CreateUserStmt)(nil)
_ stmt.Statement = (*SetPwdStmt)(nil)
)
// CreateUserStmt creates user account.
// See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html
type CreateUserStmt struct {
IfNotExists bool
Specs []*coldef.UserSpecification
Text string
}
// Explain implements the stmt.Statement Explain interface.
func (s *CreateUserStmt) Explain(ctx context.Context, w format.Formatter) {
w.Format("%s\n", s.Text)
}
// IsDDL implements the stmt.Statement IsDDL interface.
func (s *CreateUserStmt) IsDDL() bool {
return true
}
// OriginText implements the stmt.Statement OriginText interface.
func (s *CreateUserStmt) OriginText() string {
return s.Text
}
// SetText implements the stmt.Statement SetText interface.
func (s *CreateUserStmt) SetText(text string) {
s.Text = text
}
func composeUserTableFilter(name string, host string) expression.Expression {
nameMatch := expressions.NewBinaryOperation(opcode.EQ, &expressions.Ident{CIStr: model.NewCIStr("User")}, &expressions.Value{Val: name})
hostMatch := expressions.NewBinaryOperation(opcode.EQ, &expressions.Ident{CIStr: model.NewCIStr("Host")}, &expressions.Value{Val: host})
return expressions.NewBinaryOperation(opcode.AndAnd, nameMatch, hostMatch)
}
func composeUserTableRset() *rsets.JoinRset {
return &rsets.JoinRset{
Left: &rsets.TableSource{
Source: table.Ident{
Name: model.NewCIStr(mysql.UserTable),
Schema: model.NewCIStr(mysql.SystemDB),
},
},
}
}
func (s *CreateUserStmt) userExists(ctx context.Context, name string, host string) (bool, error) {
r := composeUserTableRset()
p, err := r.Plan(ctx)
if err != nil {
return false, errors.Trace(err)
}
where := &rsets.WhereRset{
Src: p,
Expr: composeUserTableFilter(name, host),
}
p, err = where.Plan(ctx)
if err != nil {
return false, errors.Trace(err)
}
row, err := p.Next(ctx)
if err != nil {
return false, errors.Trace(err)
}
p.Close()
return row != nil, nil
}
// Exec implements the stmt.Statement Exec interface.
func (s *CreateUserStmt) Exec(ctx context.Context) (rset.Recordset, error) {
st := &InsertIntoStmt{
TableIdent: table.Ident{
Name: model.NewCIStr(mysql.UserTable),
Schema: model.NewCIStr(mysql.SystemDB),
},
}
values := make([][]expression.Expression, 0, len(s.Specs))
for _, spec := range s.Specs {
strs := strings.Split(spec.User, "@")
userName := strs[0]
host := strs[1]
exists, err1 := s.userExists(ctx, userName, host)
if err1 != nil {
return nil, errors.Trace(err1)
}
if exists {
if !s.IfNotExists {
return nil, errors.Errorf("Duplicate user")
}
continue
}
value := make([]expression.Expression, 0, 3)
value = append(value, expressions.Value{Val: host})
value = append(value, expressions.Value{Val: userName})
if spec.AuthOpt.ByAuthString {
value = append(value, expressions.Value{Val: spec.AuthOpt.AuthString})
} else {
// TODO: Maybe we should hash the string here?
value = append(value, expressions.Value{Val: spec.AuthOpt.HashString})
}
values = append(values, value)
}
if len(values) == 0 {
return nil, nil
}
st.Lists = values
_, err := st.Exec(ctx)
if err != nil {
return nil, errors.Trace(err)
}
return nil, nil
}
// SetPwdStmt is a statement to assign a password to user account.
// See: https://dev.mysql.com/doc/refman/5.7/en/set-password.html
type SetPwdStmt struct {
User string
Password string
Text string
}
// Explain implements the stmt.Statement Explain interface.
func (s *SetPwdStmt) Explain(ctx context.Context, w format.Formatter) {
w.Format("%s\n", s.Text)
}
// IsDDL implements the stmt.Statement IsDDL interface.
func (s *SetPwdStmt) IsDDL() bool {
return false
}
// OriginText implements the stmt.Statement OriginText interface.
func (s *SetPwdStmt) OriginText() string {
return s.Text
}
// SetText implements the stmt.Statement SetText interface.
func (s *SetPwdStmt) SetText(text string) {
s.Text = text
}
// Exec implements the stmt.Statement Exec interface.
func (s *SetPwdStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
// If len(s.User) == 0, use CURRENT_USER()
strs := strings.Split(s.User, "@")
userName := strs[0]
host := strs[1]
// Update mysql.user
asgn := expressions.Assignment{
ColName: "Password",
Expr: expressions.Value{Val: s.Password},
}
st := &UpdateStmt{
TableRefs: composeUserTableRset(),
List: []expressions.Assignment{asgn},
Where: composeUserTableFilter(userName, host),
}
return st.Exec(ctx)
}

View File

@ -0,0 +1,80 @@
// Copyright 2015 PingCAP, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// See the License for the specific language governing permissions and
// limitations under the License.
package stmts_test
import (
. "github.com/pingcap/check"
)
func (s *testStmtSuite) TestCreateUserStmt(c *C) {
// Make sure user test not in mysql.User.
tx := mustBegin(c, s.testDB)
rows, err := tx.Query(`SELECT Password FROM mysql.User WHERE User="test" and Host="localhost"`)
c.Assert(err, IsNil)
rows.Next()
c.Assert(rows.Next(), IsFalse)
rows.Close()
mustCommit(c, tx)
// Create user test.
createUserSQL := `CREATE USER 'test'@'localhost' IDENTIFIED BY '123';`
mustExec(c, s.testDB, createUserSQL)
// Make sure user test in mysql.User.
tx = mustBegin(c, s.testDB)
rows, err = tx.Query(`SELECT Password FROM mysql.User WHERE User="test" and Host="localhost"`)
c.Assert(err, IsNil)
rows.Next()
var pwd string
rows.Scan(&pwd)
c.Assert(pwd, Equals, "123")
c.Assert(rows.Next(), IsFalse)
rows.Close()
mustCommit(c, tx)
// Create duplicate user with IfNotExists will be success.
createUserSQL = `CREATE USER IF NOT EXISTS 'test'@'localhost' IDENTIFIED BY '123';`
mustExec(c, s.testDB, createUserSQL)
// Create duplicate user without IfNotExists will cause error.
createUserSQL = `CREATE USER 'test'@'localhost' IDENTIFIED BY '123';`
tx = mustBegin(c, s.testDB)
_, err = tx.Query(createUserSQL)
c.Assert(err, NotNil)
}
func (s *testStmtSuite) TestSetPwdStmt(c *C) {
tx := mustBegin(c, s.testDB)
tx.Query(`INSERT INTO mysql.User VALUES ("localhost", "root", ""), ("127.0.0.1", "root", "")`)
rows, err := tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`)
c.Assert(err, IsNil)
rows.Next()
var pwd string
rows.Scan(&pwd)
c.Assert(pwd, Equals, "")
c.Assert(rows.Next(), IsFalse)
rows.Close()
mustCommit(c, tx)
tx = mustBegin(c, s.testDB)
tx.Query(`SET PASSWORD FOR 'root'@'localhost' = 'password';`)
mustCommit(c, tx)
tx = mustBegin(c, s.testDB)
rows, err = tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`)
c.Assert(err, IsNil)
rows.Next()
rows.Scan(&pwd)
c.Assert(pwd, Equals, "password")
c.Assert(rows.Next(), IsFalse)
rows.Close()
mustCommit(c, tx)
}