Merge master and pass mysqldump dump data

This commit is contained in:
Shen Li
2015-11-11 00:36:10 +08:00
38 changed files with 858 additions and 205 deletions

15
Dockerfile Normal file
View File

@ -0,0 +1,15 @@
FROM golang
VOLUME /opt
RUN apt-get update && apt-get install -y wget git make ; \
cd /opt ; \
export PATH=$GOROOT/bin:$GOPATH/bin:$PATH ; \
go get -d github.com/pingcap/tidb ; \
cd $GOPATH/src/github.com/pingcap/tidb ; \
make ; make server ; cp tidb-server/tidb-server /usr/bin/
EXPOSE 4000
CMD ["/usr/bin/tidb-server"]

View File

@ -338,26 +338,39 @@ func (n *FuncTrimExpr) IsStatic() bool {
return n.Str.IsStatic() && n.RemStr.IsStatic()
}
// DateArithType is type for DateArith option.
// DateArithType is type for DateArith type.
type DateArithType byte
const (
// DateAdd is to run date_add function option.
// DateAdd is to run adddate or date_add function option.
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
DateAdd DateArithType = iota + 1
// DateSub is to run date_sub function option.
// DateSub is to run subdate or date_sub function option.
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
DateSub
// DateArithDaysForm is to run adddate or subdate function with days form Flag.
DateArithDaysForm
)
// DateArithInterval is the struct of DateArith interval part.
type DateArithInterval struct {
// Form is the flag of DateArith running form.
// The function runs with interval or days.
Form DateArithType
Unit string
Interval ExprNode
}
// FuncDateArithExpr is the struct for date arithmetic functions.
type FuncDateArithExpr struct {
funcNode
Op DateArithType
Unit string
Date ExprNode
Interval ExprNode
// Op is used for distinguishing date_add and date_sub.
Op DateArithType
Date ExprNode
DateArithInterval
}
// Accept implements Node Accept interface.
@ -375,11 +388,11 @@ func (n *FuncDateArithExpr) Accept(v Visitor) (Node, bool) {
n.Date = node.(ExprNode)
}
if n.Interval != nil {
node, ok := n.Date.Accept(v)
node, ok := n.Interval.Accept(v)
if !ok {
return n, false
}
n.Date = node.(ExprNode)
n.Interval = node.(ExprNode)
}
return v.Leave(n)
}

View File

@ -166,6 +166,7 @@ const (
ShowCollation
ShowCreateTable
ShowGrants
ShowTriggers
)
// ShowStmt is a statement to provide information about databases, tables, columns and so on.

View File

