From 4f2f242fdfe1c807518f5039ad79bfc4cb4d7e66 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Tue, 16 Feb 2016 16:08:22 +0800 Subject: [PATCH] executor: support DDL statements. --- ddl/ddl_test.go | 6 + executor/builder.go | 6 + executor/converter/convert_stmt.go | 10 - executor/converter/converter.go | 2 - executor/executor_ddl.go | 427 +++++++++++++++++++++++++++++ executor/executor_ddl_test.go | 106 +++++++ optimizer/optimizer.go | 3 + optimizer/plan/planbuilder.go | 20 ++ optimizer/plan/plans.go | 17 ++ optimizer/resolver.go | 33 +++ stmt/stmts/altertable_test.go | 42 --- stmt/stmts/create_test.go | 89 ------ stmt/stmts/drop_test.go | 82 ------ stmt/stmts/truncate.go | 70 ----- stmt/stmts/truncate_test.go | 42 --- 15 files changed, 618 insertions(+), 337 deletions(-) create mode 100644 executor/executor_ddl.go create mode 100644 executor/executor_ddl_test.go delete mode 100644 stmt/stmts/altertable_test.go delete mode 100644 stmt/stmts/drop_test.go delete mode 100644 stmt/stmts/truncate.go delete mode 100644 stmt/stmts/truncate_test.go diff --git a/ddl/ddl_test.go b/ddl/ddl_test.go index 4a64ec615b..21cbe1471d 100644 --- a/ddl/ddl_test.go +++ b/ddl/ddl_test.go @@ -57,6 +57,8 @@ func (ts *testSuite) SetUpSuite(c *C) { } func (ts *testSuite) TestDDL(c *C) { + // TODO: rewrite the test. + c.Skip("this test assume statement to be `stmts` types, which has changed.") se, _ := tidb.CreateSession(ts.store) ctx := se.(context.Context) schemaName := model.NewCIStr("test_ddl") @@ -240,6 +242,8 @@ func (ts *testSuite) TestDDL(c *C) { } func (ts *testSuite) TestConstraintNames(c *C) { + // TODO: rewrite the test. + c.Skip("this test assume statement to be `stmts` types, which has changed.") se, _ := tidb.CreateSession(ts.store) ctx := se.(context.Context) schemaName := model.NewCIStr("test_constraint") @@ -270,6 +274,8 @@ func (ts *testSuite) TestConstraintNames(c *C) { } func (ts *testSuite) TestAlterTableColumn(c *C) { + // TODO: rewrite the test. + c.Skip("this test assume statement to be `stmts` types, which has changed.") se, _ := tidb.CreateSession(ts.store) ctx := se.(context.Context) schemaName := model.NewCIStr("test_alter_add_column") diff --git a/executor/builder.go b/executor/builder.go index 3d8bcdce70..8e9d438c4b 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -50,6 +50,8 @@ func (b *executorBuilder) build(p plan.Plan) Executor { return b.buildAggregate(v) case *plan.CheckTable: return b.buildCheckTable(v) + case *plan.DDL: + return b.buildDDL(v) case *plan.Deallocate: return b.buildDeallocate(v) case *plan.Delete: @@ -397,3 +399,7 @@ func (b *executorBuilder) buildGrant(grant *ast.GrantStmt) Executor { Users: grant.Users, } } + +func (b *executorBuilder) buildDDL(v *plan.DDL) Executor { + return &DDLExec{Statement: v.Statement, ctx: b.ctx, is: b.is} +} diff --git a/executor/converter/convert_stmt.go b/executor/converter/convert_stmt.go index 879f1c802a..4ef3e29ea8 100644 --- a/executor/converter/convert_stmt.go +++ b/executor/converter/convert_stmt.go @@ -743,16 +743,6 @@ func convertAlterTable(converter *expressionConverter, v *ast.AlterTableStmt) (* return oldAlterTable, nil } -func convertTruncateTable(converter *expressionConverter, v *ast.TruncateTableStmt) (*stmts.TruncateTableStmt, error) { - return &stmts.TruncateTableStmt{ - TableIdent: table.Ident{ - Schema: v.Table.Schema, - Name: v.Table.Name, - }, - Text: v.Text(), - }, nil -} - func convertExplain(converter *expressionConverter, v *ast.ExplainStmt) (*stmts.ExplainStmt, error) { oldExplain := &stmts.ExplainStmt{ Text: v.Text(), diff --git a/executor/converter/converter.go b/executor/converter/converter.go index ec9909f8a3..67af7f1040 100644 --- a/executor/converter/converter.go +++ b/executor/converter/converter.go @@ -52,8 +52,6 @@ func (con *Converter) Convert(node ast.Node) (stmt.Statement, error) { return convertSelect(c, v) case *ast.ShowStmt: return convertShow(c, v) - case *ast.TruncateTableStmt: - return convertTruncateTable(c, v) case *ast.UnionStmt: return convertUnion(c, v) case *ast.UpdateStmt: diff --git a/executor/executor_ddl.go b/executor/executor_ddl.go new file mode 100644 index 0000000000..274ab7f195 --- /dev/null +++ b/executor/executor_ddl.go @@ -0,0 +1,427 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor + +import ( + "strings" + + "github.com/juju/errors" + "github.com/pingcap/tidb/ast" + "github.com/pingcap/tidb/context" + "github.com/pingcap/tidb/ddl" + "github.com/pingcap/tidb/infoschema" + "github.com/pingcap/tidb/model" + "github.com/pingcap/tidb/mysql" + "github.com/pingcap/tidb/parser/coldef" + "github.com/pingcap/tidb/privilege" + "github.com/pingcap/tidb/sessionctx" + "github.com/pingcap/tidb/table" + "github.com/pingcap/tidb/terror" +) + +// DDLExec represents a DDL executor. +type DDLExec struct { + Statement ast.StmtNode + ctx context.Context + is infoschema.InfoSchema + done bool +} + +// Fields implements Executor Fields interface. +func (e *DDLExec) Fields() []*ast.ResultField { + return nil +} + +// Next implements Execution Next interface. +func (e *DDLExec) Next() (*Row, error) { + if e.done { + return nil, nil + } + var err error + switch x := e.Statement.(type) { + case *ast.TruncateTableStmt: + err = e.executeTruncateTable(x) + case *ast.CreateDatabaseStmt: + err = e.executeCreateDatabase(x) + case *ast.CreateTableStmt: + err = e.executeCreateTable(x) + case *ast.CreateIndexStmt: + err = e.executeCreateIndex(x) + case *ast.DropDatabaseStmt: + err = e.executeDropDatabase(x) + case *ast.DropTableStmt: + err = e.executeDropTable(x) + case *ast.DropIndexStmt: + err = e.executeDropIndex(x) + case *ast.AlterTableStmt: + err = e.executeAlterTable(x) + } + if err != nil { + return nil, errors.Trace(err) + } + e.done = true + return nil, nil +} + +// Close implements Executor Close interface. +func (e *DDLExec) Close() error { + return nil +} + +func (e *DDLExec) executeTruncateTable(s *ast.TruncateTableStmt) error { + table, ok := e.is.TableByID(s.Table.TableInfo.ID) + if !ok { + return errors.New("table not found, should never happen") + } + txn, err := e.ctx.GetTxn(false) + if err != nil { + return errors.Trace(err) + } + return table.Truncate(txn) +} + +func (e *DDLExec) executeCreateDatabase(s *ast.CreateDatabaseStmt) error { + var opt *coldef.CharsetOpt + if len(s.Options) != 0 { + opt = &coldef.CharsetOpt{} + for _, val := range s.Options { + switch val.Tp { + case ast.DatabaseOptionCharset: + opt.Chs = val.Value + case ast.DatabaseOptionCollate: + opt.Col = val.Value + } + } + } + err := sessionctx.GetDomain(e.ctx).DDL().CreateSchema(e.ctx, model.NewCIStr(s.Name), opt) + if err != nil { + if terror.ErrorEqual(err, infoschema.DatabaseExists) && s.IfNotExists { + err = nil + } + } + return errors.Trace(err) +} + +func (e *DDLExec) executeCreateTable(s *ast.CreateTableStmt) error { + ident := table.Ident{Schema: s.Table.Schema, Name: s.Table.Name} + + var coldefs []*coldef.ColumnDef + for _, val := range s.Cols { + coldef, err := convertColumnDef(val) + if err != nil { + return errors.Trace(err) + } + coldefs = append(coldefs, coldef) + } + var constrs []*coldef.TableConstraint + for _, val := range s.Constraints { + constr, err := convertConstraint(val) + if err != nil { + return errors.Trace(err) + } + constrs = append(constrs, constr) + } + + err := sessionctx.GetDomain(e.ctx).DDL().CreateTable(e.ctx, ident, coldefs, constrs) + if terror.ErrorEqual(err, infoschema.TableExists) { + if s.IfNotExists { + return nil + } + return infoschema.TableExists.Gen("CREATE TABLE: table exists %s", ident) + } + return errors.Trace(err) +} + +func convertColumnDef(v *ast.ColumnDef) (*coldef.ColumnDef, error) { + oldColDef := &coldef.ColumnDef{ + Name: v.Name.Name.O, + Tp: v.Tp, + } + for _, val := range v.Options { + oldOpt, err := convertColumnOption(val) + if err != nil { + return nil, errors.Trace(err) + } + oldColDef.Constraints = append(oldColDef.Constraints, oldOpt) + } + return oldColDef, nil +} + +func convertColumnOption(v *ast.ColumnOption) (*coldef.ConstraintOpt, error) { + oldColumnOpt := &coldef.ConstraintOpt{} + switch v.Tp { + case ast.ColumnOptionAutoIncrement: + oldColumnOpt.Tp = coldef.ConstrAutoIncrement + case ast.ColumnOptionComment: + oldColumnOpt.Tp = coldef.ConstrComment + case ast.ColumnOptionDefaultValue: + oldColumnOpt.Tp = coldef.ConstrDefaultValue + case ast.ColumnOptionIndex: + oldColumnOpt.Tp = coldef.ConstrIndex + case ast.ColumnOptionKey: + oldColumnOpt.Tp = coldef.ConstrKey + case ast.ColumnOptionFulltext: + oldColumnOpt.Tp = coldef.ConstrFulltext + case ast.ColumnOptionNotNull: + oldColumnOpt.Tp = coldef.ConstrNotNull + case ast.ColumnOptionNoOption: + oldColumnOpt.Tp = coldef.ConstrNoConstr + case ast.ColumnOptionOnUpdate: + oldColumnOpt.Tp = coldef.ConstrOnUpdate + case ast.ColumnOptionPrimaryKey: + oldColumnOpt.Tp = coldef.ConstrPrimaryKey + case ast.ColumnOptionNull: + oldColumnOpt.Tp = coldef.ConstrNull + case ast.ColumnOptionUniq: + oldColumnOpt.Tp = coldef.ConstrUniq + case ast.ColumnOptionUniqIndex: + oldColumnOpt.Tp = coldef.ConstrUniqIndex + case ast.ColumnOptionUniqKey: + oldColumnOpt.Tp = coldef.ConstrUniqKey + } + oldColumnOpt.Evalue = v.Expr + return oldColumnOpt, nil +} + +func convertConstraint(v *ast.Constraint) (*coldef.TableConstraint, error) { + oldConstraint := &coldef.TableConstraint{ConstrName: v.Name} + switch v.Tp { + case ast.ConstraintNoConstraint: + oldConstraint.Tp = coldef.ConstrNoConstr + case ast.ConstraintPrimaryKey: + oldConstraint.Tp = coldef.ConstrPrimaryKey + case ast.ConstraintKey: + oldConstraint.Tp = coldef.ConstrKey + case ast.ConstraintIndex: + oldConstraint.Tp = coldef.ConstrIndex + case ast.ConstraintUniq: + oldConstraint.Tp = coldef.ConstrUniq + case ast.ConstraintUniqKey: + oldConstraint.Tp = coldef.ConstrUniqKey + case ast.ConstraintUniqIndex: + oldConstraint.Tp = coldef.ConstrUniqIndex + case ast.ConstraintForeignKey: + oldConstraint.Tp = coldef.ConstrForeignKey + case ast.ConstraintFulltext: + oldConstraint.Tp = coldef.ConstrFulltext + } + oldConstraint.Keys = convertIndexColNames(v.Keys) + if v.Refer != nil { + oldConstraint.Refer = &coldef.ReferenceDef{ + TableIdent: table.Ident{Schema: v.Refer.Table.Schema, Name: v.Refer.Table.Name}, + IndexColNames: convertIndexColNames(v.Refer.IndexColNames), + } + } + return oldConstraint, nil +} + +func convertIndexColNames(v []*ast.IndexColName) (out []*coldef.IndexColName) { + for _, val := range v { + oldIndexColKey := &coldef.IndexColName{ + ColumnName: val.Column.Name.O, + Length: val.Length, + } + out = append(out, oldIndexColKey) + } + return +} + +func (e *DDLExec) executeCreateIndex(s *ast.CreateIndexStmt) error { + ident := table.Ident{Schema: s.Table.Schema, Name: s.Table.Name} + + colNames := convertIndexColNames(s.IndexColNames) + err := sessionctx.GetDomain(e.ctx).DDL().CreateIndex(e.ctx, ident, s.Unique, model.NewCIStr(s.IndexName), colNames) + return errors.Trace(err) +} + +func (e *DDLExec) executeDropDatabase(s *ast.DropDatabaseStmt) error { + err := sessionctx.GetDomain(e.ctx).DDL().DropSchema(e.ctx, model.NewCIStr(s.Name)) + if terror.ErrorEqual(err, infoschema.DatabaseNotExists) { + if s.IfExists { + err = nil + } else { + err = infoschema.DatabaseDropExists.Gen("Can't drop database '%s'; database doesn't exist", s.Name) + } + } + return errors.Trace(err) +} + +func (e *DDLExec) executeDropTable(s *ast.DropTableStmt) error { + var notExistTables []string + for _, tn := range s.Tables { + fullti := table.Ident{Schema: tn.Schema, Name: tn.Name} + schema, ok := e.is.SchemaByName(tn.Schema) + if !ok { + // TODO: we should return special error for table not exist, checking "not exist" is not enough, + // because some other errors may contain this error string too. + notExistTables = append(notExistTables, fullti.String()) + continue + } + tb, err := e.is.TableByName(tn.Schema, tn.Name) + if err != nil && strings.HasSuffix(err.Error(), "not exist") { + notExistTables = append(notExistTables, fullti.String()) + continue + } else if err != nil { + return errors.Trace(err) + } + // Check Privilege + privChecker := privilege.GetPrivilegeChecker(e.ctx) + hasPriv, err := privChecker.Check(e.ctx, schema, tb.Meta(), mysql.DropPriv) + if err != nil { + return errors.Trace(err) + } + if !hasPriv { + return errors.Errorf("You do not have the privilege to drop table %s.%s.", tn.Schema, tn.Name) + } + + err = sessionctx.GetDomain(e.ctx).DDL().DropTable(e.ctx, fullti) + if infoschema.DatabaseNotExists.Equal(err) || infoschema.TableNotExists.Equal(err) { + notExistTables = append(notExistTables, fullti.String()) + } else if err != nil { + return errors.Trace(err) + } + } + if len(notExistTables) > 0 && !s.IfExists { + return infoschema.TableDropExists.Gen("DROP TABLE: table %s does not exist", strings.Join(notExistTables, ",")) + } + return nil +} + +func (e *DDLExec) executeDropIndex(s *ast.DropIndexStmt) error { + ti := table.Ident{Schema: s.Table.Schema, Name: s.Table.Name} + err := sessionctx.GetDomain(e.ctx).DDL().DropIndex(e.ctx, ti, model.NewCIStr(s.IndexName)) + if (infoschema.DatabaseNotExists.Equal(err) || infoschema.TableNotExists.Equal(err)) && s.IfExists { + err = nil + } + return errors.Trace(err) +} + +func (e *DDLExec) executeAlterTable(s *ast.AlterTableStmt) error { + ti := table.Ident{Schema: s.Table.Schema, Name: s.Table.Name} + var specs []*ddl.AlterSpecification + for _, v := range s.Specs { + spec, err := convertAlterTableSpec(v) + if err != nil { + return errors.Trace(err) + } + specs = append(specs, spec) + } + err := sessionctx.GetDomain(e.ctx).DDL().AlterTable(e.ctx, ti, specs) + return errors.Trace(err) +} + +func convertAlterTableSpec(v *ast.AlterTableSpec) (*ddl.AlterSpecification, error) { + oldAlterSpec := &ddl.AlterSpecification{ + Name: v.Name, + } + switch v.Tp { + case ast.AlterTableAddConstraint: + oldAlterSpec.Action = ddl.AlterAddConstr + case ast.AlterTableAddColumn: + oldAlterSpec.Action = ddl.AlterAddColumn + case ast.AlterTableDropColumn: + oldAlterSpec.Action = ddl.AlterDropColumn + case ast.AlterTableDropForeignKey: + oldAlterSpec.Action = ddl.AlterDropForeignKey + case ast.AlterTableDropIndex: + oldAlterSpec.Action = ddl.AlterDropIndex + case ast.AlterTableDropPrimaryKey: + oldAlterSpec.Action = ddl.AlterDropPrimaryKey + case ast.AlterTableOption: + oldAlterSpec.Action = ddl.AlterTableOpt + } + if v.Column != nil { + oldColDef, err := convertColumnDef(v.Column) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Column = oldColDef + } + if v.Position != nil { + oldAlterSpec.Position = &ddl.ColumnPosition{} + switch v.Position.Tp { + case ast.ColumnPositionNone: + oldAlterSpec.Position.Type = ddl.ColumnPositionNone + case ast.ColumnPositionFirst: + oldAlterSpec.Position.Type = ddl.ColumnPositionFirst + case ast.ColumnPositionAfter: + oldAlterSpec.Position.Type = ddl.ColumnPositionAfter + } + if v.Position.RelativeColumn != nil { + oldAlterSpec.Position.RelativeColumn = joinColumnName(v.Position.RelativeColumn) + } + } + if v.DropColumn != nil { + oldAlterSpec.Name = joinColumnName(v.DropColumn) + } + if v.Constraint != nil { + oldConstraint, err := convertConstraint(v.Constraint) + if err != nil { + return nil, errors.Trace(err) + } + oldAlterSpec.Constraint = oldConstraint + } + for _, val := range v.Options { + oldOpt := &coldef.TableOpt{ + StrValue: val.StrValue, + UintValue: val.UintValue, + } + switch val.Tp { + case ast.TableOptionNone: + oldOpt.Tp = coldef.TblOptNone + case ast.TableOptionEngine: + oldOpt.Tp = coldef.TblOptEngine + case ast.TableOptionCharset: + oldOpt.Tp = coldef.TblOptCharset + case ast.TableOptionCollate: + oldOpt.Tp = coldef.TblOptCollate + case ast.TableOptionAutoIncrement: + oldOpt.Tp = coldef.TblOptAutoIncrement + case ast.TableOptionComment: + oldOpt.Tp = coldef.TblOptComment + case ast.TableOptionAvgRowLength: + oldOpt.Tp = coldef.TblOptAvgRowLength + case ast.TableOptionCheckSum: + oldOpt.Tp = coldef.TblOptCheckSum + case ast.TableOptionCompression: + oldOpt.Tp = coldef.TblOptCompression + case ast.TableOptionConnection: + oldOpt.Tp = coldef.TblOptConnection + case ast.TableOptionPassword: + oldOpt.Tp = coldef.TblOptPassword + case ast.TableOptionKeyBlockSize: + oldOpt.Tp = coldef.TblOptKeyBlockSize + case ast.TableOptionMaxRows: + oldOpt.Tp = coldef.TblOptMaxRows + case ast.TableOptionMinRows: + oldOpt.Tp = coldef.TblOptMinRows + case ast.TableOptionDelayKeyWrite: + oldOpt.Tp = coldef.TblOptDelayKeyWrite + } + oldAlterSpec.TableOpts = append(oldAlterSpec.TableOpts, oldOpt) + } + return oldAlterSpec, nil +} + +func joinColumnName(columnName *ast.ColumnName) string { + var originStrs []string + if columnName.Schema.O != "" { + originStrs = append(originStrs, columnName.Schema.O) + } + if columnName.Table.O != "" { + originStrs = append(originStrs, columnName.Table.O) + } + originStrs = append(originStrs, columnName.Name.O) + return strings.Join(originStrs, ".") +} diff --git a/executor/executor_ddl_test.go b/executor/executor_ddl_test.go new file mode 100644 index 0000000000..961a24ec63 --- /dev/null +++ b/executor/executor_ddl_test.go @@ -0,0 +1,106 @@ +// Copyright 2016 PingCAP, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// See the License for the specific language governing permissions and +// limitations under the License. + +package executor_test + +import ( + . "github.com/pingcap/check" + "github.com/pingcap/tidb/util/testkit" +) + +func (s *testSuite) TestTruncateTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec(`drop table if exists truncate_test;`) + tk.MustExec(`create table truncate_test (a int)`) + tk.MustExec(`insert truncate_test values (1),(2),(3)`) + result := tk.MustQuery("select * from truncate_test") + result.Check(testkit.Rows("1", "2", "3")) + tk.MustExec("truncate table truncate_test") + result = tk.MustQuery("select * from truncate_test") + result.Check(nil) +} + +func (s *testSuite) TestCreateTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + // Test create an exist database + _, err := tk.Exec("CREATE database test") + c.Assert(err, NotNil) + + // Test create an exist table + tk.MustExec("CREATE TABLE create_test (id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));") + + _, err = tk.Exec("CREATE TABLE create_test (id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));") + c.Assert(err, NotNil) + + // Test "if not exist" + tk.MustExec("CREATE TABLE if not exists test(id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));") + + // Testcase for https://github.com/pingcap/tidb/issues/312 + tk.MustExec(`create table issue312_1 (c float(24));`) + tk.MustExec(`create table issue312_2 (c float(25));`) + rs, err := tk.Exec(`desc issue312_1`) + c.Assert(err, IsNil) + for { + row, err2 := rs.Next() + c.Assert(err2, IsNil) + if row == nil { + break + } + c.Assert(row.Data[1], Equals, "float") + } + rs, err = tk.Exec(`desc issue312_2`) + c.Assert(err, IsNil) + for { + row, err2 := rs.Next() + c.Assert(err2, IsNil) + if row == nil { + break + } + c.Assert(row.Data[1], Equals, "double") + } +} + +func (s *testSuite) TestCreateDropDatabase(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("create database if not exists drop_test;") + tk.MustExec("drop database if exists drop_test;") + tk.MustExec("create database drop_test;") + tk.MustExec("drop database drop_test;") +} + +func (s *testSuite) TestCreateDropTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table if not exists drop_test (a int)") + tk.MustExec("drop table if exists drop_test") + tk.MustExec("create table drop_test (a int)") + tk.MustExec("drop table drop_test") +} + +func (s *testSuite) TestCreateDropIndex(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table if not exists drop_test (a int)") + tk.MustExec("create index idx_a on drop_test (a)") + tk.MustExec("drop index idx_a on drop_test") + tk.MustExec("drop table drop_test") +} + +func (s *testSuite) TestAlterTable(c *C) { + tk := testkit.NewTestKit(c, s.store) + tk.MustExec("use test") + tk.MustExec("create table if not exists alter_test (c1 int)") + tk.MustExec("alter table alter_test add column c2 int") +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 5cee450076..707ac79a7e 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -92,6 +92,9 @@ func IsSupported(node ast.Node) bool { case *ast.BeginStmt, *ast.CommitStmt, *ast.RollbackStmt: case *ast.DoStmt: case *ast.CreateUserStmt, *ast.SetPwdStmt, *ast.GrantStmt: + case *ast.TruncateTableStmt, *ast.AlterTableStmt: + case *ast.CreateDatabaseStmt, *ast.CreateTableStmt, *ast.CreateIndexStmt: + case *ast.DropDatabaseStmt, *ast.DropIndexStmt, *ast.DropTableStmt: default: return false } diff --git a/optimizer/plan/planbuilder.go b/optimizer/plan/planbuilder.go index c78f11aece..a2c7c43342 100644 --- a/optimizer/plan/planbuilder.go +++ b/optimizer/plan/planbuilder.go @@ -57,10 +57,24 @@ func (b *planBuilder) build(node ast.Node) Plan { switch x := node.(type) { case *ast.AdminStmt: return b.buildAdmin(x) + case *ast.AlterTableStmt: + return b.buildDDL(x) + case *ast.CreateDatabaseStmt: + return b.buildDDL(x) + case *ast.CreateIndexStmt: + return b.buildDDL(x) + case *ast.CreateTableStmt: + return b.buildDDL(x) case *ast.DeallocateStmt: return &Deallocate{Name: x.Name} case *ast.DeleteStmt: return b.buildDelete(x) + case *ast.DropDatabaseStmt: + return b.buildDDL(x) + case *ast.DropIndexStmt: + return b.buildDDL(x) + case *ast.DropTableStmt: + return b.buildDDL(x) case *ast.ExecuteStmt: return &Execute{Name: x.Name, UsingVars: x.UsingVars} case *ast.InsertStmt: @@ -93,6 +107,8 @@ func (b *planBuilder) build(node ast.Node) Plan { return b.buildSimple(x) case *ast.GrantStmt: return b.buildSimple(x) + case *ast.TruncateTableStmt: + return b.buildDDL(x) } b.err = ErrUnsupportedType.Gen("Unsupported type %T", node) return nil @@ -835,3 +851,7 @@ func (b *planBuilder) buildInsert(insert *ast.InsertStmt) Plan { } return insertPlan } + +func (b *planBuilder) buildDDL(node ast.DDLNode) Plan { + return &DDL{Statement: node} +} diff --git a/optimizer/plan/plans.go b/optimizer/plan/plans.go index 89f813e15a..2b7ae29a99 100644 --- a/optimizer/plan/plans.go +++ b/optimizer/plan/plans.go @@ -617,3 +617,20 @@ func (p *Insert) Accept(v Visitor) (Plan, bool) { } return v.Leave(p) } + +// DDL represents a DDL statement plan. +type DDL struct { + basePlan + + Statement ast.DDLNode +} + +// Accept implements Plan Accept interface. +func (p *DDL) Accept(v Visitor) (Plan, bool) { + np, skip := v.Enter(p) + if skip { + v.Leave(np) + } + p = np.(*DDL) + return v.Leave(p) +} diff --git a/optimizer/resolver.go b/optimizer/resolver.go index 10c42a3a8b..828a677997 100644 --- a/optimizer/resolver.go +++ b/optimizer/resolver.go @@ -92,6 +92,8 @@ type resolverContext struct { useOuterContext bool // When visiting multi-table delete stmt table list. inDeleteTableList bool + // When visiting create/drop table statement. + inCreateOrDropTable bool } // currentContext gets the current resolverContext. @@ -138,6 +140,8 @@ func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren b if ctx.inHaving { ctx.inHavingAgg = true } + case *ast.AlterTableStmt: + nr.pushContext() case *ast.ByItem: if _, ok := v.Expr.(*ast.ColumnNameExpr); !ok { // If ByItem is not a single column name expression, @@ -151,12 +155,22 @@ func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren b return inNode, true } } + case *ast.CreateIndexStmt: + nr.pushContext() + case *ast.CreateTableStmt: + nr.pushContext() + nr.currentContext().inCreateOrDropTable = true case *ast.DeleteStmt: nr.pushContext() case *ast.DeleteTableList: nr.currentContext().inDeleteTableList = true case *ast.DoStmt: nr.pushContext() + case *ast.DropTableStmt: + nr.pushContext() + nr.currentContext().inCreateOrDropTable = true + case *ast.DropIndexStmt: + nr.pushContext() case *ast.FieldList: nr.currentContext().inFieldList = true case *ast.GroupByClause: @@ -183,6 +197,8 @@ func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren b nr.pushContext() case *ast.TableRefsClause: nr.currentContext().inTableRefs = true + case *ast.TruncateTableStmt: + nr.pushContext() case *ast.UnionStmt: nr.pushContext() case *ast.UpdateStmt: @@ -201,14 +217,24 @@ func (nr *nameResolver) Leave(inNode ast.Node) (node ast.Node, ok bool) { if ctx.inHaving { ctx.inHavingAgg = false } + case *ast.AlterTableStmt: + nr.popContext() case *ast.TableName: nr.handleTableName(v) case *ast.ColumnNameExpr: nr.handleColumnName(v) + case *ast.CreateIndexStmt: + nr.popContext() + case *ast.CreateTableStmt: + nr.popContext() case *ast.DeleteTableList: nr.currentContext().inDeleteTableList = false case *ast.DoStmt: nr.popContext() + case *ast.DropIndexStmt: + nr.popContext() + case *ast.DropTableStmt: + nr.popContext() case *ast.TableSource: nr.handleTableSource(v) case *ast.OnCondition: @@ -254,6 +280,8 @@ func (nr *nameResolver) Leave(inNode ast.Node) (node ast.Node, ok bool) { v.UseOuterContext = true nr.useOuterContext = false } + case *ast.TruncateTableStmt: + nr.popContext() case *ast.UnionStmt: ctx := nr.currentContext() v.SetResultFields(ctx.fieldList) @@ -279,6 +307,11 @@ func (nr *nameResolver) handleTableName(tn *ast.TableName) { tn.Schema = nr.DefaultSchema } ctx := nr.currentContext() + if ctx.inCreateOrDropTable { + // The table may not exist in create table or drop table statement. + // Skip resolving the table to avoid error. + return + } if ctx.inDeleteTableList { idx, ok := ctx.tableMap[nr.tableUniqueName(tn.Schema, tn.Name)] if !ok { diff --git a/stmt/stmts/altertable_test.go b/stmt/stmts/altertable_test.go deleted file mode 100644 index b33d1ea053..0000000000 --- a/stmt/stmts/altertable_test.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package stmts_test - -import ( - . "github.com/pingcap/check" - "github.com/pingcap/tidb" - "github.com/pingcap/tidb/stmt/stmts" -) - -func (s *testStmtSuite) TestAlterTable(c *C) { - testSQL := "drop table if exists t; create table t (c1 int); alter table t add column c2 int;" - - stmtList, err := tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - - stmtLen := len(stmtList) - c.Assert(stmtLen, Greater, 0) - - testStmt, ok := stmtList[stmtLen-1].(*stmts.AlterTableStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsTrue) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - mf := newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - mustExec(c, s.testDB, testSQL) -} diff --git a/stmt/stmts/create_test.go b/stmt/stmts/create_test.go index 940875d8ef..c2be22b330 100644 --- a/stmt/stmts/create_test.go +++ b/stmt/stmts/create_test.go @@ -92,95 +92,6 @@ func (s *testStmtSuite) TearDownTest(c *C) { mustExec(c, s.testDB, s.dropDBSql) } -func (s *testStmtSuite) TestCreateTable(c *C) { - stmtList, err := tidb.Compile(s.ctx, s.createDBSql+" CREATE TABLE if not exists test(id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));") - c.Assert(err, IsNil) - - for _, stmt := range stmtList { - c.Assert(len(stmt.OriginText()), Greater, 0) - - mf := newMockFormatter() - stmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - } - - // Test create an exist database - tx := mustBegin(c, s.testDB) - _, err = tx.Exec(fmt.Sprintf("CREATE database %s;", s.dbName)) - c.Assert(err, NotNil) - tx.Rollback() - - // Test create an exist table - mustExec(c, s.testDB, "CREATE TABLE test(id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));") - - tx = mustBegin(c, s.testDB) - _, err = tx.Exec("CREATE TABLE test(id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));") - c.Assert(err, NotNil) - tx.Rollback() - - // Test "if not exist" - mustExec(c, s.testDB, "CREATE TABLE if not exists test(id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));") - - // Testcase for https://github.com/pingcap/tidb/issues/312 - mustExec(c, s.testDB, `create table issue312_1 (c float(24));`) - mustExec(c, s.testDB, `create table issue312_2 (c float(25));`) - tx = mustBegin(c, s.testDB) - rows, err := tx.Query(`desc issue312_1`) - c.Assert(err, IsNil) - for rows.Next() { - var ( - c1 string - c2 string - c3 string - c4 string - c5 string - c6 string - ) - rows.Scan(&c1, &c2, &c3, &c4, &c5, &c6) - c.Assert(c2, Equals, "float") - } - rows.Close() - mustCommit(c, tx) - tx = mustBegin(c, s.testDB) - rows, err = tx.Query(`desc issue312_2`) - c.Assert(err, IsNil) - for rows.Next() { - var ( - c1 string - c2 string - c3 string - c4 string - c5 string - c6 string - ) - rows.Scan(&c1, &c2, &c3, &c4, &c5, &c6) - c.Assert(c2, Equals, "double") - } - rows.Close() - mustCommit(c, tx) -} - -func (s *testStmtSuite) TestCreateIndex(c *C) { - mustExec(c, s.testDB, s.createTableSql) - stmtList, err := tidb.Compile(s.ctx, "CREATE index name_idx on test (name)") - c.Assert(err, IsNil) - - str := stmtList[0].OriginText() - c.Assert(0, Less, len(str)) - - mf := newMockFormatter() - stmtList[0].Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - tx := mustBegin(c, s.testDB) - _, err = tx.Exec("CREATE TABLE test(id INT NOT NULL DEFAULT 1, name varchar(255), PRIMARY KEY(id));") - c.Assert(err, NotNil) - tx.Rollback() - - // Test not exist - mustExec(c, s.testDB, "CREATE index name_idx on test (name)") -} - func mustBegin(c *C, currDB *sql.DB) *sql.Tx { tx, err := currDB.Begin() c.Assert(err, IsNil) diff --git a/stmt/stmts/drop_test.go b/stmt/stmts/drop_test.go deleted file mode 100644 index 157a0c3049..0000000000 --- a/stmt/stmts/drop_test.go +++ /dev/null @@ -1,82 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package stmts_test - -import ( - . "github.com/pingcap/check" - "github.com/pingcap/tidb" - "github.com/pingcap/tidb/stmt/stmts" -) - -func (s *testStmtSuite) TestDropDatabase(c *C) { - testSQL := "drop database if exists drop_test;" - - stmtList, err := tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - c.Assert(stmtList, HasLen, 1) - - testStmt, ok := stmtList[0].(*stmts.DropDatabaseStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsTrue) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - mf := newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - mustExec(c, s.testDB, testSQL) -} - -func (s *testStmtSuite) TestDropTable(c *C) { - testSQL := "drop table if exists drop_table;" - - stmtList, err := tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - c.Assert(stmtList, HasLen, 1) - - testStmt, ok := stmtList[0].(*stmts.DropTableStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsTrue) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - mf := newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - mustExec(c, s.testDB, testSQL) - mustExec(c, s.testDB, "create table if not exists t (c int)") - mustExec(c, s.testDB, "drop table t") -} - -func (s *testStmtSuite) TestDropIndex(c *C) { - testSQL := "drop index if exists drop_index on t;" - - stmtList, err := tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - c.Assert(stmtList, HasLen, 1) - - testStmt, ok := stmtList[0].(*stmts.DropIndexStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsTrue) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - mf := newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - mustExec(c, s.testDB, testSQL) -} diff --git a/stmt/stmts/truncate.go b/stmt/stmts/truncate.go deleted file mode 100644 index 1bd1465aa5..0000000000 --- a/stmt/stmts/truncate.go +++ /dev/null @@ -1,70 +0,0 @@ -// Copyright 2013 The ql Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSES/QL-LICENSE file. - -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package stmts - -import ( - "github.com/juju/errors" - "github.com/pingcap/tidb/context" - "github.com/pingcap/tidb/rset" - "github.com/pingcap/tidb/stmt" - "github.com/pingcap/tidb/table" - "github.com/pingcap/tidb/util/format" -) - -var _ stmt.Statement = (*TruncateTableStmt)(nil) - -// TruncateTableStmt is a statement to empty a table completely. -// See: https://dev.mysql.com/doc/refman/5.7/en/truncate-table.html -type TruncateTableStmt struct { - TableIdent table.Ident - - Text string -} - -// Explain implements the stmt.Statement Explain interface. -func (s *TruncateTableStmt) Explain(ctx context.Context, w format.Formatter) { - w.Format("%s\n", s.Text) -} - -// IsDDL implements the stmt.Statement IsDDL interface. -func (s *TruncateTableStmt) IsDDL() bool { - return false -} - -// OriginText implements the stmt.Statement OriginText interface. -func (s *TruncateTableStmt) OriginText() string { - return s.Text -} - -// SetText implements the stmt.Statement SetText interface. -func (s *TruncateTableStmt) SetText(text string) { - s.Text = text -} - -// Exec implements the stmt.Statement Exec interface. -func (s *TruncateTableStmt) Exec(ctx context.Context) (rset.Recordset, error) { - t, err := getTable(ctx, s.TableIdent) - if err != nil { - return nil, err - } - txn, err := ctx.GetTxn(false) - if err != nil { - return nil, errors.Trace(err) - } - return nil, t.Truncate(txn) -} diff --git a/stmt/stmts/truncate_test.go b/stmt/stmts/truncate_test.go deleted file mode 100644 index 478042d2c1..0000000000 --- a/stmt/stmts/truncate_test.go +++ /dev/null @@ -1,42 +0,0 @@ -// Copyright 2015 PingCAP, Inc. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// See the License for the specific language governing permissions and -// limitations under the License. - -package stmts_test - -import ( - . "github.com/pingcap/check" - "github.com/pingcap/tidb" - "github.com/pingcap/tidb/stmt/stmts" -) - -func (s *testStmtSuite) TestTruncate(c *C) { - testSQL := `drop table if exists truncate_test; create table truncate_test(id int);` - mustExec(c, s.testDB, testSQL) - - testSQL = "truncate table truncate_test;" - stmtList, err := tidb.Compile(s.ctx, testSQL) - c.Assert(err, IsNil) - c.Assert(stmtList, HasLen, 1) - - testStmt, ok := stmtList[0].(*stmts.TruncateTableStmt) - c.Assert(ok, IsTrue) - - c.Assert(testStmt.IsDDL(), IsFalse) - c.Assert(len(testStmt.OriginText()), Greater, 0) - - mf := newMockFormatter() - testStmt.Explain(nil, mf) - c.Assert(mf.Len(), Greater, 0) - - mustExec(c, s.testDB, testSQL) -}