From 5bd601bdaea8f9ae1cd0a2f8fb477582cf93e3aa Mon Sep 17 00:00:00 2001 From: Lynn Date: Thu, 6 May 2021 21:01:52 +0800 Subject: [PATCH] ddl: disallow changing the default value of the primary key column to null (#24417) --- ddl/column_type_change_test.go | 14 +++++++++ ddl/ddl_api.go | 57 +++++++++++++++------------------- 2 files changed, 39 insertions(+), 32 deletions(-) diff --git a/ddl/column_type_change_test.go b/ddl/column_type_change_test.go index 68ee059f47..16c0c6f38e 100644 --- a/ddl/column_type_change_test.go +++ b/ddl/column_type_change_test.go @@ -1702,6 +1702,20 @@ func (s *testColumnTypeChangeSuite) TestChangingAttributeOfColumnWithFK(c *C) { tk.MustExec("drop table if exists orders, users") } +func (s *testColumnTypeChangeSuite) TestAlterPrimaryKeyToNull(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.Se.GetSessionVars().EnableChangeColumnType = true + + tk.MustExec("drop table if exists t, t1") + tk.MustExec("create table t(a int not null, b int not null, primary key(a, b));") + tk.MustGetErrCode("alter table t modify a bigint null;", mysql.ErrPrimaryCantHaveNull) + tk.MustGetErrCode("alter table t change column a a bigint null;", mysql.ErrPrimaryCantHaveNull) + tk.MustExec("create table t1(a int not null, b int not null, primary key(a));") + tk.MustGetErrCode("alter table t modify a bigint null;", mysql.ErrPrimaryCantHaveNull) + tk.MustGetErrCode("alter table t change column a a bigint null;", mysql.ErrPrimaryCantHaveNull) +} + // Close issue #23202 func (s *testColumnTypeChangeSuite) TestDDLExitWhenCancelMeetPanic(c *C) { tk := testkit.NewTestKit(c, s.store) diff --git a/ddl/ddl_api.go b/ddl/ddl_api.go index 138bb69756..0e1213e59c 100644 --- a/ddl/ddl_api.go +++ b/ddl/ddl_api.go @@ -690,24 +690,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, o } } - processDefaultValue(col, hasDefaultValue, setOnUpdateNow) - - processColumnFlags(col) - - err = checkPriKeyConstraint(col, hasDefaultValue, hasNullFlag, outPriKeyConstraint) - if err != nil { - return nil, nil, errors.Trace(err) - } - err = checkColumnValueConstraint(col, col.Collate) - if err != nil { - return nil, nil, errors.Trace(err) - } - err = checkDefaultValue(ctx, col, hasDefaultValue) - if err != nil { - return nil, nil, errors.Trace(err) - } - err = checkColumnFieldLength(col) - if err != nil { + if err = processAndCheckDefaultValueAndColumn(ctx, col, outPriKeyConstraint, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { return nil, nil, errors.Trace(err) } return col, constraints, nil @@ -3675,6 +3658,7 @@ func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []* var hasDefaultValue, setOnUpdateNow bool var err error + var hasNullFlag bool for _, opt := range options { switch opt.Tp { case ast.ColumnOptionDefaultValue: @@ -3690,6 +3674,7 @@ func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []* case ast.ColumnOptionNotNull: col.Flag |= mysql.NotNullFlag case ast.ColumnOptionNull: + hasNullFlag = true col.Flag &= ^mysql.NotNullFlag case ast.ColumnOptionAutoIncrement: col.Flag |= mysql.AutoIncrementFlag @@ -3734,17 +3719,33 @@ func processColumnOptions(ctx sessionctx.Context, col *table.Column, options []* } } - processDefaultValue(col, hasDefaultValue, setOnUpdateNow) - - processColumnFlags(col) - - if hasDefaultValue { - return errors.Trace(checkDefaultValue(ctx, col, true)) + if err = processAndCheckDefaultValueAndColumn(ctx, col, nil, hasDefaultValue, setOnUpdateNow, hasNullFlag); err != nil { + return errors.Trace(err) } return nil } +func processAndCheckDefaultValueAndColumn(ctx sessionctx.Context, col *table.Column, outPriKeyConstraint *ast.Constraint, hasDefaultValue, setOnUpdateNow, hasNullFlag bool) error { + processDefaultValue(col, hasDefaultValue, setOnUpdateNow) + processColumnFlags(col) + + err := checkPriKeyConstraint(col, hasDefaultValue, hasNullFlag, outPriKeyConstraint) + if err != nil { + return errors.Trace(err) + } + if err = checkColumnValueConstraint(col, col.Collate); err != nil { + return errors.Trace(err) + } + if err = checkDefaultValue(ctx, col, hasDefaultValue); err != nil { + return errors.Trace(err) + } + if err = checkColumnFieldLength(col); err != nil { + return errors.Trace(err) + } + return nil +} + func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, originalColName model.CIStr, spec *ast.AlterTableSpec) (*model.Job, error) { specNewColumn := spec.NewColumns[0] @@ -3854,10 +3855,6 @@ func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, or return nil, errors.Trace(err) } - if err = checkColumnValueConstraint(newCol, newCol.Collate); err != nil { - return nil, errors.Trace(err) - } - if err = checkModifyTypes(ctx, &col.FieldType, &newCol.FieldType, isColumnWithIndex(col.Name.L, t.Meta().Indices)); err != nil { if strings.Contains(err.Error(), "Unsupported modifying collation") { colErrMsg := "Unsupported modifying collation of column '%s' from '%s' to '%s' when index is defined on it." @@ -3894,10 +3891,6 @@ func (d *ddl) getModifiableColumnJob(ctx sessionctx.Context, ident ast.Ident, or modifyColumnTp = mysql.TypeNull } - if err = checkColumnFieldLength(newCol); err != nil { - return nil, err - } - if err = checkColumnWithIndexConstraint(t.Meta(), col.ColumnInfo, newCol.ColumnInfo); err != nil { return nil, err }