diff --git a/.travis.yml b/.travis.yml index 59e1ff828f..27eedb36b0 100644 --- a/.travis.yml +++ b/.travis.yml @@ -2,8 +2,5 @@ language: go go: - 1.5 - -before_script: go get github.com/golang/lint/golint -script: make check - +script: make diff --git a/Makefile b/Makefile index 2001858948..89965e5e4e 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ LDFLAGS += -X "github.com/pingcap/tidb/util/printer.TiDBGitHash=$(shell git rev- TARGET = "" -.PHONY: godep deps all build install parser clean todo test tidbtest mysqltest gotest interpreter +.PHONY: godep deps all build install parser clean todo test gotest interpreter server all: godep parser build test check @@ -45,6 +45,8 @@ parser: golex -o parser/scanner.go parser/scanner.l check: + go get github.com/golang/lint/golint + @echo "vet" @ go tool vet . 2>&1 | grep -vE 'Godeps|parser/scanner.*unreachable code' | awk '{print} END{if(NR>0) {exit 1}}' @echo "vet --shadow" diff --git a/parser/coldef/opt.go b/parser/coldef/opt.go index 7fff76a90b..0fac05dcea 100644 --- a/parser/coldef/opt.go +++ b/parser/coldef/opt.go @@ -221,3 +221,18 @@ func (tc *TableConstraint) String() string { } return strings.Join(tokens, " ") } + +// AuthOption is used for parsing create use statement. +type AuthOption struct { + // AuthString/HashString can be empty, so we need to decide which one to use. + ByAuthString bool + AuthString string + HashString string + // TODO: support auth_plugin +} + +// UserSpecification is used for parsing create use statement. +type UserSpecification struct { + User string + AuthOpt *AuthOption +} diff --git a/parser/parser.y b/parser/parser.y index 57491d3569..ca22095c7a 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -137,6 +137,7 @@ import ( having "HAVING" highPriority "HIGH_PRIORITY" hour "HOUR" + identified "IDENTIFIED" ignore "IGNORE" ifKwd "IF" ifNull "IFNULL" @@ -217,6 +218,7 @@ import ( unsigned "UNSIGNED" update "UPDATE" use "USE" + user "USER" using "USING" userVar "USER_VAR" value "VALUE" @@ -299,6 +301,7 @@ import ( Assignment "assignment" AssignmentList "assignment list" AssignmentListOpt "assignment list opt" + AuthOption "User auth option" AuthString "Password string value" BeginTransactionStmt "BEGIN TRANSACTION statement" CastType "Cast function target type" @@ -327,6 +330,7 @@ import ( CreateSpecificationList "CREATE Database specification list" CreateSpecListOpt "CREATE Database specification list opt" CreateTableStmt "CREATE TABLE statement" + CreateUserStmt "CREATE User statement" CrossOpt "Cross join option" DBName "Database Name" DeallocateSym "Deallocate or drop" @@ -366,6 +370,7 @@ import ( GlobalScope "The scope of variable" GroupByClause "GROUP BY clause" GroupByList "GROUP BY list" + HashString "Hashed string" HavingClause "HAVING clause" IfExists "If Exists" IfNotExists "If Not Exists" @@ -418,7 +423,6 @@ import ( SelectStmtFieldList "SELECT statement field list" SelectStmtLimit "SELECT statement optional LIMIT clause" SelectStmtOpts "Select statement options" - SelectStmtWhere "SELECT statement optional WHERE clause" SelectStmtGroup "SELECT statement optional GROUP BY clause" SelectStmtOrder "SELECT statement optional ORDER BY clause" SetStmt "Set variable statement" @@ -451,6 +455,8 @@ import ( UnionStmt "Union statement" UpdateStmt "UPDATE statement" Username "Username" + UserSpecification "Username and auth option" + UserSpecificationList "Username and auth option list" UserVariable "User defined variable name" UserVariableList "User defined variable name list" UseStmt "USE statement" @@ -1154,7 +1160,7 @@ DeleteFromStmt: Ignore: $4.(bool)} if $7 != nil { - x.Where = $7.(*rsets.WhereRset).Expr + x.Where = $7.(expression.Expression) } if $8 != nil { @@ -1183,7 +1189,7 @@ DeleteFromStmt: Refs: $7.(*rsets.JoinRset), } if $8 != nil { - x.Where = $8.(*rsets.WhereRset).Expr + x.Where = $8.(expression.Expression) } $$ = x if yylex.(*lexer).root { @@ -1202,7 +1208,7 @@ DeleteFromStmt: Refs: $8.(*rsets.JoinRset), } if $9 != nil { - x.Where = $9.(*rsets.WhereRset).Expr + x.Where = $9.(expression.Expression) } $$ = x if yylex.(*lexer).root { @@ -1593,7 +1599,7 @@ UnReservedKeyword: | "DATE" | "DATETIME" | "DEALLOCATE" | "DO" | "END" | "ENGINE" | "ENGINES" | "EXECUTE" | "FIRST" | "FULL" | "LOCAL" | "NAMES" | "OFFSET" | "PASSWORD" %prec lowerThanEq | "PREPARE" | "QUICK" | "ROLLBACK" | "SESSION" | "SIGNED" | "START" | "GLOBAL" | "TABLES"| "TEXT" | "TIME" | "TIMESTAMP" | "TRANSACTION" | "TRUNCATE" | "UNKNOWN" -| "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" +| "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" | "USER" | "IDENTIFIED" NotKeywordToken: "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" @@ -2634,7 +2640,7 @@ SelectStmt: } } | "SELECT" SelectStmtOpts SelectStmtFieldList "FROM" - FromClause SelectStmtWhere SelectStmtGroup HavingClause SelectStmtOrder + FromClause WhereClauseOptional SelectStmtGroup HavingClause SelectStmtOrder SelectStmtLimit SelectLockOpt { st := &stmts.SelectStmt{ @@ -2645,7 +2651,7 @@ SelectStmt: } if $6 != nil { - st.Where = $6.(*rsets.WhereRset) + st.Where = &rsets.WhereRset{Expr: $6.(expression.Expression)} } if $7 != nil { @@ -2851,13 +2857,6 @@ SelectStmtFieldList: $$ = $1 } -SelectStmtWhere: - /* EMPTY */ - { - $$ = nil - } -| WhereClause - SelectStmtGroup: /* EMPTY */ { @@ -3026,7 +3025,7 @@ PasswordOpt: } AuthString: - Identifier + stringLit { $$ = $1.(string) } @@ -3086,6 +3085,14 @@ ShowStmt: Pattern: &expressions.PatternLike{Pattern: $5.(expression.Expression)}, } } +| "SHOW" GlobalScope "VARIABLES" "WHERE" Expression + { + $$ = &stmts.ShowStmt{ + Target: stmt.ShowVariables, + GlobalScope: $2.(bool), + Where: expressions.Expr($5), + } + } GlobalScope: { @@ -3144,6 +3151,7 @@ Statement: | CreateDatabaseStmt | CreateIndexStmt | CreateTableStmt +| CreateUserStmt | DoStmt | DropDatabaseStmt | DropIndexStmt @@ -3722,17 +3730,15 @@ UpdateStmt: "UPDATE" LowPriorityOptional IgnoreOptional TableRef "SET" AssignmentList WhereClauseOptional OrderByOptional LimitClause { // Single-table syntax - var expr expression.Expression - if w := $7; w != nil { - expr = w.(*rsets.WhereRset).Expr - } r := &rsets.JoinRset{Left: $4, Right: nil} st := &stmts.UpdateStmt{ LowPriority: $2.(bool), TableRefs: r, - List: $6.([]expressions.Assignment), - Where: expr, - } + List: $6.([]expressions.Assignment), + } + if $7 != nil { + st.Where = $7.(expression.Expression) + } if $8 != nil { st.Order = $8.(*rsets.OrderByRset) } @@ -3747,17 +3753,15 @@ UpdateStmt: | "UPDATE" LowPriorityOptional IgnoreOptional TableRefs "SET" AssignmentList WhereClauseOptional { // Multiple-table syntax - var expr expression.Expression - if w := $7; w != nil { - expr = w.(*rsets.WhereRset).Expr - } st := &stmts.UpdateStmt{ LowPriority: $2.(bool), TableRefs: $4.(*rsets.JoinRset), - List: $6.([]expressions.Assignment), - Where: expr, + List: $6.([]expressions.Assignment), MultipleTable: true, - } + } + if $7 != nil { + st.Where = $7.(expression.Expression) + } $$ = st if yylex.(*lexer).root { break @@ -3776,7 +3780,7 @@ UseStmt: WhereClause: "WHERE" Expression { - $$ = &rsets.WhereRset{Expr: expressions.Expr($2)} + $$ = expressions.Expr($2) } WhereClauseOptional: @@ -3795,5 +3799,56 @@ CommaOpt: { } +/************************************************************************************ + * Account Management Statements + * https://dev.mysql.com/doc/refman/5.7/en/account-management-sql.html + ************************************************************************************/ +CreateUserStmt: + "CREATE" "USER" IfNotExists UserSpecificationList + { + // See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html + $$ = &stmts.CreateUserStmt{ + IfNotExists: $3.(bool), + Specs: $4.([]*coldef.UserSpecification), + } + } + +UserSpecification: + Username AuthOption + { + $$ = &coldef.UserSpecification{ + User: $1.(string), + AuthOpt: $2.(*coldef.AuthOption), + } + } + +UserSpecificationList: + UserSpecification + { + $$ = []*coldef.UserSpecification{$1.(*coldef.UserSpecification)} + } +| UserSpecificationList ',' UserSpecification + { + $$ = append($1.([]*coldef.UserSpecification), $3.(*coldef.UserSpecification)) + } + +AuthOption: + {} +| "IDENTIFIED" "BY" AuthString + { + $$ = &coldef.AuthOption { + AuthString: $3.(string), + ByAuthString: true, + } + } +| "IDENTIFIED" "BY" "PASSWORD" HashString + { + $$ = &coldef.AuthOption { + HashString: $4.(string), + } + } + +HashString: + stringLit %% diff --git a/parser/parser_test.go b/parser/parser_test.go index 170f30c75d..1603edf55e 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -290,6 +290,7 @@ func (s *testParserSuite) TestParser0(c *C) { {"SHOW SESSION VARIABLES LIKE 'character_set_results'", true}, {"SHOW VARIABLES", true}, {"SHOW GLOBAL VARIABLES", true}, + {"SHOW GLOBAL VARIABLES WHERE Variable_name = 'autocommit'", true}, // For compare subquery {"SELECT 1 > (select 1)", true}, @@ -325,6 +326,15 @@ func (s *testParserSuite) TestParser0(c *C) { // For comparison {"select 1 <=> 0, 1 <=> null, 1 = null", true}, + + // For create user + {`CREATE USER IF NOT EXISTS 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY PASSWORD 'hashstring'`, true}, + {`CREATE USER 'root'@'localhost' IDENTIFIED BY 'new-password', 'root'@'127.0.0.1' IDENTIFIED BY PASSWORD 'hashstring'`, true}, + + // For select with where clause + {"SELECT * FROM t WHERE 1 = 1", true}, } for _, t := range table { @@ -347,7 +357,7 @@ func (s *testParserSuite) TestParser0(c *C) { "date", "datetime", "deallocate", "do", "end", "engine", "engines", "execute", "first", "full", "local", "names", "offset", "password", "prepare", "quick", "rollback", "session", "signed", "start", "global", "tables", "text", "time", "timestamp", "transaction", "truncate", "unknown", - "value", "warnings", "year", "now", "substring", "mode", "any", "some", + "value", "warnings", "year", "now", "substring", "mode", "any", "some", "user", "identified", } for _, kw := range unreservedKws { src := fmt.Sprintf("SELECT %s FROM tbl;", kw) diff --git a/parser/scanner.l b/parser/scanner.l index 67731e232e..772b2eef85 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -298,6 +298,7 @@ group_concat {g}{r}{o}{u}{p}_{c}{o}{n}{c}{a}{t} having {h}{a}{v}{i}{n}{g} high_priority {h}{i}{g}{h}_{p}{r}{i}{o}{r}{i}{t}{y} hour {h}{o}{u}{r} +identified {i}{d}{e}{n}{t}{i}{f}{i}{e}{d} if {i}{f} ifnull {i}{f}{n}{u}{l}{l} ignore {i}{g}{n}{o}{r}{e} @@ -431,6 +432,7 @@ duration {d}{u}{r}{a}{t}{i}{o}{n} rune {r}{u}{n}{e} string {s}{t}{r}{i}{n}{g} use {u}{s}{e} +user {u}{s}{e}{r} using {u}{s}{i}{n}{g} idchar0 [a-zA-Z_] @@ -588,6 +590,8 @@ sys_var "@@"(({global}".")|({session}".")|{local}".")?{ident} {high_priority} return highPriority {hour} lval.item = string(l.val) return hour +{identified} lval.item = string(l.val) + return identified {if} lval.item = string(l.val) return ifKwd {ifnull} lval.item = string(l.val) @@ -694,6 +698,8 @@ sys_var "@@"(({global}".")|({session}".")|{local}".")?{ident} return nullIf {update} return update {use} return use +{user} lval.item = string(l.val) + return user {using} return using {value} lval.item = string(l.val) return value diff --git a/plan/plans/show.go b/plan/plans/show.go index dd108ed658..0fd921a804 100644 --- a/plan/plans/show.go +++ b/plan/plans/show.go @@ -48,6 +48,7 @@ type ShowPlan struct { // Used by SHOW VARIABLES GlobalScope bool Pattern expression.Expression + Where expression.Expression rows []*plan.Row cursor int } @@ -229,6 +230,24 @@ func (s *ShowPlan) fetchAll(ctx context.Context) error { if !match { continue } + } else if s.Where != nil { + m := map[interface{}]interface{}{} + + m[expressions.ExprEvalIdentFunc] = func(name string) (interface{}, error) { + if strings.EqualFold(name, "Variable_name") { + return v.Name, nil + } + + return nil, errors.Errorf("unknown field %s", name) + } + + match, err := expressions.EvalBoolExpr(ctx, s.Where, m) + if err != nil { + return errors.Trace(err) + } + if !match { + continue + } } value := v.Value if !s.GlobalScope { diff --git a/plan/plans/show_test.go b/plan/plans/show_test.go index 32d8001b5f..055c2589b2 100644 --- a/plan/plans/show_test.go +++ b/plan/plans/show_test.go @@ -20,6 +20,8 @@ import ( "github.com/pingcap/tidb" "github.com/pingcap/tidb/expression/expressions" "github.com/pingcap/tidb/kv" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset/rsets" "github.com/pingcap/tidb/sessionctx/variable" @@ -113,6 +115,25 @@ func (p *testShowSuit) TestShowVariables(c *C) { c.Assert(ok, IsTrue) // Show session varibale get utf8 c.Assert(v, Equals, "utf8") + pln.Close() + pln.Pattern = nil + pln.Where = &expressions.BinaryOperation{ + L: &expressions.Ident{CIStr: model.NewCIStr("Variable_name")}, + R: expressions.Value{Val: "autocommit"}, + Op: opcode.EQ, + } + + ret = map[string]string{} + sessionVars.Systems["autocommit"] = "on" + rset.Do(func(data []interface{}) (bool, error) { + ret[data[0].(string)] = data[1].(string) + return true, nil + }) + + c.Assert(ret, HasLen, 1) + v, ok = ret["autocommit"] + c.Assert(ok, IsTrue) + c.Assert(v, Equals, "on") } func (p *testShowSuit) TearDownSuite(c *C) { diff --git a/stmt/stmts/account_manage.go b/stmt/stmts/account_manage.go new file mode 100644 index 0000000000..e956de5890 --- /dev/null +++ b/stmt/stmts/account_manage.go @@ -0,0 +1,206 @@ +// Copyright 2013 The ql Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSES/QL-LICENSE file. + +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package stmts + +import ( + "strings" + + "github.com/juju/errors" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/expression/expressions" + "github.com/pingcap/tidb/model" + mysql "github.com/pingcap/tidb/mysqldef" + "github.com/pingcap/tidb/parser/coldef" + "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/rset" + "github.com/pingcap/tidb/rset/rsets" + "github.com/pingcap/tidb/stmt" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/util/format" +) + +/************************************************************************************ + * Account Management Statements + * https://dev.mysql.com/doc/refman/5.7/en/account-management-sql.html + ************************************************************************************/ +var ( + _ stmt.Statement = (*CreateUserStmt)(nil) + _ stmt.Statement = (*SetPwdStmt)(nil) +) + +// CreateUserStmt creates user account. +// See: https://dev.mysql.com/doc/refman/5.7/en/create-user.html +type CreateUserStmt struct { + IfNotExists bool + Specs []*coldef.UserSpecification + + Text string +} + +// Explain implements the stmt.Statement Explain interface. +func (s *CreateUserStmt) Explain(ctx context.Context, w format.Formatter) { + w.Format("%s\n", s.Text) +} + +// IsDDL implements the stmt.Statement IsDDL interface. +func (s *CreateUserStmt) IsDDL() bool { + return true +} + +// OriginText implements the stmt.Statement OriginText interface. +func (s *CreateUserStmt) OriginText() string { + return s.Text +} + +// SetText implements the stmt.Statement SetText interface. +func (s *CreateUserStmt) SetText(text string) { + s.Text = text +} + +func composeUserTableFilter(name string, host string) expression.Expression { + nameMatch := expressions.NewBinaryOperation(opcode.EQ, &expressions.Ident{CIStr: model.NewCIStr("User")}, &expressions.Value{Val: name}) + hostMatch := expressions.NewBinaryOperation(opcode.EQ, &expressions.Ident{CIStr: model.NewCIStr("Host")}, &expressions.Value{Val: host}) + return expressions.NewBinaryOperation(opcode.AndAnd, nameMatch, hostMatch) +} + +func composeUserTableRset() *rsets.JoinRset { + return &rsets.JoinRset{ + Left: &rsets.TableSource{ + Source: table.Ident{ + Name: model.NewCIStr(mysql.UserTable), + Schema: model.NewCIStr(mysql.SystemDB), + }, + }, + } +} + +func (s *CreateUserStmt) userExists(ctx context.Context, name string, host string) (bool, error) { + r := composeUserTableRset() + p, err := r.Plan(ctx) + if err != nil { + return false, errors.Trace(err) + } + where := &rsets.WhereRset{ + Src: p, + Expr: composeUserTableFilter(name, host), + } + p, err = where.Plan(ctx) + if err != nil { + return false, errors.Trace(err) + } + defer p.Close() + row, err := p.Next(ctx) + if err != nil { + return false, errors.Trace(err) + } + return row != nil, nil +} + +// Exec implements the stmt.Statement Exec interface. +func (s *CreateUserStmt) Exec(ctx context.Context) (rset.Recordset, error) { + st := &InsertIntoStmt{ + TableIdent: table.Ident{ + Name: model.NewCIStr(mysql.UserTable), + Schema: model.NewCIStr(mysql.SystemDB), + }, + } + values := make([][]expression.Expression, 0, len(s.Specs)) + for _, spec := range s.Specs { + strs := strings.Split(spec.User, "@") + userName := strs[0] + host := strs[1] + exists, err1 := s.userExists(ctx, userName, host) + if err1 != nil { + return nil, errors.Trace(err1) + } + if exists { + if !s.IfNotExists { + return nil, errors.Errorf("Duplicate user") + } + continue + } + value := make([]expression.Expression, 0, 3) + value = append(value, expressions.Value{Val: host}) + value = append(value, expressions.Value{Val: userName}) + if spec.AuthOpt.ByAuthString { + value = append(value, expressions.Value{Val: spec.AuthOpt.AuthString}) + } else { + // TODO: Maybe we should hash the string here? + value = append(value, expressions.Value{Val: spec.AuthOpt.HashString}) + } + values = append(values, value) + } + if len(values) == 0 { + return nil, nil + } + st.Lists = values + _, err := st.Exec(ctx) + if err != nil { + return nil, errors.Trace(err) + } + return nil, nil +} + +// SetPwdStmt is a statement to assign a password to user account. +// See: https://dev.mysql.com/doc/refman/5.7/en/set-password.html +type SetPwdStmt struct { + User string + Password string + + Text string +} + +// Explain implements the stmt.Statement Explain interface. +func (s *SetPwdStmt) Explain(ctx context.Context, w format.Formatter) { + w.Format("%s\n", s.Text) +} + +// IsDDL implements the stmt.Statement IsDDL interface. +func (s *SetPwdStmt) IsDDL() bool { + return false +} + +// OriginText implements the stmt.Statement OriginText interface. +func (s *SetPwdStmt) OriginText() string { + return s.Text +} + +// SetText implements the stmt.Statement SetText interface. +func (s *SetPwdStmt) SetText(text string) { + s.Text = text +} + +// Exec implements the stmt.Statement Exec interface. +func (s *SetPwdStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { + // If len(s.User) == 0, use CURRENT_USER() + strs := strings.Split(s.User, "@") + userName := strs[0] + host := strs[1] + // Update mysql.user + asgn := expressions.Assignment{ + ColName: "Password", + Expr: expressions.Value{Val: s.Password}, + } + st := &UpdateStmt{ + TableRefs: composeUserTableRset(), + List: []expressions.Assignment{asgn}, + Where: composeUserTableFilter(userName, host), + } + return st.Exec(ctx) +} diff --git a/stmt/stmts/account_manage_test.go b/stmt/stmts/account_manage_test.go new file mode 100644 index 0000000000..ce78a83201 --- /dev/null +++ b/stmt/stmts/account_manage_test.go @@ -0,0 +1,79 @@ +// Copyright 2015 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package stmts_test + +import ( + . "github.com/pingcap/check" +) + +func (s *testStmtSuite) TestCreateUserStmt(c *C) { + // Make sure user test not in mysql.User. + tx := mustBegin(c, s.testDB) + rows, err := tx.Query(`SELECT Password FROM mysql.User WHERE User="test" and Host="localhost"`) + c.Assert(err, IsNil) + c.Assert(rows.Next(), IsFalse) + rows.Close() + mustCommit(c, tx) + // Create user test. + createUserSQL := `CREATE USER 'test'@'localhost' IDENTIFIED BY '123';` + mustExec(c, s.testDB, createUserSQL) + // Make sure user test in mysql.User. + tx = mustBegin(c, s.testDB) + rows, err = tx.Query(`SELECT Password FROM mysql.User WHERE User="test" and Host="localhost"`) + c.Assert(err, IsNil) + rows.Next() + var pwd string + rows.Scan(&pwd) + c.Assert(pwd, Equals, "123") + c.Assert(rows.Next(), IsFalse) + rows.Close() + mustCommit(c, tx) + // Create duplicate user with IfNotExists will be success. + createUserSQL = `CREATE USER IF NOT EXISTS 'test'@'localhost' IDENTIFIED BY '123';` + mustExec(c, s.testDB, createUserSQL) + + // Create duplicate user without IfNotExists will cause error. + createUserSQL = `CREATE USER 'test'@'localhost' IDENTIFIED BY '123';` + tx = mustBegin(c, s.testDB) + _, err = tx.Query(createUserSQL) + c.Assert(err, NotNil) +} + +func (s *testStmtSuite) TestSetPwdStmt(c *C) { + tx := mustBegin(c, s.testDB) + tx.Query(`INSERT INTO mysql.User VALUES ("localhost", "root", ""), ("127.0.0.1", "root", "")`) + rows, err := tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`) + c.Assert(err, IsNil) + rows.Next() + var pwd string + rows.Scan(&pwd) + c.Assert(pwd, Equals, "") + c.Assert(rows.Next(), IsFalse) + rows.Close() + mustCommit(c, tx) + + tx = mustBegin(c, s.testDB) + tx.Query(`SET PASSWORD FOR 'root'@'localhost' = 'password';`) + mustCommit(c, tx) + + tx = mustBegin(c, s.testDB) + rows, err = tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`) + c.Assert(err, IsNil) + rows.Next() + rows.Scan(&pwd) + c.Assert(pwd, Equals, "password") + c.Assert(rows.Next(), IsFalse) + rows.Close() + mustCommit(c, tx) +} diff --git a/stmt/stmts/set.go b/stmt/stmts/set.go index 740b6c73d3..8432a77146 100644 --- a/stmt/stmts/set.go +++ b/stmt/stmts/set.go @@ -22,21 +22,15 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/expression/expressions" - "github.com/pingcap/tidb/model" - mysql "github.com/pingcap/tidb/mysqldef" - "github.com/pingcap/tidb/parser/opcode" "github.com/pingcap/tidb/rset" - "github.com/pingcap/tidb/rset/rsets" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/stmt" - "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util/format" ) var ( _ stmt.Statement = (*SetStmt)(nil) _ stmt.Statement = (*SetCharsetStmt)(nil) - _ stmt.Statement = (*SetPwdStmt)(nil) ) // VariableAssignment is a varible assignment struct. @@ -203,62 +197,3 @@ func (s *SetCharsetStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) // ctx.Charset = s.Charset return nil, nil } - -// SetPwdStmt is a statement to assign a password to user account. -// See: https://dev.mysql.com/doc/refman/5.7/en/set-password.html -type SetPwdStmt struct { - User string - Password string - - Text string -} - -// Explain implements the stmt.Statement Explain interface. -func (s *SetPwdStmt) Explain(ctx context.Context, w format.Formatter) { - w.Format("%s\n", s.Text) -} - -// IsDDL implements the stmt.Statement IsDDL interface. -func (s *SetPwdStmt) IsDDL() bool { - return false -} - -// OriginText implements the stmt.Statement OriginText interface. -func (s *SetPwdStmt) OriginText() string { - return s.Text -} - -// SetText implements the stmt.Statement SetText interface. -func (s *SetPwdStmt) SetText(text string) { - s.Text = text -} - -// Exec implements the stmt.Statement Exec interface. -func (s *SetPwdStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { - // If len(s.User) == 0, use CURRENT_USER() - strs := strings.Split(s.User, "@") - userName := strs[0] - host := strs[1] - // Update mysql.user - r := &rsets.JoinRset{ - Left: &rsets.TableSource{ - Source: table.Ident{ - Name: model.NewCIStr(mysql.UserTable), - Schema: model.NewCIStr(mysql.SystemDB), - }, - }, - } - asgn := expressions.Assignment{ - ColName: "Password", - Expr: expressions.Value{Val: s.Password}, - } - nameMatch := expressions.NewBinaryOperation(opcode.EQ, &expressions.Ident{CIStr: model.NewCIStr("User")}, &expressions.Value{Val: userName}) - hostMatch := expressions.NewBinaryOperation(opcode.EQ, &expressions.Ident{CIStr: model.NewCIStr("Host")}, &expressions.Value{Val: host}) - where := expressions.NewBinaryOperation(opcode.AndAnd, nameMatch, hostMatch) - st := &UpdateStmt{ - TableRefs: r, - List: []expressions.Assignment{asgn}, - Where: where, - } - return st.Exec(ctx) -} diff --git a/stmt/stmts/set_test.go b/stmt/stmts/set_test.go index a0af32ad59..93de2018c1 100644 --- a/stmt/stmts/set_test.go +++ b/stmt/stmts/set_test.go @@ -137,31 +137,3 @@ func (s *testStmtSuite) TestSetCharsetStmt(c *C) { testStmt.Explain(nil, mf) c.Assert(mf.Len(), Greater, 0) } - -func (s *testStmtSuite) TestSetPwdStmt(c *C) { - tx := mustBegin(c, s.testDB) - tx.Query(`INSERT INTO mysql.User VALUES ("localhost", "root", ""), ("127.0.0.1", "root", "")`) - rows, err := tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`) - c.Assert(err, IsNil) - rows.Next() - var pwd string - rows.Scan(&pwd) - c.Assert(pwd, Equals, "") - c.Assert(rows.Next(), IsFalse) - rows.Close() - mustCommit(c, tx) - - tx = mustBegin(c, s.testDB) - tx.Query(`SET PASSWORD FOR 'root'@'localhost' = 'password';`) - mustCommit(c, tx) - - tx = mustBegin(c, s.testDB) - rows, err = tx.Query(`SELECT Password FROM mysql.User WHERE User="root" and Host="localhost"`) - c.Assert(err, IsNil) - rows.Next() - rows.Scan(&pwd) - c.Assert(pwd, Equals, "password") - c.Assert(rows.Next(), IsFalse) - rows.Close() - mustCommit(c, tx) -} diff --git a/stmt/stmts/show.go b/stmt/stmts/show.go index a495967759..494e1dfeb1 100644 --- a/stmt/stmts/show.go +++ b/stmt/stmts/show.go @@ -20,6 +20,7 @@ import ( "github.com/pingcap/tidb/plan/plans" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/rset/rsets" + "github.com/pingcap/tidb/sessionctx/db" "github.com/pingcap/tidb/stmt" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util/format" @@ -40,6 +41,7 @@ type ShowStmt struct { // Used by show variables GlobalScope bool Pattern expression.Expression + Where expression.Expression Text string } @@ -70,14 +72,24 @@ func (s *ShowStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { log.Debug("Exec Show Stmt") r := &plans.ShowPlan{ Target: s.Target, - DBName: s.DBName, + DBName: s.getDBName(ctx), TableName: s.TableIdent.Name.O, ColumnName: s.ColumnName, Flag: s.Flag, Full: s.Full, GlobalScope: s.GlobalScope, Pattern: s.Pattern, + Where: s.Where, } return rsets.Recordset{Ctx: ctx, Plan: r}, nil } + +func (s *ShowStmt) getDBName(ctx context.Context) string { + if len(s.DBName) > 0 { + return s.DBName + } + + // if s.DBName is empty, we should use current db name if possible. + return db.GetCurrentSchema(ctx) +} diff --git a/tidb_test.go b/tidb_test.go index da2649d64c..06f0b31775 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -771,6 +771,24 @@ func (s *testSessionSuite) TestSubQuery(c *C) { match(c, rows[1], 2) } +func (s *testSessionSuite) TestShow(c *C) { + store := newStore(c, s.dbName) + se := newSession(c, store, s.dbName) + + mustExecSQL(c, se, "set global autocommit=1") + r := mustExecSQL(c, se, "show variables where variable_name = 'autocommit'") + row, err := r.FirstRow() + c.Assert(err, IsNil) + match(c, row, "autocommit", 1) + + mustExecSQL(c, se, "create table if not exists t (c int)") + r = mustExecSQL(c, se, `show columns from t`) + rows, err := r.Rows(-1, 0) + c.Assert(err, IsNil) + c.Assert(rows, HasLen, 1) + match(c, rows[0], "c", "INT", "YES", "", nil, "") +} + func newSession(c *C, store kv.Storage, dbName string) Session { se, err := CreateSession(store) c.Assert(err, IsNil)