ddl: add primary key default value check when create table (#7169)
This commit is contained in:
@ -937,7 +937,7 @@ func (s *testColumnSuite) colDefStrToFieldType(c *C, str string) *types.FieldTyp
|
||||
stmt, err := parser.New().ParseOneStmt(sqlA, "", "")
|
||||
c.Assert(err, IsNil)
|
||||
colDef := stmt.(*ast.AlterTableStmt).Specs[0].NewColumns[0]
|
||||
col, _, err := buildColumnAndConstraint(nil, 0, colDef)
|
||||
col, _, err := buildColumnAndConstraint(nil, 0, colDef, nil)
|
||||
c.Assert(err, IsNil)
|
||||
return &col.FieldType
|
||||
}
|
||||
|
||||
@ -187,6 +187,17 @@ func (s *testDBSuite) TestMySQLErrorCode(c *C) {
|
||||
s.testErrorCode(c, sql, tmysql.ErrWrongNameForIndex)
|
||||
sql = "create table t2(c1.c2 blob default null);"
|
||||
s.testErrorCode(c, sql, tmysql.ErrWrongTableName)
|
||||
sql = "create table t2 (id int default null primary key , age int);"
|
||||
s.testErrorCode(c, sql, tmysql.ErrInvalidDefault)
|
||||
sql = "create table t2 (id int null primary key , age int);"
|
||||
s.testErrorCode(c, sql, tmysql.ErrPrimaryCantHaveNull)
|
||||
sql = "create table t2 (id int default null, age int, primary key(id));"
|
||||
s.testErrorCode(c, sql, tmysql.ErrPrimaryCantHaveNull)
|
||||
sql = "create table t2 (id int null, age int, primary key(id));"
|
||||
s.testErrorCode(c, sql, tmysql.ErrPrimaryCantHaveNull)
|
||||
|
||||
sql = "create table t2 (id int primary key , age int);"
|
||||
s.tk.MustExec(sql)
|
||||
|
||||
// add column
|
||||
sql = "alter table test_error_code_succ add column c1 int"
|
||||
|
||||
@ -158,6 +158,8 @@ var (
|
||||
ErrWrongNameForIndex = terror.ClassDDL.New(codeWrongNameForIndex, mysql.MySQLErrName[mysql.ErrWrongNameForIndex])
|
||||
// ErrUnknownCharacterSet returns unknown character set.
|
||||
ErrUnknownCharacterSet = terror.ClassDDL.New(codeUnknownCharacterSet, "Unknown character set: '%s'")
|
||||
// ErrPrimaryCantHaveNull returns All parts of a PRIMARY KEY must be NOT NULL; if you need NULL in a key, use UNIQUE instead
|
||||
ErrPrimaryCantHaveNull = terror.ClassDDL.New(codePrimaryCantHaveNull, mysql.MySQLErrName[mysql.ErrPrimaryCantHaveNull])
|
||||
|
||||
// ErrNotAllowedTypeInPartition returns not allowed type error when creating table partiton with unsupport expression type.
|
||||
ErrNotAllowedTypeInPartition = terror.ClassDDL.New(codeErrFieldTypeNotAllowedAsPartitionField, mysql.MySQLErrName[mysql.ErrFieldTypeNotAllowedAsPartitionField])
|
||||
@ -592,6 +594,7 @@ const (
|
||||
codePartitionFunctionIsNotAllowed = terror.ErrCode(mysql.ErrPartitionFunctionIsNotAllowed)
|
||||
codeErrPartitionFuncNotAllowed = terror.ErrCode(mysql.ErrPartitionFuncNotAllowed)
|
||||
codeErrFieldTypeNotAllowedAsPartitionField = terror.ErrCode(mysql.ErrFieldTypeNotAllowedAsPartitionField)
|
||||
codePrimaryCantHaveNull = terror.ErrCode(mysql.ErrPrimaryCantHaveNull)
|
||||
)
|
||||
|
||||
func init() {
|
||||
@ -636,6 +639,7 @@ func init() {
|
||||
codePartitionFunctionIsNotAllowed: mysql.ErrPartitionFunctionIsNotAllowed,
|
||||
codeErrPartitionFuncNotAllowed: mysql.ErrPartitionFuncNotAllowed,
|
||||
codeErrFieldTypeNotAllowedAsPartitionField: mysql.ErrFieldTypeNotAllowedAsPartitionField,
|
||||
codePrimaryCantHaveNull: mysql.ErrPrimaryCantHaveNull,
|
||||
}
|
||||
terror.ErrClassToMySQLCodes[terror.ClassDDL] = ddlMySQLErrCodes
|
||||
}
|
||||
|
||||
@ -173,8 +173,16 @@ func buildColumnsAndConstraints(ctx sessionctx.Context, colDefs []*ast.ColumnDef
|
||||
constraints []*ast.Constraint) ([]*table.Column, []*ast.Constraint, error) {
|
||||
var cols []*table.Column
|
||||
colMap := map[string]*table.Column{}
|
||||
// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id));
|
||||
var outPriKeyConstraint *ast.Constraint
|
||||
for _, v := range constraints {
|
||||
if v.Tp == ast.ConstraintPrimaryKey {
|
||||
outPriKeyConstraint = v
|
||||
break
|
||||
}
|
||||
}
|
||||
for i, colDef := range colDefs {
|
||||
col, cts, err := buildColumnAndConstraint(ctx, i, colDef)
|
||||
col, cts, err := buildColumnAndConstraint(ctx, i, colDef, outPriKeyConstraint)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
@ -229,13 +237,14 @@ func setCharsetCollationFlenDecimal(tp *types.FieldType) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id));
|
||||
func buildColumnAndConstraint(ctx sessionctx.Context, offset int,
|
||||
colDef *ast.ColumnDef) (*table.Column, []*ast.Constraint, error) {
|
||||
colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) {
|
||||
err := setCharsetCollationFlenDecimal(colDef.Tp)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
col, cts, err := columnDefToCol(ctx, offset, colDef)
|
||||
col, cts, err := columnDefToCol(ctx, offset, colDef, outPriKeyConstraint)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
@ -263,7 +272,8 @@ func isExplicitTimeStamp() bool {
|
||||
}
|
||||
|
||||
// columnDefToCol converts ColumnDef to Col and TableConstraints.
|
||||
func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (*table.Column, []*ast.Constraint, error) {
|
||||
// outPriKeyConstraint is the primary key constraint out of column definition. such as: create table t1 (id int , age int, primary key(id));
|
||||
func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef, outPriKeyConstraint *ast.Constraint) (*table.Column, []*ast.Constraint, error) {
|
||||
var constraints = make([]*ast.Constraint, 0)
|
||||
col := table.ToColumn(&model.ColumnInfo{
|
||||
Offset: offset,
|
||||
@ -282,6 +292,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (
|
||||
|
||||
setOnUpdateNow := false
|
||||
hasDefaultValue := false
|
||||
hasNullFlag := false
|
||||
if colDef.Options != nil {
|
||||
length := types.UnspecifiedLength
|
||||
|
||||
@ -299,6 +310,7 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (
|
||||
case ast.ColumnOptionNull:
|
||||
col.Flag &= ^mysql.NotNullFlag
|
||||
removeOnUpdateNowFlag(col)
|
||||
hasNullFlag = true
|
||||
case ast.ColumnOptionAutoIncrement:
|
||||
col.Flag |= mysql.AutoIncrementFlag
|
||||
case ast.ColumnOptionPrimaryKey:
|
||||
@ -367,7 +379,11 @@ func columnDefToCol(ctx sessionctx.Context, offset int, colDef *ast.ColumnDef) (
|
||||
col.Flag &= ^mysql.BinaryFlag
|
||||
col.Flag |= mysql.ZerofillFlag
|
||||
}
|
||||
err := checkDefaultValue(ctx, col, hasDefaultValue)
|
||||
err := checkPriKeyConstraint(col, hasDefaultValue, hasNullFlag, outPriKeyConstraint)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
err = checkDefaultValue(ctx, col, hasDefaultValue)
|
||||
if err != nil {
|
||||
return nil, nil, errors.Trace(err)
|
||||
}
|
||||
@ -475,6 +491,10 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Primary key default null is invalid.
|
||||
if mysql.HasPriKeyFlag(c.Flag) {
|
||||
return ErrPrimaryCantHaveNull
|
||||
}
|
||||
|
||||
// Set not null but default null is invalid.
|
||||
if mysql.HasNotNullFlag(c.Flag) {
|
||||
@ -484,6 +504,30 @@ func checkDefaultValue(ctx sessionctx.Context, c *table.Column, hasDefaultValue
|
||||
return nil
|
||||
}
|
||||
|
||||
// checkPriKeyConstraint check all parts of a PRIMARY KEY must be NOT NULL
|
||||
func checkPriKeyConstraint(col *table.Column, hasDefaultValue, hasNullFlag bool, outPriKeyConstraint *ast.Constraint) error {
|
||||
// Primary key should not be null.
|
||||
if mysql.HasPriKeyFlag(col.Flag) && hasDefaultValue && col.DefaultValue == nil {
|
||||
return types.ErrInvalidDefault.GenByArgs(col.Name)
|
||||
}
|
||||
// Set primary key flag for outer primary key constraint.
|
||||
// Such as: create table t1 (id int , age int, primary key(id))
|
||||
if !mysql.HasPriKeyFlag(col.Flag) && outPriKeyConstraint != nil {
|
||||
for _, key := range outPriKeyConstraint.Keys {
|
||||
if key.Column.Name.L != col.Name.L {
|
||||
continue
|
||||
}
|
||||
col.Flag |= mysql.PriKeyFlag
|
||||
break
|
||||
}
|
||||
}
|
||||
// Primary key should not be null.
|
||||
if mysql.HasPriKeyFlag(col.Flag) && hasNullFlag {
|
||||
return ErrPrimaryCantHaveNull
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func checkDuplicateColumn(colDefs []*ast.ColumnDef) error {
|
||||
colNames := map[string]bool{}
|
||||
for _, colDef := range colDefs {
|
||||
@ -1177,7 +1221,7 @@ func (d *ddl) AddColumn(ctx sessionctx.Context, ti ast.Ident, spec *ast.AlterTab
|
||||
// Ingore table constraints now, maybe return error later.
|
||||
// We use length(t.Cols()) as the default offset firstly, we will change the
|
||||
// column's offset later.
|
||||
col, _, err = buildColumnAndConstraint(ctx, len(t.Cols()), specNewColumn)
|
||||
col, _, err = buildColumnAndConstraint(ctx, len(t.Cols()), specNewColumn, nil)
|
||||
if err != nil {
|
||||
return errors.Trace(err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user