diff --git a/ddl/ddl.go b/ddl/ddl.go index dd33031182..f99cd7f424 100644 --- a/ddl/ddl.go +++ b/ddl/ddl.go @@ -1206,10 +1206,6 @@ func getAnonymousIndex(t table.Table, colName model.CIStr) model.CIStr { } func (d *ddl) CreateIndex(ctx context.Context, ti ast.Ident, unique bool, indexName model.CIStr, idxColNames []*ast.IndexColName) error { - err := checkDuplicateColumnName(idxColNames) - if err != nil { - return errors.Trace(err) - } is := d.infoHandle.Get() schema, ok := is.SchemaByName(ti.Schema) if !ok { @@ -1433,20 +1429,6 @@ func findCol(cols []*model.ColumnInfo, name string) *model.ColumnInfo { return nil } -// checkDuplicateColumnName checks if index exists duplicated columns. -func checkDuplicateColumnName(indexColNames []*ast.IndexColName) error { - for i := 0; i < len(indexColNames); i++ { - name1 := indexColNames[i].Column.Name - for j := i + 1; j < len(indexColNames); j++ { - name2 := indexColNames[j].Column.Name - if name1.L == name2.L { - return infoschema.ErrColumnExists.GenByArgs(name2) - } - } - } - return nil -} - // DDL error codes. const ( codeInvalidWorker terror.ErrCode = 1 diff --git a/plan/validator.go b/plan/validator.go index 2113861809..1804225851 100644 --- a/plan/validator.go +++ b/plan/validator.go @@ -42,7 +42,7 @@ type validator struct { } func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { - switch in.(type) { + switch node := in.(type) { case *ast.AggregateFuncExpr: if v.inAggregate { // Aggregate function can not contain aggregate function. @@ -51,12 +51,17 @@ func (v *validator) Enter(in ast.Node) (out ast.Node, skipChildren bool) { } v.inAggregate = true case *ast.CreateTableStmt: - v.checkCreateTableGrammar(in.(*ast.CreateTableStmt)) + v.checkCreateTableGrammar(node) if v.err != nil { return in, true } case *ast.CreateIndexStmt: - v.checkCreateIndexGrammar(in.(*ast.CreateIndexStmt)) + v.checkCreateIndexGrammar(node) + if v.err != nil { + return in, true + } + case *ast.AlterTableStmt: + v.checkAlterTableGrammar(node) if v.err != nil { return in, true } @@ -196,6 +201,17 @@ func (v *validator) checkCreateTableGrammar(stmt *ast.CreateTableStmt) { return } } + for _, constraint := range stmt.Constraints { + switch tp := constraint.Tp; tp { + case ast.ConstraintKey: + err := checkDuplicateColumnName(constraint.Keys) + if err != nil { + v.err = err + return + } + } + } + } func isPrimary(ops []*ast.ColumnOption) int { @@ -208,14 +224,41 @@ func isPrimary(ops []*ast.ColumnOption) int { } func (v *validator) checkCreateIndexGrammar(stmt *ast.CreateIndexStmt) { - for i := 0; i < len(stmt.IndexColNames); i++ { - name1 := stmt.IndexColNames[i].Column.Name - for j := i + 1; j < len(stmt.IndexColNames); j++ { - name2 := stmt.IndexColNames[j].Column.Name - if name1.L == name2.L { - v.err = infoschema.ErrColumnExists.GenByArgs(name2) - return + v.err = checkDuplicateColumnName(stmt.IndexColNames) + return +} + +func (v *validator) checkAlterTableGrammar(stmt *ast.AlterTableStmt) { + specs := stmt.Specs + for _, spec := range specs { + switch spec.Tp { + case ast.AlterTableAddConstraint: + switch spec.Constraint.Tp { + case ast.ConstraintKey, ast.ConstraintIndex, ast.ConstraintUniq, ast.ConstraintUniqIndex, + ast.ConstraintUniqKey: + v.err = checkDuplicateColumnName(spec.Constraint.Keys) + if v.err != nil { + return + } + default: + // Nothing to do now. } + default: + // Nothing to do now. } } } + +// checkDuplicateColumnName checks if index exists duplicated columns. +func checkDuplicateColumnName(indexColNames []*ast.IndexColName) error { + for i := 0; i < len(indexColNames); i++ { + name1 := indexColNames[i].Column.Name + for j := i + 1; j < len(indexColNames); j++ { + name2 := indexColNames[j].Column.Name + if name1.L == name2.L { + return infoschema.ErrColumnExists.GenByArgs(name2) + } + } + } + return nil +} diff --git a/plan/validator_test.go b/plan/validator_test.go index 2e419c4b22..675c61841d 100644 --- a/plan/validator_test.go +++ b/plan/validator_test.go @@ -54,7 +54,9 @@ func (s *testValidatorSuite) TestValidator(c *C) { {"create table t(id int auto_increment) ENGINE=MYISAM", true, nil}, {"create table t(a int primary key, b int, c varchar(10), d char(256));", true, errors.New("Column length too big for column 'd' (max = 255); use BLOB or TEXT instead")}, {"create index ib on t(b,a,b);", true, errors.New("[schema:1060]Duplicate column name 'b'")}, + {"create table t (a int, b int, key(a, b, A))", true, errors.New("[schema:1060]Duplicate column name 'A'")}, {"create table t(c1 int not null primary key, c2 int not null primary key)", true, errors.New("Multiple primary key defined")}, + {"alter table t add index idx(a, b, A)", true, errors.New("[schema:1060]Duplicate column name 'A'")}, } store, err := tidb.NewStore(tidb.EngineGoLevelDBMemory) c.Assert(err, IsNil)