diff --git a/ddl/ddl.go b/ddl/ddl.go index 7752b05748..a7a5f7fd69 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -32,10 +32,12 @@ import ( "github.com/pingcap/tidb/meta" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/optimizer/evaluator" "github.com/pingcap/tidb/parser/coldef" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/util/charset" + "github.com/pingcap/tidb/util/types" "github.com/twinj/uuid" ) @@ -345,11 +347,12 @@ func setColumnFlagWithConstraint(colMap map[string]*column.Col, v *coldef.TableC } } -func (d *ddl) buildColumnsAndConstraints(colDefs []*coldef.ColumnDef, constraints []*coldef.TableConstraint) ([]*column.Col, []*coldef.TableConstraint, error) { +func (d *ddl) buildColumnsAndConstraints(ctx context.Context, colDefs []*coldef.ColumnDef, + constraints []*coldef.TableConstraint) ([]*column.Col, []*coldef.TableConstraint, error) { var cols []*column.Col colMap := map[string]*column.Col{} for i, colDef := range colDefs { - col, cts, err := d.buildColumnAndConstraint(i, colDef) + col, cts, err := d.buildColumnAndConstraint(ctx, i, colDef) if err != nil { return nil, nil, errors.Trace(err) } @@ -365,7 +368,8 @@ func (d *ddl) buildColumnsAndConstraints(colDefs []*coldef.ColumnDef, constraint return cols, constraints, nil } -func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*column.Col, []*coldef.TableConstraint, error) { +func (d *ddl) buildColumnAndConstraint(ctx context.Context, offset int, + colDef *coldef.ColumnDef) (*column.Col, []*coldef.TableConstraint, error) { // Set charset. if len(colDef.Tp.Charset) == 0 { switch colDef.Tp.Tp { @@ -377,7 +381,7 @@ func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*c } } - col, cts, err := coldef.ColumnDefToCol(offset, colDef) + col, cts, err := columnDefToCol(ctx, offset, colDef) if err != nil { return nil, nil, errors.Trace(err) } @@ -390,6 +394,192 @@ func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*c return col, cts, nil } +// columnDefToCol converts ColumnDef to Col and TableConstraints. +func columnDefToCol(ctx context.Context, offset int, colDef *coldef.ColumnDef) (*column.Col, []*coldef.TableConstraint, error) { + constraints := []*coldef.TableConstraint{} + col := &column.Col{ + ColumnInfo: model.ColumnInfo{ + Offset: offset, + Name: model.NewCIStr(colDef.Name), + FieldType: *colDef.Tp, + }, + } + + // Check and set TimestampFlag and OnUpdateNowFlag. + if col.Tp == mysql.TypeTimestamp { + col.Flag |= mysql.TimestampFlag + col.Flag |= mysql.OnUpdateNowFlag + col.Flag |= mysql.NotNullFlag + } + + // If flen is not assigned, assigned it by type. + if col.Flen == types.UnspecifiedLength { + col.Flen = mysql.GetDefaultFieldLength(col.Tp) + } + if col.Decimal == types.UnspecifiedLength { + col.Decimal = mysql.GetDefaultDecimal(col.Tp) + } + + setOnUpdateNow := false + hasDefaultValue := false + if colDef.Constraints != nil { + keys := []*coldef.IndexColName{ + { + colDef.Name, + colDef.Tp.Flen, + }, + } + for _, v := range colDef.Constraints { + switch v.Tp { + case coldef.ConstrNotNull: + col.Flag |= mysql.NotNullFlag + case coldef.ConstrNull: + col.Flag &= ^uint(mysql.NotNullFlag) + removeOnUpdateNowFlag(col) + case coldef.ConstrAutoIncrement: + col.Flag |= mysql.AutoIncrementFlag + case coldef.ConstrPrimaryKey: + constraint := &coldef.TableConstraint{Tp: coldef.ConstrPrimaryKey, Keys: keys} + constraints = append(constraints, constraint) + col.Flag |= mysql.PriKeyFlag + case coldef.ConstrUniq: + constraint := &coldef.TableConstraint{Tp: coldef.ConstrUniq, ConstrName: colDef.Name, Keys: keys} + constraints = append(constraints, constraint) + col.Flag |= mysql.UniqueKeyFlag + case coldef.ConstrIndex: + constraint := &coldef.TableConstraint{Tp: coldef.ConstrIndex, ConstrName: colDef.Name, Keys: keys} + constraints = append(constraints, constraint) + case coldef.ConstrUniqIndex: + constraint := &coldef.TableConstraint{Tp: coldef.ConstrUniqIndex, ConstrName: colDef.Name, Keys: keys} + constraints = append(constraints, constraint) + col.Flag |= mysql.UniqueKeyFlag + case coldef.ConstrKey: + constraint := &coldef.TableConstraint{Tp: coldef.ConstrKey, ConstrName: colDef.Name, Keys: keys} + constraints = append(constraints, constraint) + case coldef.ConstrUniqKey: + constraint := &coldef.TableConstraint{Tp: coldef.ConstrUniqKey, ConstrName: colDef.Name, Keys: keys} + constraints = append(constraints, constraint) + col.Flag |= mysql.UniqueKeyFlag + case coldef.ConstrDefaultValue: + value, err := getDefaultValue(ctx, v, colDef.Tp.Tp, colDef.Tp.Decimal) + if err != nil { + return nil, nil, errors.Errorf("invalid default value - %s", errors.Trace(err)) + } + col.DefaultValue = value + hasDefaultValue = true + removeOnUpdateNowFlag(col) + case coldef.ConstrOnUpdate: + if !evaluator.IsCurrentTimeExpr(v.Evalue) { + return nil, nil, errors.Errorf("invalid ON UPDATE for - %s", col.Name) + } + + col.Flag |= mysql.OnUpdateNowFlag + setOnUpdateNow = true + case coldef.ConstrFulltext: + // Do nothing. + case coldef.ConstrComment: + // Do nothing. + } + } + } + + setTimestampDefaultValue(col, hasDefaultValue, setOnUpdateNow) + + // Set `NoDefaultValueFlag` if this field doesn't have a default value and + // it is `not null` and not an `AUTO_INCREMENT` field or `TIMESTAMP` field. + setNoDefaultValueFlag(col, hasDefaultValue) + + err := checkDefaultValue(col, hasDefaultValue) + if err != nil { + return nil, nil, errors.Trace(err) + } + if col.Charset == charset.CharsetBin { + col.Flag |= mysql.BinaryFlag + } + return col, constraints, nil +} + +func getDefaultValue(ctx context.Context, c *coldef.ConstraintOpt, tp byte, fsp int) (interface{}, error) { + if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime { + value, err := evaluator.GetTimeValue(ctx, c.Evalue, tp, fsp) + if err != nil { + return nil, errors.Trace(err) + } + + // Value is nil means `default null`. + if value == nil { + return nil, nil + } + + // If value is mysql.Time, convert it to string. + if vv, ok := value.(mysql.Time); ok { + return vv.String(), nil + } + + return value, nil + } + v, err := evaluator.Eval(ctx, c.Evalue) + if err != nil { + return nil, errors.Trace(err) + } + return types.RawData(v), nil +} + +func removeOnUpdateNowFlag(c *column.Col) { + // For timestamp Col, if it is set null or default value, + // OnUpdateNowFlag should be removed. + if mysql.HasTimestampFlag(c.Flag) { + c.Flag &= ^uint(mysql.OnUpdateNowFlag) + } +} + +func setTimestampDefaultValue(c *column.Col, hasDefaultValue bool, setOnUpdateNow bool) { + if hasDefaultValue { + return + } + + // For timestamp Col, if is not set default value or not set null, use current timestamp. + if mysql.HasTimestampFlag(c.Flag) && mysql.HasNotNullFlag(c.Flag) { + if setOnUpdateNow { + c.DefaultValue = evaluator.ZeroTimestamp + } else { + c.DefaultValue = evaluator.CurrentTimestamp + } + } +} + +func setNoDefaultValueFlag(c *column.Col, hasDefaultValue bool) { + if hasDefaultValue { + return + } + + if !mysql.HasNotNullFlag(c.Flag) { + return + } + + // Check if it is an `AUTO_INCREMENT` field or `TIMESTAMP` field. + if !mysql.HasAutoIncrementFlag(c.Flag) && !mysql.HasTimestampFlag(c.Flag) { + c.Flag |= mysql.NoDefaultValueFlag + } +} + +func checkDefaultValue(c *column.Col, hasDefaultValue bool) error { + if !hasDefaultValue { + return nil + } + + if c.DefaultValue != nil { + return nil + } + + // Set not null but default null is invalid. + if mysql.HasNotNullFlag(c.Flag) { + return errors.Errorf("invalid default value for %s", c.Name) + } + + return nil +} + func checkDuplicateColumn(colDefs []*coldef.ColumnDef) error { colNames := map[string]bool{} for _, colDef := range colDefs { @@ -511,7 +701,7 @@ func (d *ddl) CreateTable(ctx context.Context, ident table.Ident, colDefs []*col return errors.Trace(err) } - cols, newConstraints, err := d.buildColumnsAndConstraints(colDefs, constraints) + cols, newConstraints, err := d.buildColumnsAndConstraints(ctx, colDefs, constraints) if err != nil { return errors.Trace(err) } @@ -614,7 +804,7 @@ func (d *ddl) AddColumn(ctx context.Context, ti table.Ident, spec *AlterSpecific // ingore table constraints now, maybe return error later // we use length(t.Cols()) as the default offset first, later we will change the // column's offset later. - col, _, err = d.buildColumnAndConstraint(len(t.Cols()), spec.Column) + col, _, err = d.buildColumnAndConstraint(ctx, len(t.Cols()), spec.Column) if err != nil { return errors.Trace(err) } diff --git a/executor/converter/convert_stmt.go b/executor/converter/convert_stmt.go index 04be971242..81a9b3f579 100644 --- a/executor/converter/convert_stmt.go +++ b/executor/converter/convert_stmt.go @@ -486,13 +486,7 @@ func convertColumnOption(converter *expressionConverter, v *ast.ColumnOption) (* case ast.ColumnOptionUniqKey: oldColumnOpt.Tp = coldef.ConstrUniqKey } - if v.Expr != nil { - oldExpr, err := convertExpr(converter, v.Expr) - if err != nil { - return nil, errors.Trace(err) - } - oldColumnOpt.Evalue = oldExpr - } + oldColumnOpt.Evalue = v.Expr return oldColumnOpt, nil } diff --git a/expression/helper_test.go b/expression/helper_test.go index 78a425e4bc..a338c57432 100644 --- a/expression/helper_test.go +++ b/expression/helper_test.go @@ -2,13 +2,10 @@ package expression import ( "errors" - "time" . "github.com/pingcap/check" "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser/opcode" - "github.com/pingcap/tidb/sessionctx/variable" ) var _ = Suite(&testHelperSuite{}) @@ -164,96 +161,6 @@ func (s *testHelperSuite) TestBase(c *C) { } } -func (s *testHelperSuite) TestGetTimeValue(c *C) { - v, err := GetTimeValue(nil, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) - c.Assert(err, IsNil) - - timeValue, ok := v.(mysql.Time) - c.Assert(ok, IsTrue) - c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") - - ctx := newMockCtx() - variable.BindSessionVars(ctx) - sessionVars := variable.GetSessionVars(ctx) - - sessionVars.Systems["timestamp"] = "" - v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) - c.Assert(err, IsNil) - - timeValue, ok = v.(mysql.Time) - c.Assert(ok, IsTrue) - c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") - - sessionVars.Systems["timestamp"] = "0" - v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) - c.Assert(err, IsNil) - - timeValue, ok = v.(mysql.Time) - c.Assert(ok, IsTrue) - c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") - - delete(sessionVars.Systems, "timestamp") - v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) - c.Assert(err, IsNil) - - timeValue, ok = v.(mysql.Time) - c.Assert(ok, IsTrue) - c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") - - sessionVars.Systems["timestamp"] = "1234" - - tbl := []struct { - Expr interface{} - Ret interface{} - }{ - {"2012-12-12 00:00:00", "2012-12-12 00:00:00"}, - {CurrentTimestamp, time.Unix(1234, 0).Format(mysql.TimeFormat)}, - {ZeroTimestamp, "0000-00-00 00:00:00"}, - {Value{"2012-12-12 00:00:00"}, "2012-12-12 00:00:00"}, - {Value{int64(0)}, "0000-00-00 00:00:00"}, - {Value{}, nil}, - {CurrentTimeExpr, CurrentTimestamp}, - {NewUnaryOperation(opcode.Minus, Value{int64(0)}), "0000-00-00 00:00:00"}, - {mockExpr{}, nil}, - } - - for _, t := range tbl { - v, err := GetTimeValue(ctx, t.Expr, mysql.TypeTimestamp, mysql.MinFsp) - c.Assert(err, IsNil) - - switch x := v.(type) { - case mysql.Time: - c.Assert(x.String(), DeepEquals, t.Ret) - default: - c.Assert(x, DeepEquals, t.Ret) - } - } - - errTbl := []struct { - Expr interface{} - }{ - {"2012-13-12 00:00:00"}, - {Value{"2012-13-12 00:00:00"}}, - {Value{0}}, - {Value{int64(1)}}, - {&Call{F: "xxx"}}, - {NewUnaryOperation(opcode.Minus, Value{int64(1)})}, - } - - for _, t := range errTbl { - _, err := GetTimeValue(ctx, t.Expr, mysql.TypeTimestamp, mysql.MinFsp) - c.Assert(err, NotNil) - } -} - -func (s *testHelperSuite) TestIsCurrentTimeExpr(c *C) { - v := IsCurrentTimeExpr(mockExpr{}) - c.Assert(v, IsFalse) - - v = IsCurrentTimeExpr(CurrentTimeExpr) - c.Assert(v, IsTrue) -} - func convert(v interface{}) interface{} { switch x := v.(type) { case int: diff --git a/optimizer/evaluator/evaluator_test.go b/optimizer/evaluator/evaluator_test.go index a1cbc01d1d..e6f0b6d640 100644 --- a/optimizer/evaluator/evaluator_test.go +++ b/optimizer/evaluator/evaluator_test.go @@ -24,6 +24,7 @@ import ( "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/parser" "github.com/pingcap/tidb/parser/opcode" + "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/mock" "github.com/pingcap/tidb/util/types" @@ -1091,3 +1092,93 @@ func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) { c.Assert(ok, IsTrue) c.Assert(v.Equals(expect), IsTrue) } + +func (s *testEvaluatorSuite) TestGetTimeValue(c *C) { + v, err := GetTimeValue(nil, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) + c.Assert(err, IsNil) + + timeValue, ok := v.(mysql.Time) + c.Assert(ok, IsTrue) + c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") + + ctx := mock.NewContext() + variable.BindSessionVars(ctx) + sessionVars := variable.GetSessionVars(ctx) + + sessionVars.Systems["timestamp"] = "" + v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) + c.Assert(err, IsNil) + + timeValue, ok = v.(mysql.Time) + c.Assert(ok, IsTrue) + c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") + + sessionVars.Systems["timestamp"] = "0" + v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) + c.Assert(err, IsNil) + + timeValue, ok = v.(mysql.Time) + c.Assert(ok, IsTrue) + c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") + + delete(sessionVars.Systems, "timestamp") + v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp) + c.Assert(err, IsNil) + + timeValue, ok = v.(mysql.Time) + c.Assert(ok, IsTrue) + c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00") + + sessionVars.Systems["timestamp"] = "1234" + + tbl := []struct { + Expr interface{} + Ret interface{} + }{ + {"2012-12-12 00:00:00", "2012-12-12 00:00:00"}, + {CurrentTimestamp, time.Unix(1234, 0).Format(mysql.TimeFormat)}, + {ZeroTimestamp, "0000-00-00 00:00:00"}, + {ast.NewValueExpr("2012-12-12 00:00:00"), "2012-12-12 00:00:00"}, + {ast.NewValueExpr(int64(0)), "0000-00-00 00:00:00"}, + {ast.NewValueExpr(nil), nil}, + {&ast.FuncCallExpr{FnName: model.NewCIStr(CurrentTimestamp)}, CurrentTimestamp}, + {&ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr(int64(0))}, "0000-00-00 00:00:00"}, + } + + for i, t := range tbl { + comment := Commentf("expr: %d", i) + v, err := GetTimeValue(ctx, t.Expr, mysql.TypeTimestamp, mysql.MinFsp) + c.Assert(err, IsNil) + + switch x := v.(type) { + case mysql.Time: + c.Assert(x.String(), DeepEquals, t.Ret, comment) + default: + c.Assert(x, DeepEquals, t.Ret, comment) + } + } + + errTbl := []struct { + Expr interface{} + }{ + {"2012-13-12 00:00:00"}, + {ast.NewValueExpr("2012-13-12 00:00:00")}, + {ast.NewValueExpr(0)}, + {ast.NewValueExpr(int64(1))}, + {&ast.FuncCallExpr{FnName: model.NewCIStr("xxx")}}, + {&ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr(int64(1))}}, + } + + for _, t := range errTbl { + _, err := GetTimeValue(ctx, t.Expr, mysql.TypeTimestamp, mysql.MinFsp) + c.Assert(err, NotNil) + } +} + +func (s *testEvaluatorSuite) TestIsCurrentTimeExpr(c *C) { + v := IsCurrentTimeExpr(ast.NewValueExpr("abc")) + c.Assert(v, IsFalse) + + v = IsCurrentTimeExpr(&ast.FuncCallExpr{FnName: model.NewCIStr("CURRENT_TIMESTAMP")}) + c.Assert(v, IsTrue) +} diff --git a/optimizer/evaluator/helper.go b/optimizer/evaluator/helper.go new file mode 100644 index 0000000000..b4d171784c --- /dev/null +++ b/optimizer/evaluator/helper.go @@ -0,0 +1,136 @@ +package evaluator + +import ( + "strconv" + "strings" + "time" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/sessionctx/variable" + "github.com/pingcap/tidb/util/types" +) + +var ( + // CurrentTimestamp is the keyword getting default value for datetime and timestamp type. + CurrentTimestamp = "CURRENT_TIMESTAMP" + currentTimestampL = "current_timestamp" + // ZeroTimestamp shows the zero datetime and timestamp. + ZeroTimestamp = "0000-00-00 00:00:00" +) + +var ( + errDefaultValue = errors.New("invalid default value") +) + +// GetTimeValue gets the time value with type tp. +func GetTimeValue(ctx context.Context, v interface{}, tp byte, fsp int) (interface{}, error) { + return getTimeValue(ctx, v, tp, fsp) +} + +func getTimeValue(ctx context.Context, v interface{}, tp byte, fsp int) (interface{}, error) { + value := mysql.Time{ + Type: tp, + Fsp: fsp, + } + + defaultTime, err := getSystemTimestamp(ctx) + if err != nil { + return nil, errors.Trace(err) + } + + switch x := v.(type) { + case string: + upperX := strings.ToUpper(x) + if upperX == CurrentTimestamp { + value.Time = defaultTime + } else if upperX == ZeroTimestamp { + value, _ = mysql.ParseTimeFromNum(0, tp, fsp) + } else { + value, err = mysql.ParseTime(x, tp, fsp) + if err != nil { + return nil, errors.Trace(err) + } + } + case *ast.ValueExpr: + switch xval := x.Data.(type) { + case string: + value, err = mysql.ParseTime(xval, tp, fsp) + if err != nil { + return nil, errors.Trace(err) + } + case int64: + value, err = mysql.ParseTimeFromNum(int64(xval), tp, fsp) + if err != nil { + return nil, errors.Trace(err) + } + case nil: + return nil, nil + default: + return nil, errors.Trace(errDefaultValue) + } + case *ast.FuncCallExpr: + if x.FnName.L == currentTimestampL { + return CurrentTimestamp, nil + } + return nil, errors.Trace(errDefaultValue) + case *ast.UnaryOperationExpr: + // support some expression, like `-1` + v, err := Eval(ctx, x) + if err != nil { + return nil, errors.Trace(err) + } + ft := types.NewFieldType(mysql.TypeLonglong) + xval, err := types.Convert(v, ft) + if err != nil { + return nil, errors.Trace(err) + } + + value, err = mysql.ParseTimeFromNum(xval.(int64), tp, fsp) + if err != nil { + return nil, errors.Trace(err) + } + default: + return nil, nil + } + + return value, nil +} + +// IsCurrentTimeExpr returns whether e is CurrentTimeExpr. +func IsCurrentTimeExpr(e ast.ExprNode) bool { + x, ok := e.(*ast.FuncCallExpr) + if !ok { + return false + } + return x.FnName.L == currentTimestampL +} + +func getSystemTimestamp(ctx context.Context) (time.Time, error) { + value := time.Now() + + if ctx == nil { + return value, nil + } + + // check whether use timestamp varibale + sessionVars := variable.GetSessionVars(ctx) + if v, ok := sessionVars.Systems["timestamp"]; ok { + if v != "" { + timestamp, err := strconv.ParseInt(v, 10, 64) + if err != nil { + return time.Time{}, errors.Trace(err) + } + + if timestamp <= 0 { + return value, nil + } + + return time.Unix(timestamp, 0), nil + } + } + + return value, nil +} diff --git a/parser/coldef/col_def.go b/parser/coldef/col_def.go index af9cd02785..ad1a9618fd 100644 --- a/parser/coldef/col_def.go +++ b/parser/coldef/col_def.go @@ -17,13 +17,7 @@ import ( "fmt" "strings" - "github.com/juju/errors" - "github.com/pingcap/tidb/column" - "github.com/pingcap/tidb/expression" - "github.com/pingcap/tidb/model" - "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/util/charset" "github.com/pingcap/tidb/util/types" ) @@ -83,186 +77,3 @@ func (c *ColumnDef) String() string { } return strings.Join(ans, " ") } - -func getDefaultValue(c *ConstraintOpt, tp byte, fsp int) (interface{}, error) { - if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime { - value, err := expression.GetTimeValue(nil, c.Evalue, tp, fsp) - if err != nil { - return nil, errors.Trace(err) - } - - // Value is nil means `default null`. - if value == nil { - return nil, nil - } - - // If value is mysql.Time, convert it to string. - if vv, ok := value.(mysql.Time); ok { - return vv.String(), nil - } - - return value, nil - } - v := expression.FastEval(c.Evalue) - return types.RawData(v), nil -} - -func removeOnUpdateNowFlag(c *column.Col) { - // For timestamp Col, if it is set null or default value, - // OnUpdateNowFlag should be removed. - if mysql.HasTimestampFlag(c.Flag) { - c.Flag &= ^uint(mysql.OnUpdateNowFlag) - } -} - -func setTimestampDefaultValue(c *column.Col, hasDefaultValue bool, setOnUpdateNow bool) { - if hasDefaultValue { - return - } - - // For timestamp Col, if is not set default value or not set null, use current timestamp. - if mysql.HasTimestampFlag(c.Flag) && mysql.HasNotNullFlag(c.Flag) { - if setOnUpdateNow { - c.DefaultValue = expression.ZeroTimestamp - } else { - c.DefaultValue = expression.CurrentTimestamp - } - } -} - -func setNoDefaultValueFlag(c *column.Col, hasDefaultValue bool) { - if hasDefaultValue { - return - } - - if !mysql.HasNotNullFlag(c.Flag) { - return - } - - // Check if it is an `AUTO_INCREMENT` field or `TIMESTAMP` field. - if !mysql.HasAutoIncrementFlag(c.Flag) && !mysql.HasTimestampFlag(c.Flag) { - c.Flag |= mysql.NoDefaultValueFlag - } -} - -func checkDefaultValue(c *column.Col, hasDefaultValue bool) error { - if !hasDefaultValue { - return nil - } - - if c.DefaultValue != nil { - return nil - } - - // Set not null but default null is invalid. - if mysql.HasNotNullFlag(c.Flag) { - return errors.Errorf("invalid default value for %s", c.Name) - } - - return nil -} - -// ColumnDefToCol converts ColumnDef to Col and TableConstraints. -func ColumnDefToCol(offset int, colDef *ColumnDef) (*column.Col, []*TableConstraint, error) { - constraints := []*TableConstraint{} - col := &column.Col{ - ColumnInfo: model.ColumnInfo{ - Offset: offset, - Name: model.NewCIStr(colDef.Name), - FieldType: *colDef.Tp, - }, - } - - // Check and set TimestampFlag and OnUpdateNowFlag. - if col.Tp == mysql.TypeTimestamp { - col.Flag |= mysql.TimestampFlag - col.Flag |= mysql.OnUpdateNowFlag - col.Flag |= mysql.NotNullFlag - } - - // If flen is not assigned, assigned it by type. - if col.Flen == types.UnspecifiedLength { - col.Flen = mysql.GetDefaultFieldLength(col.Tp) - } - if col.Decimal == types.UnspecifiedLength { - col.Decimal = mysql.GetDefaultDecimal(col.Tp) - } - - setOnUpdateNow := false - hasDefaultValue := false - if colDef.Constraints != nil { - keys := []*IndexColName{ - { - colDef.Name, - colDef.Tp.Flen, - }, - } - for _, v := range colDef.Constraints { - switch v.Tp { - case ConstrNotNull: - col.Flag |= mysql.NotNullFlag - case ConstrNull: - col.Flag &= ^uint(mysql.NotNullFlag) - removeOnUpdateNowFlag(col) - case ConstrAutoIncrement: - col.Flag |= mysql.AutoIncrementFlag - case ConstrPrimaryKey: - constraint := &TableConstraint{Tp: ConstrPrimaryKey, Keys: keys} - constraints = append(constraints, constraint) - col.Flag |= mysql.PriKeyFlag - case ConstrUniq: - constraint := &TableConstraint{Tp: ConstrUniq, ConstrName: colDef.Name, Keys: keys} - constraints = append(constraints, constraint) - col.Flag |= mysql.UniqueKeyFlag - case ConstrIndex: - constraint := &TableConstraint{Tp: ConstrIndex, ConstrName: colDef.Name, Keys: keys} - constraints = append(constraints, constraint) - case ConstrUniqIndex: - constraint := &TableConstraint{Tp: ConstrUniqIndex, ConstrName: colDef.Name, Keys: keys} - constraints = append(constraints, constraint) - col.Flag |= mysql.UniqueKeyFlag - case ConstrKey: - constraint := &TableConstraint{Tp: ConstrKey, ConstrName: colDef.Name, Keys: keys} - constraints = append(constraints, constraint) - case ConstrUniqKey: - constraint := &TableConstraint{Tp: ConstrUniqKey, ConstrName: colDef.Name, Keys: keys} - constraints = append(constraints, constraint) - col.Flag |= mysql.UniqueKeyFlag - case ConstrDefaultValue: - value, err := getDefaultValue(v, colDef.Tp.Tp, colDef.Tp.Decimal) - if err != nil { - return nil, nil, errors.Errorf("invalid default value - %s", errors.Trace(err)) - } - col.DefaultValue = value - hasDefaultValue = true - removeOnUpdateNowFlag(col) - case ConstrOnUpdate: - if !expression.IsCurrentTimeExpr(v.Evalue) { - return nil, nil, errors.Errorf("invalid ON UPDATE for - %s", col.Name) - } - - col.Flag |= mysql.OnUpdateNowFlag - setOnUpdateNow = true - case ConstrFulltext: - // Do nothing. - case ConstrComment: - // Do nothing. - } - } - } - - setTimestampDefaultValue(col, hasDefaultValue, setOnUpdateNow) - - // Set `NoDefaultValueFlag` if this field doesn't have a default value and - // it is `not null` and not an `AUTO_INCREMENT` field or `TIMESTAMP` field. - setNoDefaultValueFlag(col, hasDefaultValue) - - err := checkDefaultValue(col, hasDefaultValue) - if err != nil { - return nil, nil, errors.Trace(err) - } - if col.Charset == charset.CharsetBin { - col.Flag |= mysql.BinaryFlag - } - return col, constraints, nil -} diff --git a/parser/coldef/opt.go b/parser/coldef/opt.go index 84034fc11b..c35cda20dd 100644 --- a/parser/coldef/opt.go +++ b/parser/coldef/opt.go @@ -17,7 +17,7 @@ import ( "fmt" "strings" - "github.com/pingcap/tidb/expression" + "github.com/pingcap/tidb/ast" "github.com/pingcap/tidb/mysql" ) @@ -50,7 +50,7 @@ func (o *CharsetOpt) String() string { type ConstraintOpt struct { Tp int Bvalue bool - Evalue expression.Expression + Evalue ast.ExprNode } // String implements fmt.Stringer interface. @@ -69,9 +69,9 @@ func (c *ConstraintOpt) String() string { case ConstrUniqKey: return "UNIQUE KEY" case ConstrDefaultValue: - return "DEFAULT " + c.Evalue.String() + return "DEFAULT " + ast.ToString(c.Evalue) case ConstrOnUpdate: - return "ON UPDATE " + c.Evalue.String() + return "ON UPDATE " + ast.ToString(c.Evalue) default: return "" } diff --git a/stmt/stmts/create_test.go b/stmt/stmts/create_test.go index 482beb4a98..940875d8ef 100644 --- a/stmt/stmts/create_test.go +++ b/stmt/stmts/create_test.go @@ -194,7 +194,7 @@ func mustCommit(c *C, tx *sql.Tx) { func mustExecuteSql(c *C, tx *sql.Tx, sql string) sql.Result { r, err := tx.Exec(sql) - c.Assert(err, IsNil) + c.Assert(err, IsNil, Commentf(sql)) return r } diff --git a/stmt/stmts/stmt_helper_test.go b/stmt/stmts/stmt_helper_test.go index a9d50737d2..a15bf6bf56 100644 --- a/stmt/stmts/stmt_helper_test.go +++ b/stmt/stmts/stmt_helper_test.go @@ -34,8 +34,9 @@ func (f *mockFormatter) Format(format string, args ...interface{}) (n int, errno } func (s *testStmtSuite) TestGetColDefaultValue(c *C) { - testSQL := `drop table if exists helper_test; - create table helper_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int not null, c2 timestamp, c3 int default 1);` + testSQL := `drop table if exists helper_test;` + mustExec(c, s.testDB, testSQL) + testSQL = `create table helper_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int not null, c2 timestamp, c3 int default 1);` mustExec(c, s.testDB, testSQL) testSQL = " insert helper_test (c1) values (1);" @@ -47,15 +48,17 @@ func (s *testStmtSuite) TestGetColDefaultValue(c *C) { c.Assert(err, NotNil) tx.Rollback() - testSQL = `drop table if exists helper_test; - create table helper_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int, c2 datetime, c3 int default 1);` + testSQL = `drop table if exists helper_test;` + mustExec(c, s.testDB, testSQL) + testSQL = `create table helper_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int, c2 datetime, c3 int default 1);` mustExec(c, s.testDB, testSQL) testSQL = " insert helper_test (c1) values (1);" mustExec(c, s.testDB, testSQL) - testSQL = `drop table if exists helper_test; - create table helper_test (c1 enum("a"), c2 enum("b", "e") not null, c3 enum("c") default "c", c4 enum("d") default "d" not null);` + testSQL = `drop table if exists helper_test;` + mustExec(c, s.testDB, testSQL) + testSQL = `create table helper_test (c1 enum("a"), c2 enum("b", "e") not null, c3 enum("c") default "c", c4 enum("d") default "d" not null);` mustExec(c, s.testDB, testSQL) testSQL = "insert into helper_test values();" diff --git a/table/tables/tables.go b/table/tables/tables.go index 6f34496de9..199f3385fb 100644 --- a/table/tables/tables.go +++ b/table/tables/tables.go @@ -26,11 +26,11 @@ import ( "github.com/ngaut/log" "github.com/pingcap/tidb/column" "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/expression" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/meta/autoid" "github.com/pingcap/tidb/model" "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/optimizer/evaluator" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/terror" @@ -314,7 +314,7 @@ func (t *Table) setOnUpdateData(ctx context.Context, touched map[int]bool, data ucols := column.FindOnUpdateCols(t.writableCols()) for _, col := range ucols { if !touched[col.Offset] { - value, err := expression.GetTimeValue(ctx, expression.CurrentTimestamp, col.Tp, col.Decimal) + value, err := evaluator.GetTimeValue(ctx, evaluator.CurrentTimestamp, col.Tp, col.Decimal) if err != nil { return errors.Trace(err) } @@ -762,11 +762,10 @@ func GetColDefaultValue(ctx context.Context, col *model.ColumnInfo) (interface{} return nil, true, nil } - value, err := expression.GetTimeValue(ctx, col.DefaultValue, col.Tp, col.Decimal) + value, err := evaluator.GetTimeValue(ctx, col.DefaultValue, col.Tp, col.Decimal) if err != nil { return nil, true, errors.Errorf("Field '%s' get default value fail - %s", col.Name, errors.Trace(err)) } - return value, true, nil } else if col.Tp == mysql.TypeEnum { // For enum type, if no default value and not null is set,