ddl: add primary key default value check when create table (#7169)

This commit is contained in:
crazycs
2018-07-31 15:44:11 +08:00
committed by GitHub
parent c38f567645
commit 58248563bc
4 changed files with 66 additions and 7 deletions

View File

@ -937,7 +937,7 @@ func (s *testColumnSuite) colDefStrToFieldType(c *C, str string) *types.FieldTyp
stmt, err := parser.New().ParseOneStmt(sqlA, "", "")
c.Assert(err, IsNil)
colDef := stmt.(*ast.AlterTableStmt).Specs[0].NewColumns[0]
col, _, err := buildColumnAndConstraint(nil, 0, colDef)
col, _, err := buildColumnAndConstraint(nil, 0, colDef, nil)
c.Assert(err, IsNil)
return &col.FieldType
}

View File

@ -187,6 +187,17 @@ func (s *testDBSuite) TestMySQLErrorCode(c *C) {
s.testErrorCode(c, sql, tmysql.ErrWrongNameForIndex)
sql = "create table t2(c1.c2 blob default null);"
s.testErrorCode(c, sql, tmysql.ErrWrongTableName)
sql = "create table t2 (id int default null primary key , age int);"
s.testErrorCode(c, sql, tmysql.ErrInvalidDefault)
sql = "create table t2 (id int null primary key , age int);"
s.testErrorCode(c, sql, tmysql.ErrPrimaryCantHaveNull)
sql = "create table t2 (id int default null, age int, primary key(id));"
s.testErrorCode(c, sql, tmysql.ErrPrimaryCantHaveNull)
sql = "create table t2 (id int null, age int, primary key(id));"
s.testErrorCode(c, sql, tmysql.ErrPrimaryCantHaveNull)
sql = "create table t2 (id int primary key , age int);"
s.tk.MustExec(sql)
// add column
sql = "alter table test_error_code_succ add column c1 int"

View File

@ -158,6 +158,8 @@ var (
ErrWrongNameForIndex = terror.ClassDDL.New(codeWrongNameForIndex, mysql.MySQLErrName[mysql.ErrWrongNameForIndex])
// ErrUnknownCharacterSet returns unknown character set.
ErrUnknownCharacterSet = terror.ClassDDL.New(codeUnknownCharacterSet, "Unknown character set: '%s'")
// ErrPrimaryCantHaveNull returns All parts of a PRIMARY KEY must be NOT NULL; if you need NULL in a key, use UNIQUE instead
ErrPrimaryCantHaveNull = terror.ClassDDL.New(codePrimaryCantHaveNull, mysql.MySQLErrName[mysql.ErrPrimaryCantHaveNull])
// ErrNotAllowedTypeInPartition returns not allowed type error when creating table partiton with unsupport expression type.
ErrNotAllowedTypeInPartition = terror.ClassDDL.New(codeErrFieldTypeNotAllowedAsPartitionField, mysql.MySQLErrName[mysql.ErrFieldTypeNotAllowedAsPartitionField])
@ -592,6 +594,7 @@ const (
codePartitionFunctionIsNotAllowed = terror.ErrCode(mysql.ErrPartitionFunctionIsNotAllowed)
codeErrPartitionFuncNotAllowed = terror.ErrCode(mysql.ErrPartitionFuncNotAllowed)
codeErrFieldTypeNotAllowedAsPartitionField = terror.ErrCode(mysql.ErrFieldTypeNotAllowedAsPartitionField)
codePrimaryCantHaveNull = terror.ErrCode(mysql.ErrPrimaryCantHaveNull)
)
func init() {
@ -636,6 +639,7 @@ func init() {
codePartitionFunctionIsNotAllowed: mysql.ErrPartitionFunctionIsNotAllowed,
codeErrPartitionFuncNotAllowed: mysql.ErrPartitionFuncNotAllowed,
codeErrFieldTypeNotAllowedAsPartitionField: mysql.ErrFieldTypeNotAllowedAsPartitionField,
codePrimaryCantHaveNull: mysql.ErrPrimaryCantHaveNull,
}
terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes
}

View File

@ -173,8 +173,16 @@ func buildColumnsAndConstraints(ctx sessionctx.Context, colDefs []*ast.ColumnDef
constraints []*ast.Constraint) ([]*table.Column, []*ast.Constraint, error) {
var cols []*table.Column
colMap := map[string]*table.Column{}
// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id));
var outPriKeyConstraint *ast.Constraint
for _, v := range constraints {
if v.Tp == ast.ConstraintPrimaryKey {
outPriKeyConstraint = v
break
}
}
for i, colDef := range colDefs {
col, cts, err := buildColumnAndConstraint(ctx, i, colDef)
col, cts, err := buildColumnAndConstraint(ctx, i, colDef, outPriKeyConstraint)
if err != nil {
return nil, nil, errors.Trace(err)
}
@ -229,13 +237,14 @@ func setCharsetCollationFlenDecimal(tp *types.FieldType) error {
return nil
}
// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id));
func buildColumnAndConstraint(ctx sessionctx.Context, offset int,
colDef *ast.ColumnDef) (*table.Column, []*ast.Constraint, error) {
colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) {
err := setCharsetCollationFlenDecimal(colDef.Tp)
if err != nil {
return nil, nil, errors.Trace(err)
}
col, cts, err := columnDefToCol(ctx, offset, colDef)
col, cts, err := columnDefToCol(ctx, offset, colDef, outPriKeyConstraint)
if err != nil {
return nil, nil, errors.Trace(err)
}
@ -263,7 +272,8 @@ func isExplicitTimeStamp() bool {
}
// columnDefToCol converts ColumnDef to Col and TableConstraints.
func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (*table.Column, []*ast.Constraint, error) {
// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id));
func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) {
var constraints = make([]*ast.Constraint, 0)
col := table.ToColumn(&model.ColumnInfo{
Offset: offset,
@ -282,6 +292,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (
setOnUpdateNow := false
hasDefaultValue := false
hasNullFlag := false
if colDef.Options != nil {
length := types.UnspecifiedLength
@ -299,6 +310,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (
case ast.ColumnOptionNull:
col.Flag &= ^mysql.NotNullFlag
removeOnUpdateNowFlag(col)
hasNullFlag = true
case ast.ColumnOptionAutoIncrement:
col.Flag |= mysql.AutoIncrementFlag
case ast.ColumnOptionPrimaryKey:
@ -367,7 +379,11 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (
col.Flag &= ^mysql.BinaryFlag
col.Flag |= mysql.ZerofillFlag
}
err := checkDefaultValue(ctx, col, hasDefaultValue)
err := checkPriKeyConstraint(col, hasDefaultValue, hasNullFlag, outPriKeyConstraint)
if err != nil {
return nil, nil, errors.Trace(err)
}
err = checkDefaultValue(ctx, col, hasDefaultValue)
if err != nil {
return nil, nil, errors.Trace(err)
}
@ -475,6 +491,10 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
}
return nil
}
// Primary key default null is invalid.
if mysql.HasPriKeyFlag(c.Flag) {
return ErrPrimaryCantHaveNull
}
// Set not null but default null is invalid.
if mysql.HasNotNullFlag(c.Flag) {
@ -484,6 +504,30 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
return nil
}
// checkPriKeyConstraint check all parts of a PRIMARY KEY must be NOT NULL
func checkPriKeyConstraint(col *table.Column, hasDefaultValue, hasNullFlag bool, outPriKeyConstraint *ast.Constraint) error {
// Primary key should not be null.
if mysql.HasPriKeyFlag(col.Flag) && hasDefaultValue && col.DefaultValue == nil {
return types.ErrInvalidDefault.GenByArgs(col.Name)
}
// Set primary key flag for outer primary key constraint.
// Such as: create table t1 (id int , age int, primary key(id))
if !mysql.HasPriKeyFlag(col.Flag) && outPriKeyConstraint != nil {
for _, key := range outPriKeyConstraint.Keys {
if key.Column.Name.L != col.Name.L {
continue
}
col.Flag |= mysql.PriKeyFlag
break
}
}
// Primary key should not be null.
if mysql.HasPriKeyFlag(col.Flag) && hasNullFlag {
return ErrPrimaryCantHaveNull
}
return nil
}
func checkDuplicateColumn(colDefs []*ast.ColumnDef) error {
colNames := map[string]bool{}
for _, colDef := range colDefs {
@ -1177,7 +1221,7 @@ func (d *ddl) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTab
// Ingore table constraints now, maybe return error later.
// We use length(t.Cols()) as the default offset firstly, we will change the
// column's offset later.
col, _, err = buildColumnAndConstraint(ctx, len(t.Cols()), specNewColumn)
col, _, err = buildColumnAndConstraint(ctx, len(t.Cols()), specNewColumn, nil)
if err != nil {
return errors.Trace(err)
}