diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000..110acb1688 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,15 @@ +FROM golang + +VOLUME /opt + +RUN apt-get update && apt-get install -y wget git make ; \ + cd /opt ; \ + export PATH=$GOROOT/bin:$GOPATH/bin:$PATH ; \ + go get -d github.com/pingcap/tidb ; \ + cd $GOPATH/src/github.com/pingcap/tidb ; \ + make ; make server ; cp tidb-server/tidb-server /usr/bin/ + +EXPOSE 4000 + +CMD ["/usr/bin/tidb-server"] + diff --git a/ast/functions.go b/ast/functions.go index 7faab4c6e1..df7d752558 100644 --- a/ast/functions.go +++ b/ast/functions.go @@ -338,26 +338,39 @@ func (n *FuncTrimExpr) IsStatic() bool { return n.Str.IsStatic() && n.RemStr.IsStatic() } -// DateArithType is type for DateArith option. +// DateArithType is type for DateArith type. type DateArithType byte const ( - // DateAdd is to run date_add function option. + // DateAdd is to run adddate or date_add function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add DateAdd DateArithType = iota + 1 - // DateSub is to run date_sub function option. + // DateSub is to run subdate or date_sub function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub DateSub + // DateArithDaysForm is to run adddate or subdate function with days form Flag. + DateArithDaysForm ) +// DateArithInterval is the struct of DateArith interval part. +type DateArithInterval struct { + // Form is the flag of DateArith running form. + // The function runs with interval or days. + Form DateArithType + Unit string + Interval ExprNode +} + // FuncDateArithExpr is the struct for date arithmetic functions. type FuncDateArithExpr struct { funcNode - Op DateArithType - Unit string - Date ExprNode - Interval ExprNode + // Op is used for distinguishing date_add and date_sub. + Op DateArithType + Date ExprNode + DateArithInterval } // Accept implements Node Accept interface. @@ -375,11 +388,11 @@ func (n *FuncDateArithExpr) Accept(v Visitor) (Node, bool) { n.Date = node.(ExprNode) } if n.Interval != nil { - node, ok := n.Date.Accept(v) + node, ok := n.Interval.Accept(v) if !ok { return n, false } - n.Date = node.(ExprNode) + n.Interval = node.(ExprNode) } return v.Leave(n) } diff --git a/ast/misc.go b/ast/misc.go index 2ae8ccbd77..ff58fecd08 100644 --- a/ast/misc.go +++ b/ast/misc.go @@ -166,6 +166,7 @@ const ( ShowCollation ShowCreateTable ShowGrants + ShowTriggers ) // ShowStmt is a statement to provide information about databases, tables, columns and so on. diff --git a/bootstrap.go b/bootstrap.go index b6c3272880..5c7538e424 100644 --- a/bootstrap.go +++ b/bootstrap.go @@ -22,6 +22,7 @@ import ( "runtime/debug" "strings" + "github.com/juju/errors" "github.com/ngaut/log" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/sessionctx/variable" @@ -91,45 +92,94 @@ const ( CreateGloablVariablesTable = `CREATE TABLE if not exists mysql.GLOBAL_VARIABLES( VARIABLE_NAME VARCHAR(64) Not Null PRIMARY KEY, VARIABLE_VALUE VARCHAR(1024) DEFAULT Null);` + // CreateTiDBTable is the SQL statement creates a table in system db. + // This table is a key-value struct contains some information used by TiDB. + // Currently we only put bootstrapped in it which indicates if the system is already bootstrapped. + CreateTiDBTable = `CREATE TABLE if not exists mysql.tidb( + VARIABLE_NAME VARCHAR(64) Not Null PRIMARY KEY, + VARIABLE_VALUE VARCHAR(1024) DEFAULT Null, + COMMENT VARCHAR(1024));` ) // Bootstrap initiates system DB for a store. func bootstrap(s Session) { - // Create a test database. - mustExecute(s, "CREATE DATABASE IF NOT EXISTS test") - - // Check if system db exists. - _, err := s.Execute(fmt.Sprintf("USE %s;", mysql.SystemDB)) - if err == nil { - // We have already finished bootstrap. - return - } else if terror.DatabaseNotExists.NotEqual(err) { + b, err := checkBootstrapped(s) + if err != nil { log.Fatal(err) } - mustExecute(s, fmt.Sprintf("CREATE DATABASE %s;", mysql.SystemDB)) - initUserTable(s) - initPrivTables(s) - initGlobalVariables(s) + if b { + return + } + doDDLWorks(s) + doDMLWorks(s) } -func initUserTable(s Session) { +const ( + bootstrappedVar = "bootstrapped" + bootstrappedVarTrue = "True" +) + +func checkBootstrapped(s Session) (bool, error) { + // Check if system db exists. + _, err := s.Execute(fmt.Sprintf("USE %s;", mysql.SystemDB)) + if err != nil && terror.DatabaseNotExists.NotEqual(err) { + log.Fatal(err) + } + // Check bootstrapped variable value in TiDB table. + v, err := checkBootstrappedVar(s) + if err != nil { + return false, errors.Trace(err) + } + return v, nil +} + +func checkBootstrappedVar(s Session) (bool, error) { + sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s"`, mysql.SystemDB, mysql.TiDBTable, bootstrappedVar) + rs, err := s.Execute(sql) + if err != nil { + if terror.TableNotExists.Equal(err) { + return false, nil + } + return false, errors.Trace(err) + } + if len(rs) != 1 { + return false, errors.New("Wrong number of Recordset") + } + r := rs[0] + row, err := r.Next() + if err != nil || row == nil { + return false, errors.Trace(err) + } + return row.Data[0].(string) == bootstrappedVarTrue, nil +} + +// Execute DDL statements in bootstrap stage. +func doDDLWorks(s Session) { + // Create a test database. + mustExecute(s, "CREATE DATABASE IF NOT EXISTS test") + // Create system db. + mustExecute(s, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", mysql.SystemDB)) + // Create user table. mustExecute(s, CreateUserTable) + // Create privilege tables. + mustExecute(s, CreateDBPrivTable) + mustExecute(s, CreateTablePrivTable) + mustExecute(s, CreateColumnPrivTable) + // Create global systemt variable table. + mustExecute(s, CreateGloablVariablesTable) + // Create TiDB table. + mustExecute(s, CreateTiDBTable) +} + +// Execute DML statements in bootstrap stage. +// All the statements run in a single transaction. +func doDMLWorks(s Session) { + mustExecute(s, "BEGIN") // Insert a default user with empty password. mustExecute(s, `INSERT INTO mysql.user VALUES ("localhost", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y"), ("127.0.0.1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y"), ("::1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y");`) -} - -// Initiates privilege tables including mysql.db, mysql.tables_priv and mysql.column_priv. -func initPrivTables(s Session) { - mustExecute(s, CreateDBPrivTable) - mustExecute(s, CreateTablePrivTable) - mustExecute(s, CreateColumnPrivTable) -} - -// Initiates global system variables table. -func initGlobalVariables(s Session) { - mustExecute(s, CreateGloablVariablesTable) + // Init global system variable table. values := make([]string, 0, len(variable.SysVars)) for k, v := range variable.SysVars { value := fmt.Sprintf(`("%s", "%s")`, strings.ToLower(k), v.Value) @@ -137,6 +187,9 @@ func initGlobalVariables(s Session) { } sql := fmt.Sprintf("INSERT INTO %s.%s VALUES %s;", mysql.SystemDB, mysql.GlobalVariablesTable, strings.Join(values, ", ")) mustExecute(s, sql) + sql = fmt.Sprintf(`INSERT INTO %s.%s VALUES("%s", "%s", "Bootstrap flag. Do not delete.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%s"`, mysql.SystemDB, mysql.TiDBTable, bootstrappedVar, bootstrappedVarTrue, bootstrappedVarTrue) + mustExecute(s, sql) + mustExecute(s, "COMMIT") } func mustExecute(s Session, sql string) { diff --git a/column/column.go b/column/column.go index d6831c8f97..bd88c96b13 100644 --- a/column/column.go +++ b/column/column.go @@ -55,10 +55,10 @@ func (c *Col) String() string { } // FindCol finds column in cols by name. -func FindCol(cols []*Col, name string) (c *Col) { - for _, c = range cols { +func FindCol(cols []*Col, name string) *Col { + for _, c := range cols { if strings.EqualFold(c.Name.O, name) { - return + return c } } return nil diff --git a/ddl/ddl.go b/ddl/ddl.go index f6f8a84469..2762d9722f 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -31,7 +31,6 @@ import ( "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/coldef" - "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" "github.com/pingcap/tidb/terror" @@ -577,16 +576,6 @@ func (d *ddl) DropTable(ctx context.Context, ti table.Ident) (err error) { if err != nil { return errors.Trace(err) } - // Check Privilege - privChecker := privilege.GetPrivilegeChecker(ctx) - hasPriv, err := privChecker.Check(ctx, schema, tb.Meta(), mysql.DropPriv) - if err != nil { - return errors.Trace(err) - } - if !hasPriv { - return errors.Errorf("You do not have the privilege to drop table %s.%s.", ti.Schema, ti.Name) - } - err = kv.RunInNewTxn(d.store, false, func(txn kv.Transaction) error { t := meta.NewMeta(txn) err := d.verifySchemaMetaVersion(t, is.SchemaMetaVersion()) diff --git a/expression/date_arith.go b/expression/date_arith.go index 0a9a6c0de3..ea4c469a30 100644 --- a/expression/date_arith.go +++ b/expression/date_arith.go @@ -15,6 +15,7 @@ package expression import ( "fmt" + "regexp" "strings" "time" @@ -28,22 +29,38 @@ import ( type DateArithType byte const ( - // DateAdd is to run date_add function option. + // DateAdd is to run adddate or date_add function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add DateAdd DateArithType = iota + 1 - // DateSub is to run date_sub function option. + // DateSub is to run subdate or date_sub function option. + // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate // See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub DateSub + // DateArithDaysForm is to run adddate or subdate function with days form Flag. + DateArithDaysForm ) // DateArith is used for dealing with addition and substraction of time. type DateArith struct { - Op DateArithType - Unit string + // Op is used for distinguishing date_add and date_sub. + Op DateArithType + // Form is the flag of DateArith running form. + // The function runs with interval or days. + Form DateArithType Date Expression + Unit string Interval Expression } +type evalArgsResult struct { + time mysql.Time + year int64 + month int64 + day int64 + duration time.Duration +} + func (da *DateArith) isAdd() bool { if da.Op == DateAdd { return true @@ -82,58 +99,88 @@ func (da *DateArith) String() string { // Eval implements the Expression Eval interface. func (da *DateArith) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { - t, years, months, days, durations, err := da.evalArgs(ctx, args) - if t.IsZero() || err != nil { + val, err := da.evalArgs(ctx, args) + if val.time.IsZero() || err != nil { return nil, errors.Trace(err) } if !da.isAdd() { - years, months, days, durations = -years, -months, -days, -durations + val.year, val.month, val.day, val.duration = + -val.year, -val.month, -val.day, -val.duration } - t.Time = t.Time.Add(durations) - t.Time = t.Time.AddDate(int(years), int(months), int(days)) + val.time.Time = val.time.Time.Add(val.duration) + val.time.Time = val.time.Time.AddDate(int(val.year), int(val.month), int(val.day)) // "2011-11-11 10:10:20.000000" outputs "2011-11-11 10:10:20". - if t.Time.Nanosecond() == 0 { - t.Fsp = 0 + if val.time.Time.Nanosecond() == 0 { + val.time.Fsp = 0 } - return t, nil + return val.time, nil } func (da *DateArith) evalArgs(ctx context.Context, args map[interface{}]interface{}) ( - mysql.Time, int64, int64, int64, time.Duration, error) { + *evalArgsResult, error) { + ret := &evalArgsResult{time: mysql.ZeroTimestamp} + dVal, err := da.Date.Eval(ctx, args) if dVal == nil || err != nil { - return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err) + return ret, errors.Trace(err) } dValStr, err := types.ToString(dVal) if err != nil { - return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err) + return ret, errors.Trace(err) } f := types.NewFieldType(mysql.TypeDatetime) f.Decimal = mysql.MaxFsp dVal, err = types.Convert(dValStr, f) if dVal == nil || err != nil { - return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err) + return ret, errors.Trace(err) } - t, ok := dVal.(mysql.Time) + + var ok bool + ret.time, ok = dVal.(mysql.Time) if !ok { - return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Errorf("need time type, but got %T", dVal) + return ret, errors.Errorf("need time type, but got %T", dVal) } iVal, err := da.Interval.Eval(ctx, args) if iVal == nil || err != nil { - return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err) + ret.time = mysql.ZeroTimestamp + return ret, errors.Trace(err) } - iValStr, err := types.ToString(iVal) - if err != nil { - return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err) - } - years, months, days, durations, err := mysql.ExtractTimeValue(da.Unit, strings.TrimSpace(iValStr)) - if err != nil { - return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err) + // handle adddate(expr,days) or subdate(expr,days) form + if da.Form == DateArithDaysForm { + if iVal, err = da.evalDaysForm(iVal); err != nil { + return ret, errors.Trace(err) + } } - return t, years, months, days, durations, nil + iValStr, err := types.ToString(iVal) + if err != nil { + return ret, errors.Trace(err) + } + ret.year, ret.month, ret.day, ret.duration, err = mysql.ExtractTimeValue(da.Unit, strings.TrimSpace(iValStr)) + if err != nil { + return ret, errors.Trace(err) + } + + return ret, nil +} + +var reg = regexp.MustCompile(`[\d]+`) + +func (da *DateArith) evalDaysForm(val interface{}) (interface{}, error) { + switch val.(type) { + case string: + if strings.ToLower(val.(string)) == "false" { + return 0, nil + } + if strings.ToLower(val.(string)) == "true" { + return 1, nil + } + val = reg.FindString(val.(string)) + } + + return types.ToInt64(val) } diff --git a/expression/date_arith_test.go b/expression/date_arith_test.go index 2775d7d914..4f492fc640 100644 --- a/expression/date_arith_test.go +++ b/expression/date_arith_test.go @@ -118,6 +118,42 @@ func (t *testDateArithSuite) TestDateArith(c *C) { c.Assert(value.String(), Equals, t.SubExpect) } + // Test eval for adddate and subdate with days form + tblDays := []struct { + Interval interface{} + AddExpect string + SubExpect string + }{ + {"20", "2011-12-01 10:10:10", "2011-10-22 10:10:10"}, + {19.88, "2011-12-01 10:10:10", "2011-10-22 10:10:10"}, + {"19.88", "2011-11-30 10:10:10", "2011-10-23 10:10:10"}, + {"20-11", "2011-12-01 10:10:10", "2011-10-22 10:10:10"}, + {"20,11", "2011-12-01 10:10:10", "2011-10-22 10:10:10"}, + {"1000", "2014-08-07 10:10:10", "2009-02-14 10:10:10"}, + {"true", "2011-11-12 10:10:10", "2011-11-10 10:10:10"}, + } + for _, t := range tblDays { + e := &DateArith{ + Op: DateAdd, + Form: DateArithDaysForm, + Unit: "day", + Date: Value{Val: input}, + Interval: Value{Val: t.Interval}, + } + v, err := e.Eval(nil, nil) + c.Assert(err, IsNil) + value, ok := v.(mysql.Time) + c.Assert(ok, IsTrue) + c.Assert(value.String(), Equals, t.AddExpect) + + e.Op = DateSub + v, err = e.Eval(nil, nil) + c.Assert(err, IsNil) + value, ok = v.(mysql.Time) + c.Assert(ok, IsTrue) + c.Assert(value.String(), Equals, t.SubExpect) + } + // Test error. errInput := "20111111 10:10:10" errTbl := []struct { @@ -147,12 +183,7 @@ func (t *testDateArithSuite) TestDateArith(c *C) { v, err := e.Eval(nil, nil) c.Assert(err, NotNil, Commentf("%s", v)) - e = &DateArith{ - Op: DateSub, - Unit: t.Unit, - Date: Value{Val: input}, - Interval: Value{Val: t.Interval}, - } + e.Op = DateSub _, err = e.Eval(nil, nil) c.Assert(err, NotNil) e.Date = Value{Val: errInput} diff --git a/expression/variable.go b/expression/variable.go index 07acfebb0d..e58169fbde 100644 --- a/expression/variable.go +++ b/expression/variable.go @@ -62,6 +62,7 @@ func (v *Variable) String() string { func (v *Variable) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) { name := strings.ToLower(v.Name) sessionVars := variable.GetSessionVars(ctx) + globalVars := variable.GetGlobalSysVarAccessor(ctx) if !v.IsSystem { // user vars if value, ok := sessionVars.Users[name]; ok { @@ -82,7 +83,7 @@ func (v *Variable) Eval(ctx context.Context, args map[interface{}]interface{}) ( return value, nil } } - value, err := ctx.(variable.GlobalSysVarAccessor).GetGlobalSysVar(ctx, name) + value, err := globalVars.GetGlobalSysVar(ctx, name) if err != nil { return nil, errors.Trace(err) } diff --git a/expression/variable_test.go b/expression/variable_test.go index 2e976fd47f..eb6954fd8e 100644 --- a/expression/variable_test.go +++ b/expression/variable_test.go @@ -27,8 +27,10 @@ type testVariableSuite struct { } func (s *testVariableSuite) SetUpSuite(c *C) { - s.ctx = mock.NewContext() + nc := mock.NewContext() + s.ctx = nc variable.BindSessionVars(s.ctx) + variable.BindGlobalSysVarAccessor(s.ctx, nc) } func (s *testVariableSuite) TestVariable(c *C) { diff --git a/expression/visitor.go b/expression/visitor.go index 6646d59db7..80c5ab46d9 100644 --- a/expression/visitor.go +++ b/expression/visitor.go @@ -108,7 +108,7 @@ type Visitor interface { VisitFunctionTrim(v *FunctionTrim) (Expression, error) // VisitDateArith visits DateArith expression. - VisitDateArith(dc *DateArith) (Expression, error) + VisitDateArith(da *DateArith) (Expression, error) } // BaseVisitor is the base implementation of Visitor. diff --git a/infoschema/infoschema.go b/infoschema/infoschema.go index 7ce9b36c79..8bd2c77c54 100644 --- a/infoschema/infoschema.go +++ b/infoschema/infoschema.go @@ -16,11 +16,11 @@ package infoschema import ( "sync/atomic" - "github.com/juju/errors" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/terror" // import table implementation to init table.TableFromMeta _ "github.com/pingcap/tidb/table/tables" ) @@ -105,7 +105,7 @@ func (is *infoSchema) SchemaExists(schema model.CIStr) bool { func (is *infoSchema) TableByName(schema, table model.CIStr) (t table.Table, err error) { id, ok := is.tableNameToID[tableName{schema: schema.L, table: table.L}] if !ok { - return nil, errors.Errorf("table %s.%s does not exist", schema, table) + return nil, terror.TableNotExists.Gen("table %s.%s does not exist", schema, table) } t = is.tables[id] return diff --git a/mysql/const.go b/mysql/const.go index 6cd649380b..3172ec6932 100644 --- a/mysql/const.go +++ b/mysql/const.go @@ -130,6 +130,8 @@ const ( ColumnPrivTable = "Columns_priv" // GlobalVariablesTable is the table contains global system variables. GlobalVariablesTable = "GLOBAL_VARIABLES" + // TiDBTable is the table contains tidb info. + TiDBTable = "tidb" ) // PrivilegeType privilege diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index a136819ae1..9fe63dd268 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -410,6 +410,9 @@ func (c *expressionConverter) funcDateArith(v *ast.FuncDateArithExpr) { case ast.DateSub: oldDateArith.Op = expression.DateSub } + if v.Form == ast.DateArithDaysForm { + oldDateArith.Form = expression.DateArithDaysForm + } c.exprMap[v] = oldDateArith } diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index df769237f1..e578da288a 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -174,6 +174,13 @@ func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.Up return oldUpdate, nil } +func getInnerFromParentheses(expr ast.ExprNode) ast.ExprNode { + if pexpr, ok := expr.(*ast.ParenthesesExpr); ok { + return getInnerFromParentheses(pexpr.Expr) + } + return expr +} + func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.SelectStmt, error) { oldSelect := &stmts.SelectStmt{ Distinct: s.Distinct, @@ -189,9 +196,21 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se if err != nil { return nil, errors.Trace(err) } - // TODO: handle parenthesesed column name expression, which should not set AsName. - if _, ok := oldField.Expr.(*expression.Ident); !ok && oldField.AsName == "" { - oldField.AsName = val.Text() + if oldField.AsName == "" { + innerExpr := getInnerFromParentheses(val.Expr) + switch innerExpr.(type) { + case *ast.ColumnNameExpr: + // Do not set column name as name and remove parentheses. + oldField.Expr = converter.exprMap[innerExpr] + case *ast.ValueExpr: + if innerExpr.Text() != "" { + oldField.AsName = innerExpr.Text() + } else { + oldField.AsName = val.Text() + } + default: + oldField.AsName = val.Text() + } } } else if val.WildCard != nil { str := "*" @@ -847,6 +866,10 @@ func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowSt oldShow.Target = stmt.ShowVariables case ast.ShowWarnings: oldShow.Target = stmt.ShowWarnings + case ast.ShowTableStatus: + oldShow.Target = stmt.ShowTableStatus + case ast.ShowTriggers: + oldShow.Target = stmt.ShowTriggers case ast.ShowNone: oldShow.Target = stmt.ShowNone } diff --git a/parser/parser.y b/parser/parser.y index be5988952a..7a33b945ba 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -59,6 +59,7 @@ import ( abs "ABS" add "ADD" + addDate "ADDDATE" after "AFTER" all "ALL" alter "ALTER" @@ -134,6 +135,7 @@ import ( explain "EXPLAIN" extract "EXTRACT" falseKwd "false" + fields "FIELDS" first "FIRST" foreign "FOREIGN" forKwd "FOR" @@ -230,6 +232,7 @@ import ( start "START" status "STATUS" stringType "string" + subDate "SUBDATE" substring "SUBSTRING" substringIndex "SUBSTRING_INDEX" sum "SUM" @@ -241,6 +244,7 @@ import ( to "TO" trailing "TRAILING" transaction "TRANSACTION" + triggers "TRIGGERS" trim "TRIM" trueKwd "true" truncate "TRUNCATE" @@ -376,6 +380,9 @@ import ( CreateTableStmt "CREATE TABLE statement" CreateUserStmt "CREATE User statement" CrossOpt "Cross join option" + DateArithOpt "Date arith dateadd or datesub option" + DateArithMultiFormsOpt "Date arith adddate or subdate option" + DateArithInterval "Date arith interval part" DatabaseSym "DATABASE or SCHEMA" DBName "Database Name" DeallocateSym "Deallocate or drop" @@ -1529,12 +1536,6 @@ Field: { expr := $1.(ast.ExprNode) asName := $2.(string) - if asName != "" { - // Set expr original text. - offset := yyS[yypt-1].offset - end := yyS[yypt].offset-1 - expr.SetText(yylex.(*lexer).src[offset:end]) - } $$ = &ast.SelectField{Expr: expr, AsName: model.NewCIStr(asName)} } @@ -1570,7 +1571,7 @@ FieldList: Field { field := $1.(*ast.SelectField) - field.Offset = yyS[yypt].offset + field.Offset = yylex.(*lexer).startOffset(yyS[yypt].offset) $$ = []*ast.SelectField{field} } | FieldList ',' Field @@ -1578,12 +1579,13 @@ FieldList: fl := $1.([]*ast.SelectField) last := fl[len(fl)-1] + l := yylex.(*lexer) if last.Expr != nil && last.AsName.O == "" { - lastEnd := yyS[yypt-1].offset-1 // Comma offset. - last.SetText(yylex.(*lexer).src[last.Offset:lastEnd]) + lastEnd := l.endOffset(yyS[yypt-1].offset) + last.SetText(l.src[last.Offset:lastEnd]) } newField := $3.(*ast.SelectField) - newField.Offset = yyS[yypt].offset + newField.Offset = l.startOffset(yyS[yypt].offset) $$ = append(fl, newField) } @@ -1658,12 +1660,14 @@ UnReservedKeyword: | "START" | "GLOBAL" | "TABLES"| "TEXT" | "TIME" | "TIMESTAMP" | "TRANSACTION" | "TRUNCATE" | "UNKNOWN" | "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" | "USER" | "IDENTIFIED" | "COLLATION" | "COMMENT" | "AVG_ROW_LENGTH" | "CONNECTION" | "CHECKSUM" | "COMPRESSION" | "KEY_BLOCK_SIZE" | "MAX_ROWS" | "MIN_ROWS" -| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" | "STATUS" +| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" | "STATUS" | "FIELDS" | "TRIGGERS" NotKeywordToken: - "ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DATE_SUB" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT" -| "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS" -| "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK" + "ABS" | "ADDDATE" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DATE_SUB" | "DAYOFMONTH" +| "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT"| "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" +| "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS" +| "SUBDATE" | "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" +| "YEARWEEK" /************************************************************************************ * @@ -1872,7 +1876,12 @@ Operand: } | '(' Expression ')' { - $$ = &ast.ParenthesesExpr{Expr: $2.(ast.ExprNode)} + l := yylex.(*lexer) + startOffset := l.startOffset(yyS[yypt-1].offset) + endOffset := l.endOffset(yyS[yypt].offset) + expr := $2.(ast.ExprNode) + expr.SetText(l.src[startOffset:endOffset]) + $$ = &ast.ParenthesesExpr{Expr: expr} } | "DEFAULT" %prec lowerThanLeftParen { @@ -2136,22 +2145,22 @@ FunctionCallNonKeyword: { $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}} } -| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' +| DateArithOpt '(' Expression ',' "INTERVAL" Expression TimeUnit ')' { $$ = &ast.FuncDateArithExpr{ - Op:ast.DateAdd, - Unit: $7.(string), + Op: $1.(ast.DateArithType), Date: $3.(ast.ExprNode), - Interval: $6.(ast.ExprNode), + DateArithInterval: ast.DateArithInterval{ + Unit: $7.(string), + Interval: $6.(ast.ExprNode)}, } } -| "DATE_SUB" '(' Expression ',' "INTERVAL" Expression TimeUnit ')' +| DateArithMultiFormsOpt '(' Expression ',' DateArithInterval')' { $$ = &ast.FuncDateArithExpr{ - Op:ast.DateSub, - Unit: $7.(string), + Op: $1.(ast.DateArithType), Date: $3.(ast.ExprNode), - Interval: $6.(ast.ExprNode), + DateArithInterval: $5.(ast.DateArithInterval), } } | "EXTRACT" '(' TimeUnit "FROM" Expression ')' @@ -2329,6 +2338,40 @@ FunctionCallNonKeyword: $$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)} } +DateArithOpt: + "DATE_ADD" + { + $$ = ast.DateAdd + } +| "DATE_SUB" + { + $$ = ast.DateSub + } + +DateArithMultiFormsOpt: + "ADDDATE" + { + $$ = ast.DateAdd + } +| "SUBDATE" + { + $$ = ast.DateSub + } + +DateArithInterval: + Expression + { + $$ = ast.DateArithInterval{ + Form: ast.DateArithDaysForm, + Unit: "day", + Interval: $1.(ast.ExprNode), + } + } +| "INTERVAL" Expression TimeUnit + { + $$ = ast.DateArithInterval{Unit: $3.(string), Interval: $2.(ast.ExprNode)} + } + TrimDirection: "BOTH" { @@ -2829,7 +2872,9 @@ TableFactor: | '(' SelectStmt ')' TableAsName { st := $2.(*ast.SelectStmt) - yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-1].offset-1) + l := yylex.(*lexer) + endOffset := l.endOffset(yyS[yypt-1].offset) + l.SetLastSelectFieldText(st, endOffset) $$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)} } | '(' UnionStmt ')' TableAsName @@ -2976,7 +3021,9 @@ SubSelect: '(' SelectStmt ')' { s := $2.(*ast.SelectStmt) - yylex.(*lexer).SetLastSelectFieldText(s, yyS[yypt].offset-1) + l := yylex.(*lexer) + endOffset := l.endOffset(yyS[yypt].offset) + l.SetLastSelectFieldText(s, endOffset) src := yylex.(*lexer).src // See the implemention of yyParse function s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1]) @@ -3013,7 +3060,9 @@ UnionStmt: union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) lastSelect := union.Selects[len(union.Selects)-1] - yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) + l := yylex.(*lexer) + endOffset := l.endOffset(yyS[yypt-2].offset) + l.SetLastSelectFieldText(lastSelect, endOffset) union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) $$ = union } @@ -3022,9 +3071,12 @@ UnionStmt: union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) lastSelect := union.Selects[len(union.Selects)-1] - yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-6].offset-1) + l := yylex.(*lexer) + endOffset := l.endOffset(yyS[yypt-6].offset) + l.SetLastSelectFieldText(lastSelect, endOffset) st := $5.(*ast.SelectStmt) - yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1) + endOffset = l.endOffset(yyS[yypt-2].offset) + l.SetLastSelectFieldText(st, endOffset) union.Selects = append(union.Selects, st) if $7 != nil { union.OrderBy = $7.(*ast.OrderByClause) @@ -3048,7 +3100,9 @@ UnionClauseList: union := $1.(*ast.UnionStmt) union.Distinct = union.Distinct || $3.(bool) lastSelect := union.Selects[len(union.Selects)-1] - yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1) + l := yylex.(*lexer) + endOffset := l.endOffset(yyS[yypt-2].offset) + l.SetLastSelectFieldText(lastSelect, endOffset) union.Selects = append(union.Selects, $4.(*ast.SelectStmt)) $$ = union } @@ -3058,7 +3112,9 @@ UnionSelect: | '(' SelectStmt ')' { st := $2.(*ast.SelectStmt) - yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt].offset-1) + l := yylex.(*lexer) + endOffset := l.endOffset(yyS[yypt].offset) + l.SetLastSelectFieldText(st, endOffset) $$ = st } @@ -3268,6 +3324,16 @@ ShowStmt: Full: $2.(bool), } } +| "SHOW" OptFull "FIELDS" ShowTableAliasOpt ShowDatabaseNameOpt + { + // SHOW FIELDS is a synonym for SHOW COLUMNS. + $$ = &ast.ShowStmt{ + Tp: ast.ShowColumns, + Table: $4.(*ast.TableName), + DBName: $5.(string), + Full: $2.(bool), + } + } | "SHOW" "WARNINGS" { $$ = &ast.ShowStmt{Tp: ast.ShowWarnings} @@ -3317,6 +3383,21 @@ ShowStmt: User: $4.(string), } } +| "SHOW" "TRIGGERS" ShowDatabaseNameOpt ShowLikeOrWhereOpt + { + stmt := &ast.ShowStmt{ + Tp: ast.ShowTriggers, + DBName: $3.(string), + } + if $4 != nil { + if x, ok := $4.(*ast.PatternLikeExpr); ok { + stmt.Pattern = x + } else { + stmt.Where = $4.(ast.ExprNode) + } + } + $$ = stmt + } ShowLikeOrWhereOpt: { diff --git a/parser/parser_test.go b/parser/parser_test.go index c0569ebd4d..24727d8e68 100644 --- a/parser/parser_test.go +++ b/parser/parser_test.go @@ -244,6 +244,9 @@ func (s *testParserSuite) TestDMLStmt(c *C) { {`SHOW FULL TABLES WHERE Table_Type != 'VIEW'`, true}, {`SHOW GRANTS`, true}, {`SHOW GRANTS FOR 'test'@'localhost'`, true}, + {`SHOW COLUMNS FROM City;`, true}, + {`SHOW FIELDS FROM City;`, true}, + {`SHOW TRIGGERS LIKE 't'`, true}, // For default value {"CREATE TABLE sbtest (id INTEGER UNSIGNED NOT NULL AUTO_INCREMENT, k integer UNSIGNED DEFAULT '0' NOT NULL, c char(120) DEFAULT '' NOT NULL, pad char(60) DEFAULT '' NOT NULL, PRIMARY KEY (id) )", true}, @@ -421,6 +424,31 @@ func (s *testParserSuite) TestBuiltin(c *C) { {`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, {`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + // For adddate + {`select adddate("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", 10)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", 0.10)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", "11,11")`, true}, + // For date_sub {`select date_sub("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, {`select date_sub("2011-11-11 10:10:10.123456", interval 10 second)`, true}, @@ -442,6 +470,31 @@ func (s *testParserSuite) TestBuiltin(c *C) { {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, {`select date_sub("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, {`select date_sub("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + + // For subdate + {`select subdate("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval 10 second)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval 10 minute)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval 10 hour)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval 10 day)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval 1 week)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval 1 month)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval 1 quarter)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval 1 year)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true}, + {`select subdate("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", 10)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", 0.10)`, true}, + {`select adddate("2011-11-11 10:10:10.123456", "11,11")`, true}, } s.RunTest(c, table) } diff --git a/parser/scanner.l b/parser/scanner.l index da09e647be..8884b4c1fc 100644 --- a/parser/scanner.l +++ b/parser/scanner.l @@ -128,6 +128,22 @@ func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) { } } +func (l *lexer) startOffset(offset int) int { + offset-- + for unicode.IsSpace(rune(l.src[offset])) { + offset++ + } + return offset +} + +func (l *lexer) endOffset(offset int) int { + offset-- + for offset > 0 && unicode.IsSpace(rune(l.src[offset-1])) { + offset-- + } + return offset +} + func (l *lexer) unget(b byte) { l.ungetBuf = append(l.ungetBuf, b) l.i-- @@ -266,6 +282,7 @@ z [zZ] abs {a}{b}{s} add {a}{d}{d} +adddate {a}{d}{d}{d}{a}{t}{e} after {a}{f}{t}{e}{r} all {a}{l}{l} alter {a}{l}{t}{e}{r} @@ -334,6 +351,7 @@ execute {e}{x}{e}{c}{u}{t}{e} exists {e}{x}{i}{s}{t}{s} explain {e}{x}{p}{l}{a}{i}{n} extract {e}{x}{t}{r}{a}{c}{t} +fields {f}{i}{e}{l}{d}{s} first {f}{i}{r}{s}{t} for {f}{o}{r} foreign {f}{o}{r}{e}{i}{g}{n} @@ -415,6 +433,7 @@ show {s}{h}{o}{w} some {s}{o}{m}{e} start {s}{t}{a}{r}{t} status {s}{t}{a}{t}{u}{s} +subdate {s}{u}{b}{d}{a}{t}{e} substring {s}{u}{b}{s}{t}{r}{i}{n}{g} substring_index {s}{u}{b}{s}{t}{r}{i}{n}{g}_{i}{n}{d}{e}{x} sum {s}{u}{m} @@ -425,6 +444,7 @@ then {t}{h}{e}{n} to {t}{o} trailing {t}{r}{a}{i}{l}{i}{n}{g} transaction {t}{r}{a}{n}{s}{a}{c}{t}{i}{o}{n} +triggers {t}{r}{i}{g}{g}{e}{r}{s} trim {t}{r}{i}{m} truncate {t}{r}{u}{n}{c}{a}{t}{e} max {m}{a}{x} @@ -598,6 +618,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {abs} lval.item = string(l.val) return abs {add} return add +{adddate} return addDate {after} lval.item = string(l.val) return after {all} return all @@ -661,10 +682,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {database} lval.item = string(l.val) return database {databases} return databases -{date_add} lval.item = string(l.val) - return dateAdd -{date_sub} lval.item = string(l.val) - return dateSub +{date_add} return dateAdd +{date_sub} return dateSub {day} lval.item = string(l.val) return day {dayofweek} lval.item = string(l.val) @@ -712,6 +731,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {explain} return explain {extract} lval.item = string(l.val) return extract +{fields} lval.item = string(l.val) + return fields {first} lval.item = string(l.val) return first {for} return forKwd @@ -856,6 +877,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {set} return set {share} return share {show} return show +{subdate} return subDate {substring} lval.item = string(l.val) return substring {substring_index} lval.item = string(l.val) @@ -872,6 +894,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h} {trailing} return trailing {transaction} lval.item = string(l.val) return transaction +{triggers} lval.item = string(l.val) + return triggers {trim} lval.item = string(l.val) return trim {truncate} lval.item = string(l.val) diff --git a/plan/plans/info_test.go b/plan/plans/info_test.go index dc36cfed0f..3db4706f44 100644 --- a/plan/plans/info_test.go +++ b/plan/plans/info_test.go @@ -86,7 +86,7 @@ func (p *testInfoSchemaSuit) TestInfoSchema(c *C) { cnt = mustQuery(c, testDB, "select * from information_schema.columns") c.Assert(cnt, Greater, 0) cnt = mustQuery(c, testDB, "select * from information_schema.statistics") - c.Assert(cnt, Equals, 15) + c.Assert(cnt, Equals, 16) cnt = mustQuery(c, testDB, "select * from information_schema.character_sets") c.Assert(cnt, Greater, 0) cnt = mustQuery(c, testDB, "select * from information_schema.collations") diff --git a/plan/plans/show.go b/plan/plans/show.go index 260af1d702..d157519307 100644 --- a/plan/plans/show.go +++ b/plan/plans/show.go @@ -120,6 +120,11 @@ func (s *ShowPlan) GetFields() []*field.ResultField { names = []string{"Table", "Create Table"} case stmt.ShowGrants: names = []string{fmt.Sprintf("Grants for %s", s.User)} + case stmt.ShowTriggers: + names = []string{"Trigger", "Event", "Table", "Statement", "Timing", "Created", + "sql_mode", "Definer", "character_set_client", "collation_connection", "Database Collation"} + types = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, + mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar} } fields := make([]*field.ResultField, 0, len(names)) for i, name := range names { @@ -181,6 +186,8 @@ func (s *ShowPlan) fetchAll(ctx context.Context) error { return s.fetchShowCreateTable(ctx) case stmt.ShowGrants: return s.fetchShowGrants(ctx) + case stmt.ShowTriggers: + return s.fetchShowTriggers(ctx) } return nil } @@ -396,7 +403,7 @@ func (s *ShowPlan) fetchShowTableStatus(ctx context.Context) error { } now := mysql.GetCurrentTime(mysql.TypeDatetime) data := []interface{}{ - v, "", "", "", "", 100, 100, + v, "InnoDB", "10", "Compact", 100, 100, 100, 100, 100, 100, 100, now, now, now, "utf8_general_ci", "", "", "", @@ -407,6 +414,7 @@ func (s *ShowPlan) fetchShowTableStatus(ctx context.Context) error { } func (s *ShowPlan) fetchShowVariables(ctx context.Context) error { sessionVars := variable.GetSessionVars(ctx) + globalVars := variable.GetGlobalSysVarAccessor(ctx) m := map[interface{}]interface{}{} for _, v := range variable.SysVars { @@ -432,13 +440,19 @@ func (s *ShowPlan) fetchShowVariables(ctx context.Context) error { var value string if !s.GlobalScope { - // Try to get Session Scope variable value + // Try to get Session Scope variable value first. sv, ok := sessionVars.Systems[v.Name] if ok { value = sv + } else { + // If session scope variable is not set, get the global scope value. + value, err = globalVars.GetGlobalSysVar(ctx, v.Name) + if err != nil { + return errors.Trace(err) + } } } else { - value, err = ctx.(variable.GlobalSysVarAccessor).GetGlobalSysVar(ctx, v.Name) + value, err = globalVars.GetGlobalSysVar(ctx, v.Name) if err != nil { return errors.Trace(err) } @@ -462,6 +476,7 @@ func (s *ShowPlan) fetchShowCharset(ctx context.Context) error { } func (s *ShowPlan) fetchShowEngines(ctx context.Context) error { + // Mock data row := &plan.Row{ Data: []interface{}{"InnoDB", "DEFAULT", "Supports transactions, row-level locking, and foreign keys", "YES", "YES", "YES"}, } @@ -574,3 +589,7 @@ func (s *ShowPlan) fetchShowGrants(ctx context.Context) error { } return nil } + +func (s *ShowPlan) fetchShowTriggers(ctx context.Context) error { + return nil +} diff --git a/plan/plans/show_test.go b/plan/plans/show_test.go index df6c9a425a..54e1462eb0 100644 --- a/plan/plans/show_test.go +++ b/plan/plans/show_test.go @@ -53,8 +53,10 @@ type testShowSuit struct { var _ = Suite(&testShowSuit{}) func (p *testShowSuit) SetUpSuite(c *C) { - p.ctx = mock.NewContext() + nc := mock.NewContext() + p.ctx = nc variable.BindSessionVars(p.ctx) + variable.BindGlobalSysVarAccessor(p.ctx, nc) p.dbName = "testshowplan" p.store = newStore(c, p.dbName) @@ -180,6 +182,28 @@ func (p *testShowSuit) TestShowVariables(c *C) { c.Assert(v, Equals, "on") } +func (p *testShowSuit) TestIssue540(c *C) { + // Show variables where variable_name="time_zone" + pln := &plans.ShowPlan{ + Target: stmt.ShowVariables, + GlobalScope: false, + Pattern: &expression.PatternLike{ + Pattern: &expression.Value{ + Val: "time_zone", + }, + }, + } + // Make sure the session scope var is not set. + sessionVars := variable.GetSessionVars(p.ctx) + _, ok := sessionVars.Systems["time_zone"] + c.Assert(ok, IsFalse) + + r, err := pln.Next(p.ctx) + c.Assert(err, IsNil) + c.Assert(r.Data[0], Equals, "time_zone") + c.Assert(r.Data[1], Equals, "SYSTEM") +} + func (p *testShowSuit) TestShowCollation(c *C) { pln := &plans.ShowPlan{} diff --git a/session.go b/session.go index 3b64312a3f..151e16ec6f 100644 --- a/session.go +++ b/session.go @@ -256,7 +256,7 @@ func (s *session) ExecRestrictedSQL(ctx context.Context, sql string) (rset.Recor // Check statement for some restriction // For example only support DML on system meta table. // TODO: Add more restrictions. - log.Infof("Executing %s [%s]", st.OriginText(), sql) + log.Debugf("Executing %s [%s]", st.OriginText(), sql) ctx.SetValue(&sqlexec.RestrictedSQLExecutorKeyType{}, true) defer ctx.ClearValue(&sqlexec.RestrictedSQLExecutorKeyType{}) rs, err := st.Exec(ctx) @@ -543,6 +543,9 @@ func CreateSession(store kv.Storage) (Session, error) { variable.BindSessionVars(s) variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusAutocommit, true) + // session implements variable.GlobalSysVarAccessor. Bind it to ctx. + variable.BindGlobalSysVarAccessor(s, s) + // session implements autocommit.Checker. Bind it to ctx autocommit.BindAutocommitChecker(s, s) sessionMu.Lock() diff --git a/sessionctx/variable/sysvar.go b/sessionctx/variable/sysvar.go index 9d5997cad3..50e554d53b 100644 --- a/sessionctx/variable/sysvar.go +++ b/sessionctx/variable/sysvar.go @@ -581,3 +581,27 @@ type GlobalSysVarAccessor interface { // SetGlobalSysVar sets the global system variable name to value. SetGlobalSysVar(ctx context.Context, name string, value string) error } + +// globalSysVarAccessorKeyType is a dummy type to avoid naming collision in context. +type globalSysVarAccessorKeyType int + +// String defines a Stringer function for debugging and pretty printing. +func (k globalSysVarAccessorKeyType) String() string { + return "global_sysvar_accessor" +} + +const accessorKey globalSysVarAccessorKeyType = 0 + +// BindGlobalSysVarAccessor binds global sysvar accessor to context. +func BindGlobalSysVarAccessor(ctx context.Context, accessor GlobalSysVarAccessor) { + ctx.SetValue(accessorKey, accessor) +} + +// GetGlobalSysVarAccessor gets accessor from ctx. +func GetGlobalSysVarAccessor(ctx context.Context) GlobalSysVarAccessor { + v, ok := ctx.Value(accessorKey).(GlobalSysVarAccessor) + if !ok { + panic("Miss global sysvar accessor") + } + return v +} diff --git a/stmt/stmt.go b/stmt/stmt.go index b7be28744d..242b4997f8 100644 --- a/stmt/stmt.go +++ b/stmt/stmt.go @@ -59,6 +59,7 @@ const ( ShowCollation ShowCreateTable ShowGrants + ShowTriggers ) const ( diff --git a/stmt/stmts/delete.go b/stmt/stmts/delete.go index 68c0d14938..c16fff4053 100644 --- a/stmt/stmts/delete.go +++ b/stmt/stmts/delete.go @@ -100,14 +100,11 @@ func (s *DeleteStmt) plan(ctx context.Context) (plan.Plan, error) { } func removeRow(ctx context.Context, t table.Table, h int64, data []interface{}) error { - // remove row's all indexies - if err := t.RemoveRowAllIndex(ctx, h, data); err != nil { - return err - } - // remove row - if err := t.RemoveRow(ctx, h); err != nil { - return err + err := t.RemoveRecord(ctx, h, data) + if err != nil { + return errors.Trace(err) } + variable.GetSessionVars(ctx).AddAffectedRows(1) return nil } diff --git a/stmt/stmts/drop.go b/stmt/stmts/drop.go index 1111eb2089..3831654d15 100644 --- a/stmt/stmts/drop.go +++ b/stmt/stmts/drop.go @@ -25,6 +25,8 @@ import ( "github.com/pingcap/tidb/context" "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/privilege" "github.com/pingcap/tidb/rset" "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/stmt" @@ -109,15 +111,37 @@ func (s *DropTableStmt) SetText(text string) { // Exec implements the stmt.Statement Exec interface. func (s *DropTableStmt) Exec(ctx context.Context) (rset.Recordset, error) { var notExistTables []string + is := sessionctx.GetDomain(ctx).InfoSchema() for _, ti := range s.TableIdents { - err := sessionctx.GetDomain(ctx).DDL().DropTable(ctx, ti.Full(ctx)) - // TODO: we should return special error for table not exist, checking "not exist" is not enough, - // because some other errors may contain this error string too. + fullti := ti.Full(ctx) + schema, ok := is.SchemaByName(fullti.Schema) + if !ok { + // TODO: we should return special error for table not exist, checking "not exist" is not enough, + // because some other errors may contain this error string too. + notExistTables = append(notExistTables, ti.String()) + continue + } + tb, err := is.TableByName(fullti.Schema, fullti.Name) if err != nil && strings.HasSuffix(err.Error(), "not exist") { notExistTables = append(notExistTables, ti.String()) + continue } else if err != nil { return nil, errors.Trace(err) } + // Check Privilege + privChecker := privilege.GetPrivilegeChecker(ctx) + hasPriv, err := privChecker.Check(ctx, schema, tb.Meta(), mysql.DropPriv) + if err != nil { + return nil, errors.Trace(err) + } + if !hasPriv { + return nil, errors.Errorf("You do not have the privilege to drop table %s.%s.", ti.Schema, ti.Name) + } + + err = sessionctx.GetDomain(ctx).DDL().DropTable(ctx, fullti) + if err != nil { + return nil, errors.Trace(err) + } } if len(notExistTables) > 0 && !s.IfExists { return nil, errors.Errorf("DROP TABLE: table %s does not exist", strings.Join(notExistTables, ",")) diff --git a/stmt/stmts/set.go b/stmt/stmts/set.go index af9d4c0cae..b266b8c792 100644 --- a/stmt/stmts/set.go +++ b/stmt/stmts/set.go @@ -100,7 +100,7 @@ func (s *SetStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { log.Debug("Set sys/user variables") sessionVars := variable.GetSessionVars(ctx) - + globalVars := variable.GetGlobalSysVarAccessor(ctx) for _, v := range s.Variables { // Variable is case insensitive, we use lower case. name := strings.ToLower(v.Name) @@ -138,7 +138,7 @@ func (s *SetStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) { if err != nil { return nil, errors.Trace(err) } - err = ctx.(variable.GlobalSysVarAccessor).SetGlobalSysVar(ctx, name, svalue) + err = globalVars.SetGlobalSysVar(ctx, name, svalue) return nil, errors.Trace(err) } return nil, errors.Errorf("Variable '%s' is a SESSION variable and can't be used with SET GLOBAL", name) diff --git a/store/localstore/kv_test.go b/store/localstore/kv_test.go index 67264d2071..fea59447d5 100644 --- a/store/localstore/kv_test.go +++ b/store/localstore/kv_test.go @@ -451,7 +451,8 @@ func (s *testKVSuite) TestConditionIfNotExist(c *C) { }() } wg.Wait() - c.Assert(success, Greater, int64(1)) + // At least one txn can success. + c.Assert(success, Greater, int64(0)) // Clean up txn, err := s.s.Begin() diff --git a/store/localstore/txn.go b/store/localstore/txn.go index e31ff236be..4a30423d5d 100644 --- a/store/localstore/txn.go +++ b/store/localstore/txn.go @@ -231,7 +231,7 @@ func (txn *dbTxn) Commit() error { txn.close() }() - return txn.doCommit() + return errors.Trace(txn.doCommit()) } func (txn *dbTxn) CommittedVersion() (kv.Version, error) { diff --git a/structure/hash.go b/structure/hash.go index 3162e819c4..80dbf3236f 100644 --- a/structure/hash.go +++ b/structure/hash.go @@ -30,17 +30,17 @@ type HashPair struct { } type hashMeta struct { - Length int64 + FieldCount int64 } func (meta hashMeta) Value() []byte { buf := make([]byte, 8) - binary.BigEndian.PutUint64(buf[0:8], uint64(meta.Length)) + binary.BigEndian.PutUint64(buf[0:8], uint64(meta.FieldCount)) return buf } func (meta hashMeta) IsEmpty() bool { - return meta.Length <= 0 + return meta.FieldCount <= 0 } // HSet sets the string value of a hash field. @@ -92,34 +92,40 @@ func (t *TxStructure) HGetInt64(key []byte, field []byte) (int64, error) { } func (t *TxStructure) updateHash(key []byte, field []byte, fn func(oldValue []byte) ([]byte, error)) error { - metaKey := t.encodeHashMetaKey(key) - meta, err := t.loadHashMeta(metaKey) - if err != nil { - return errors.Trace(err) - } - dataKey := t.encodeHashDataKey(key, field) - var oldValue []byte - oldValue, err = t.loadHashValue(dataKey) + oldValue, err := t.loadHashValue(dataKey) if err != nil { return errors.Trace(err) } - if oldValue == nil { - meta.Length++ - } - - var newValue []byte - newValue, err = fn(oldValue) + newValue, err := fn(oldValue) if err != nil { return errors.Trace(err) } + // Check if new value is equal to old value. + if bytes.Equal(oldValue, newValue) { + return nil + } + if err = t.txn.Set(dataKey, newValue); err != nil { return errors.Trace(err) } - return errors.Trace(t.txn.Set(metaKey, meta.Value())) + metaKey := t.encodeHashMetaKey(key) + meta, err := t.loadHashMeta(metaKey) + if err != nil { + return errors.Trace(err) + } + + if oldValue == nil { + meta.FieldCount++ + if err = t.txn.Set(metaKey, meta.Value()); err != nil { + return errors.Trace(err) + } + } + + return nil } // HLen gets the number of fields in a hash. @@ -129,7 +135,7 @@ func (t *TxStructure) HLen(key []byte) (int64, error) { if err != nil { return 0, errors.Trace(err) } - return meta.Length, nil + return meta.FieldCount, nil } // HDel deletes one or more hash fields. @@ -154,7 +160,7 @@ func (t *TxStructure) HDel(key []byte, fields ...[]byte) error { return errors.Trace(err) } - meta.Length-- + meta.FieldCount-- } } @@ -254,7 +260,7 @@ func (t *TxStructure) loadHashMeta(metaKey []byte) (hashMeta, error) { return hashMeta{}, errors.Trace(err) } - meta := hashMeta{Length: 0} + meta := hashMeta{FieldCount: 0} if v == nil { return meta, nil } @@ -263,7 +269,7 @@ func (t *TxStructure) loadHashMeta(metaKey []byte) (hashMeta, error) { return meta, errors.New("invalid list meta data") } - meta.Length = int64(binary.BigEndian.Uint64(v[0:8])) + meta.FieldCount = int64(binary.BigEndian.Uint64(v[0:8])) return meta, nil } diff --git a/structure/structure_test.go b/structure/structure_test.go index 3b1afd348d..76c2cdd13c 100644 --- a/structure/structure_test.go +++ b/structure/structure_test.go @@ -193,7 +193,7 @@ func (s *tesTxStructureSuite) TestHash(c *C) { value, err = tx.HGet(key, []byte("fake")) c.Assert(err, IsNil) - c.Assert(err, IsNil) + c.Assert(value, IsNil) keys, err := tx.HKeys(key) c.Assert(err, IsNil) @@ -224,13 +224,41 @@ func (s *tesTxStructureSuite) TestHash(c *C) { c.Assert(err, IsNil) c.Assert(l, Equals, int64(2)) + // Test set new value which equals to old value. + value, err = tx.HGet(key, []byte("1")) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, []byte("1")) + + err = tx.HSet(key, []byte("1"), []byte("1")) + c.Assert(err, IsNil) + + value, err = tx.HGet(key, []byte("1")) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, []byte("1")) + + l, err = tx.HLen(key) + c.Assert(err, IsNil) + c.Assert(l, Equals, int64(2)) + n, err = tx.HInc(key, []byte("1"), 1) c.Assert(err, IsNil) c.Assert(n, Equals, int64(2)) + l, err = tx.HLen(key) + c.Assert(err, IsNil) + c.Assert(l, Equals, int64(2)) + + n, err = tx.HInc(key, []byte("1"), 1) + c.Assert(err, IsNil) + c.Assert(n, Equals, int64(3)) + + l, err = tx.HLen(key) + c.Assert(err, IsNil) + c.Assert(l, Equals, int64(2)) + n, err = tx.HGetInt64(key, []byte("1")) c.Assert(err, IsNil) - c.Assert(n, Equals, int64(2)) + c.Assert(n, Equals, int64(3)) l, err = tx.HLen(key) c.Assert(err, IsNil) @@ -246,6 +274,55 @@ func (s *tesTxStructureSuite) TestHash(c *C) { err = tx.HDel(key, []byte("fake_key")) c.Assert(err, IsNil) + // Test set nil value. + value, err = tx.HGet(key, []byte("nil_key")) + c.Assert(err, IsNil) + c.Assert(value, IsNil) + + l, err = tx.HLen(key) + c.Assert(err, IsNil) + c.Assert(l, Equals, int64(0)) + + err = tx.HSet(key, []byte("nil_key"), nil) + c.Assert(err, IsNil) + + l, err = tx.HLen(key) + c.Assert(err, IsNil) + c.Assert(l, Equals, int64(0)) + + err = tx.HSet(key, []byte("nil_key"), []byte("1")) + c.Assert(err, IsNil) + + l, err = tx.HLen(key) + c.Assert(err, IsNil) + c.Assert(l, Equals, int64(1)) + + value, err = tx.HGet(key, []byte("nil_key")) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, []byte("1")) + + err = tx.HSet(key, []byte("nil_key"), nil) + c.Assert(err, NotNil) + + l, err = tx.HLen(key) + c.Assert(err, IsNil) + c.Assert(l, Equals, int64(1)) + + value, err = tx.HGet(key, []byte("nil_key")) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, []byte("1")) + + err = tx.HSet(key, []byte("nil_key"), []byte("2")) + c.Assert(err, IsNil) + + l, err = tx.HLen(key) + c.Assert(err, IsNil) + c.Assert(l, Equals, int64(1)) + + value, err = tx.HGet(key, []byte("nil_key")) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, []byte("2")) + err = txn.Commit() c.Assert(err, IsNil) diff --git a/table/table.go b/table/table.go index e08a3c1640..67c34db220 100644 --- a/table/table.go +++ b/table/table.go @@ -41,15 +41,9 @@ type Table interface { // Row returns a row for all columns. Row(ctx context.Context, h int64) ([]interface{}, error) - // RemoveRow removes the row of handle h. - RemoveRow(ctx context.Context, h int64) error - // RemoveRowIndex removes an index of a row. RemoveRowIndex(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error - // RemoveRowAllIndex removes all the indices of a row. - RemoveRowAllIndex(ctx context.Context, h int64, rec []interface{}) error - // BuildIndexForRow builds an index for a row. BuildIndexForRow(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error @@ -89,6 +83,9 @@ type Table interface { // UpdateRecord updates a row in the table. UpdateRecord(ctx context.Context, h int64, currData []interface{}, newData []interface{}, touched []bool) error + // RemoveRecord removes a row in the table. + RemoveRecord(ctx context.Context, h int64, r []interface{}) error + // TableID returns the ID of the table. TableID() int64 diff --git a/table/tables/tables.go b/table/tables/tables.go index 68ccef38bd..8e917545ed 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -444,8 +444,22 @@ func (t *Table) LockRow(ctx context.Context, h int64) error { return errors.Trace(err) } -// RemoveRow implements table.Table RemoveRow interface. -func (t *Table) RemoveRow(ctx context.Context, h int64) error { +// RemoveRecord implements table.Table RemoveRecord interface. +func (t *Table) RemoveRecord(ctx context.Context, h int64, r []interface{}) error { + err := t.removeRowData(ctx, h) + if err != nil { + return errors.Trace(err) + } + + err = t.removeRowIndices(ctx, h, r) + if err != nil { + return errors.Trace(err) + } + + return nil +} + +func (t *Table) removeRowData(ctx context.Context, h int64) error { if err := t.LockRow(ctx, h); err != nil { return errors.Trace(err) } @@ -469,20 +483,8 @@ func (t *Table) RemoveRow(ctx context.Context, h int64) error { return nil } -// RemoveRowIndex implements table.Table RemoveRowIndex interface. -func (t *Table) RemoveRowIndex(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error { - txn, err := ctx.GetTxn(false) - if err != nil { - return errors.Trace(err) - } - if err = idx.X.Delete(txn, vals, h); err != nil { - return errors.Trace(err) - } - return nil -} - -// RemoveRowAllIndex implements table.Table RemoveRowAllIndex interface. -func (t *Table) RemoveRowAllIndex(ctx context.Context, h int64, rec []interface{}) error { +// removeRowAllIndex removes all the indices of a row. +func (t *Table) removeRowIndices(ctx context.Context, h int64, rec []interface{}) error { for _, v := range t.indices { vals, err := v.FetchValues(rec) if vals == nil { @@ -500,6 +502,18 @@ func (t *Table) RemoveRowAllIndex(ctx context.Context, h int64, rec []interface{ return nil } +// RemoveRowIndex implements table.Table RemoveRowIndex interface. +func (t *Table) RemoveRowIndex(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error { + txn, err := ctx.GetTxn(false) + if err != nil { + return errors.Trace(err) + } + if err = idx.X.Delete(txn, vals, h); err != nil { + return errors.Trace(err) + } + return nil +} + // BuildIndexForRow implements table.Table BuildIndexForRow interface. func (t *Table) BuildIndexForRow(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error { txn, err := ctx.GetTxn(false) diff --git a/table/tables/tables_test.go b/table/tables/tables_test.go index edaeeb8c2d..727fe2a9cd 100644 --- a/table/tables/tables_test.go +++ b/table/tables/tables_test.go @@ -88,9 +88,8 @@ func (ts *testSuite) TestBasic(c *C) { return true, nil }) - c.Assert(tb.RemoveRowAllIndex(ctx, rid, []interface{}{1, "cba"}), IsNil) + c.Assert(tb.RemoveRecord(ctx, rid, []interface{}{1, "cba"}), IsNil) - c.Assert(tb.RemoveRow(ctx, rid), IsNil) // Make sure there is index data in the storage. prefix := tb.IndexPrefix() cnt, err := countEntriesWithPrefix(ctx, prefix) diff --git a/terror/terror.go b/terror/terror.go index 7eb56777f8..3845392d54 100644 --- a/terror/terror.go +++ b/terror/terror.go @@ -15,6 +15,7 @@ package terror import ( "fmt" + "runtime" "strconv" "github.com/juju/errors" @@ -23,6 +24,7 @@ import ( // Common base error instances. var ( DatabaseNotExists = ClassSchema.New(CodeDatabaseNotExists, "database not exists") + TableNotExists = ClassSchema.New(CodeTableNotExists, "table not exists") CommitNotInTransaction = ClassExecutor.New(CodeCommitNotInTransaction, "commit not in transaction") RollbackNotInTransaction = ClassExecutor.New(CodeRollbackNotInTransaction, "rollback not in transaction") @@ -35,6 +37,7 @@ type ErrCode int // Schema error codes. const ( CodeDatabaseNotExists ErrCode = iota + 1 + CodeTableNotExists ) // Executor error codes. @@ -89,7 +92,7 @@ func (ec ErrClass) EqualClass(err error) bool { return false } if te, ok := e.(*Error); ok { - return te.Class == ec + return te.class == ec } return false } @@ -103,29 +106,48 @@ func (ec ErrClass) NotEqualClass(err error) bool { // Usually used to create base *Error. func (ec ErrClass) New(code ErrCode, message string) *Error { return &Error{ - Class: ec, - Code: code, - Message: message, + class: ec, + code: code, + message: message, } } // Error implements error interface and adds integer Class and Code, so // errors with different message can be compared. type Error struct { - Class ErrClass - Code ErrCode - Message string + class ErrClass + code ErrCode + message string + file string + line int +} + +// Class returns ErrClass +func (e *Error) Class() ErrClass { + return e.class +} + +// Code returns ErrCode +func (e *Error) Code() ErrCode { + return e.code +} + +// Location returns the location where the error is created, +// implements juju/errors locationer interface. +func (e *Error) Location() (file string, line int) { + return e.file, e.line } // Error implements error interface. func (e *Error) Error() string { - return fmt.Sprintf("[%s:%d]%s", e.Class, e.Code, e.Message) + return fmt.Sprintf("[%s:%d]%s", e.class, e.code, e.message) } // Gen generates a new *Error with the same class and code, and a new formatted message. func (e *Error) Gen(format string, args ...interface{}) *Error { err := *e - err.Message = fmt.Sprintf(format, args...) + err.message = fmt.Sprintf(format, args...) + _, err.file, err.line, _ = runtime.Caller(1) return &err } @@ -136,7 +158,7 @@ func (e *Error) Equal(err error) bool { return false } inErr, ok := originErr.(*Error) - return ok && e.Class == inErr.Class && e.Code == inErr.Code + return ok && e.class == inErr.class && e.code == inErr.code } // NotEqual checks if err is not equal to e. @@ -160,7 +182,7 @@ func ErrorEqual(err1, err2 error) bool { te1, ok1 := e1.(*Error) te2, ok2 := e2.(*Error) if ok1 && ok2 { - return te1.Class == te2.Class && te1.Code == te2.Code + return te1.class == te2.class && te1.code == te2.code } return e1.Error() == e2.Error() diff --git a/terror/terror_test.go b/terror/terror_test.go index e96d6c2754..9cb32b1273 100644 --- a/terror/terror_test.go +++ b/terror/terror_test.go @@ -18,6 +18,7 @@ import ( "github.com/juju/errors" . "github.com/pingcap/check" + "strings" ) func TestT(t *testing.T) { @@ -49,6 +50,27 @@ func (s *testTErrorSuite) TestTError(c *C) { c.Assert(optimizerErr.Equal(errors.New("abc")), IsFalse) } +var predefinedErr = ClassExecutor.New(ErrCode(123), "predefiend error") + +func example() error { + err := call() + return errors.Trace(err) +} + +func call() error { + return predefinedErr.Gen("error message:%s", "abc") +} + +func (s *testTErrorSuite) TestTraceAndLocation(c *C) { + err := example() + stack := errors.ErrorStack(err) + lines := strings.Split(stack, "\n") + c.Assert(len(lines), Equals, 2) + for _, v := range lines { + c.Assert(strings.Contains(v, "terror_test.go"), IsTrue) + } +} + func (s *testTErrorSuite) TestErrorEqual(c *C) { e1 := errors.New("test error") c.Assert(e1, NotNil) diff --git a/tidb_test.go b/tidb_test.go index f635867840..2832e1b0e7 100644 --- a/tidb_test.go +++ b/tidb_test.go @@ -20,6 +20,7 @@ import ( "os" "runtime" "sync" + "sync/atomic" "testing" "time" @@ -28,6 +29,8 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/rset" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/sessionctx/autocommit" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/terror" ) @@ -106,7 +109,8 @@ func checkResult(c *C, r sql.Result, affectedRows int64, insertID int64) { } type testSessionSuite struct { - dbName string + dbName string + dbNameBootstrap string createDBSQL string dropDBSQL string @@ -332,6 +336,7 @@ func sessionExec(c *C, se Session, sql string) ([]rset.Recordset, error) { func (s *testSessionSuite) SetUpSuite(c *C) { s.dbName = "test_session_db" + s.dbNameBootstrap = "test_main_db_bootstrap" s.createDBSQL = fmt.Sprintf("create database if not exists %s;", s.dbName) s.dropDBSQL = fmt.Sprintf("drop database %s;", s.dbName) s.useDBSQL = fmt.Sprintf("use %s;", s.dbName) @@ -1025,6 +1030,61 @@ func (s *testSessionSuite) TestBootstrap(c *C) { c.Assert(v[0], Equals, int64(len(variable.SysVars))) } +// Create a new session on store but only do ddl works. +func (s *testSessionSuite) bootstrapWithError(store kv.Storage, c *C) { + ss := &session{ + values: make(map[fmt.Stringer]interface{}), + store: store, + sid: atomic.AddInt64(&sessionID, 1), + } + domain, err := domap.Get(store) + c.Assert(err, IsNil) + sessionctx.BindDomain(ss, domain) + variable.BindSessionVars(ss) + variable.GetSessionVars(ss).SetStatusFlag(mysql.ServerStatusAutocommit, true) + // session implements autocommit.Checker. Bind it to ctx + autocommit.BindAutocommitChecker(ss, ss) + sessionMu.Lock() + defer sessionMu.Unlock() + b, err := checkBootstrapped(ss) + c.Assert(b, IsFalse) + c.Assert(err, IsNil) + doDDLWorks(ss) + // Leave dml unfinished. +} + +func (s *testSessionSuite) TestBootstrapWithError(c *C) { + store := newStore(c, s.dbNameBootstrap) + s.bootstrapWithError(store, c) + se := newSession(c, store, s.dbNameBootstrap) + mustExecSQL(c, se, "USE mysql;") + r := mustExecSQL(c, se, `select * from user;`) + row, err := r.Next() + c.Assert(err, IsNil) + c.Assert(row, NotNil) + match(c, row.Data, "localhost", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") + row, err = r.Next() + c.Assert(err, IsNil) + c.Assert(row, NotNil) + match(c, row.Data, "127.0.0.1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y") + mustExecSQL(c, se, "USE test;") + // Check privilege tables. + mustExecSQL(c, se, "SELECT * from mysql.db;") + mustExecSQL(c, se, "SELECT * from mysql.tables_priv;") + mustExecSQL(c, se, "SELECT * from mysql.columns_priv;") + // Check privilege tables. + r = mustExecSQL(c, se, "SELECT COUNT(*) from mysql.global_variables;") + v, err := r.FirstRow() + c.Assert(err, IsNil) + c.Assert(v[0], Equals, int64(len(variable.SysVars))) + r = mustExecSQL(c, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="bootstrapped";`) + row, err = r.Next() + c.Assert(err, IsNil) + c.Assert(row, NotNil) + c.Assert(row.Data, HasLen, 1) + c.Assert(row.Data[0], Equals, "True") +} + func (s *testSessionSuite) TestEnum(c *C) { store := newStore(c, s.dbName) se := newSession(c, store, s.dbName) @@ -1287,6 +1347,31 @@ func (s *testSessionSuite) TestBuiltin(c *C) { mustExecFailed(c, se, `select cast("xxx 10:10:10" as datetime)`) } +func (s *testSessionSuite) TestFieldText(c *C) { + store := newStore(c, s.dbName) + se := newSession(c, store, s.dbName) + mustExecSQL(c, se, "drop table if exists t") + mustExecSQL(c, se, "create table t (a int)") + cases := []struct { + sql string + field string + }{ + {"select distinct(a) from t", "a"}, + {"select (1)", "1"}, + {"select (1+1)", "(1+1)"}, + {"select a from t", "a"}, + {"select ((a+1)) from t", "((a+1))"}, + } + for _, v := range cases { + results, err := se.Execute(v.sql) + c.Assert(err, IsNil) + result := results[0] + fields, err := result.Fields() + c.Assert(err, IsNil) + c.Assert(fields[0].Name, Equals, v.field) + } +} + func newSession(c *C, store kv.Storage, dbName string) Session { se, err := CreateSession(store) c.Assert(err, IsNil) diff --git a/util/mock/context.go b/util/mock/context.go index 5cf65c4888..c5f4115a25 100644 --- a/util/mock/context.go +++ b/util/mock/context.go @@ -79,7 +79,7 @@ func (c *Context) SetGlobalSysVar(ctx context.Context, name string, value string } // NewContext creates a new mocked context.Context. -func NewContext() context.Context { +func NewContext() *Context { return &Context{ values: make(map[fmt.Stringer]interface{}), }