parser/coldef: remove dependency on expression.
This commit is contained in:
202
ddl/ddl.go
202
ddl/ddl.go
@ -32,10 +32,12 @@ import (
|
||||
"github.com/pingcap/tidb/meta"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/optimizer/evaluator"
|
||||
"github.com/pingcap/tidb/parser/coldef"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
"github.com/pingcap/tidb/table"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
"github.com/twinj/uuid"
|
||||
)
|
||||
|
||||
@ -345,11 +347,12 @@ func setColumnFlagWithConstraint(colMap map[string]*column.Col, v *coldef.TableC
|
||||
}
|
||||
}
|
||||
|
||||
func (d *ddl) buildColumnsAndConstraints(colDefs []*coldef.ColumnDef, constraints []*coldef.TableConstraint) ([]*column.Col, []*coldef.TableConstraint, error) {
|
||||
func (d *ddl) buildColumnsAndConstraints(ctx context.Context, colDefs []*coldef.ColumnDef,
|
||||
constraints []*coldef.TableConstraint) ([]*column.Col, []*coldef.TableConstraint, error) {
|
||||
var cols []*column.Col
|
||||
colMap := map[string]*column.Col{}
|
||||
for i, colDef := range colDefs {
|
||||
col, cts, err := d.buildColumnAndConstraint(i, colDef)
|
||||
col, cts, err := d.buildColumnAndConstraint(ctx, i, colDef)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
@ -365,7 +368,8 @@ func (d *ddl) buildColumnsAndConstraints(colDefs []*coldef.ColumnDef, constraint
|
||||
return cols, constraints, nil
|
||||
}
|
||||
|
||||
func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*column.Col, []*coldef.TableConstraint, error) {
|
||||
func (d *ddl) buildColumnAndConstraint(ctx context.Context, offset int,
|
||||
colDef *coldef.ColumnDef) (*column.Col, []*coldef.TableConstraint, error) {
|
||||
// Set charset.
|
||||
if len(colDef.Tp.Charset) == 0 {
|
||||
switch colDef.Tp.Tp {
|
||||
@ -377,7 +381,7 @@ func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*c
|
||||
}
|
||||
}
|
||||
|
||||
col, cts, err := coldef.ColumnDefToCol(offset, colDef)
|
||||
col, cts, err := columnDefToCol(ctx, offset, colDef)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
@ -390,6 +394,192 @@ func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*c
|
||||
return col, cts, nil
|
||||
}
|
||||
|
||||
// columnDefToCol converts ColumnDef to Col and TableConstraints.
|
||||
func columnDefToCol(ctx context.Context, offset int, colDef *coldef.ColumnDef) (*column.Col, []*coldef.TableConstraint, error) {
|
||||
constraints := []*coldef.TableConstraint{}
|
||||
col := &column.Col{
|
||||
ColumnInfo: model.ColumnInfo{
|
||||
Offset: offset,
|
||||
Name: model.NewCIStr(colDef.Name),
|
||||
FieldType: *colDef.Tp,
|
||||
},
|
||||
}
|
||||
|
||||
// Check and set TimestampFlag and OnUpdateNowFlag.
|
||||
if col.Tp == mysql.TypeTimestamp {
|
||||
col.Flag |= mysql.TimestampFlag
|
||||
col.Flag |= mysql.OnUpdateNowFlag
|
||||
col.Flag |= mysql.NotNullFlag
|
||||
}
|
||||
|
||||
// If flen is not assigned, assigned it by type.
|
||||
if col.Flen == types.UnspecifiedLength {
|
||||
col.Flen = mysql.GetDefaultFieldLength(col.Tp)
|
||||
}
|
||||
if col.Decimal == types.UnspecifiedLength {
|
||||
col.Decimal = mysql.GetDefaultDecimal(col.Tp)
|
||||
}
|
||||
|
||||
setOnUpdateNow := false
|
||||
hasDefaultValue := false
|
||||
if colDef.Constraints != nil {
|
||||
keys := []*coldef.IndexColName{
|
||||
{
|
||||
colDef.Name,
|
||||
colDef.Tp.Flen,
|
||||
},
|
||||
}
|
||||
for _, v := range colDef.Constraints {
|
||||
switch v.Tp {
|
||||
case coldef.ConstrNotNull:
|
||||
col.Flag |= mysql.NotNullFlag
|
||||
case coldef.ConstrNull:
|
||||
col.Flag &= ^uint(mysql.NotNullFlag)
|
||||
removeOnUpdateNowFlag(col)
|
||||
case coldef.ConstrAutoIncrement:
|
||||
col.Flag |= mysql.AutoIncrementFlag
|
||||
case coldef.ConstrPrimaryKey:
|
||||
constraint := &coldef.TableConstraint{Tp: coldef.ConstrPrimaryKey, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
col.Flag |= mysql.PriKeyFlag
|
||||
case coldef.ConstrUniq:
|
||||
constraint := &coldef.TableConstraint{Tp: coldef.ConstrUniq, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
col.Flag |= mysql.UniqueKeyFlag
|
||||
case coldef.ConstrIndex:
|
||||
constraint := &coldef.TableConstraint{Tp: coldef.ConstrIndex, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
case coldef.ConstrUniqIndex:
|
||||
constraint := &coldef.TableConstraint{Tp: coldef.ConstrUniqIndex, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
col.Flag |= mysql.UniqueKeyFlag
|
||||
case coldef.ConstrKey:
|
||||
constraint := &coldef.TableConstraint{Tp: coldef.ConstrKey, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
case coldef.ConstrUniqKey:
|
||||
constraint := &coldef.TableConstraint{Tp: coldef.ConstrUniqKey, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
col.Flag |= mysql.UniqueKeyFlag
|
||||
case coldef.ConstrDefaultValue:
|
||||
value, err := getDefaultValue(ctx, v, colDef.Tp.Tp, colDef.Tp.Decimal)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Errorf("invalid default value - %s", errors.Trace(err))
|
||||
}
|
||||
col.DefaultValue = value
|
||||
hasDefaultValue = true
|
||||
removeOnUpdateNowFlag(col)
|
||||
case coldef.ConstrOnUpdate:
|
||||
if !evaluator.IsCurrentTimeExpr(v.Evalue) {
|
||||
return nil, nil, errors.Errorf("invalid ON UPDATE for - %s", col.Name)
|
||||
}
|
||||
|
||||
col.Flag |= mysql.OnUpdateNowFlag
|
||||
setOnUpdateNow = true
|
||||
case coldef.ConstrFulltext:
|
||||
// Do nothing.
|
||||
case coldef.ConstrComment:
|
||||
// Do nothing.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setTimestampDefaultValue(col, hasDefaultValue, setOnUpdateNow)
|
||||
|
||||
// Set `NoDefaultValueFlag` if this field doesn't have a default value and
|
||||
// it is `not null` and not an `AUTO_INCREMENT` field or `TIMESTAMP` field.
|
||||
setNoDefaultValueFlag(col, hasDefaultValue)
|
||||
|
||||
err := checkDefaultValue(col, hasDefaultValue)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
if col.Charset == charset.CharsetBin {
|
||||
col.Flag |= mysql.BinaryFlag
|
||||
}
|
||||
return col, constraints, nil
|
||||
}
|
||||
|
||||
func getDefaultValue(ctx context.Context, c *coldef.ConstraintOpt, tp byte, fsp int) (interface{}, error) {
|
||||
if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime {
|
||||
value, err := evaluator.GetTimeValue(ctx, c.Evalue, tp, fsp)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
// Value is nil means `default null`.
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If value is mysql.Time, convert it to string.
|
||||
if vv, ok := value.(mysql.Time); ok {
|
||||
return vv.String(), nil
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
v, err := evaluator.Eval(ctx, c.Evalue)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
return types.RawData(v), nil
|
||||
}
|
||||
|
||||
func removeOnUpdateNowFlag(c *column.Col) {
|
||||
// For timestamp Col, if it is set null or default value,
|
||||
// OnUpdateNowFlag should be removed.
|
||||
if mysql.HasTimestampFlag(c.Flag) {
|
||||
c.Flag &= ^uint(mysql.OnUpdateNowFlag)
|
||||
}
|
||||
}
|
||||
|
||||
func setTimestampDefaultValue(c *column.Col, hasDefaultValue bool, setOnUpdateNow bool) {
|
||||
if hasDefaultValue {
|
||||
return
|
||||
}
|
||||
|
||||
// For timestamp Col, if is not set default value or not set null, use current timestamp.
|
||||
if mysql.HasTimestampFlag(c.Flag) && mysql.HasNotNullFlag(c.Flag) {
|
||||
if setOnUpdateNow {
|
||||
c.DefaultValue = evaluator.ZeroTimestamp
|
||||
} else {
|
||||
c.DefaultValue = evaluator.CurrentTimestamp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setNoDefaultValueFlag(c *column.Col, hasDefaultValue bool) {
|
||||
if hasDefaultValue {
|
||||
return
|
||||
}
|
||||
|
||||
if !mysql.HasNotNullFlag(c.Flag) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it is an `AUTO_INCREMENT` field or `TIMESTAMP` field.
|
||||
if !mysql.HasAutoIncrementFlag(c.Flag) && !mysql.HasTimestampFlag(c.Flag) {
|
||||
c.Flag |= mysql.NoDefaultValueFlag
|
||||
}
|
||||
}
|
||||
|
||||
func checkDefaultValue(c *column.Col, hasDefaultValue bool) error {
|
||||
if !hasDefaultValue {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.DefaultValue != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set not null but default null is invalid.
|
||||
if mysql.HasNotNullFlag(c.Flag) {
|
||||
return errors.Errorf("invalid default value for %s", c.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDuplicateColumn(colDefs []*coldef.ColumnDef) error {
|
||||
colNames := map[string]bool{}
|
||||
for _, colDef := range colDefs {
|
||||
@ -511,7 +701,7 @@ func (d *ddl) CreateTable(ctx context.Context, ident table.Ident, colDefs []*col
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
cols, newConstraints, err := d.buildColumnsAndConstraints(colDefs, constraints)
|
||||
cols, newConstraints, err := d.buildColumnsAndConstraints(ctx, colDefs, constraints)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
@ -614,7 +804,7 @@ func (d *ddl) AddColumn(ctx context.Context, ti table.Ident, spec *AlterSpecific
|
||||
// ingore table constraints now, maybe return error later
|
||||
// we use length(t.Cols()) as the default offset first, later we will change the
|
||||
// column's offset later.
|
||||
col, _, err = d.buildColumnAndConstraint(len(t.Cols()), spec.Column)
|
||||
col, _, err = d.buildColumnAndConstraint(ctx, len(t.Cols()), spec.Column)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
@ -486,13 +486,7 @@ func convertColumnOption(converter *expressionConverter, v *ast.ColumnOption) (*
|
||||
case ast.ColumnOptionUniqKey:
|
||||
oldColumnOpt.Tp = coldef.ConstrUniqKey
|
||||
}
|
||||
if v.Expr != nil {
|
||||
oldExpr, err := convertExpr(converter, v.Expr)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
oldColumnOpt.Evalue = oldExpr
|
||||
}
|
||||
oldColumnOpt.Evalue = v.Expr
|
||||
return oldColumnOpt, nil
|
||||
}
|
||||
|
||||
|
||||
@ -2,13 +2,10 @@ package expression
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
. "github.com/pingcap/check"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
)
|
||||
|
||||
var _ = Suite(&testHelperSuite{})
|
||||
@ -164,96 +161,6 @@ func (s *testHelperSuite) TestBase(c *C) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testHelperSuite) TestGetTimeValue(c *C) {
|
||||
v, err := GetTimeValue(nil, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
timeValue, ok := v.(mysql.Time)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")
|
||||
|
||||
ctx := newMockCtx()
|
||||
variable.BindSessionVars(ctx)
|
||||
sessionVars := variable.GetSessionVars(ctx)
|
||||
|
||||
sessionVars.Systems["timestamp"] = ""
|
||||
v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
timeValue, ok = v.(mysql.Time)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")
|
||||
|
||||
sessionVars.Systems["timestamp"] = "0"
|
||||
v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
timeValue, ok = v.(mysql.Time)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")
|
||||
|
||||
delete(sessionVars.Systems, "timestamp")
|
||||
v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
timeValue, ok = v.(mysql.Time)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")
|
||||
|
||||
sessionVars.Systems["timestamp"] = "1234"
|
||||
|
||||
tbl := []struct {
|
||||
Expr interface{}
|
||||
Ret interface{}
|
||||
}{
|
||||
{"2012-12-12 00:00:00", "2012-12-12 00:00:00"},
|
||||
{CurrentTimestamp, time.Unix(1234, 0).Format(mysql.TimeFormat)},
|
||||
{ZeroTimestamp, "0000-00-00 00:00:00"},
|
||||
{Value{"2012-12-12 00:00:00"}, "2012-12-12 00:00:00"},
|
||||
{Value{int64(0)}, "0000-00-00 00:00:00"},
|
||||
{Value{}, nil},
|
||||
{CurrentTimeExpr, CurrentTimestamp},
|
||||
{NewUnaryOperation(opcode.Minus, Value{int64(0)}), "0000-00-00 00:00:00"},
|
||||
{mockExpr{}, nil},
|
||||
}
|
||||
|
||||
for _, t := range tbl {
|
||||
v, err := GetTimeValue(ctx, t.Expr, mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
switch x := v.(type) {
|
||||
case mysql.Time:
|
||||
c.Assert(x.String(), DeepEquals, t.Ret)
|
||||
default:
|
||||
c.Assert(x, DeepEquals, t.Ret)
|
||||
}
|
||||
}
|
||||
|
||||
errTbl := []struct {
|
||||
Expr interface{}
|
||||
}{
|
||||
{"2012-13-12 00:00:00"},
|
||||
{Value{"2012-13-12 00:00:00"}},
|
||||
{Value{0}},
|
||||
{Value{int64(1)}},
|
||||
{&Call{F: "xxx"}},
|
||||
{NewUnaryOperation(opcode.Minus, Value{int64(1)})},
|
||||
}
|
||||
|
||||
for _, t := range errTbl {
|
||||
_, err := GetTimeValue(ctx, t.Expr, mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, NotNil)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testHelperSuite) TestIsCurrentTimeExpr(c *C) {
|
||||
v := IsCurrentTimeExpr(mockExpr{})
|
||||
c.Assert(v, IsFalse)
|
||||
|
||||
v = IsCurrentTimeExpr(CurrentTimeExpr)
|
||||
c.Assert(v, IsTrue)
|
||||
}
|
||||
|
||||
func convert(v interface{}) interface{} {
|
||||
switch x := v.(type) {
|
||||
case int:
|
||||
|
||||
@ -24,6 +24,7 @@ import (
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/parser"
|
||||
"github.com/pingcap/tidb/parser/opcode"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
"github.com/pingcap/tidb/util/mock"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
@ -1091,3 +1092,93 @@ func (s *testEvaluatorSuite) TestAggFuncAvg(c *C) {
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(v.Equals(expect), IsTrue)
|
||||
}
|
||||
|
||||
func (s *testEvaluatorSuite) TestGetTimeValue(c *C) {
|
||||
v, err := GetTimeValue(nil, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
timeValue, ok := v.(mysql.Time)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")
|
||||
|
||||
ctx := mock.NewContext()
|
||||
variable.BindSessionVars(ctx)
|
||||
sessionVars := variable.GetSessionVars(ctx)
|
||||
|
||||
sessionVars.Systems["timestamp"] = ""
|
||||
v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
timeValue, ok = v.(mysql.Time)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")
|
||||
|
||||
sessionVars.Systems["timestamp"] = "0"
|
||||
v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
timeValue, ok = v.(mysql.Time)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")
|
||||
|
||||
delete(sessionVars.Systems, "timestamp")
|
||||
v, err = GetTimeValue(ctx, "2012-12-12 00:00:00", mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
timeValue, ok = v.(mysql.Time)
|
||||
c.Assert(ok, IsTrue)
|
||||
c.Assert(timeValue.String(), Equals, "2012-12-12 00:00:00")
|
||||
|
||||
sessionVars.Systems["timestamp"] = "1234"
|
||||
|
||||
tbl := []struct {
|
||||
Expr interface{}
|
||||
Ret interface{}
|
||||
}{
|
||||
{"2012-12-12 00:00:00", "2012-12-12 00:00:00"},
|
||||
{CurrentTimestamp, time.Unix(1234, 0).Format(mysql.TimeFormat)},
|
||||
{ZeroTimestamp, "0000-00-00 00:00:00"},
|
||||
{ast.NewValueExpr("2012-12-12 00:00:00"), "2012-12-12 00:00:00"},
|
||||
{ast.NewValueExpr(int64(0)), "0000-00-00 00:00:00"},
|
||||
{ast.NewValueExpr(nil), nil},
|
||||
{&ast.FuncCallExpr{FnName: model.NewCIStr(CurrentTimestamp)}, CurrentTimestamp},
|
||||
{&ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr(int64(0))}, "0000-00-00 00:00:00"},
|
||||
}
|
||||
|
||||
for i, t := range tbl {
|
||||
comment := Commentf("expr: %d", i)
|
||||
v, err := GetTimeValue(ctx, t.Expr, mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, IsNil)
|
||||
|
||||
switch x := v.(type) {
|
||||
case mysql.Time:
|
||||
c.Assert(x.String(), DeepEquals, t.Ret, comment)
|
||||
default:
|
||||
c.Assert(x, DeepEquals, t.Ret, comment)
|
||||
}
|
||||
}
|
||||
|
||||
errTbl := []struct {
|
||||
Expr interface{}
|
||||
}{
|
||||
{"2012-13-12 00:00:00"},
|
||||
{ast.NewValueExpr("2012-13-12 00:00:00")},
|
||||
{ast.NewValueExpr(0)},
|
||||
{ast.NewValueExpr(int64(1))},
|
||||
{&ast.FuncCallExpr{FnName: model.NewCIStr("xxx")}},
|
||||
{&ast.UnaryOperationExpr{Op: opcode.Minus, V: ast.NewValueExpr(int64(1))}},
|
||||
}
|
||||
|
||||
for _, t := range errTbl {
|
||||
_, err := GetTimeValue(ctx, t.Expr, mysql.TypeTimestamp, mysql.MinFsp)
|
||||
c.Assert(err, NotNil)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *testEvaluatorSuite) TestIsCurrentTimeExpr(c *C) {
|
||||
v := IsCurrentTimeExpr(ast.NewValueExpr("abc"))
|
||||
c.Assert(v, IsFalse)
|
||||
|
||||
v = IsCurrentTimeExpr(&ast.FuncCallExpr{FnName: model.NewCIStr("CURRENT_TIMESTAMP")})
|
||||
c.Assert(v, IsTrue)
|
||||
}
|
||||
|
||||
136
optimizer/evaluator/helper.go
Normal file
136
optimizer/evaluator/helper.go
Normal file
@ -0,0 +1,136 @@
|
||||
package evaluator
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
var (
|
||||
// CurrentTimestamp is the keyword getting default value for datetime and timestamp type.
|
||||
CurrentTimestamp = "CURRENT_TIMESTAMP"
|
||||
currentTimestampL = "current_timestamp"
|
||||
// ZeroTimestamp shows the zero datetime and timestamp.
|
||||
ZeroTimestamp = "0000-00-00 00:00:00"
|
||||
)
|
||||
|
||||
var (
|
||||
errDefaultValue = errors.New("invalid default value")
|
||||
)
|
||||
|
||||
// GetTimeValue gets the time value with type tp.
|
||||
func GetTimeValue(ctx context.Context, v interface{}, tp byte, fsp int) (interface{}, error) {
|
||||
return getTimeValue(ctx, v, tp, fsp)
|
||||
}
|
||||
|
||||
func getTimeValue(ctx context.Context, v interface{}, tp byte, fsp int) (interface{}, error) {
|
||||
value := mysql.Time{
|
||||
Type: tp,
|
||||
Fsp: fsp,
|
||||
}
|
||||
|
||||
defaultTime, err := getSystemTimestamp(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
switch x := v.(type) {
|
||||
case string:
|
||||
upperX := strings.ToUpper(x)
|
||||
if upperX == CurrentTimestamp {
|
||||
value.Time = defaultTime
|
||||
} else if upperX == ZeroTimestamp {
|
||||
value, _ = mysql.ParseTimeFromNum(0, tp, fsp)
|
||||
} else {
|
||||
value, err = mysql.ParseTime(x, tp, fsp)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
}
|
||||
case *ast.ValueExpr:
|
||||
switch xval := x.Data.(type) {
|
||||
case string:
|
||||
value, err = mysql.ParseTime(xval, tp, fsp)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
case int64:
|
||||
value, err = mysql.ParseTimeFromNum(int64(xval), tp, fsp)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
case nil:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, errors.Trace(errDefaultValue)
|
||||
}
|
||||
case *ast.FuncCallExpr:
|
||||
if x.FnName.L == currentTimestampL {
|
||||
return CurrentTimestamp, nil
|
||||
}
|
||||
return nil, errors.Trace(errDefaultValue)
|
||||
case *ast.UnaryOperationExpr:
|
||||
// support some expression, like `-1`
|
||||
v, err := Eval(ctx, x)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
ft := types.NewFieldType(mysql.TypeLonglong)
|
||||
xval, err := types.Convert(v, ft)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
value, err = mysql.ParseTimeFromNum(xval.(int64), tp, fsp)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
default:
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// IsCurrentTimeExpr returns whether e is CurrentTimeExpr.
|
||||
func IsCurrentTimeExpr(e ast.ExprNode) bool {
|
||||
x, ok := e.(*ast.FuncCallExpr)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return x.FnName.L == currentTimestampL
|
||||
}
|
||||
|
||||
func getSystemTimestamp(ctx context.Context) (time.Time, error) {
|
||||
value := time.Now()
|
||||
|
||||
if ctx == nil {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
// check whether use timestamp varibale
|
||||
sessionVars := variable.GetSessionVars(ctx)
|
||||
if v, ok := sessionVars.Systems["timestamp"]; ok {
|
||||
if v != "" {
|
||||
timestamp, err := strconv.ParseInt(v, 10, 64)
|
||||
if err != nil {
|
||||
return time.Time{}, errors.Trace(err)
|
||||
}
|
||||
|
||||
if timestamp <= 0 {
|
||||
return value, nil
|
||||
}
|
||||
|
||||
return time.Unix(timestamp, 0), nil
|
||||
}
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
@ -17,13 +17,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/juju/errors"
|
||||
"github.com/pingcap/tidb/column"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/table"
|
||||
"github.com/pingcap/tidb/util/charset"
|
||||
"github.com/pingcap/tidb/util/types"
|
||||
)
|
||||
|
||||
@ -83,186 +77,3 @@ func (c *ColumnDef) String() string {
|
||||
}
|
||||
return strings.Join(ans, " ")
|
||||
}
|
||||
|
||||
func getDefaultValue(c *ConstraintOpt, tp byte, fsp int) (interface{}, error) {
|
||||
if tp == mysql.TypeTimestamp || tp == mysql.TypeDatetime {
|
||||
value, err := expression.GetTimeValue(nil, c.Evalue, tp, fsp)
|
||||
if err != nil {
|
||||
return nil, errors.Trace(err)
|
||||
}
|
||||
|
||||
// Value is nil means `default null`.
|
||||
if value == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// If value is mysql.Time, convert it to string.
|
||||
if vv, ok := value.(mysql.Time); ok {
|
||||
return vv.String(), nil
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
v := expression.FastEval(c.Evalue)
|
||||
return types.RawData(v), nil
|
||||
}
|
||||
|
||||
func removeOnUpdateNowFlag(c *column.Col) {
|
||||
// For timestamp Col, if it is set null or default value,
|
||||
// OnUpdateNowFlag should be removed.
|
||||
if mysql.HasTimestampFlag(c.Flag) {
|
||||
c.Flag &= ^uint(mysql.OnUpdateNowFlag)
|
||||
}
|
||||
}
|
||||
|
||||
func setTimestampDefaultValue(c *column.Col, hasDefaultValue bool, setOnUpdateNow bool) {
|
||||
if hasDefaultValue {
|
||||
return
|
||||
}
|
||||
|
||||
// For timestamp Col, if is not set default value or not set null, use current timestamp.
|
||||
if mysql.HasTimestampFlag(c.Flag) && mysql.HasNotNullFlag(c.Flag) {
|
||||
if setOnUpdateNow {
|
||||
c.DefaultValue = expression.ZeroTimestamp
|
||||
} else {
|
||||
c.DefaultValue = expression.CurrentTimestamp
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setNoDefaultValueFlag(c *column.Col, hasDefaultValue bool) {
|
||||
if hasDefaultValue {
|
||||
return
|
||||
}
|
||||
|
||||
if !mysql.HasNotNullFlag(c.Flag) {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if it is an `AUTO_INCREMENT` field or `TIMESTAMP` field.
|
||||
if !mysql.HasAutoIncrementFlag(c.Flag) && !mysql.HasTimestampFlag(c.Flag) {
|
||||
c.Flag |= mysql.NoDefaultValueFlag
|
||||
}
|
||||
}
|
||||
|
||||
func checkDefaultValue(c *column.Col, hasDefaultValue bool) error {
|
||||
if !hasDefaultValue {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c.DefaultValue != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Set not null but default null is invalid.
|
||||
if mysql.HasNotNullFlag(c.Flag) {
|
||||
return errors.Errorf("invalid default value for %s", c.Name)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ColumnDefToCol converts ColumnDef to Col and TableConstraints.
|
||||
func ColumnDefToCol(offset int, colDef *ColumnDef) (*column.Col, []*TableConstraint, error) {
|
||||
constraints := []*TableConstraint{}
|
||||
col := &column.Col{
|
||||
ColumnInfo: model.ColumnInfo{
|
||||
Offset: offset,
|
||||
Name: model.NewCIStr(colDef.Name),
|
||||
FieldType: *colDef.Tp,
|
||||
},
|
||||
}
|
||||
|
||||
// Check and set TimestampFlag and OnUpdateNowFlag.
|
||||
if col.Tp == mysql.TypeTimestamp {
|
||||
col.Flag |= mysql.TimestampFlag
|
||||
col.Flag |= mysql.OnUpdateNowFlag
|
||||
col.Flag |= mysql.NotNullFlag
|
||||
}
|
||||
|
||||
// If flen is not assigned, assigned it by type.
|
||||
if col.Flen == types.UnspecifiedLength {
|
||||
col.Flen = mysql.GetDefaultFieldLength(col.Tp)
|
||||
}
|
||||
if col.Decimal == types.UnspecifiedLength {
|
||||
col.Decimal = mysql.GetDefaultDecimal(col.Tp)
|
||||
}
|
||||
|
||||
setOnUpdateNow := false
|
||||
hasDefaultValue := false
|
||||
if colDef.Constraints != nil {
|
||||
keys := []*IndexColName{
|
||||
{
|
||||
colDef.Name,
|
||||
colDef.Tp.Flen,
|
||||
},
|
||||
}
|
||||
for _, v := range colDef.Constraints {
|
||||
switch v.Tp {
|
||||
case ConstrNotNull:
|
||||
col.Flag |= mysql.NotNullFlag
|
||||
case ConstrNull:
|
||||
col.Flag &= ^uint(mysql.NotNullFlag)
|
||||
removeOnUpdateNowFlag(col)
|
||||
case ConstrAutoIncrement:
|
||||
col.Flag |= mysql.AutoIncrementFlag
|
||||
case ConstrPrimaryKey:
|
||||
constraint := &TableConstraint{Tp: ConstrPrimaryKey, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
col.Flag |= mysql.PriKeyFlag
|
||||
case ConstrUniq:
|
||||
constraint := &TableConstraint{Tp: ConstrUniq, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
col.Flag |= mysql.UniqueKeyFlag
|
||||
case ConstrIndex:
|
||||
constraint := &TableConstraint{Tp: ConstrIndex, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
case ConstrUniqIndex:
|
||||
constraint := &TableConstraint{Tp: ConstrUniqIndex, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
col.Flag |= mysql.UniqueKeyFlag
|
||||
case ConstrKey:
|
||||
constraint := &TableConstraint{Tp: ConstrKey, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
case ConstrUniqKey:
|
||||
constraint := &TableConstraint{Tp: ConstrUniqKey, ConstrName: colDef.Name, Keys: keys}
|
||||
constraints = append(constraints, constraint)
|
||||
col.Flag |= mysql.UniqueKeyFlag
|
||||
case ConstrDefaultValue:
|
||||
value, err := getDefaultValue(v, colDef.Tp.Tp, colDef.Tp.Decimal)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Errorf("invalid default value - %s", errors.Trace(err))
|
||||
}
|
||||
col.DefaultValue = value
|
||||
hasDefaultValue = true
|
||||
removeOnUpdateNowFlag(col)
|
||||
case ConstrOnUpdate:
|
||||
if !expression.IsCurrentTimeExpr(v.Evalue) {
|
||||
return nil, nil, errors.Errorf("invalid ON UPDATE for - %s", col.Name)
|
||||
}
|
||||
|
||||
col.Flag |= mysql.OnUpdateNowFlag
|
||||
setOnUpdateNow = true
|
||||
case ConstrFulltext:
|
||||
// Do nothing.
|
||||
case ConstrComment:
|
||||
// Do nothing.
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
setTimestampDefaultValue(col, hasDefaultValue, setOnUpdateNow)
|
||||
|
||||
// Set `NoDefaultValueFlag` if this field doesn't have a default value and
|
||||
// it is `not null` and not an `AUTO_INCREMENT` field or `TIMESTAMP` field.
|
||||
setNoDefaultValueFlag(col, hasDefaultValue)
|
||||
|
||||
err := checkDefaultValue(col, hasDefaultValue)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
if col.Charset == charset.CharsetBin {
|
||||
col.Flag |= mysql.BinaryFlag
|
||||
}
|
||||
return col, constraints, nil
|
||||
}
|
||||
|
||||
@ -17,7 +17,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/ast"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
)
|
||||
|
||||
@ -50,7 +50,7 @@ func (o *CharsetOpt) String() string {
|
||||
type ConstraintOpt struct {
|
||||
Tp int
|
||||
Bvalue bool
|
||||
Evalue expression.Expression
|
||||
Evalue ast.ExprNode
|
||||
}
|
||||
|
||||
// String implements fmt.Stringer interface.
|
||||
@ -69,9 +69,9 @@ func (c *ConstraintOpt) String() string {
|
||||
case ConstrUniqKey:
|
||||
return "UNIQUE KEY"
|
||||
case ConstrDefaultValue:
|
||||
return "DEFAULT " + c.Evalue.String()
|
||||
return "DEFAULT " + ast.ToString(c.Evalue)
|
||||
case ConstrOnUpdate:
|
||||
return "ON UPDATE " + c.Evalue.String()
|
||||
return "ON UPDATE " + ast.ToString(c.Evalue)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
|
||||
@ -194,7 +194,7 @@ func mustCommit(c *C, tx *sql.Tx) {
|
||||
|
||||
func mustExecuteSql(c *C, tx *sql.Tx, sql string) sql.Result {
|
||||
r, err := tx.Exec(sql)
|
||||
c.Assert(err, IsNil)
|
||||
c.Assert(err, IsNil, Commentf(sql))
|
||||
return r
|
||||
}
|
||||
|
||||
|
||||
@ -34,8 +34,9 @@ func (f *mockFormatter) Format(format string, args ...interface{}) (n int, errno
|
||||
}
|
||||
|
||||
func (s *testStmtSuite) TestGetColDefaultValue(c *C) {
|
||||
testSQL := `drop table if exists helper_test;
|
||||
create table helper_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int not null, c2 timestamp, c3 int default 1);`
|
||||
testSQL := `drop table if exists helper_test;`
|
||||
mustExec(c, s.testDB, testSQL)
|
||||
testSQL = `create table helper_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int not null, c2 timestamp, c3 int default 1);`
|
||||
mustExec(c, s.testDB, testSQL)
|
||||
|
||||
testSQL = " insert helper_test (c1) values (1);"
|
||||
@ -47,15 +48,17 @@ func (s *testStmtSuite) TestGetColDefaultValue(c *C) {
|
||||
c.Assert(err, NotNil)
|
||||
tx.Rollback()
|
||||
|
||||
testSQL = `drop table if exists helper_test;
|
||||
create table helper_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int, c2 datetime, c3 int default 1);`
|
||||
testSQL = `drop table if exists helper_test;`
|
||||
mustExec(c, s.testDB, testSQL)
|
||||
testSQL = `create table helper_test (id int PRIMARY KEY AUTO_INCREMENT, c1 int, c2 datetime, c3 int default 1);`
|
||||
mustExec(c, s.testDB, testSQL)
|
||||
|
||||
testSQL = " insert helper_test (c1) values (1);"
|
||||
mustExec(c, s.testDB, testSQL)
|
||||
|
||||
testSQL = `drop table if exists helper_test;
|
||||
create table helper_test (c1 enum("a"), c2 enum("b", "e") not null, c3 enum("c") default "c", c4 enum("d") default "d" not null);`
|
||||
testSQL = `drop table if exists helper_test;`
|
||||
mustExec(c, s.testDB, testSQL)
|
||||
testSQL = `create table helper_test (c1 enum("a"), c2 enum("b", "e") not null, c3 enum("c") default "c", c4 enum("d") default "d" not null);`
|
||||
mustExec(c, s.testDB, testSQL)
|
||||
|
||||
testSQL = "insert into helper_test values();"
|
||||
|
||||
@ -26,11 +26,11 @@ import (
|
||||
"github.com/ngaut/log"
|
||||
"github.com/pingcap/tidb/column"
|
||||
"github.com/pingcap/tidb/context"
|
||||
"github.com/pingcap/tidb/expression"
|
||||
"github.com/pingcap/tidb/kv"
|
||||
"github.com/pingcap/tidb/meta/autoid"
|
||||
"github.com/pingcap/tidb/model"
|
||||
"github.com/pingcap/tidb/mysql"
|
||||
"github.com/pingcap/tidb/optimizer/evaluator"
|
||||
"github.com/pingcap/tidb/sessionctx/variable"
|
||||
"github.com/pingcap/tidb/table"
|
||||
"github.com/pingcap/tidb/terror"
|
||||
@ -314,7 +314,7 @@ func (t *Table) setOnUpdateData(ctx context.Context, touched map[int]bool, data
|
||||
ucols := column.FindOnUpdateCols(t.writableCols())
|
||||
for _, col := range ucols {
|
||||
if !touched[col.Offset] {
|
||||
value, err := expression.GetTimeValue(ctx, expression.CurrentTimestamp, col.Tp, col.Decimal)
|
||||
value, err := evaluator.GetTimeValue(ctx, evaluator.CurrentTimestamp, col.Tp, col.Decimal)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
@ -762,11 +762,10 @@ func GetColDefaultValue(ctx context.Context, col *model.ColumnInfo) (interface{}
|
||||
return nil, true, nil
|
||||
}
|
||||
|
||||
value, err := expression.GetTimeValue(ctx, col.DefaultValue, col.Tp, col.Decimal)
|
||||
value, err := evaluator.GetTimeValue(ctx, col.DefaultValue, col.Tp, col.Decimal)
|
||||
if err != nil {
|
||||
return nil, true, errors.Errorf("Field '%s' get default value fail - %s", col.Name, errors.Trace(err))
|
||||
}
|
||||
|
||||
return value, true, nil
|
||||
} else if col.Tp == mysql.TypeEnum {
|
||||
// For enum type, if no default value and not null is set,
|
||||
|
||||
Reference in New Issue
Block a user