diff --git a/ddl/integration_test.go b/ddl/integration_test.go index 17e4ab14cb..f26620aec9 100644 --- a/ddl/integration_test.go +++ b/ddl/integration_test.go @@ -18,6 +18,7 @@ import ( "github.com/juju/errors" . "github.com/pingcap/check" + "github.com/pingcap/tidb/ddl" "github.com/pingcap/tidb/domain" "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/session" @@ -76,6 +77,25 @@ func (s *testIntegrationSuite) TestInvalidDefault(c *C) { c.Assert(terror.ErrorEqual(err, types.ErrInvalidDefault), IsTrue) } +// for issue #3848 +func (s *testIntegrationSuite) TestInvalidNameWhenCreateTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + + tk.MustExec("USE test;") + + _, err := tk.Exec("create table t(xxx.t.a bigint)") + c.Assert(err, NotNil) + c.Assert(terror.ErrorEqual(err, ddl.ErrWrongDBName), IsTrue) + + _, err = tk.Exec("create table t(test.tttt.a bigint)") + c.Assert(err, NotNil) + c.Assert(terror.ErrorEqual(err, ddl.ErrWrongTableName), IsTrue) + + _, err = tk.Exec("create table t(t.tttt.a bigint)") + c.Assert(err, NotNil) + c.Assert(terror.ErrorEqual(err, ddl.ErrWrongDBName), IsTrue) +} + func newStoreWithBootstrap() (kv.Storage, *domain.Domain, error) { store, err := mockstore.NewMockTikvStore() if err != nil { diff --git a/plan/preprocess.go b/plan/preprocess.go index 2d2f50b5f3..204f69342a 100644 --- a/plan/preprocess.go +++ b/plan/preprocess.go @@ -82,6 +82,7 @@ func (p *preprocessor) Leave(in ast.Node) (out ast.Node, ok bool) { case *ast.CreateTableStmt: p.inCreateOrDropTable = false p.checkAutoIncrement(x) + p.checkContainDotColumn(x) case *ast.DropTableStmt, *ast.AlterTableStmt, *ast.RenameTableStmt: p.inCreateOrDropTable = false case *ast.ParamMarkerExpr: @@ -238,7 +239,6 @@ func (p *preprocessor) checkCreateTableGrammar(stmt *ast.CreateTableStmt) { p.err = ddl.ErrWrongTableName.GenByArgs(tName) return } - p.checkContainDotColumn(stmt) countPrimaryKey := 0 for _, colDef := range stmt.Cols { if err := checkColumn(colDef); err != nil { @@ -496,9 +496,14 @@ func isIncorrectName(name string) bool { // for example :create table t (c1.c2 int default null). func (p *preprocessor) checkContainDotColumn(stmt *ast.CreateTableStmt) { tName := stmt.Table.Name.String() + sName := stmt.Table.Schema.String() for _, colDef := range stmt.Cols { - // check table name. + // check schema and table names. + if colDef.Name.Schema.O != sName && len(colDef.Name.Schema.O) != 0 { + p.err = ddl.ErrWrongDBName.GenByArgs(colDef.Name.Schema.O) + return + } if colDef.Name.Table.O != tName && len(colDef.Name.Table.O) != 0 { p.err = ddl.ErrWrongTableName.GenByArgs(colDef.Name.Table.O) return