diff --git a/ddl/ddl.go b/ddl/ddl.go index ed10f45d0d..f99ffba5d9 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -19,6 +19,7 @@ package ddl import ( "encoding/json" + "fmt" "strings" "sync" @@ -261,6 +262,37 @@ func checkDuplicateColumn(colDefs []*coldef.ColumnDef) error { return nil } +func checkConstraintNames(constraints []*coldef.TableConstraint) error { + m := map[string]bool{} + + // check not empty constraint name do not have duplication. + for _, constr := range constraints { + if constr.ConstrName != "" { + nameLower := strings.ToLower(constr.ConstrName) + if m[nameLower] { + return errors.Errorf("CREATE TABLE: duplicate key %s", constr.ConstrName) + } + m[nameLower] = true + } + } + + // set empty constraint names. + for _, constr := range constraints { + if constr.ConstrName == "" && len(constr.Keys) > 0 { + colName := constr.Keys[0].ColumnName + constrName := colName + i := 2 + for m[strings.ToLower(constrName)] { + constrName = fmt.Sprintf("%s_%d", colName, i) + i++ + } + constr.ConstrName = constrName + m[constrName] = true + } + } + return nil +} + func (d *ddl) buildTableInfo(tableName model.CIStr, cols []*column.Col, constraints []*coldef.TableConstraint) (tbInfo *model.TableInfo, err error) { tbInfo = &model.TableInfo{ Name: tableName, @@ -321,6 +353,11 @@ func (d *ddl) CreateTable(ctx context.Context, ident table.Ident, colDefs []*col return errors.Trace(err) } + err = checkConstraintNames(newConstraints) + if err != nil { + return errors.Trace(err) + } + tbInfo, err := d.buildTableInfo(ident.Name, cols, newConstraints) if err != nil { return errors.Trace(err) diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index dfbbe58a45..c5b61e4b4b 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -134,6 +134,37 @@ func (ts *testSuite) TestT(c *C) { c.Assert(err, IsNil) } +func (ts *testSuite) TestConstraintNames(c *C) { + handle := infoschema.NewHandle(ts.store) + handle.Set(nil) + dd := ddl.NewDDL(ts.store, handle) + se, _ := tidb.CreateSession(ts.store) + ctx := se.(context.Context) + schemaName := model.NewCIStr("test") + tblName := model.NewCIStr("t") + tbIdent := table.Ident{ + Schema: schemaName, + Name: tblName, + } + err := dd.CreateSchema(ctx, tbIdent.Schema) + c.Assert(err, IsNil) + tbStmt := statement("create table t (a int, b int, index a (a, b), index a (a))").(*stmts.CreateTableStmt) + err = dd.CreateTable(ctx, tbIdent, tbStmt.Cols, tbStmt.Constraints) + c.Assert(err, NotNil) + + tbStmt = statement("create table t (a int, b int, index A (a, b), index (a))").(*stmts.CreateTableStmt) + err = dd.CreateTable(ctx, tbIdent, tbStmt.Cols, tbStmt.Constraints) + c.Assert(err, IsNil) + tbl, err := handle.Get().TableByName(schemaName, tblName) + indices := tbl.Indices() + c.Assert(len(indices), Equals, 2) + c.Assert(indices[0].Name.O, Equals, "A") + c.Assert(indices[1].Name.O, Equals, "a_2") + + err = dd.DropSchema(ctx, tbIdent.Schema) + c.Assert(err, IsNil) +} + func statement(sql string) stmt.Statement { lexer := parser.NewLexer(sql) parser.YYParse(lexer)