@ -22,6 +22,7 @@ import (
"runtime/debug"
"strings"
"github.com/juju/errors"
"github.com/ngaut/log"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/sessionctx/variable"
@ -91,45 +92,94 @@ const (
CreateGloablVariablesTable = `CREATE TABLE if not exists mysql.GLOBAL_VARIABLES(
VARIABLE_NAME VARCHAR(64) Not Null PRIMARY KEY,
VARIABLE_VALUE VARCHAR(1024) DEFAULT Null);`
// CreateTiDBTable is the SQL statement creates a table in system db.
// This table is a key-value struct contains some information used by TiDB.
// Currently we only put bootstrapped in it which indicates if the system is already bootstrapped.
CreateTiDBTable = `CREATE TABLE if not exists mysql.tidb(
VARIABLE_NAME VARCHAR(64) Not Null PRIMARY KEY,
VARIABLE_VALUE VARCHAR(1024) DEFAULT Null,
COMMENT VARCHAR(1024));`
)
// Bootstrap initiates system DB for a store.
func bootstrap(s Session) {
// Create a test database.
mustExecute(s, "CREATE DATABASE IF NOT EXISTS test")
// Check if system db exists.
_, err := s.Execute(fmt.Sprintf("USE %s;", mysql.SystemDB))
if err == nil {
// We have already finished bootstrap.
return
} else if terror.DatabaseNotExists.NotEqual(err) {
b, err := checkBootstrapped(s)
if err != nil {
log.Fatal(err)
}
mustExecute(s, fmt.Sprintf("CREATE DATABASE %s;", mysql.SystemDB))
initUserTable(s)
initPrivTables(s)
initGlobalVariables(s)
if b {
return
}
doDDLWorks(s)
doDMLWorks(s)
}
func initUserTable(s Session) {
const (
bootstrappedVar = "bootstrapped"
bootstrappedVarTrue = "True"
)
func checkBootstrapped(s Session) (bool, error) {
// Check if system db exists.
_, err := s.Execute(fmt.Sprintf("USE %s;", mysql.SystemDB))
if err != nil && terror.DatabaseNotExists.NotEqual(err) {
log.Fatal(err)
}
// Check bootstrapped variable value in TiDB table.
v, err := checkBootstrappedVar(s)
if err != nil {
return false, errors.Trace(err)
}
return v, nil
}
func checkBootstrappedVar(s Session) (bool, error) {
sql := fmt.Sprintf(`SELECT VARIABLE_VALUE FROM %s.%s WHERE VARIABLE_NAME="%s"`, mysql.SystemDB, mysql.TiDBTable, bootstrappedVar)
rs, err := s.Execute(sql)
if err != nil {
if terror.TableNotExists.Equal(err) {
return false, nil
}
return false, errors.Trace(err)
}
if len(rs) != 1 {
return false, errors.New("Wrong number of Recordset")
}
r := rs[0]
row, err := r.Next()
if err != nil || row == nil {
return false, errors.Trace(err)
}
return row.Data[0].(string) == bootstrappedVarTrue, nil
}
// Execute DDL statements in bootstrap stage.
func doDDLWorks(s Session) {
// Create a test database.
mustExecute(s, "CREATE DATABASE IF NOT EXISTS test")
// Create system db.
mustExecute(s, fmt.Sprintf("CREATE DATABASE IF NOT EXISTS %s;", mysql.SystemDB))
// Create user table.
mustExecute(s, CreateUserTable)
// Create privilege tables.
mustExecute(s, CreateDBPrivTable)
mustExecute(s, CreateTablePrivTable)
mustExecute(s, CreateColumnPrivTable)
// Create global systemt variable table.
mustExecute(s, CreateGloablVariablesTable)
// Create TiDB table.
mustExecute(s, CreateTiDBTable)
}
// Execute DML statements in bootstrap stage.
// All the statements run in a single transaction.
func doDMLWorks(s Session) {
mustExecute(s, "BEGIN")
// Insert a default user with empty password.
mustExecute(s, `INSERT INTO mysql.user VALUES ("localhost", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y"),
("127.0.0.1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y"),
("::1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y");`)
}
// Initiates privilege tables including mysql.db, mysql.tables_priv and mysql.column_priv.
func initPrivTables(s Session) {
mustExecute(s, CreateDBPrivTable)
mustExecute(s, CreateTablePrivTable)
mustExecute(s, CreateColumnPrivTable)
}
// Initiates global system variables table.
func initGlobalVariables(s Session) {
mustExecute(s, CreateGloablVariablesTable)
// Init global system variable table.
values := make([]string, 0, len(variable.SysVars))
for k, v := range variable.SysVars {
value := fmt.Sprintf(`("%s", "%s")`, strings.ToLower(k), v.Value)
@ -137,6 +187,9 @@ func initGlobalVariables(s Session) {
}
sql := fmt.Sprintf("INSERT INTO %s.%s VALUES %s;", mysql.SystemDB, mysql.GlobalVariablesTable, strings.Join(values, ", "))
mustExecute(s, sql)
sql = fmt.Sprintf(`INSERT INTO %s.%s VALUES("%s", "%s", "Bootstrap flag. Do not delete.") ON DUPLICATE KEY UPDATE VARIABLE_VALUE="%s"`, mysql.SystemDB, mysql.TiDBTable, bootstrappedVar, bootstrappedVarTrue, bootstrappedVarTrue)
mustExecute(s, sql)
mustExecute(s, "COMMIT")
}
func mustExecute(s Session, sql string) {

View File

@ -55,10 +55,10 @@ func (c *Col) String() string {
}
// FindCol finds column in cols by name.
func FindCol(cols []*Col, name string) (c *Col) {
for _, c = range cols {
func FindCol(cols []*Col, name string) *Col {
for _, c := range cols {
if strings.EqualFold(c.Name.O, name) {
return
return c
}
}
return nil

View File

@ -31,7 +31,6 @@ import (
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/parser/coldef"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/table/tables"
"github.com/pingcap/tidb/terror"
@ -577,16 +576,6 @@ func (d *ddl) DropTable(ctx context.Context, ti table.Ident) (err error) {
if err != nil {
return errors.Trace(err)
}
// Check Privilege
privChecker := privilege.GetPrivilegeChecker(ctx)
hasPriv, err := privChecker.Check(ctx, schema, tb.Meta(), mysql.DropPriv)
if err != nil {
return errors.Trace(err)
}
if !hasPriv {
return errors.Errorf("You do not have the privilege to drop table %s.%s.", ti.Schema, ti.Name)
}
err = kv.RunInNewTxn(d.store, false, func(txn kv.Transaction) error {
t := meta.NewMeta(txn)
err := d.verifySchemaMetaVersion(t, is.SchemaMetaVersion())

View File

@ -15,6 +15,7 @@ package expression
import (
"fmt"
"regexp"
"strings"
"time"
@ -28,22 +29,38 @@ import (
type DateArithType byte
const (
// DateAdd is to run date_add function option.
// DateAdd is to run adddate or date_add function option.
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_adddate
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-add
DateAdd DateArithType = iota + 1
// DateSub is to run date_sub function option.
// DateSub is to run subdate or date_sub function option.
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_subdate
// See: https://dev.mysql.com/doc/refman/5.7/en/date-and-time-functions.html#function_date-sub
DateSub
// DateArithDaysForm is to run adddate or subdate function with days form Flag.
DateArithDaysForm
)
// DateArith is used for dealing with addition and substraction of time.
type DateArith struct {
Op DateArithType
Unit string
// Op is used for distinguishing date_add and date_sub.
Op DateArithType
// Form is the flag of DateArith running form.
// The function runs with interval or days.
Form DateArithType
Date Expression
Unit string
Interval Expression
}
type evalArgsResult struct {
time mysql.Time
year int64
month int64
day int64
duration time.Duration
}
func (da *DateArith) isAdd() bool {
if da.Op == DateAdd {
return true
@ -82,58 +99,88 @@ func (da *DateArith) String() string {
// Eval implements the Expression Eval interface.
func (da *DateArith) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
t, years, months, days, durations, err := da.evalArgs(ctx, args)
if t.IsZero() || err != nil {
val, err := da.evalArgs(ctx, args)
if val.time.IsZero() || err != nil {
return nil, errors.Trace(err)
}
if !da.isAdd() {
years, months, days, durations = -years, -months, -days, -durations
val.year, val.month, val.day, val.duration =
-val.year, -val.month, -val.day, -val.duration
}
t.Time = t.Time.Add(durations)
t.Time = t.Time.AddDate(int(years), int(months), int(days))
val.time.Time = val.time.Time.Add(val.duration)
val.time.Time = val.time.Time.AddDate(int(val.year), int(val.month), int(val.day))
// "2011-11-11 10:10:20.000000" outputs "2011-11-11 10:10:20".
if t.Time.Nanosecond() == 0 {
t.Fsp = 0
if val.time.Time.Nanosecond() == 0 {
val.time.Fsp = 0
}
return t, nil
return val.time, nil
}
func (da *DateArith) evalArgs(ctx context.Context, args map[interface{}]interface{}) (
mysql.Time, int64, int64, int64, time.Duration, error) {
*evalArgsResult, error) {
ret := &evalArgsResult{time: mysql.ZeroTimestamp}
dVal, err := da.Date.Eval(ctx, args)
if dVal == nil || err != nil {
return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err)
return ret, errors.Trace(err)
}
dValStr, err := types.ToString(dVal)
if err != nil {
return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err)
return ret, errors.Trace(err)
}
f := types.NewFieldType(mysql.TypeDatetime)
f.Decimal = mysql.MaxFsp
dVal, err = types.Convert(dValStr, f)
if dVal == nil || err != nil {
return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err)
return ret, errors.Trace(err)
}
t, ok := dVal.(mysql.Time)
var ok bool
ret.time, ok = dVal.(mysql.Time)
if !ok {
return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Errorf("need time type, but got %T", dVal)
return ret, errors.Errorf("need time type, but got %T", dVal)
}
iVal, err := da.Interval.Eval(ctx, args)
if iVal == nil || err != nil {
return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err)
ret.time = mysql.ZeroTimestamp
return ret, errors.Trace(err)
}
iValStr, err := types.ToString(iVal)
if err != nil {
return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err)
}
years, months, days, durations, err := mysql.ExtractTimeValue(da.Unit, strings.TrimSpace(iValStr))
if err != nil {
return mysql.ZeroTimestamp, 0, 0, 0, 0, errors.Trace(err)
// handle adddate(expr,days) or subdate(expr,days) form
if da.Form == DateArithDaysForm {
if iVal, err = da.evalDaysForm(iVal); err != nil {
return ret, errors.Trace(err)
}
}
return t, years, months, days, durations, nil
iValStr, err := types.ToString(iVal)
if err != nil {
return ret, errors.Trace(err)
}
ret.year, ret.month, ret.day, ret.duration, err = mysql.ExtractTimeValue(da.Unit, strings.TrimSpace(iValStr))
if err != nil {
return ret, errors.Trace(err)
}
return ret, nil
}
var reg = regexp.MustCompile(`[\d]+`)
func (da *DateArith) evalDaysForm(val interface{}) (interface{}, error) {
switch val.(type) {
case string:
if strings.ToLower(val.(string)) == "false" {
return 0, nil
}
if strings.ToLower(val.(string)) == "true" {
return 1, nil
}
val = reg.FindString(val.(string))
}
return types.ToInt64(val)
}

View File

@ -118,6 +118,42 @@ func (t *testDateArithSuite) TestDateArith(c *C) {
c.Assert(value.String(), Equals, t.SubExpect)
}
// Test eval for adddate and subdate with days form
tblDays := []struct {
Interval interface{}
AddExpect string
SubExpect string
}{
{"20", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{19.88, "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"19.88", "2011-11-30 10:10:10", "2011-10-23 10:10:10"},
{"20-11", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"20,11", "2011-12-01 10:10:10", "2011-10-22 10:10:10"},
{"1000", "2014-08-07 10:10:10", "2009-02-14 10:10:10"},
{"true", "2011-11-12 10:10:10", "2011-11-10 10:10:10"},
}
for _, t := range tblDays {
e := &DateArith{
Op: DateAdd,
Form: DateArithDaysForm,
Unit: "day",
Date: Value{Val: input},
Interval: Value{Val: t.Interval},
}
v, err := e.Eval(nil, nil)
c.Assert(err, IsNil)
value, ok := v.(mysql.Time)
c.Assert(ok, IsTrue)
c.Assert(value.String(), Equals, t.AddExpect)
e.Op = DateSub
v, err = e.Eval(nil, nil)
c.Assert(err, IsNil)
value, ok = v.(mysql.Time)
c.Assert(ok, IsTrue)
c.Assert(value.String(), Equals, t.SubExpect)
}
// Test error.
errInput := "20111111 10:10:10"
errTbl := []struct {
@ -147,12 +183,7 @@ func (t *testDateArithSuite) TestDateArith(c *C) {
v, err := e.Eval(nil, nil)
c.Assert(err, NotNil, Commentf("%s", v))
e = &DateArith{
Op: DateSub,
Unit: t.Unit,
Date: Value{Val: input},
Interval: Value{Val: t.Interval},
}
e.Op = DateSub
_, err = e.Eval(nil, nil)
c.Assert(err, NotNil)
e.Date = Value{Val: errInput}

View File

@ -62,6 +62,7 @@ func (v *Variable) String() string {
func (v *Variable) Eval(ctx context.Context, args map[interface{}]interface{}) (interface{}, error) {
name := strings.ToLower(v.Name)
sessionVars := variable.GetSessionVars(ctx)
globalVars := variable.GetGlobalSysVarAccessor(ctx)
if !v.IsSystem {
// user vars
if value, ok := sessionVars.Users[name]; ok {
@ -82,7 +83,7 @@ func (v *Variable) Eval(ctx context.Context, args map[interface{}]interface{}) (
return value, nil
}
}
value, err := ctx.(variable.GlobalSysVarAccessor).GetGlobalSysVar(ctx, name)
value, err := globalVars.GetGlobalSysVar(ctx, name)
if err != nil {
return nil, errors.Trace(err)
}

View File

@ -27,8 +27,10 @@ type testVariableSuite struct {
}
func (s *testVariableSuite) SetUpSuite(c *C) {
s.ctx = mock.NewContext()
nc := mock.NewContext()
s.ctx = nc
variable.BindSessionVars(s.ctx)
variable.BindGlobalSysVarAccessor(s.ctx, nc)
}
func (s *testVariableSuite) TestVariable(c *C) {

View File

@ -108,7 +108,7 @@ type Visitor interface {
VisitFunctionTrim(v *FunctionTrim) (Expression, error)
// VisitDateArith visits DateArith expression.
VisitDateArith(dc *DateArith) (Expression, error)
VisitDateArith(da *DateArith) (Expression, error)
}
// BaseVisitor is the base implementation of Visitor.

View File

@ -16,11 +16,11 @@ package infoschema
import (
"sync/atomic"
"github.com/juju/errors"
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/meta/autoid"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/table"
"github.com/pingcap/tidb/terror"
// import table implementation to init table.TableFromMeta
_ "github.com/pingcap/tidb/table/tables"
)
@ -105,7 +105,7 @@ func (is *infoSchema) SchemaExists(schema model.CIStr) bool {
func (is *infoSchema) TableByName(schema, table model.CIStr) (t table.Table, err error) {
id, ok := is.tableNameToID[tableName{schema: schema.L, table: table.L}]
if !ok {
return nil, errors.Errorf("table %s.%s does not exist", schema, table)
return nil, terror.TableNotExists.Gen("table %s.%s does not exist", schema, table)
}
t = is.tables[id]
return

View File

@ -130,6 +130,8 @@ const (
ColumnPrivTable = "Columns_priv"
// GlobalVariablesTable is the table contains global system variables.
GlobalVariablesTable = "GLOBAL_VARIABLES"
// TiDBTable is the table contains tidb info.
TiDBTable = "tidb"
)
// PrivilegeType privilege

View File

@ -410,6 +410,9 @@ func (c *expressionConverter) funcDateArith(v *ast.FuncDateArithExpr) {
case ast.DateSub:
oldDateArith.Op = expression.DateSub
}
if v.Form == ast.DateArithDaysForm {
oldDateArith.Form = expression.DateArithDaysForm
}
c.exprMap[v] = oldDateArith
}

View File

@ -174,6 +174,13 @@ func convertUpdate(converter *expressionConverter, v *ast.UpdateStmt) (*stmts.Up
return oldUpdate, nil
}
func getInnerFromParentheses(expr ast.ExprNode) ast.ExprNode {
if pexpr, ok := expr.(*ast.ParenthesesExpr); ok {
return getInnerFromParentheses(pexpr.Expr)
}
return expr
}
func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.SelectStmt, error) {
oldSelect := &stmts.SelectStmt{
Distinct: s.Distinct,
@ -189,9 +196,21 @@ func convertSelect(converter *expressionConverter, s *ast.SelectStmt) (*stmts.Se
if err != nil {
return nil, errors.Trace(err)
}
// TODO: handle parenthesesed column name expression, which should not set AsName.
if _, ok := oldField.Expr.(*expression.Ident); !ok && oldField.AsName == "" {
oldField.AsName = val.Text()
if oldField.AsName == "" {
innerExpr := getInnerFromParentheses(val.Expr)
switch innerExpr.(type) {
case *ast.ColumnNameExpr:
// Do not set column name as name and remove parentheses.
oldField.Expr = converter.exprMap[innerExpr]
case *ast.ValueExpr:
if innerExpr.Text() != "" {
oldField.AsName = innerExpr.Text()
} else {
oldField.AsName = val.Text()
}
default:
oldField.AsName = val.Text()
}
}
} else if val.WildCard != nil {
str := "*"
@ -847,6 +866,10 @@ func convertShow(converter *expressionConverter, v *ast.ShowStmt) (*stmts.ShowSt
oldShow.Target = stmt.ShowVariables
case ast.ShowWarnings:
oldShow.Target = stmt.ShowWarnings
case ast.ShowTableStatus:
oldShow.Target = stmt.ShowTableStatus
case ast.ShowTriggers:
oldShow.Target = stmt.ShowTriggers
case ast.ShowNone:
oldShow.Target = stmt.ShowNone
}

View File

@ -59,6 +59,7 @@ import (
abs "ABS"
add "ADD"
addDate "ADDDATE"
after "AFTER"
all "ALL"
alter "ALTER"
@ -134,6 +135,7 @@ import (
explain "EXPLAIN"
extract "EXTRACT"
falseKwd "false"
fields "FIELDS"
first "FIRST"
foreign "FOREIGN"
forKwd "FOR"
@ -230,6 +232,7 @@ import (
start "START"
status "STATUS"
stringType "string"
subDate "SUBDATE"
substring "SUBSTRING"
substringIndex "SUBSTRING_INDEX"
sum "SUM"
@ -241,6 +244,7 @@ import (
to "TO"
trailing "TRAILING"
transaction "TRANSACTION"
triggers "TRIGGERS"
trim "TRIM"
trueKwd "true"
truncate "TRUNCATE"
@ -376,6 +380,9 @@ import (
CreateTableStmt "CREATE TABLE statement"
CreateUserStmt "CREATE User statement"
CrossOpt "Cross join option"
DateArithOpt "Date arith dateadd or datesub option"
DateArithMultiFormsOpt "Date arith adddate or subdate option"
DateArithInterval "Date arith interval part"
DatabaseSym "DATABASE or SCHEMA"
DBName "Database Name"
DeallocateSym "Deallocate or drop"
@ -1529,12 +1536,6 @@ Field:
{
expr := $1.(ast.ExprNode)
asName := $2.(string)
if asName != "" {
// Set expr original text.
offset := yyS[yypt-1].offset
end := yyS[yypt].offset-1
expr.SetText(yylex.(*lexer).src[offset:end])
}
$$ = &ast.SelectField{Expr: expr, AsName: model.NewCIStr(asName)}
}
@ -1570,7 +1571,7 @@ FieldList:
Field
{
field := $1.(*ast.SelectField)
field.Offset = yyS[yypt].offset
field.Offset = yylex.(*lexer).startOffset(yyS[yypt].offset)
$$ = []*ast.SelectField{field}
}
| FieldList ',' Field
@ -1578,12 +1579,13 @@ FieldList:
fl := $1.([]*ast.SelectField)
last := fl[len(fl)-1]
l := yylex.(*lexer)
if last.Expr != nil && last.AsName.O == "" {
lastEnd := yyS[yypt-1].offset-1 // Comma offset.
last.SetText(yylex.(*lexer).src[last.Offset:lastEnd])
lastEnd := l.endOffset(yyS[yypt-1].offset)
last.SetText(l.src[last.Offset:lastEnd])
}
newField := $3.(*ast.SelectField)
newField.Offset = yyS[yypt].offset
newField.Offset = l.startOffset(yyS[yypt].offset)
$$ = append(fl, newField)
}
@ -1658,12 +1660,14 @@ UnReservedKeyword:
| "START" | "GLOBAL" | "TABLES"| "TEXT" | "TIME" | "TIMESTAMP" | "TRANSACTION" | "TRUNCATE" | "UNKNOWN"
| "VALUE" | "WARNINGS" | "YEAR" | "MODE" | "WEEK" | "ANY" | "SOME" | "USER" | "IDENTIFIED" | "COLLATION"
| "COMMENT" | "AVG_ROW_LENGTH" | "CONNECTION" | "CHECKSUM" | "COMPRESSION" | "KEY_BLOCK_SIZE" | "MAX_ROWS" | "MIN_ROWS"
| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" | "STATUS"
| "NATIONAL" | "ROW" | "QUARTER" | "ESCAPE" | "GRANTS" | "STATUS" | "FIELDS" | "TRIGGERS"
NotKeywordToken:
"ABS" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DATE_SUB" | "DAYOFMONTH" | "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT"
| "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX" | "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS"
| "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR" | "YEARWEEK"
"ABS" | "ADDDATE" | "COALESCE" | "CONCAT" | "CONCAT_WS" | "COUNT" | "DAY" | "DATE_ADD" | "DATE_SUB" | "DAYOFMONTH"
| "DAYOFWEEK" | "DAYOFYEAR" | "FOUND_ROWS" | "GROUP_CONCAT"| "HOUR" | "IFNULL" | "LENGTH" | "LOCATE" | "MAX"
| "MICROSECOND" | "MIN" | "MINUTE" | "NULLIF" | "MONTH" | "NOW" | "RAND" | "SECOND" | "SQL_CALC_FOUND_ROWS"
| "SUBDATE" | "SUBSTRING" %prec lowerThanLeftParen | "SUBSTRING_INDEX" | "SUM" | "TRIM" | "WEEKDAY" | "WEEKOFYEAR"
| "YEARWEEK"
/************************************************************************************
*
@ -1872,7 +1876,12 @@ Operand:
}
| '(' Expression ')'
{
$$ = &ast.ParenthesesExpr{Expr: $2.(ast.ExprNode)}
l := yylex.(*lexer)
startOffset := l.startOffset(yyS[yypt-1].offset)
endOffset := l.endOffset(yyS[yypt].offset)
expr := $2.(ast.ExprNode)
expr.SetText(l.src[startOffset:endOffset])
$$ = &ast.ParenthesesExpr{Expr: expr}
}
| "DEFAULT" %prec lowerThanLeftParen
{
@ -2136,22 +2145,22 @@ FunctionCallNonKeyword:
{
$$ = &ast.FuncCallExpr{FnName: $1.(string), Args: []ast.ExprNode{$3.(ast.ExprNode)}}
}
| "DATE_ADD" '(' Expression ',' "INTERVAL" Expression TimeUnit ')'
| DateArithOpt '(' Expression ',' "INTERVAL" Expression TimeUnit ')'
{
$$ = &ast.FuncDateArithExpr{
Op:ast.DateAdd,
Unit: $7.(string),
Op: $1.(ast.DateArithType),
Date: $3.(ast.ExprNode),
Interval: $6.(ast.ExprNode),
DateArithInterval: ast.DateArithInterval{
Unit: $7.(string),
Interval: $6.(ast.ExprNode)},
}
}
| "DATE_SUB" '(' Expression ',' "INTERVAL" Expression TimeUnit ')'
| DateArithMultiFormsOpt '(' Expression ',' DateArithInterval')'
{
$$ = &ast.FuncDateArithExpr{
Op:ast.DateSub,
Unit: $7.(string),
Op: $1.(ast.DateArithType),
Date: $3.(ast.ExprNode),
Interval: $6.(ast.ExprNode),
DateArithInterval: $5.(ast.DateArithInterval),
}
}
| "EXTRACT" '(' TimeUnit "FROM" Expression ')'
@ -2329,6 +2338,40 @@ FunctionCallNonKeyword:
$$ = &ast.FuncCallExpr{FnName: $1.(string), Args: $3.([]ast.ExprNode)}
}
DateArithOpt:
"DATE_ADD"
{
$$ = ast.DateAdd
}
| "DATE_SUB"
{
$$ = ast.DateSub
}
DateArithMultiFormsOpt:
"ADDDATE"
{
$$ = ast.DateAdd
}
| "SUBDATE"
{
$$ = ast.DateSub
}
DateArithInterval:
Expression
{
$$ = ast.DateArithInterval{
Form: ast.DateArithDaysForm,
Unit: "day",
Interval: $1.(ast.ExprNode),
}
}
| "INTERVAL" Expression TimeUnit
{
$$ = ast.DateArithInterval{Unit: $3.(string), Interval: $2.(ast.ExprNode)}
}
TrimDirection:
"BOTH"
{
@ -2829,7 +2872,9 @@ TableFactor:
| '(' SelectStmt ')' TableAsName
{
st := $2.(*ast.SelectStmt)
yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-1].offset-1)
l := yylex.(*lexer)
endOffset := l.endOffset(yyS[yypt-1].offset)
l.SetLastSelectFieldText(st, endOffset)
$$ = &ast.TableSource{Source: $2.(*ast.SelectStmt), AsName: $4.(model.CIStr)}
}
| '(' UnionStmt ')' TableAsName
@ -2976,7 +3021,9 @@ SubSelect:
'(' SelectStmt ')'
{
s := $2.(*ast.SelectStmt)
yylex.(*lexer).SetLastSelectFieldText(s, yyS[yypt].offset-1)
l := yylex.(*lexer)
endOffset := l.endOffset(yyS[yypt].offset)
l.SetLastSelectFieldText(s, endOffset)
src := yylex.(*lexer).src
// See the implemention of yyParse function
s.SetText(src[yyS[yypt-1].offset-1:yyS[yypt].offset-1])
@ -3013,7 +3060,9 @@ UnionStmt:
union := $1.(*ast.UnionStmt)
union.Distinct = union.Distinct || $3.(bool)
lastSelect := union.Selects[len(union.Selects)-1]
yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1)
l := yylex.(*lexer)
endOffset := l.endOffset(yyS[yypt-2].offset)
l.SetLastSelectFieldText(lastSelect, endOffset)
union.Selects = append(union.Selects, $4.(*ast.SelectStmt))
$$ = union
}
@ -3022,9 +3071,12 @@ UnionStmt:
union := $1.(*ast.UnionStmt)
union.Distinct = union.Distinct || $3.(bool)
lastSelect := union.Selects[len(union.Selects)-1]
yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-6].offset-1)
l := yylex.(*lexer)
endOffset := l.endOffset(yyS[yypt-6].offset)
l.SetLastSelectFieldText(lastSelect, endOffset)
st := $5.(*ast.SelectStmt)
yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt-2].offset-1)
endOffset = l.endOffset(yyS[yypt-2].offset)
l.SetLastSelectFieldText(st, endOffset)
union.Selects = append(union.Selects, st)
if $7 != nil {
union.OrderBy = $7.(*ast.OrderByClause)
@ -3048,7 +3100,9 @@ UnionClauseList:
union := $1.(*ast.UnionStmt)
union.Distinct = union.Distinct || $3.(bool)
lastSelect := union.Selects[len(union.Selects)-1]
yylex.(*lexer).SetLastSelectFieldText(lastSelect, yyS[yypt-2].offset-1)
l := yylex.(*lexer)
endOffset := l.endOffset(yyS[yypt-2].offset)
l.SetLastSelectFieldText(lastSelect, endOffset)
union.Selects = append(union.Selects, $4.(*ast.SelectStmt))
$$ = union
}
@ -3058,7 +3112,9 @@ UnionSelect:
| '(' SelectStmt ')'
{
st := $2.(*ast.SelectStmt)
yylex.(*lexer).SetLastSelectFieldText(st, yyS[yypt].offset-1)
l := yylex.(*lexer)
endOffset := l.endOffset(yyS[yypt].offset)
l.SetLastSelectFieldText(st, endOffset)
$$ = st
}
@ -3268,6 +3324,16 @@ ShowStmt:
Full: $2.(bool),
}
}
| "SHOW" OptFull "FIELDS" ShowTableAliasOpt ShowDatabaseNameOpt
{
// SHOW FIELDS is a synonym for SHOW COLUMNS.
$$ = &ast.ShowStmt{
Tp: ast.ShowColumns,
Table: $4.(*ast.TableName),
DBName: $5.(string),
Full: $2.(bool),
}
}
| "SHOW" "WARNINGS"
{
$$ = &ast.ShowStmt{Tp: ast.ShowWarnings}
@ -3317,6 +3383,21 @@ ShowStmt:
User: $4.(string),
}
}
| "SHOW" "TRIGGERS" ShowDatabaseNameOpt ShowLikeOrWhereOpt
{
stmt := &ast.ShowStmt{
Tp: ast.ShowTriggers,
DBName: $3.(string),
}
if $4 != nil {
if x, ok := $4.(*ast.PatternLikeExpr); ok {
stmt.Pattern = x
} else {
stmt.Where = $4.(ast.ExprNode)
}
}
$$ = stmt
}
ShowLikeOrWhereOpt:
{

View File

@ -244,6 +244,9 @@ func (s *testParserSuite) TestDMLStmt(c *C) {
{`SHOW FULL TABLES WHERE Table_Type != 'VIEW'`, true},
{`SHOW GRANTS`, true},
{`SHOW GRANTS FOR 'test'@'localhost'`, true},
{`SHOW COLUMNS FROM City;`, true},
{`SHOW FIELDS FROM City;`, true},
{`SHOW TRIGGERS LIKE 't'`, true},
// For default value
{"CREATE TABLE sbtest (id INTEGER UNSIGNED NOT NULL AUTO_INCREMENT, k integer UNSIGNED DEFAULT '0' NOT NULL, c char(120) DEFAULT '' NOT NULL, pad char(60) DEFAULT '' NOT NULL, PRIMARY KEY (id) )", true},
@ -421,6 +424,31 @@ func (s *testParserSuite) TestBuiltin(c *C) {
{`select date_add("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true},
{`select date_add("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true},
// For adddate
{`select adddate("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval 10 second)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval 10 minute)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval 10 hour)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval 10 day)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval 1 week)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval 1 month)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval 1 quarter)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval 1 year)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true},
{`select adddate("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true},
{`select adddate("2011-11-11 10:10:10.123456", 10)`, true},
{`select adddate("2011-11-11 10:10:10.123456", 0.10)`, true},
{`select adddate("2011-11-11 10:10:10.123456", "11,11")`, true},
// For date_sub
{`select date_sub("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true},
{`select date_sub("2011-11-11 10:10:10.123456", interval 10 second)`, true},
@ -442,6 +470,31 @@ func (s *testParserSuite) TestBuiltin(c *C) {
{`select date_sub("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true},
{`select date_sub("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true},
{`select date_sub("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true},
// For subdate
{`select subdate("2011-11-11 10:10:10.123456", interval 10 microsecond)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval 10 second)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval 10 minute)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval 10 hour)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval 10 day)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval 1 week)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval 1 month)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval 1 quarter)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval 1 year)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "10.10" second_microsecond)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "10:10.10" minute_microsecond)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "10:10" minute_second)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "10:10:10.10" hour_microsecond)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "10:10:10" hour_second)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "10:10" hour_minute)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "11 10:10:10.10" day_microsecond)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "11 10:10:10" day_second)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "11 10:10" day_minute)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "11 10" day_hour)`, true},
{`select subdate("2011-11-11 10:10:10.123456", interval "11-11" year_month)`, true},
{`select adddate("2011-11-11 10:10:10.123456", 10)`, true},
{`select adddate("2011-11-11 10:10:10.123456", 0.10)`, true},
{`select adddate("2011-11-11 10:10:10.123456", "11,11")`, true},
}
s.RunTest(c, table)
}

View File

@ -128,6 +128,22 @@ func (l *lexer) SetLastSelectFieldText(st *ast.SelectStmt, lastEnd int) {
}
}
func (l *lexer) startOffset(offset int) int {
offset--
for unicode.IsSpace(rune(l.src[offset])) {
offset++
}
return offset
}
func (l *lexer) endOffset(offset int) int {
offset--
for offset > 0 && unicode.IsSpace(rune(l.src[offset-1])) {
offset--
}
return offset
}
func (l *lexer) unget(b byte) {
l.ungetBuf = append(l.ungetBuf, b)
l.i--
@ -266,6 +282,7 @@ z [zZ]
abs {a}{b}{s}
add {a}{d}{d}
adddate {a}{d}{d}{d}{a}{t}{e}
after {a}{f}{t}{e}{r}
all {a}{l}{l}
alter {a}{l}{t}{e}{r}
@ -334,6 +351,7 @@ execute {e}{x}{e}{c}{u}{t}{e}
exists {e}{x}{i}{s}{t}{s}
explain {e}{x}{p}{l}{a}{i}{n}
extract {e}{x}{t}{r}{a}{c}{t}
fields {f}{i}{e}{l}{d}{s}
first {f}{i}{r}{s}{t}
for {f}{o}{r}
foreign {f}{o}{r}{e}{i}{g}{n}
@ -415,6 +433,7 @@ show {s}{h}{o}{w}
some {s}{o}{m}{e}
start {s}{t}{a}{r}{t}
status {s}{t}{a}{t}{u}{s}
subdate {s}{u}{b}{d}{a}{t}{e}
substring {s}{u}{b}{s}{t}{r}{i}{n}{g}
substring_index {s}{u}{b}{s}{t}{r}{i}{n}{g}_{i}{n}{d}{e}{x}
sum {s}{u}{m}
@ -425,6 +444,7 @@ then {t}{h}{e}{n}
to {t}{o}
trailing {t}{r}{a}{i}{l}{i}{n}{g}
transaction {t}{r}{a}{n}{s}{a}{c}{t}{i}{o}{n}
triggers {t}{r}{i}{g}{g}{e}{r}{s}
trim {t}{r}{i}{m}
truncate {t}{r}{u}{n}{c}{a}{t}{e}
max {m}{a}{x}
@ -598,6 +618,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
{abs} lval.item = string(l.val)
return abs
{add} return add
{adddate} return addDate
{after} lval.item = string(l.val)
return after
{all} return all
@ -661,10 +682,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
{database} lval.item = string(l.val)
return database
{databases} return databases
{date_add} lval.item = string(l.val)
return dateAdd
{date_sub} lval.item = string(l.val)
return dateSub
{date_add} return dateAdd
{date_sub} return dateSub
{day} lval.item = string(l.val)
return day
{dayofweek} lval.item = string(l.val)
@ -712,6 +731,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
{explain} return explain
{extract} lval.item = string(l.val)
return extract
{fields} lval.item = string(l.val)
return fields
{first} lval.item = string(l.val)
return first
{for} return forKwd
@ -856,6 +877,7 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
{set} return set
{share} return share
{show} return show
{subdate} return subDate
{substring} lval.item = string(l.val)
return substring
{substring_index} lval.item = string(l.val)
@ -872,6 +894,8 @@ year_month {y}{e}{a}{r}_{m}{o}{n}{t}{h}
{trailing} return trailing
{transaction} lval.item = string(l.val)
return transaction
{triggers} lval.item = string(l.val)
return triggers
{trim} lval.item = string(l.val)
return trim
{truncate} lval.item = string(l.val)

View File

@ -86,7 +86,7 @@ func (p *testInfoSchemaSuit) TestInfoSchema(c *C) {
cnt = mustQuery(c, testDB, "select * from information_schema.columns")
c.Assert(cnt, Greater, 0)
cnt = mustQuery(c, testDB, "select * from information_schema.statistics")
c.Assert(cnt, Equals, 15)
c.Assert(cnt, Equals, 16)
cnt = mustQuery(c, testDB, "select * from information_schema.character_sets")
c.Assert(cnt, Greater, 0)
cnt = mustQuery(c, testDB, "select * from information_schema.collations")

View File

@ -120,6 +120,11 @@ func (s *ShowPlan) GetFields() []*field.ResultField {
names = []string{"Table", "Create Table"}
case stmt.ShowGrants:
names = []string{fmt.Sprintf("Grants for %s", s.User)}
case stmt.ShowTriggers:
names = []string{"Trigger", "Event", "Table", "Statement", "Timing", "Created",
"sql_mode", "Definer", "character_set_client", "collation_connection", "Database Collation"}
types = []byte{mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar,
mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar, mysql.TypeVarchar}
}
fields := make([]*field.ResultField, 0, len(names))
for i, name := range names {
@ -181,6 +186,8 @@ func (s *ShowPlan) fetchAll(ctx context.Context) error {
return s.fetchShowCreateTable(ctx)
case stmt.ShowGrants:
return s.fetchShowGrants(ctx)
case stmt.ShowTriggers:
return s.fetchShowTriggers(ctx)
}
return nil
}
@ -396,7 +403,7 @@ func (s *ShowPlan) fetchShowTableStatus(ctx context.Context) error {
}
now := mysql.GetCurrentTime(mysql.TypeDatetime)
data := []interface{}{
v, "", "", "", "", 100, 100,
v, "InnoDB", "10", "Compact", 100, 100,
100, 100, 100, 100, 100,
now, now, now, "utf8_general_ci", "",
"", "",
@ -407,6 +414,7 @@ func (s *ShowPlan) fetchShowTableStatus(ctx context.Context) error {
}
func (s *ShowPlan) fetchShowVariables(ctx context.Context) error {
sessionVars := variable.GetSessionVars(ctx)
globalVars := variable.GetGlobalSysVarAccessor(ctx)
m := map[interface{}]interface{}{}
for _, v := range variable.SysVars {
@ -432,13 +440,19 @@ func (s *ShowPlan) fetchShowVariables(ctx context.Context) error {
var value string
if !s.GlobalScope {
// Try to get Session Scope variable value
// Try to get Session Scope variable value first.
sv, ok := sessionVars.Systems[v.Name]
if ok {
value = sv
} else {
// If session scope variable is not set, get the global scope value.
value, err = globalVars.GetGlobalSysVar(ctx, v.Name)
if err != nil {
return errors.Trace(err)
}
}
} else {
value, err = ctx.(variable.GlobalSysVarAccessor).GetGlobalSysVar(ctx, v.Name)
value, err = globalVars.GetGlobalSysVar(ctx, v.Name)
if err != nil {
return errors.Trace(err)
}
@ -462,6 +476,7 @@ func (s *ShowPlan) fetchShowCharset(ctx context.Context) error {
}
func (s *ShowPlan) fetchShowEngines(ctx context.Context) error {
// Mock data
row := &plan.Row{
Data: []interface{}{"InnoDB", "DEFAULT", "Supports transactions, row-level locking, and foreign keys", "YES", "YES", "YES"},
}
@ -574,3 +589,7 @@ func (s *ShowPlan) fetchShowGrants(ctx context.Context) error {
}
return nil
}
func (s *ShowPlan) fetchShowTriggers(ctx context.Context) error {
return nil
}

View File

@ -53,8 +53,10 @@ type testShowSuit struct {
var _ = Suite(&testShowSuit{})
func (p *testShowSuit) SetUpSuite(c *C) {
p.ctx = mock.NewContext()
nc := mock.NewContext()
p.ctx = nc
variable.BindSessionVars(p.ctx)
variable.BindGlobalSysVarAccessor(p.ctx, nc)
p.dbName = "testshowplan"
p.store = newStore(c, p.dbName)
@ -180,6 +182,28 @@ func (p *testShowSuit) TestShowVariables(c *C) {
c.Assert(v, Equals, "on")
}
func (p *testShowSuit) TestIssue540(c *C) {
// Show variables where variable_name="time_zone"
pln := &plans.ShowPlan{
Target: stmt.ShowVariables,
GlobalScope: false,
Pattern: &expression.PatternLike{
Pattern: &expression.Value{
Val: "time_zone",
},
},
}
// Make sure the session scope var is not set.
sessionVars := variable.GetSessionVars(p.ctx)
_, ok := sessionVars.Systems["time_zone"]
c.Assert(ok, IsFalse)
r, err := pln.Next(p.ctx)
c.Assert(err, IsNil)
c.Assert(r.Data[0], Equals, "time_zone")
c.Assert(r.Data[1], Equals, "SYSTEM")
}
func (p *testShowSuit) TestShowCollation(c *C) {
pln := &plans.ShowPlan{}

View File

@ -256,7 +256,7 @@ func (s *session) ExecRestrictedSQL(ctx context.Context, sql string) (rset.Recor
// Check statement for some restriction
// For example only support DML on system meta table.
// TODO: Add more restrictions.
log.Infof("Executing %s [%s]", st.OriginText(), sql)
log.Debugf("Executing %s [%s]", st.OriginText(), sql)
ctx.SetValue(&sqlexec.RestrictedSQLExecutorKeyType{}, true)
defer ctx.ClearValue(&sqlexec.RestrictedSQLExecutorKeyType{})
rs, err := st.Exec(ctx)
@ -543,6 +543,9 @@ func CreateSession(store kv.Storage) (Session, error) {
variable.BindSessionVars(s)
variable.GetSessionVars(s).SetStatusFlag(mysql.ServerStatusAutocommit, true)
// session implements variable.GlobalSysVarAccessor. Bind it to ctx.
variable.BindGlobalSysVarAccessor(s, s)
// session implements autocommit.Checker. Bind it to ctx
autocommit.BindAutocommitChecker(s, s)
sessionMu.Lock()

View File

@ -581,3 +581,27 @@ type GlobalSysVarAccessor interface {
// SetGlobalSysVar sets the global system variable name to value.
SetGlobalSysVar(ctx context.Context, name string, value string) error
}
// globalSysVarAccessorKeyType is a dummy type to avoid naming collision in context.
type globalSysVarAccessorKeyType int
// String defines a Stringer function for debugging and pretty printing.
func (k globalSysVarAccessorKeyType) String() string {
return "global_sysvar_accessor"
}
const accessorKey globalSysVarAccessorKeyType = 0
// BindGlobalSysVarAccessor binds global sysvar accessor to context.
func BindGlobalSysVarAccessor(ctx context.Context, accessor GlobalSysVarAccessor) {
ctx.SetValue(accessorKey, accessor)
}
// GetGlobalSysVarAccessor gets accessor from ctx.
func GetGlobalSysVarAccessor(ctx context.Context) GlobalSysVarAccessor {
v, ok := ctx.Value(accessorKey).(GlobalSysVarAccessor)
if !ok {
panic("Miss global sysvar accessor")
}
return v
}

View File

@ -59,6 +59,7 @@ const (
ShowCollation
ShowCreateTable
ShowGrants
ShowTriggers
)
const (

View File

@ -100,14 +100,11 @@ func (s *DeleteStmt) plan(ctx context.Context) (plan.Plan, error) {
}
func removeRow(ctx context.Context, t table.Table, h int64, data []interface{}) error {
// remove row's all indexies
if err := t.RemoveRowAllIndex(ctx, h, data); err != nil {
return err
}
// remove row
if err := t.RemoveRow(ctx, h); err != nil {
return err
err := t.RemoveRecord(ctx, h, data)
if err != nil {
return errors.Trace(err)
}
variable.GetSessionVars(ctx).AddAffectedRows(1)
return nil
}

View File

@ -25,6 +25,8 @@ import (
"github.com/pingcap/tidb/context"
"github.com/pingcap/tidb/ddl"
"github.com/pingcap/tidb/model"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/privilege"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/stmt"
@ -109,15 +111,37 @@ func (s *DropTableStmt) SetText(text string) {
// Exec implements the stmt.Statement Exec interface.
func (s *DropTableStmt) Exec(ctx context.Context) (rset.Recordset, error) {
var notExistTables []string
is := sessionctx.GetDomain(ctx).InfoSchema()
for _, ti := range s.TableIdents {
err := sessionctx.GetDomain(ctx).DDL().DropTable(ctx, ti.Full(ctx))
// TODO: we should return special error for table not exist, checking "not exist" is not enough,
// because some other errors may contain this error string too.
fullti := ti.Full(ctx)
schema, ok := is.SchemaByName(fullti.Schema)
if !ok {
// TODO: we should return special error for table not exist, checking "not exist" is not enough,
// because some other errors may contain this error string too.
notExistTables = append(notExistTables, ti.String())
continue
}
tb, err := is.TableByName(fullti.Schema, fullti.Name)
if err != nil && strings.HasSuffix(err.Error(), "not exist") {
notExistTables = append(notExistTables, ti.String())
continue
} else if err != nil {
return nil, errors.Trace(err)
}
// Check Privilege
privChecker := privilege.GetPrivilegeChecker(ctx)
hasPriv, err := privChecker.Check(ctx, schema, tb.Meta(), mysql.DropPriv)
if err != nil {
return nil, errors.Trace(err)
}
if !hasPriv {
return nil, errors.Errorf("You do not have the privilege to drop table %s.%s.", ti.Schema, ti.Name)
}
err = sessionctx.GetDomain(ctx).DDL().DropTable(ctx, fullti)
if err != nil {
return nil, errors.Trace(err)
}
}
if len(notExistTables) > 0 && !s.IfExists {
return nil, errors.Errorf("DROP TABLE: table %s does not exist", strings.Join(notExistTables, ","))

View File

@ -100,7 +100,7 @@ func (s *SetStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
log.Debug("Set sys/user variables")
sessionVars := variable.GetSessionVars(ctx)
globalVars := variable.GetGlobalSysVarAccessor(ctx)
for _, v := range s.Variables {
// Variable is case insensitive, we use lower case.
name := strings.ToLower(v.Name)
@ -138,7 +138,7 @@ func (s *SetStmt) Exec(ctx context.Context) (_ rset.Recordset, err error) {
if err != nil {
return nil, errors.Trace(err)
}
err = ctx.(variable.GlobalSysVarAccessor).SetGlobalSysVar(ctx, name, svalue)
err = globalVars.SetGlobalSysVar(ctx, name, svalue)
return nil, errors.Trace(err)
}
return nil, errors.Errorf("Variable '%s' is a SESSION variable and can't be used with SET GLOBAL", name)

View File

@ -451,7 +451,8 @@ func (s *testKVSuite) TestConditionIfNotExist(c *C) {
}()
}
wg.Wait()
c.Assert(success, Greater, int64(1))
// At least one txn can success.
c.Assert(success, Greater, int64(0))
// Clean up
txn, err := s.s.Begin()

View File

@ -231,7 +231,7 @@ func (txn *dbTxn) Commit() error {
txn.close()
}()
return txn.doCommit()
return errors.Trace(txn.doCommit())
}
func (txn *dbTxn) CommittedVersion() (kv.Version, error) {

View File

@ -30,17 +30,17 @@ type HashPair struct {
}
type hashMeta struct {
Length int64
FieldCount int64
}
func (meta hashMeta) Value() []byte {
buf := make([]byte, 8)
binary.BigEndian.PutUint64(buf[0:8], uint64(meta.Length))
binary.BigEndian.PutUint64(buf[0:8], uint64(meta.FieldCount))
return buf
}
func (meta hashMeta) IsEmpty() bool {
return meta.Length <= 0
return meta.FieldCount <= 0
}
// HSet sets the string value of a hash field.
@ -92,34 +92,40 @@ func (t *TxStructure) HGetInt64(key []byte, field []byte) (int64, error) {
}
func (t *TxStructure) updateHash(key []byte, field []byte, fn func(oldValue []byte) ([]byte, error)) error {
metaKey := t.encodeHashMetaKey(key)
meta, err := t.loadHashMeta(metaKey)
if err != nil {
return errors.Trace(err)
}
dataKey := t.encodeHashDataKey(key, field)
var oldValue []byte
oldValue, err = t.loadHashValue(dataKey)
oldValue, err := t.loadHashValue(dataKey)
if err != nil {
return errors.Trace(err)
}
if oldValue == nil {
meta.Length++
}
var newValue []byte
newValue, err = fn(oldValue)
newValue, err := fn(oldValue)
if err != nil {
return errors.Trace(err)
}
// Check if new value is equal to old value.
if bytes.Equal(oldValue, newValue) {
return nil
}
if err = t.txn.Set(dataKey, newValue); err != nil {
return errors.Trace(err)
}
return errors.Trace(t.txn.Set(metaKey, meta.Value()))
metaKey := t.encodeHashMetaKey(key)
meta, err := t.loadHashMeta(metaKey)
if err != nil {
return errors.Trace(err)
}
if oldValue == nil {
meta.FieldCount++
if err = t.txn.Set(metaKey, meta.Value()); err != nil {
return errors.Trace(err)
}
}
return nil
}
// HLen gets the number of fields in a hash.
@ -129,7 +135,7 @@ func (t *TxStructure) HLen(key []byte) (int64, error) {
if err != nil {
return 0, errors.Trace(err)
}
return meta.Length, nil
return meta.FieldCount, nil
}
// HDel deletes one or more hash fields.
@ -154,7 +160,7 @@ func (t *TxStructure) HDel(key []byte, fields ...[]byte) error {
return errors.Trace(err)
}
meta.Length--
meta.FieldCount--
}
}
@ -254,7 +260,7 @@ func (t *TxStructure) loadHashMeta(metaKey []byte) (hashMeta, error) {
return hashMeta{}, errors.Trace(err)
}
meta := hashMeta{Length: 0}
meta := hashMeta{FieldCount: 0}
if v == nil {
return meta, nil
}
@ -263,7 +269,7 @@ func (t *TxStructure) loadHashMeta(metaKey []byte) (hashMeta, error) {
return meta, errors.New("invalid list meta data")
}
meta.Length = int64(binary.BigEndian.Uint64(v[0:8]))
meta.FieldCount = int64(binary.BigEndian.Uint64(v[0:8]))
return meta, nil
}

View File

@ -193,7 +193,7 @@ func (s *tesTxStructureSuite) TestHash(c *C) {
value, err = tx.HGet(key, []byte("fake"))
c.Assert(err, IsNil)
c.Assert(err, IsNil)
c.Assert(value, IsNil)
keys, err := tx.HKeys(key)
c.Assert(err, IsNil)
@ -224,13 +224,41 @@ func (s *tesTxStructureSuite) TestHash(c *C) {
c.Assert(err, IsNil)
c.Assert(l, Equals, int64(2))
// Test set new value which equals to old value.
value, err = tx.HGet(key, []byte("1"))
c.Assert(err, IsNil)
c.Assert(value, DeepEquals, []byte("1"))
err = tx.HSet(key, []byte("1"), []byte("1"))
c.Assert(err, IsNil)
value, err = tx.HGet(key, []byte("1"))
c.Assert(err, IsNil)
c.Assert(value, DeepEquals, []byte("1"))
l, err = tx.HLen(key)
c.Assert(err, IsNil)
c.Assert(l, Equals, int64(2))
n, err = tx.HInc(key, []byte("1"), 1)
c.Assert(err, IsNil)
c.Assert(n, Equals, int64(2))
l, err = tx.HLen(key)
c.Assert(err, IsNil)
c.Assert(l, Equals, int64(2))
n, err = tx.HInc(key, []byte("1"), 1)
c.Assert(err, IsNil)
c.Assert(n, Equals, int64(3))
l, err = tx.HLen(key)
c.Assert(err, IsNil)
c.Assert(l, Equals, int64(2))
n, err = tx.HGetInt64(key, []byte("1"))
c.Assert(err, IsNil)
c.Assert(n, Equals, int64(2))
c.Assert(n, Equals, int64(3))
l, err = tx.HLen(key)
c.Assert(err, IsNil)
@ -246,6 +274,55 @@ func (s *tesTxStructureSuite) TestHash(c *C) {
err = tx.HDel(key, []byte("fake_key"))
c.Assert(err, IsNil)
// Test set nil value.
value, err = tx.HGet(key, []byte("nil_key"))
c.Assert(err, IsNil)
c.Assert(value, IsNil)
l, err = tx.HLen(key)
c.Assert(err, IsNil)
c.Assert(l, Equals, int64(0))
err = tx.HSet(key, []byte("nil_key"), nil)
c.Assert(err, IsNil)
l, err = tx.HLen(key)
c.Assert(err, IsNil)
c.Assert(l, Equals, int64(0))
err = tx.HSet(key, []byte("nil_key"), []byte("1"))
c.Assert(err, IsNil)
l, err = tx.HLen(key)
c.Assert(err, IsNil)
c.Assert(l, Equals, int64(1))
value, err = tx.HGet(key, []byte("nil_key"))
c.Assert(err, IsNil)
c.Assert(value, DeepEquals, []byte("1"))
err = tx.HSet(key, []byte("nil_key"), nil)
c.Assert(err, NotNil)
l, err = tx.HLen(key)
c.Assert(err, IsNil)
c.Assert(l, Equals, int64(1))
value, err = tx.HGet(key, []byte("nil_key"))
c.Assert(err, IsNil)
c.Assert(value, DeepEquals, []byte("1"))
err = tx.HSet(key, []byte("nil_key"), []byte("2"))
c.Assert(err, IsNil)
l, err = tx.HLen(key)
c.Assert(err, IsNil)
c.Assert(l, Equals, int64(1))
value, err = tx.HGet(key, []byte("nil_key"))
c.Assert(err, IsNil)
c.Assert(value, DeepEquals, []byte("2"))
err = txn.Commit()
c.Assert(err, IsNil)

View File

@ -41,15 +41,9 @@ type Table interface {
// Row returns a row for all columns.
Row(ctx context.Context, h int64) ([]interface{}, error)
// RemoveRow removes the row of handle h.
RemoveRow(ctx context.Context, h int64) error
// RemoveRowIndex removes an index of a row.
RemoveRowIndex(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error
// RemoveRowAllIndex removes all the indices of a row.
RemoveRowAllIndex(ctx context.Context, h int64, rec []interface{}) error
// BuildIndexForRow builds an index for a row.
BuildIndexForRow(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error
@ -89,6 +83,9 @@ type Table interface {
// UpdateRecord updates a row in the table.
UpdateRecord(ctx context.Context, h int64, currData []interface{}, newData []interface{}, touched []bool) error
// RemoveRecord removes a row in the table.
RemoveRecord(ctx context.Context, h int64, r []interface{}) error
// TableID returns the ID of the table.
TableID() int64

View File

@ -444,8 +444,22 @@ func (t *Table) LockRow(ctx context.Context, h int64) error {
return errors.Trace(err)
}
// RemoveRow implements table.Table RemoveRow interface.
func (t *Table) RemoveRow(ctx context.Context, h int64) error {
// RemoveRecord implements table.Table RemoveRecord interface.
func (t *Table) RemoveRecord(ctx context.Context, h int64, r []interface{}) error {
err := t.removeRowData(ctx, h)
if err != nil {
return errors.Trace(err)
}
err = t.removeRowIndices(ctx, h, r)
if err != nil {
return errors.Trace(err)
}
return nil
}
func (t *Table) removeRowData(ctx context.Context, h int64) error {
if err := t.LockRow(ctx, h); err != nil {
return errors.Trace(err)
}
@ -469,20 +483,8 @@ func (t *Table) RemoveRow(ctx context.Context, h int64) error {
return nil
}
// RemoveRowIndex implements table.Table RemoveRowIndex interface.
func (t *Table) RemoveRowIndex(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error {
txn, err := ctx.GetTxn(false)
if err != nil {
return errors.Trace(err)
}
if err = idx.X.Delete(txn, vals, h); err != nil {
return errors.Trace(err)
}
return nil
}
// RemoveRowAllIndex implements table.Table RemoveRowAllIndex interface.
func (t *Table) RemoveRowAllIndex(ctx context.Context, h int64, rec []interface{}) error {
// removeRowAllIndex removes all the indices of a row.
func (t *Table) removeRowIndices(ctx context.Context, h int64, rec []interface{}) error {
for _, v := range t.indices {
vals, err := v.FetchValues(rec)
if vals == nil {
@ -500,6 +502,18 @@ func (t *Table) RemoveRowAllIndex(ctx context.Context, h int64, rec []interface{
return nil
}
// RemoveRowIndex implements table.Table RemoveRowIndex interface.
func (t *Table) RemoveRowIndex(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error {
txn, err := ctx.GetTxn(false)
if err != nil {
return errors.Trace(err)
}
if err = idx.X.Delete(txn, vals, h); err != nil {
return errors.Trace(err)
}
return nil
}
// BuildIndexForRow implements table.Table BuildIndexForRow interface.
func (t *Table) BuildIndexForRow(ctx context.Context, h int64, vals []interface{}, idx *column.IndexedCol) error {
txn, err := ctx.GetTxn(false)

View File

@ -88,9 +88,8 @@ func (ts *testSuite) TestBasic(c *C) {
return true, nil
})
c.Assert(tb.RemoveRowAllIndex(ctx, rid, []interface{}{1, "cba"}), IsNil)
c.Assert(tb.RemoveRecord(ctx, rid, []interface{}{1, "cba"}), IsNil)
c.Assert(tb.RemoveRow(ctx, rid), IsNil)
// Make sure there is index data in the storage.
prefix := tb.IndexPrefix()
cnt, err := countEntriesWithPrefix(ctx, prefix)

View File

@ -15,6 +15,7 @@ package terror
import (
"fmt"
"runtime"
"strconv"
"github.com/juju/errors"
@ -23,6 +24,7 @@ import (
// Common base error instances.
var (
DatabaseNotExists = ClassSchema.New(CodeDatabaseNotExists, "database not exists")
TableNotExists = ClassSchema.New(CodeTableNotExists, "table not exists")
CommitNotInTransaction = ClassExecutor.New(CodeCommitNotInTransaction, "commit not in transaction")
RollbackNotInTransaction = ClassExecutor.New(CodeRollbackNotInTransaction, "rollback not in transaction")
@ -35,6 +37,7 @@ type ErrCode int
// Schema error codes.
const (
CodeDatabaseNotExists ErrCode = iota + 1
CodeTableNotExists
)
// Executor error codes.
@ -89,7 +92,7 @@ func (ec ErrClass) EqualClass(err error) bool {
return false
}
if te, ok := e.(*Error); ok {
return te.Class == ec
return te.class == ec
}
return false
}
@ -103,29 +106,48 @@ func (ec ErrClass) NotEqualClass(err error) bool {
// Usually used to create base *Error.
func (ec ErrClass) New(code ErrCode, message string) *Error {
return &Error{
Class: ec,
Code: code,
Message: message,
class: ec,
code: code,
message: message,
}
}
// Error implements error interface and adds integer Class and Code, so
// errors with different message can be compared.
type Error struct {
Class ErrClass
Code ErrCode
Message string
class ErrClass
code ErrCode
message string
file string
line int
}
// Class returns ErrClass
func (e *Error) Class() ErrClass {
return e.class
}
// Code returns ErrCode
func (e *Error) Code() ErrCode {
return e.code
}
// Location returns the location where the error is created,
// implements juju/errors locationer interface.
func (e *Error) Location() (file string, line int) {
return e.file, e.line
}
// Error implements error interface.
func (e *Error) Error() string {
return fmt.Sprintf("[%s:%d]%s", e.Class, e.Code, e.Message)
return fmt.Sprintf("[%s:%d]%s", e.class, e.code, e.message)
}
// Gen generates a new *Error with the same class and code, and a new formatted message.
func (e *Error) Gen(format string, args ...interface{}) *Error {
err := *e
err.Message = fmt.Sprintf(format, args...)
err.message = fmt.Sprintf(format, args...)
_, err.file, err.line, _ = runtime.Caller(1)
return &err
}
@ -136,7 +158,7 @@ func (e *Error) Equal(err error) bool {
return false
}
inErr, ok := originErr.(*Error)
return ok && e.Class == inErr.Class && e.Code == inErr.Code
return ok && e.class == inErr.class && e.code == inErr.code
}
// NotEqual checks if err is not equal to e.
@ -160,7 +182,7 @@ func ErrorEqual(err1, err2 error) bool {
te1, ok1 := e1.(*Error)
te2, ok2 := e2.(*Error)
if ok1 && ok2 {
return te1.Class == te2.Class && te1.Code == te2.Code
return te1.class == te2.class && te1.code == te2.code
}
return e1.Error() == e2.Error()

View File

@ -18,6 +18,7 @@ import (
"github.com/juju/errors"
. "github.com/pingcap/check"
"strings"
)
func TestT(t *testing.T) {
@ -49,6 +50,27 @@ func (s *testTErrorSuite) TestTError(c *C) {
c.Assert(optimizerErr.Equal(errors.New("abc")), IsFalse)
}
var predefinedErr = ClassExecutor.New(ErrCode(123), "predefiend error")
func example() error {
err := call()
return errors.Trace(err)
}
func call() error {
return predefinedErr.Gen("error message:%s", "abc")
}
func (s *testTErrorSuite) TestTraceAndLocation(c *C) {
err := example()
stack := errors.ErrorStack(err)
lines := strings.Split(stack, "\n")
c.Assert(len(lines), Equals, 2)
for _, v := range lines {
c.Assert(strings.Contains(v, "terror_test.go"), IsTrue)
}
}
func (s *testTErrorSuite) TestErrorEqual(c *C) {
e1 := errors.New("test error")
c.Assert(e1, NotNil)

View File

@ -20,6 +20,7 @@ import (
"os"
"runtime"
"sync"
"sync/atomic"
"testing"
"time"
@ -28,6 +29,8 @@ import (
"github.com/pingcap/tidb/kv"
"github.com/pingcap/tidb/mysql"
"github.com/pingcap/tidb/rset"
"github.com/pingcap/tidb/sessionctx"
"github.com/pingcap/tidb/sessionctx/autocommit"
"github.com/pingcap/tidb/sessionctx/variable"
"github.com/pingcap/tidb/terror"
)
@ -106,7 +109,8 @@ func checkResult(c *C, r sql.Result, affectedRows int64, insertID int64) {
}
type testSessionSuite struct {
dbName string
dbName string
dbNameBootstrap string
createDBSQL string
dropDBSQL string
@ -332,6 +336,7 @@ func sessionExec(c *C, se Session, sql string) ([]rset.Recordset, error) {
func (s *testSessionSuite) SetUpSuite(c *C) {
s.dbName = "test_session_db"
s.dbNameBootstrap = "test_main_db_bootstrap"
s.createDBSQL = fmt.Sprintf("create database if not exists %s;", s.dbName)
s.dropDBSQL = fmt.Sprintf("drop database %s;", s.dbName)
s.useDBSQL = fmt.Sprintf("use %s;", s.dbName)
@ -1025,6 +1030,61 @@ func (s *testSessionSuite) TestBootstrap(c *C) {
c.Assert(v[0], Equals, int64(len(variable.SysVars)))
}
// Create a new session on store but only do ddl works.
func (s *testSessionSuite) bootstrapWithError(store kv.Storage, c *C) {
ss := &session{
values: make(map[fmt.Stringer]interface{}),
store: store,
sid: atomic.AddInt64(&sessionID, 1),
}
domain, err := domap.Get(store)
c.Assert(err, IsNil)
sessionctx.BindDomain(ss, domain)
variable.BindSessionVars(ss)
variable.GetSessionVars(ss).SetStatusFlag(mysql.ServerStatusAutocommit, true)
// session implements autocommit.Checker. Bind it to ctx
autocommit.BindAutocommitChecker(ss, ss)
sessionMu.Lock()
defer sessionMu.Unlock()
b, err := checkBootstrapped(ss)
c.Assert(b, IsFalse)
c.Assert(err, IsNil)
doDDLWorks(ss)
// Leave dml unfinished.
}
func (s *testSessionSuite) TestBootstrapWithError(c *C) {
store := newStore(c, s.dbNameBootstrap)
s.bootstrapWithError(store, c)
se := newSession(c, store, s.dbNameBootstrap)
mustExecSQL(c, se, "USE mysql;")
r := mustExecSQL(c, se, `select * from user;`)
row, err := r.Next()
c.Assert(err, IsNil)
c.Assert(row, NotNil)
match(c, row.Data, "localhost", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y")
row, err = r.Next()
c.Assert(err, IsNil)
c.Assert(row, NotNil)
match(c, row.Data, "127.0.0.1", "root", "", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y", "Y")
mustExecSQL(c, se, "USE test;")
// Check privilege tables.
mustExecSQL(c, se, "SELECT * from mysql.db;")
mustExecSQL(c, se, "SELECT * from mysql.tables_priv;")
mustExecSQL(c, se, "SELECT * from mysql.columns_priv;")
// Check privilege tables.
r = mustExecSQL(c, se, "SELECT COUNT(*) from mysql.global_variables;")
v, err := r.FirstRow()
c.Assert(err, IsNil)
c.Assert(v[0], Equals, int64(len(variable.SysVars)))
r = mustExecSQL(c, se, `SELECT VARIABLE_VALUE from mysql.TiDB where VARIABLE_NAME="bootstrapped";`)
row, err = r.Next()
c.Assert(err, IsNil)
c.Assert(row, NotNil)
c.Assert(row.Data, HasLen, 1)
c.Assert(row.Data[0], Equals, "True")
}
func (s *testSessionSuite) TestEnum(c *C) {
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName)
@ -1287,6 +1347,31 @@ func (s *testSessionSuite) TestBuiltin(c *C) {
mustExecFailed(c, se, `select cast("xxx 10:10:10" as datetime)`)
}
func (s *testSessionSuite) TestFieldText(c *C) {
store := newStore(c, s.dbName)
se := newSession(c, store, s.dbName)
mustExecSQL(c, se, "drop table if exists t")
mustExecSQL(c, se, "create table t (a int)")
cases := []struct {
sql string
field string
}{
{"select distinct(a) from t", "a"},
{"select (1)", "1"},
{"select (1+1)", "(1+1)"},
{"select a from t", "a"},
{"select ((a+1)) from t", "((a+1))"},
}
for _, v := range cases {
results, err := se.Execute(v.sql)
c.Assert(err, IsNil)
result := results[0]
fields, err := result.Fields()
c.Assert(err, IsNil)
c.Assert(fields[0].Name, Equals, v.field)
}
}
func newSession(c *C, store kv.Storage, dbName string) Session {
se, err := CreateSession(store)
c.Assert(err, IsNil)

View File

@ -79,7 +79,7 @@ func (c *Context) SetGlobalSysVar(ctx context.Context, name string, value string
}
// NewContext creates a new mocked context.Context.
func NewContext() context.Context {
func NewContext() *Context {
return &Context{
values: make(map[fmt.Stringer]interface{}),
}