ddl: check constraint name.

Fixes #66 and also check duplication.
This commit is contained in:
Ewan Chou
2015-09-08 18:25:00 +08:00
parent 8920a55dab
commit 7fc744e435
2 changed files with 68 additions and 0 deletions

View File

@ -19,6 +19,7 @@ package ddl
import (
"encoding/json"
"fmt"
"strings"
"sync"
@ -261,6 +262,37 @@ func checkDuplicateColumn(colDefs []*coldef.ColumnDef) error {
return nil
}
func checkConstraintNames(constraints []*coldef.TableConstraint) error {
m := map[string]bool{}
// check not empty constraint name do not have duplication.
for _, constr := range constraints {
if constr.ConstrName != "" {
nameLower := strings.ToLower(constr.ConstrName)
if m[nameLower] {
return errors.Errorf("CREATE TABLE: duplicate key %s", constr.ConstrName)
}
m[nameLower] = true
}
}
// set empty constraint names.
for _, constr := range constraints {
if constr.ConstrName == "" && len(constr.Keys) > 0 {
colName := constr.Keys[0].ColumnName
constrName := colName
i := 2
for m[strings.ToLower(constrName)] {
constrName = fmt.Sprintf("%s_%d", colName, i)
i++
}
constr.ConstrName = constrName
m[constrName] = true
}
}
return nil
}
func (d *ddl) buildTableInfo(tableName model.CIStr, cols []*column.Col, constraints []*coldef.TableConstraint) (tbInfo *model.TableInfo, err error) {
tbInfo = &model.TableInfo{
Name: tableName,
@ -321,6 +353,11 @@ func (d *ddl) CreateTable(ctx context.Context, ident table.Ident, colDefs []*col
return errors.Trace(err)
}
err = checkConstraintNames(newConstraints)
if err != nil {
return errors.Trace(err)
}
tbInfo, err := d.buildTableInfo(ident.Name, cols, newConstraints)
if err != nil {
return errors.Trace(err)

View File

@ -134,6 +134,37 @@ func (ts *testSuite) TestT(c *C) {
c.Assert(err, IsNil)
}
func (ts *testSuite) TestConstraintNames(c *C) {
handle := infoschema.NewHandle(ts.store)
handle.Set(nil)
dd := ddl.NewDDL(ts.store, handle)
se, _ := tidb.CreateSession(ts.store)
ctx := se.(context.Context)
schemaName := model.NewCIStr("test")
tblName := model.NewCIStr("t")
tbIdent := table.Ident{
Schema: schemaName,
Name: tblName,
}
err := dd.CreateSchema(ctx, tbIdent.Schema)
c.Assert(err, IsNil)
tbStmt := statement("create table t (a int, b int, index a (a, b), index a (a))").(*stmts.CreateTableStmt)
err = dd.CreateTable(ctx, tbIdent, tbStmt.Cols, tbStmt.Constraints)
c.Assert(err, NotNil)
tbStmt = statement("create table t (a int, b int, index A (a, b), index (a))").(*stmts.CreateTableStmt)
err = dd.CreateTable(ctx, tbIdent, tbStmt.Cols, tbStmt.Constraints)
c.Assert(err, IsNil)
tbl, err := handle.Get().TableByName(schemaName, tblName)
indices := tbl.Indices()
c.Assert(len(indices), Equals, 2)
c.Assert(indices[0].Name.O, Equals, "A")
c.Assert(indices[1].Name.O, Equals, "a_2")
err = dd.DropSchema(ctx, tbIdent.Schema)
c.Assert(err, IsNil)
}
func statement(sql string) stmt.Statement {
lexer := parser.NewLexer(sql)
parser.YYParse(lexer)