From f164aa5f1e06d0ccbed643ba3c8187c9813e521b Mon Sep 17 00:00:00 2001 From: shenli Date: Sun, 7 Feb 2016 18:47:52 +0800 Subject: [PATCH] *: Support delete in new plan Move DeleteStmt from old plan to new plan. --- ast/dml.go | 53 ++++++++++---- executor/builder.go | 12 ++++ executor/converter/convert_stmt.go | 10 +-- executor/executor_write.go | 110 ++++++++++++++++++++++++++++- optimizer/optimizer.go | 2 +- optimizer/plan/planbuilder.go | 28 ++++++++ optimizer/plan/plans.go | 24 +++++++ optimizer/resolver.go | 26 ++++++- parser/parser.go | 26 +++---- parser/parser.y | 8 +-- stmt/stmts/delete_test.go | 14 ---- 11 files changed, 260 insertions(+), 53 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index 1402cf9699..c9fdf1f316 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -120,6 +120,31 @@ func (n *TableName) Accept(v Visitor) (Node, bool) { return v.Leave(n) } +// DeleteTableList is the tablelist used in delete statement multi-table mode. +type DeleteTableList struct { + node + Tables []*TableName +} + +// Accept implements Node Accept interface. +func (n *DeleteTableList) Accept(v Visitor) (Node, bool) { + newNode, skipChildren := v.Enter(n) + if skipChildren { + return v.Leave(newNode) + } + n = newNode.(*DeleteTableList) + if n != nil { + for i, t := range n.Tables { + node, ok := t.Accept(v) + if !ok { + return n, false + } + n.Tables[i] = node.(*TableName) + } + } + return v.Leave(n) +} + // OnCondition represetns JOIN on condition. type OnCondition struct { node @@ -652,15 +677,15 @@ type DeleteStmt struct { // Used in both single table and multiple table delete statement. TableRefs *TableRefsClause // Only used in multiple table delete statement. - Tables []*TableName - Where ExprNode - Order *OrderByClause - Limit *Limit - LowPriority bool - Ignore bool - Quick bool - MultiTable bool - BeforeFrom bool + Tables *DeleteTableList + Where ExprNode + Order *OrderByClause + Limit *Limit + LowPriority bool + Ignore bool + Quick bool + IsMultiTable bool + BeforeFrom bool } // Accept implements Node Accept interface. @@ -677,13 +702,11 @@ func (n *DeleteStmt) Accept(v Visitor) (Node, bool) { } n.TableRefs = node.(*TableRefsClause) - for i, val := range n.Tables { - node, ok = val.Accept(v) - if !ok { - return n, false - } - n.Tables[i] = node.(*TableName) + node, ok = n.Tables.Accept(v) + if !ok { + return n, false } + n.Tables = node.(*DeleteTableList) if n.Where != nil { node, ok = n.Where.Accept(v) diff --git a/executor/builder.go b/executor/builder.go index bc1b01bff4..74980390c2 100644 --- a/executor/builder.go +++ b/executor/builder.go @@ -51,6 +51,8 @@ func (b *executorBuilder) build(p plan.Plan) Executor { return b.buildCheckTable(v) case *plan.Deallocate: return b.buildDeallocate(v) + case *plan.Delete: + return b.buildDelete(v) case *plan.Execute: return b.buildExecute(v) case *plan.Filter: @@ -295,3 +297,13 @@ func (b *executorBuilder) buildUpdate(v *plan.Update) Executor { selExec := b.build(v.SelectPlan) return &UpdateExec{ctx: b.ctx, SelectExec: selExec, OrderedList: v.OrderedList} } + +func (b *executorBuilder) buildDelete(v *plan.Delete) Executor { + selExec := b.build(v.SelectPlan) + return &DeleteExec{ + ctx: b.ctx, + SelectExec: selExec, + Tables: v.Tables, + IsMultiTable: v.IsMultiTable, + } +} diff --git a/executor/converter/convert_stmt.go b/executor/converter/convert_stmt.go index 7d55ef2e67..96e6c35d15 100644 --- a/executor/converter/convert_stmt.go +++ b/executor/converter/convert_stmt.go @@ -103,7 +103,7 @@ func convertDelete(converter *expressionConverter, v *ast.DeleteStmt) (*stmts.De BeforeFrom: v.BeforeFrom, Ignore: v.Ignore, LowPriority: v.LowPriority, - MultiTable: v.MultiTable, + MultiTable: v.IsMultiTable, Quick: v.Quick, Text: v.Text(), } @@ -112,9 +112,11 @@ func convertDelete(converter *expressionConverter, v *ast.DeleteStmt) (*stmts.De return nil, errors.Trace(err) } oldDelete.Refs = oldRefs - for _, val := range v.Tables { - tableIdent := table.Ident{Schema: val.Schema, Name: val.Name} - oldDelete.TableIdents = append(oldDelete.TableIdents, tableIdent) + if v.Tables != nil { + for _, val := range v.Tables.Tables { + tableIdent := table.Ident{Schema: val.Schema, Name: val.Name} + oldDelete.TableIdents = append(oldDelete.TableIdents, tableIdent) + } } if v.Where != nil { oldDelete.Where, err = convertExpr(converter, v.Where) diff --git a/executor/executor_write.go b/executor/executor_write.go index 53b79ea892..b9eec569b3 100644 --- a/executor/executor_write.go +++ b/executor/executor_write.go @@ -21,6 +21,7 @@ import ( "github.com/pingcap/tidb/kv" "github.com/pingcap/tidb/mysql" "github.com/pingcap/tidb/optimizer/evaluator" + "github.com/pingcap/tidb/sessionctx" "github.com/pingcap/tidb/sessionctx/variable" "github.com/pingcap/tidb/table" "github.com/pingcap/tidb/table/tables" @@ -29,6 +30,7 @@ import ( var ( _ Executor = &UpdateExec{} + _ Executor = &DeleteExec{} ) // UpdateExec represents an update executor. @@ -45,7 +47,7 @@ type UpdateExec struct { cursor int } -// Next implements Execution Next interface. +// Next implements Executor Next interface. func (e *UpdateExec) Next() (*Row, error) { if !e.fetched { err := e.fetchRows() @@ -234,3 +236,109 @@ func (e *UpdateExec) Fields() []*ast.ResultField { func (e *UpdateExec) Close() error { return e.SelectExec.Close() } + +// DeleteExec represents a delete executor. +// See: https://dev.mysql.com/doc/refman/5.7/en/delete.html +type DeleteExec struct { + SelectExec Executor + + ctx context.Context + Tables []*ast.TableName + IsMultiTable bool + + finished bool +} + +// Next implements Executor Next interface. +func (e *DeleteExec) Next() (*Row, error) { + if e.finished { + return nil, nil + } + defer func() { + e.finished = true + }() + if e.IsMultiTable && len(e.Tables) == 0 { + return &Row{}, nil + } + tblIDMap := make(map[int64]bool, len(e.Tables)) + // Get table alias map. + tblNames := make(map[string]string) + rowKeyMap := make(map[string]table.Table) + if e.IsMultiTable { + // Delete from multiple tables should consider table ident list. + fs := e.SelectExec.Fields() + for _, f := range fs { + if len(f.TableAsName.L) > 0 { + tblNames[f.TableAsName.L] = f.TableName.Name.L + } else { + tblNames[f.TableName.Name.L] = f.TableName.Name.L + } + } + for _, t := range e.Tables { + // Consider DBName. + _, ok := tblNames[t.Name.L] + if !ok { + return nil, errors.Errorf("Unknown table '%s' in MULTI DELETE", t.Name.O) + } + tblIDMap[t.TableInfo.ID] = true + } + } + for { + row, err := e.SelectExec.Next() + if err != nil { + return nil, errors.Trace(err) + } + if row == nil { + break + } + + for _, entry := range row.RowKeys { + if e.IsMultiTable { + tid := entry.Tbl.TableID() + if _, ok := tblIDMap[tid]; !ok { + continue + } + } + rowKeyMap[entry.Key] = entry.Tbl + } + } + for k, t := range rowKeyMap { + handle, err := tables.DecodeRecordKeyHandle(kv.Key(k)) + if err != nil { + return nil, errors.Trace(err) + } + data, err := t.Row(e.ctx, handle) + if err != nil { + return nil, errors.Trace(err) + } + err = e.removeRow(e.ctx, t, handle, data) + if err != nil { + return nil, errors.Trace(err) + } + } + return nil, nil +} + +func (e *DeleteExec) getTable(ctx context.Context, tableName *ast.TableName) (table.Table, error) { + return sessionctx.GetDomain(ctx).InfoSchema().TableByName(tableName.Schema, tableName.Name) +} + +func (e *DeleteExec) removeRow(ctx context.Context, t table.Table, h int64, data []interface{}) error { + err := t.RemoveRecord(ctx, h, data) + if err != nil { + return errors.Trace(err) + } + variable.GetSessionVars(ctx).AddAffectedRows(1) + return nil +} + +// Fields implements Executor Fields interface. +// Returns nil to indicate there is no output. +func (e *DeleteExec) Fields() []*ast.ResultField { + return nil +} + +// Close implements Executor Close interface. +func (e *DeleteExec) Close() error { + return e.SelectExec.Close() +} diff --git a/optimizer/optimizer.go b/optimizer/optimizer.go index 4ec6c9ae41..21d559c96d 100644 --- a/optimizer/optimizer.go +++ b/optimizer/optimizer.go @@ -91,7 +91,7 @@ func (c *supportChecker) Leave(in ast.Node) (ast.Node, bool) { func IsSupported(node ast.Node) bool { switch node.(type) { case *ast.SelectStmt, *ast.PrepareStmt, *ast.ExecuteStmt, *ast.DeallocateStmt, - *ast.AdminStmt, *ast.UpdateStmt: + *ast.AdminStmt, *ast.UpdateStmt, *ast.DeleteStmt: default: return false } diff --git a/optimizer/plan/planbuilder.go b/optimizer/plan/planbuilder.go index 7e528875aa..d7c8a4f1d0 100644 --- a/optimizer/plan/planbuilder.go +++ b/optimizer/plan/planbuilder.go @@ -59,6 +59,8 @@ func (b *planBuilder) build(node ast.Node) Plan { return b.buildAdmin(x) case *ast.DeallocateStmt: return &Deallocate{Name: x.Name} + case *ast.DeleteStmt: + return b.buildDelete(x) case *ast.ExecuteStmt: return &Execute{Name: x.Name, UsingVars: x.UsingVars} case *ast.PrepareStmt: @@ -647,6 +649,32 @@ func (b *planBuilder) buildUpdateLists(list []*ast.Assignment, fields []*ast.Res return newList } +func (b *planBuilder) buildDelete(del *ast.DeleteStmt) Plan { + sel := &ast.SelectStmt{From: del.TableRefs, Where: del.Where, OrderBy: del.Order, Limit: del.Limit} + p := b.buildFrom(sel) + if sel.OrderBy != nil && !matchOrder(p, sel.OrderBy.Items) { + p = b.buildSort(p, sel.OrderBy.Items) + if b.err != nil { + return nil + } + } + if sel.Limit != nil { + p = b.buildLimit(p, sel.Limit) + if b.err != nil { + return nil + } + } + var tables []*ast.TableName + if del.Tables != nil { + tables = del.Tables.Tables + } + return &Delete{ + Tables: tables, + IsMultiTable: del.IsMultiTable, + SelectPlan: p, + } +} + func columnOffsetInFields(cn *ast.ColumnName, fields []*ast.ResultField) (int, error) { offset := -1 tableNameL := cn.Table.L diff --git a/optimizer/plan/plans.go b/optimizer/plan/plans.go index a07fd9249e..d2db7fbf81 100644 --- a/optimizer/plan/plans.go +++ b/optimizer/plan/plans.go @@ -461,6 +461,30 @@ func (p *Update) Accept(v Visitor) (Plan, bool) { return v.Leave(p) } +// Delete represents a delete plan. +type Delete struct { + basePlan + + SelectPlan Plan + Tables []*ast.TableName + IsMultiTable bool +} + +// Accept implements Plan Accept interface. +func (p *Delete) Accept(v Visitor) (Plan, bool) { + np, skip := v.Enter(p) + if skip { + v.Leave(np) + } + p = np.(*Delete) + var ok bool + p.SelectPlan, ok = p.SelectPlan.Accept(v) + if !ok { + return p, false + } + return v.Leave(p) +} + // Filter represents a plan that filter srcplan result. type Filter struct { planWithSrc diff --git a/optimizer/resolver.go b/optimizer/resolver.go index 513783c698..105b740146 100644 --- a/optimizer/resolver.go +++ b/optimizer/resolver.go @@ -89,6 +89,8 @@ type resolverContext struct { inByItemExpression bool // If subquery use outer context. useOuterContext bool + // When visiting multi-table delete stmt table list. + inDeleteTableList bool } // currentContext gets the current resolverContext. @@ -128,6 +130,8 @@ func (nr *nameResolver) popJoin() { // Enter implements ast.Visitor interface. func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren bool) { switch v := inNode.(type) { + case *ast.AdminStmt: + nr.pushContext() case *ast.AggregateFuncExpr: ctx := nr.currentContext() if ctx.inHaving { @@ -148,6 +152,8 @@ func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren b } case *ast.DeleteStmt: nr.pushContext() + case *ast.DeleteTableList: + nr.currentContext().inDeleteTableList = true case *ast.FieldList: nr.currentContext().inFieldList = true case *ast.GroupByClause: @@ -175,6 +181,8 @@ func (nr *nameResolver) Enter(inNode ast.Node) (outNode ast.Node, skipChildren b // Leave implements ast.Visitor interface. func (nr *nameResolver) Leave(inNode ast.Node) (node ast.Node, ok bool) { switch v := inNode.(type) { + case *ast.AdminStmt: + nr.popContext() case *ast.AggregateFuncExpr: ctx := nr.currentContext() if ctx.inHaving { @@ -184,6 +192,8 @@ func (nr *nameResolver) Leave(inNode ast.Node) (node ast.Node, ok bool) { nr.handleTableName(v) case *ast.ColumnNameExpr: nr.handleColumnName(v) + case *ast.DeleteTableList: + nr.currentContext().inDeleteTableList = false case *ast.TableSource: nr.handleTableSource(v) case *ast.OnCondition: @@ -243,9 +253,23 @@ func (nr *nameResolver) handleTableName(tn *ast.TableName) { if tn.Schema.L == "" { tn.Schema = nr.DefaultSchema } + ctx := nr.currentContext() + if ctx.inDeleteTableList { + idx, ok := ctx.tableMap[nr.tableUniqueName(tn.Schema, tn.Name)] + if !ok { + nr.Err = errors.Errorf("Unknown table %s", tn.Name.O) + return + } + ts := ctx.tables[idx] + tableName := ts.Source.(*ast.TableName) + tn.DBInfo = tableName.DBInfo + tn.TableInfo = tableName.TableInfo + tn.SetResultFields(tableName.GetResultFields()) + return + } table, err := nr.Info.TableByName(tn.Schema, tn.Name) if err != nil { - nr.Err = err + nr.Err = errors.Trace(err) return } tn.TableInfo = table.Meta() diff --git a/parser/parser.go b/parser/parser.go index d13f287add..04528d5387 100644 --- a/parser/parser.go +++ b/parser/parser.go @@ -6317,13 +6317,13 @@ yynewstate: { // Multiple Table x := &ast.DeleteStmt{ - LowPriority: yyS[yypt-6].item.(bool), - Quick: yyS[yypt-5].item.(bool), - Ignore: yyS[yypt-4].item.(bool), - MultiTable: true, - BeforeFrom: true, - Tables: yyS[yypt-3].item.([]*ast.TableName), - TableRefs: &ast.TableRefsClause{TableRefs: yyS[yypt-1].item.(*ast.Join)}, + LowPriority: yyS[yypt-6].item.(bool), + Quick: yyS[yypt-5].item.(bool), + Ignore: yyS[yypt-4].item.(bool), + IsMultiTable: true, + BeforeFrom: true, + Tables: &ast.DeleteTableList{Tables: yyS[yypt-3].item.([]*ast.TableName)}, + TableRefs: &ast.TableRefsClause{TableRefs: yyS[yypt-1].item.(*ast.Join)}, } if yyS[yypt-0].item != nil { x.Where = yyS[yypt-0].item.(ast.ExprNode) @@ -6337,12 +6337,12 @@ yynewstate: { // Multiple Table x := &ast.DeleteStmt{ - LowPriority: yyS[yypt-7].item.(bool), - Quick: yyS[yypt-6].item.(bool), - Ignore: yyS[yypt-5].item.(bool), - MultiTable: true, - Tables: yyS[yypt-3].item.([]*ast.TableName), - TableRefs: &ast.TableRefsClause{TableRefs: yyS[yypt-1].item.(*ast.Join)}, + LowPriority: yyS[yypt-7].item.(bool), + Quick: yyS[yypt-6].item.(bool), + Ignore: yyS[yypt-5].item.(bool), + IsMultiTable: true, + Tables: &ast.DeleteTableList{Tables: yyS[yypt-3].item.([]*ast.TableName)}, + TableRefs: &ast.TableRefsClause{TableRefs: yyS[yypt-1].item.(*ast.Join)}, } if yyS[yypt-0].item != nil { x.Where = yyS[yypt-0].item.(ast.ExprNode) diff --git a/parser/parser.y b/parser/parser.y index b0ceb7baa8..e8cc7e5d13 100644 --- a/parser/parser.y +++ b/parser/parser.y @@ -1224,9 +1224,9 @@ DeleteFromStmt: LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), - MultiTable: true, + IsMultiTable: true, BeforeFrom: true, - Tables: $5.([]*ast.TableName), + Tables: &ast.DeleteTableList{Tables: $5.([]*ast.TableName)}, TableRefs: &ast.TableRefsClause{TableRefs: $7.(*ast.Join)}, } if $8 != nil { @@ -1244,8 +1244,8 @@ DeleteFromStmt: LowPriority: $2.(bool), Quick: $3.(bool), Ignore: $4.(bool), - MultiTable: true, - Tables: $6.([]*ast.TableName), + IsMultiTable: true, + Tables: &ast.DeleteTableList{Tables: $6.([]*ast.TableName)}, TableRefs: &ast.TableRefsClause{TableRefs: $8.(*ast.Join)}, } if $9 != nil { diff --git a/stmt/stmts/delete_test.go b/stmt/stmts/delete_test.go index 7c659579f2..09844bfa99 100644 --- a/stmt/stmts/delete_test.go +++ b/stmt/stmts/delete_test.go @@ -18,7 +18,6 @@ import ( "strings" . "github.com/pingcap/check" - "github.com/pingcap/tidb" ) func (s *testStmtSuite) fillData(currDB *sql.DB, c *C) { @@ -70,13 +69,6 @@ func (s *testStmtSuite) queryStrings(currDB *sql.DB, sql string, c *C) []string func (s *testStmtSuite) TestDelete(c *C) { s.fillData(s.testDB, c) - // Test compile - stmtList, err := tidb.Compile(s.ctx, "DELETE from test where id = 2;") - c.Assert(err, IsNil) - - str := stmtList[0].OriginText() - c.Assert(0, Less, len(str)) - r := mustExec(c, s.testDB, `UPDATE test SET name = "abc" where id = 2;`) checkResult(c, r, 1, 0) @@ -140,12 +132,6 @@ func (s *testStmtSuite) TestDelete(c *C) { func (s *testStmtSuite) TestMultiTableDelete(c *C) { s.fillDataMultiTable(s.testDB, c) - // Test compile - stmtList, err := tidb.Compile(s.ctx, "DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;") - c.Assert(err, IsNil) - str := stmtList[0].OriginText() - c.Assert(0, Less, len(str)) - r := mustExec(c, s.testDB, `DELETE t1, t2 FROM t1 INNER JOIN t2 INNER JOIN t3 WHERE t1.id=t2.id AND t2.id=t3.id;`) checkResult(c, r, 2, 0)