From edcbbc84ff83f811e4708fadeabb74b14fdcb0e1 Mon Sep 17 00:00:00 2001 From: Ewan Chou Date: Thu, 29 Oct 2015 15:33:54 +0800 Subject: [PATCH] optimizer: implement ddl statement convert functions. --- ast/dml.go | 4 +- ast/parser/parser.y | 4 + optimizer/convert_expr.go | 18 +-- optimizer/convert_stmt.go | 274 +++++++++++++++++++++++++++++++++++++- 4 files changed, 284 insertions(+), 16 deletions(-) diff --git a/ast/dml.go b/ast/dml.go index d46cd86394..0225af0462 100644 --- a/ast/dml.go +++ b/ast/dml.go @@ -564,7 +564,7 @@ type InsertStmt struct { Setlist []*Assignment Priority int OnDuplicate []*Assignment - Select *SelectStmt + Select ResultSetNode } // Accept implements Node Accept interface. @@ -616,7 +616,7 @@ func (nod *InsertStmt) Accept(v Visitor) (Node, bool) { if !ok { return nod, false } - nod.Select = node.(*SelectStmt) + nod.Select = node.(ResultSetNode) } return v.Leave(nod) } diff --git a/ast/parser/parser.y b/ast/parser/parser.y index af69b11b61..6f096b8687 100644 --- a/ast/parser/parser.y +++ b/ast/parser/parser.y @@ -1669,6 +1669,10 @@ InsertRest: { $$ = &ast.InsertStmt{Select: $1.(*ast.SelectStmt)} } +| UnionStmt + { + $$ = &ast.InsertStmt{Select: $1.(*ast.UnionStmt)} + } | "SET" ColumnSetValueList { $$ = &ast.InsertStmt{Setlist: $2.([]*ast.Assignment)} diff --git a/optimizer/convert_expr.go b/optimizer/convert_expr.go index 4fe74ebf34..41c78a815e 100644 --- a/optimizer/convert_expr.go +++ b/optimizer/convert_expr.go @@ -190,26 +190,28 @@ func (c *expressionConverter) compareSubquery(v *ast.CompareSubqueryExpr) { c.exprMap[v] = oldCmpSubquery } -func concatCIStr(ciStrs ...model.CIStr) string { +func joinColumnName(columnName *ast.ColumnName) string { var originStrs []string - for _, v := range ciStrs { - if v.O != "" { - originStrs = append(originStrs, v.O) - } + if columnName.Schema.O != "" { + originStrs = append(originStrs, columnName.Schema.O) } + if columnName.Table.O != "" { + originStrs = append(originStrs, columnName.Table.O) + } + originStrs = append(originStrs, columnName.Name.O) return strings.Join(originStrs, ".") } func (c *expressionConverter) columnNameExpr(v *ast.ColumnNameExpr) { ident := &expression.Ident{} - ident.CIStr = model.NewCIStr(concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name)) + ident.CIStr = model.NewCIStr(joinColumnName(v.Name)) c.exprMap[v] = ident } func (c *expressionConverter) defaultExpr(v *ast.DefaultExpr) { oldDefault := &expression.Default{} if v.Name != nil { - oldDefault.Name = concatCIStr(v.Name.Schema, v.Name.Table, v.Name.Name) + oldDefault.Name = joinColumnName(v.Name) } c.exprMap[v] = oldDefault } @@ -302,7 +304,7 @@ func (c *expressionConverter) unaryOperation(v *ast.UnaryOperationExpr) { } func (c *expressionConverter) values(v *ast.ValuesExpr) { - nameStr := concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + nameStr := joinColumnName(v.Column) c.exprMap[v] = &expression.Values{CIStr: model.NewCIStr(nameStr)} } diff --git a/optimizer/convert_stmt.go b/optimizer/convert_stmt.go index 34e81052b9..036bfebefd 100644 --- a/optimizer/convert_stmt.go +++ b/optimizer/convert_stmt.go @@ -25,23 +25,153 @@ import ( "github.com/pingcap/tidb/table" ) +func convertAssignment(converter *expressionConverter, v *ast.Assignment) (*expression.Assignment, error) { + oldAssign := &expression.Assignment{ + ColName: joinColumnName(v.Column), + } + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldAssign.Expr = oldExpr + return oldAssign, nil +} + func convertInsert(v *ast.InsertStmt) (*stmts.InsertIntoStmt, error) { oldInsert := &stmts.InsertIntoStmt{ - Text: v.Text(), + Priority: v.Priority, + Text: v.Text(), + } + tableName := v.Table.TableRefs.Left.(*ast.TableName) + oldInsert.TableIdent = table.Ident{Schema: tableName.Schema, Name: tableName.Name} + for _, val := range v.Columns { + oldInsert.ColNames = append(oldInsert.ColNames, joinColumnName(val)) + } + converter := newExpressionConverter() + for _, row := range v.Lists { + var oldRow []expression.Expression + for _, val := range row { + oldExpr, err := convertExpr(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldRow = append(oldRow, oldExpr) + } + oldInsert.Lists = append(oldInsert.Lists, oldRow) + } + for _, assign := range v.Setlist { + oldAssign, err := convertAssignment(converter, assign) + if err != nil { + return nil, errors.Trace(err) + } + oldInsert.Setlist = append(oldInsert.Setlist, oldAssign) + } + for _, onDup := range v.OnDuplicate { + oldOnDup, err := convertAssignment(converter, onDup) + if err != nil { + return nil, errors.Trace(err) + } + oldInsert.OnDuplicate = append(oldInsert.OnDuplicate, *oldOnDup) + } + if v.Select != nil { + var err error + switch x := v.Select.(type) { + case *ast.SelectStmt: + oldInsert.Sel, err = convertSelect(x) + case *ast.UnionStmt: + oldInsert.Sel, err = convertUnion(x) + } + if err != nil { + return nil, errors.Trace(err) + } } return oldInsert, nil } func convertDelete(v *ast.DeleteStmt) (*stmts.DeleteStmt, error) { oldDelete := &stmts.DeleteStmt{ - Text: v.Text(), + BeforeFrom: v.BeforeFrom, + Ignore: v.Ignore, + LowPriority: v.LowPriority, + MultiTable: v.MultiTable, + Quick: v.Quick, + Text: v.Text(), + } + converter := newExpressionConverter() + oldRefs, err := convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + oldDelete.Refs = oldRefs + for _, val := range v.Tables { + tableIdent := table.Ident{Schema: val.Schema, Name: val.Name} + oldDelete.TableIdents = append(oldDelete.TableIdents, tableIdent) + } + if v.Where != nil { + oldDelete.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + if v.Order != nil { + orderRset := &rsets.OrderByRset{} + for _, val := range v.Order { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} + orderRset.By = append(orderRset.By, orderItem) + } + oldDelete.Order = orderRset + } + if v.Limit != nil { + oldDelete.Limit = &rsets.LimitRset{Count: v.Limit.Count} } return oldDelete, nil } func convertUpdate(v *ast.UpdateStmt) (*stmts.UpdateStmt, error) { oldUpdate := &stmts.UpdateStmt{ - Text: v.Text(), + Ignore: v.Ignore, + MultipleTable: v.MultipleTable, + LowPriority: v.LowPriority, + Text: v.Text(), + } + converter := newExpressionConverter() + var err error + oldUpdate.TableRefs, err = convertJoin(converter, v.TableRefs.TableRefs) + if err != nil { + return nil, errors.Trace(err) + } + if v.Where != nil { + oldUpdate.Where, err = convertExpr(converter, v.Where) + if err != nil { + return nil, errors.Trace(err) + } + } + for _, val := range v.List { + var oldAssign *expression.Assignment + oldAssign, err = convertAssignment(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldUpdate.List = append(oldUpdate.List, *oldAssign) + } + if v.Order != nil { + orderRset := &rsets.OrderByRset{} + for _, val := range v.Order { + oldExpr, err := convertExpr(converter, val.Expr) + if err != nil { + return nil, errors.Trace(err) + } + orderItem := rsets.OrderByItem{Expr: oldExpr, Asc: !val.Desc} + orderRset.By = append(orderRset.By, orderItem) + } + oldUpdate.Order = orderRset + } + if v.Limit != nil { + oldUpdate.Limit = &rsets.LimitRset{Count: v.Limit.Count} } return oldUpdate, nil } @@ -282,9 +412,141 @@ func convertDropDatabase(v *ast.DropDatabaseStmt) (*stmts.DropDatabaseStmt, erro }, nil } +func convertColumnOption(converter *expressionConverter, v *ast.ColumnOption) (*coldef.ConstraintOpt, error) { + oldColumnOpt := &coldef.ConstraintOpt{} + switch v.Tp { + case ast.ColumnOptionAutoIncrement: + oldColumnOpt.Tp = coldef.ConstrAutoIncrement + case ast.ColumnOptionComment: + oldColumnOpt.Tp = coldef.ConstrComment + case ast.ColumnOptionDefaultValue: + oldColumnOpt.Tp = coldef.ConstrDefaultValue + case ast.ColumnOptionIndex: + oldColumnOpt.Tp = coldef.ConstrIndex + case ast.ColumnOptionKey: + oldColumnOpt.Tp = coldef.ConstrKey + case ast.ColumnOptionFulltext: + oldColumnOpt.Tp = coldef.ConstrFulltext + case ast.ColumnOptionNotNull: + oldColumnOpt.Tp = coldef.ConstrNotNull + case ast.ColumnOptionNoOption: + oldColumnOpt.Tp = coldef.ConstrNoConstr + case ast.ColumnOptionOnUpdate: + oldColumnOpt.Tp = coldef.ConstrOnUpdate + case ast.ColumnOptionPrimaryKey: + oldColumnOpt.Tp = coldef.ConstrPrimaryKey + case ast.ColumnOptionNull: + oldColumnOpt.Tp = coldef.ConstrNull + case ast.ColumnOptionUniq: + oldColumnOpt.Tp = coldef.ConstrUniq + case ast.ColumnOptionUniqIndex: + oldColumnOpt.Tp = coldef.ConstrUniqIndex + case ast.ColumnOptionUniqKey: + oldColumnOpt.Tp = coldef.ConstrUniqKey + } + if v.Expr != nil { + oldExpr, err := convertExpr(converter, v.Expr) + if err != nil { + return nil, errors.Trace(err) + } + oldColumnOpt.Evalue = oldExpr + } + return oldColumnOpt, nil +} + +func convertColumnDef(converter *expressionConverter, v *ast.ColumnDef) (*coldef.ColumnDef, error) { + oldColDef := &coldef.ColumnDef{ + Name: v.Name.Name.O, + Tp: v.Tp, + } + for _, val := range v.Options { + oldOpt, err := convertColumnOption(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldColDef.Constraints = append(oldColDef.Constraints, oldOpt) + } + return oldColDef, nil +} + +func convertIndexColNames(v []*ast.IndexColName) (out []*coldef.IndexColName) { + for _, val := range v { + oldIndexColKey := &coldef.IndexColName{ + ColumnName: val.Column.Name.O, + Length: val.Length, + } + out = append(out, oldIndexColKey) + } + return +} + +func convertConstraint(converter *expressionConverter, v *ast.Constraint) (*coldef.TableConstraint, error) { + oldConstraint := &coldef.TableConstraint{ConstrName: v.Name} + switch v.Tp { + case ast.ConstraintNoConstraint: + oldConstraint.Tp = coldef.ConstrNoConstr + case ast.ConstraintPrimaryKey: + oldConstraint.Tp = coldef.ConstrPrimaryKey + case ast.ConstraintKey: + oldConstraint.Tp = coldef.ConstrKey + case ast.ConstraintIndex: + oldConstraint.Tp = coldef.ConstrIndex + case ast.ConstraintUniq: + oldConstraint.Tp = coldef.ConstrUniq + case ast.ConstraintUniqKey: + oldConstraint.Tp = coldef.ConstrUniqKey + case ast.ConstraintUniqIndex: + oldConstraint.Tp = coldef.ConstrUniqIndex + case ast.ConstraintForeignKey: + oldConstraint.Tp = coldef.ConstrForeignKey + case ast.ConstraintFulltext: + oldConstraint.Tp = coldef.ConstrFulltext + } + oldConstraint.Keys = convertIndexColNames(v.Keys) + if v.Refer != nil { + oldConstraint.Refer = &coldef.ReferenceDef{ + TableIdent: table.Ident{Schema: v.Refer.Table.Schema, Name: v.Refer.Table.Name}, + IndexColNames: convertIndexColNames(v.Refer.IndexColNames), + } + } + return oldConstraint, nil +} + func convertCreateTable(v *ast.CreateTableStmt) (*stmts.CreateTableStmt, error) { oldCreateTable := &stmts.CreateTableStmt{ - Text: v.Text(), + Ident: table.Ident{Schema: v.Table.Schema, Name: v.Table.Name}, + Text: v.Text(), + } + converter := newExpressionConverter() + for _, val := range v.Cols { + oldColDef, err := convertColumnDef(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Cols = append(oldCreateTable.Cols, oldColDef) + } + for _, val := range v.Constraints { + oldConstr, err := convertConstraint(converter, val) + if err != nil { + return nil, errors.Trace(err) + } + oldCreateTable.Constraints = append(oldCreateTable.Constraints, oldConstr) + } + if len(v.Options) != 0 { + oldTableOpt := &coldef.TableOption{} + for _, val := range v.Options { + switch val.Tp { + case ast.TableOptionEngine: + oldTableOpt.Engine = val.StrValue + case ast.TableOptionCharset: + oldTableOpt.Charset = val.StrValue + case ast.TableOptionCollate: + oldTableOpt.Collate = val.StrValue + case ast.TableOptionAutoIncrement: + oldTableOpt.AutoIncrement = val.UintValue + } + } + oldCreateTable.Opt = oldTableOpt } return oldCreateTable, nil } @@ -317,7 +579,7 @@ func convertCreateIndex(v *ast.CreateIndexStmt) (*stmts.CreateIndexStmt, error) oldCreateIndex.IndexColNames = make([]*coldef.IndexColName, len(v.IndexColNames)) for i, val := range v.IndexColNames { oldIndexColName := &coldef.IndexColName{ - ColumnName: concatCIStr(val.Column.Schema, val.Column.Table, val.Column.Name), + ColumnName: joinColumnName(val.Column), Length: val.Length, } oldCreateIndex.IndexColNames[i] = oldIndexColName @@ -430,7 +692,7 @@ func convertShow(v *ast.ShowStmt) (*stmts.ShowStmt, error) { } } if v.Column != nil { - oldShow.ColumnName = concatCIStr(v.Column.Schema, v.Column.Table, v.Column.Name) + oldShow.ColumnName = joinColumnName(v.Column) } if v.Where != nil { converter := newExpressionConverter()