diff --git a/ddl/column_test.go b/ddl/column_test.go index 67f7e5ce1d..00c7cba126 100644 --- a/ddl/column_test.go +++ b/ddl/column_test.go @@ -937,7 +937,7 @@ func (s *testColumnSuite) colDefStrToFieldType(c *C, str string) *types.FieldTyp stmt, err := parser.New().ParseOneStmt(sqlA, "", "") c.Assert(err, IsNil) colDef := stmt.(*ast.AlterTableStmt).Specs[0].NewColumns[0] - col, _, err := buildColumnAndConstraint(nil, 0, colDef) + col, _, err := buildColumnAndConstraint(nil, 0, colDef, nil) c.Assert(err, IsNil) return &col.FieldType } diff --git a/ddl/db_test.go b/ddl/db_test.go index 79e5695f74..3a662005c8 100644 --- a/ddl/db_test.go +++ b/ddl/db_test.go @@ -187,6 +187,17 @@ func (s *testDBSuite) TestMySQLErrorCode(c *C) { s.testErrorCode(c, sql, tmysql.ErrWrongNameForIndex) sql = "create table t2(c1.c2 blob default null);" s.testErrorCode(c, sql, tmysql.ErrWrongTableName) + sql = "create table t2 (id int default null primary key , age int);" + s.testErrorCode(c, sql, tmysql.ErrInvalidDefault) + sql = "create table t2 (id int null primary key , age int);" + s.testErrorCode(c, sql, tmysql.ErrPrimaryCantHaveNull) + sql = "create table t2 (id int default null, age int, primary key(id));" + s.testErrorCode(c, sql, tmysql.ErrPrimaryCantHaveNull) + sql = "create table t2 (id int null, age int, primary key(id));" + s.testErrorCode(c, sql, tmysql.ErrPrimaryCantHaveNull) + + sql = "create table t2 (id int primary key , age int);" + s.tk.MustExec(sql) // add column sql = "alter table test_error_code_succ add column c1 int" diff --git a/ddl/ddl.go b/ddl/ddl.go index 465ffa004f..2943df7437 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -158,6 +158,8 @@ var ( ErrWrongNameForIndex = terror.ClassDDL.New(codeWrongNameForIndex, mysql.MySQLErrName[mysql.ErrWrongNameForIndex]) // ErrUnknownCharacterSet returns unknown character set. ErrUnknownCharacterSet = terror.ClassDDL.New(codeUnknownCharacterSet, "Unknown character set: '%s'") + // ErrPrimaryCantHaveNull returns All parts of a PRIMARY KEY must be NOT NULL; if you need NULL in a key, use UNIQUE instead + ErrPrimaryCantHaveNull = terror.ClassDDL.New(codePrimaryCantHaveNull, mysql.MySQLErrName[mysql.ErrPrimaryCantHaveNull]) // ErrNotAllowedTypeInPartition returns not allowed type error when creating table partiton with unsupport expression type. ErrNotAllowedTypeInPartition = terror.ClassDDL.New(codeErrFieldTypeNotAllowedAsPartitionField, mysql.MySQLErrName[mysql.ErrFieldTypeNotAllowedAsPartitionField]) @@ -592,6 +594,7 @@ const ( codePartitionFunctionIsNotAllowed = terror.ErrCode(mysql.ErrPartitionFunctionIsNotAllowed) codeErrPartitionFuncNotAllowed = terror.ErrCode(mysql.ErrPartitionFuncNotAllowed) codeErrFieldTypeNotAllowedAsPartitionField = terror.ErrCode(mysql.ErrFieldTypeNotAllowedAsPartitionField) + codePrimaryCantHaveNull = terror.ErrCode(mysql.ErrPrimaryCantHaveNull) ) func init() { @@ -636,6 +639,7 @@ func init() { codePartitionFunctionIsNotAllowed: mysql.ErrPartitionFunctionIsNotAllowed, codeErrPartitionFuncNotAllowed: mysql.ErrPartitionFuncNotAllowed, codeErrFieldTypeNotAllowedAsPartitionField: mysql.ErrFieldTypeNotAllowedAsPartitionField, + codePrimaryCantHaveNull: mysql.ErrPrimaryCantHaveNull, } terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes } diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 034d8278f8..1f2533187c 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -173,8 +173,16 @@ func buildColumnsAndConstraints(ctx sessionctx.Context, colDefs []*ast.ColumnDef constraints []*ast.Constraint) ([]*table.Column, []*ast.Constraint, error) { var cols []*table.Column colMap := map[string]*table.Column{} + // outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id)); + var outPriKeyConstraint *ast.Constraint + for _, v := range constraints { + if v.Tp == ast.ConstraintPrimaryKey { + outPriKeyConstraint = v + break + } + } for i, colDef := range colDefs { - col, cts, err := buildColumnAndConstraint(ctx, i, colDef) + col, cts, err := buildColumnAndConstraint(ctx, i, colDef, outPriKeyConstraint) if err != nil { return nil, nil, errors.Trace(err) } @@ -229,13 +237,14 @@ func setCharsetCollationFlenDecimal(tp *types.FieldType) error { return nil } +// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id)); func buildColumnAndConstraint(ctx sessionctx.Context, offset int, - colDef *ast.ColumnDef) (*table.Column, []*ast.Constraint, error) { + colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) { err := setCharsetCollationFlenDecimal(colDef.Tp) if err != nil { return nil, nil, errors.Trace(err) } - col, cts, err := columnDefToCol(ctx, offset, colDef) + col, cts, err := columnDefToCol(ctx, offset, colDef, outPriKeyConstraint) if err != nil { return nil, nil, errors.Trace(err) } @@ -263,7 +272,8 @@ func isExplicitTimeStamp() bool { } // columnDefToCol converts ColumnDef to Col and TableConstraints. -func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (*table.Column, []*ast.Constraint, error) { +// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id)); +func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) { var constraints = make([]*ast.Constraint, 0) col := table.ToColumn(&model.ColumnInfo{ Offset: offset, @@ -282,6 +292,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) ( setOnUpdateNow := false hasDefaultValue := false + hasNullFlag := false if colDef.Options != nil { length := types.UnspecifiedLength @@ -299,6 +310,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) ( case ast.ColumnOptionNull: col.Flag &= ^mysql.NotNullFlag removeOnUpdateNowFlag(col) + hasNullFlag = true case ast.ColumnOptionAutoIncrement: col.Flag |= mysql.AutoIncrementFlag case ast.ColumnOptionPrimaryKey: @@ -367,7 +379,11 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) ( col.Flag &= ^mysql.BinaryFlag col.Flag |= mysql.ZerofillFlag } - err := checkDefaultValue(ctx, col, hasDefaultValue) + err := checkPriKeyConstraint(col, hasDefaultValue, hasNullFlag, outPriKeyConstraint) + if err != nil { + return nil, nil, errors.Trace(err) + } + err = checkDefaultValue(ctx, col, hasDefaultValue) if err != nil { return nil, nil, errors.Trace(err) } @@ -475,6 +491,10 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue } return nil } + // Primary key default null is invalid. + if mysql.HasPriKeyFlag(c.Flag) { + return ErrPrimaryCantHaveNull + } // Set not null but default null is invalid. if mysql.HasNotNullFlag(c.Flag) { @@ -484,6 +504,30 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue return nil } +// checkPriKeyConstraint check all parts of a PRIMARY KEY must be NOT NULL +func checkPriKeyConstraint(col *table.Column, hasDefaultValue, hasNullFlag bool, outPriKeyConstraint *ast.Constraint) error { + // Primary key should not be null. + if mysql.HasPriKeyFlag(col.Flag) && hasDefaultValue && col.DefaultValue == nil { + return types.ErrInvalidDefault.GenByArgs(col.Name) + } + // Set primary key flag for outer primary key constraint. + // Such as: create table t1 (id int , age int, primary key(id)) + if !mysql.HasPriKeyFlag(col.Flag) && outPriKeyConstraint != nil { + for _, key := range outPriKeyConstraint.Keys { + if key.Column.Name.L != col.Name.L { + continue + } + col.Flag |= mysql.PriKeyFlag + break + } + } + // Primary key should not be null. + if mysql.HasPriKeyFlag(col.Flag) && hasNullFlag { + return ErrPrimaryCantHaveNull + } + return nil +} + func checkDuplicateColumn(colDefs []*ast.ColumnDef) error { colNames := map[string]bool{} for _, colDef := range colDefs { @@ -1177,7 +1221,7 @@ func (d *ddl) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTab // Ingore table constraints now, maybe return error later. // We use length(t.Cols()) as the default offset firstly, we will change the // column's offset later. - col, _, err = buildColumnAndConstraint(ctx, len(t.Cols()), specNewColumn) + col, _, err = buildColumnAndConstraint(ctx, len(t.Cols()), specNewColumn, nil) if err != nil { return errors.Trace(err) }