Merge pull request #72 from pingcap/coocood/constraint-name

ddl: check constraint name.
This commit is contained in:
Ewan Chou
2015-09-09 08:23:14 +08:00
2 changed files with 72 additions and 3 deletions

View File

@ -19,6 +19,7 @@ package ddl
import (
"encoding/json"
"fmt"
"strings"
"sync"
@ -250,13 +251,45 @@ func (d *ddl) buildColumnAndConstraint(offset int, colDef *coldef.ColumnDef) (*c
}
func checkDuplicateColumn(colDefs []*coldef.ColumnDef) error {
m := map[string]bool{}
colNames := map[string]bool{}
for _, colDef := range colDefs {
nameLower := strings.ToLower(colDef.Name)
if m[nameLower] {
if colNames[nameLower] {
return errors.Errorf("CREATE TABLE: duplicate column %s", colDef.Name)
}
m[nameLower] = true
colNames[nameLower] = true
}
return nil
}
func checkConstraintNames(constraints []*coldef.TableConstraint) error {
constrNames := map[string]bool{}
// Check not empty constraint name do not have duplication.
for _, constr := range constraints {
if constr.ConstrName != "" {
nameLower := strings.ToLower(constr.ConstrName)
if constrNames[nameLower] {
return errors.Errorf("CREATE TABLE: duplicate key %s", constr.ConstrName)
}
constrNames[nameLower] = true
}
}
// Set empty constraint names.
for _, constr := range constraints {
if constr.ConstrName == "" && len(constr.Keys) > 0 {
colName := constr.Keys[0].ColumnName
constrName := colName
i := 2
for constrNames[strings.ToLower(constrName)] {
// We loop forever until we find constrName that haven't been used.
constrName = fmt.Sprintf("%s_%d", colName, i)
i++
}
constr.ConstrName = constrName
constrNames[constrName] = true
}
}
return nil
}
@ -321,6 +354,11 @@ func (d *ddl) CreateTable(ctx context.Context, ident table.Ident, colDefs []*col
return errors.Trace(err)
}
err = checkConstraintNames(newConstraints)
if err != nil {
return errors.Trace(err)
}
tbInfo, err := d.buildTableInfo(ident.Name, cols, newConstraints)
if err != nil {
return errors.Trace(err)

View File

@ -134,6 +134,37 @@ func (ts *testSuite) TestT(c *C) {
c.Assert(err, IsNil)
}
func (ts *testSuite) TestConstraintNames(c *C) {
handle := infoschema.NewHandle(ts.store)
handle.Set(nil)
dd := ddl.NewDDL(ts.store, handle)
se, _ := tidb.CreateSession(ts.store)
ctx := se.(context.Context)
schemaName := model.NewCIStr("test")
tblName := model.NewCIStr("t")
tbIdent := table.Ident{
Schema: schemaName,
Name: tblName,
}
err := dd.CreateSchema(ctx, tbIdent.Schema)
c.Assert(err, IsNil)
tbStmt := statement("create table t (a int, b int, index a (a, b), index a (a))").(*stmts.CreateTableStmt)
err = dd.CreateTable(ctx, tbIdent, tbStmt.Cols, tbStmt.Constraints)
c.Assert(err, NotNil)
tbStmt = statement("create table t (a int, b int, index A (a, b), index (a))").(*stmts.CreateTableStmt)
err = dd.CreateTable(ctx, tbIdent, tbStmt.Cols, tbStmt.Constraints)
c.Assert(err, IsNil)
tbl, err := handle.Get().TableByName(schemaName, tblName)
indices := tbl.Indices()
c.Assert(len(indices), Equals, 2)
c.Assert(indices[0].Name.O, Equals, "A")
c.Assert(indices[1].Name.O, Equals, "a_2")
err = dd.DropSchema(ctx, tbIdent.Schema)
c.Assert(err, IsNil)
}
func statement(sql string) stmt.Statement {
lexer := parser.NewLexer(sql)
parser.YYParse(lexer)